BADM
Implements BADM (Batch ADMM), a data-driven ADMM optimizer that splits each mini-batch into sub-batches and updates the local (primal), global, and dual variables for each batch.
BADM recasts training as a consensus problem: the loss is partitioned over \(B\) batches, each further split into \(S\) sub-batches, and every local parameter \(w_{bs}\) is forced to agree with a global parameter \(w\) through the constraint \(w = w_{bs}\). Each epoch sweeps the batches in order. Within a batch, the global parameter is first aggregated from the previous sub-batch solutions and their scaled multipliers; every sub-batch then performs an inexact (single-gradient) local solve followed by a dual ascent step. The sub-batches inside a batch are independent of one another, so the \(S\) local updates can run in parallel.
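State is carried forward from one epoch to the next via \(w_{0s}^{\ell+1} = w_{Bs}^{\ell},\ \pi_{0s}^{\ell+1} = \pi_{Bs}^{\ell}\), so the first batch of epoch \(\ell+1\) starts from the final sub-batch iterates of epoch \(\ell\). Within an epoch, for each batch \(b = 1,\dots,B\):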
where \(w_b\) is the global parameter, \(w_{bs}\) the local parameter for sub-batch \(\mathcal{N}_{bs}\), \(\pi_{bs}\) its Lagrange multiplier, \(\nabla F_{bs}\) the sub-batch gradient, \(\alpha_s\) the sub-batch sampling weight, \(\sigma\) the augmented-Lagrangian penalty, \(\rho\) the proximal coefficient of the inexact local solve, and \(\ell\) the epoch index. The returned parameter is \(w_B^{\ell+1}\).
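The sweep described above can be sketched in NumPy. This is a minimal illustration, not the library's implementation: the function name `badm_epoch`, its signature, and the toy least-squares problem are hypothetical, and the update formulas are the standard linearized consensus-ADMM steps, assumed here for illustration. They are derived under the assumption that the weights \(\alpha_s\) sum to 1 and may differ in detail from the updates in the reference.

```python
import numpy as np

def badm_epoch(w_loc, pi, batches, grad, alpha, sigma, rho):
    """One epoch of a linearized consensus-ADMM sweep (illustrative sketch).

    w_loc, pi : arrays of shape (S, d), local parameters and multipliers
    batches   : list of B batches, each a list of S sub-batches
    grad      : callable grad(w, sub_batch) -> gradient of shape (d,)
    alpha     : array of shape (S,), weights assumed to sum to 1
    """
    w_glob = None
    for batch in batches:
        # Global step (assumed form): weighted average of the previous local
        # solutions plus their multipliers scaled by 1/sigma.
        w_glob = np.sum(alpha[:, None] * (w_loc + pi / sigma), axis=0)
        new_w = np.empty_like(w_loc)
        new_pi = np.empty_like(pi)
        # The S iterations are independent of each other and could run in parallel.
        for s, sub in enumerate(batch):
            g = grad(w_glob, sub)  # single gradient at the global point
            # Inexact local solve: minimizer of the linearized augmented
            # Lagrangian with an extra proximal term of weight rho.
            new_w[s] = w_glob - (g + pi[s]) / (sigma + rho)
            # Dual ascent on the constraint w_bs = w.
            new_pi[s] = pi[s] + sigma * (new_w[s] - w_glob)
        w_loc, pi = new_w, new_pi
    # Returns the global parameter of the last batch and the state to carry over.
    return w_glob, w_loc, pi

# Toy usage: noise-free least squares with B = 4 batches of S = 2 sub-batches.
rng = np.random.default_rng(0)
d, B, S, n = 3, 4, 2, 32
w_true = rng.normal(size=d)
batches = []
for _ in range(B):
    batch = []
    for _ in range(S):
        X = rng.normal(size=(n, d))
        batch.append((X, X @ w_true))
    batches.append(batch)

def lsq_grad(w, sub):
    """Gradient of the mean squared residual 0.5 * ||Xw - y||^2 / n."""
    X, y = sub
    return X.T @ (X @ w - y) / len(y)

alpha = np.full(S, 1.0 / S)
w_loc, pi = np.zeros((S, d)), np.zeros((S, d))
for epoch in range(20):
    # w_loc and pi returned by one epoch seed the next one.
    w_glob, w_loc, pi = badm_epoch(w_loc, pi, batches, lsq_grad,
                                   alpha, sigma=1.0, rho=3.0)
print(np.linalg.norm(w_glob - w_true))
```

In this sketch, passing the returned `w_loc` and `pi` back into the next call plays the role of the epoch-to-epoch carry-over, and the returned `w_glob` corresponds to the global parameter of the last batch.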
Reference: Ouya Wang, Shenglong Zhou, Geoffrey Ye Li, "BADM: Batch ADMM for Deep Learning", arXiv 2024. https://arxiv.org/abs/2407.01640