SlimAdam
Implements SlimAdam, a memory-efficient Adam that compresses the second moment by averaging it along selected tensor dimensions.
Adam stores a full per-coordinate second moment \(v_t\) with as many entries as the parameters, so it accounts for half of Adam's optimizer state. SlimAdam keeps the first moment and the Adam update unchanged, but replaces the second-moment accumulator with the mean of squared gradients taken over a chosen set of dimensions \(K\) of the parameter tensor. The compressed values (one per index of the remaining dimensions) are broadcast back across \(K\) in the per-coordinate division, so \(v_t\) shrinks by a factor equal to the product of the sizes of the compressed dimensions. The dimensions \(K\) are chosen per tensor by a signal-to-noise ratio (SNR) criterion measured during a short warmup: a dimension is compressed when the second-moment entries along it are well described by their mean (high SNR). This yields up to ~98% savings in second-moment memory with little loss in quality.
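The SNR-based choice of \(K\) can be sketched for a 2-D weight as follows. This is a minimal illustration under assumptions not stated in this section: the tensor is laid out as `(fan_out, fan_in)`, the SNR along a dimension is taken as the squared mean divided by the variance along that dimension (averaged over the other dimension), and both the threshold of 1.0 and the rule of trying "both" first are illustrative choices. `snr`, `choose_compression`, and `state_size` are hypothetical helper names, not a library API.

```python
import numpy as np

def snr(v, axis):
    """Assumed SNR along `axis`: mean^2 / variance, averaged over remaining axes."""
    mean = v.mean(axis=axis)
    var = v.var(axis=axis)
    ratio = mean**2 / (var + 1e-30)  # tiny constant guards against zero variance
    return float(np.mean(ratio))

def choose_compression(v, threshold=1.0):
    """Hypothetical selection rule for a second moment of shape (fan_out, fan_in)."""
    candidates = {"both": (0, 1), "fan_in": 1, "fan_out": 0}
    scores = {name: snr(v, ax) for name, ax in candidates.items()}
    if scores["both"] > threshold:           # largest saving: a single scalar
        return "both", scores
    best = max(("fan_in", "fan_out"), key=scores.get)
    if scores[best] > threshold:
        return best, scores
    return None, scores                      # keep the full per-coordinate v

def state_size(shape, choice):
    """Number of stored second-moment entries after compression."""
    fan_out, fan_in = shape
    sizes = {None: fan_out * fan_in,
             "fan_in": fan_out,   # mean over fan_in leaves one entry per output row
             "fan_out": fan_in,   # mean over fan_out leaves one entry per input column
             "both": 1}
    return sizes[choice]

# Synthetic second moment: row scales span four decades, but each row is
# nearly constant across its fan_in entries.
rng = np.random.default_rng(0)
row_scale = np.logspace(-3, 1, 64)
v = row_scale[:, None] * rng.uniform(0.99, 1.01, size=(64, 256))
choice, scores = choose_compression(v)
print(choice)                       # fan_in
print(state_size(v.shape, choice))  # 64 (instead of 16384)
```

Here averaging along fan_in is nearly lossless because each row is almost constant, while averaging along fan_out would blur scales that differ by orders of magnitude, which is exactly the distinction the SNR criterion is meant to capture.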
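The SlimAdam update is
\[
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t,\\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, \mathbb{E}_K\!\left[g_t^2\right],\\
\theta_t &= \theta_{t-1} - \eta\, \frac{m_t}{\sqrt{v_t} + \epsilon},
\end{aligned}
\]
where \(\theta\) are the parameters, \(\eta\) the learning rate, \(g_t\) the gradient (squared elementwise in \(g_t^2\)), \(m_t,v_t\) the first and second moments, \(\beta_1,\beta_2\) their decay rates, \(\epsilon\) a stability constant, and \(\mathbb{E}_K[\cdot]\) the mean over the compressed dimensions \(K\) (fan-in, fan-out, or both), with the resulting values broadcast back across \(K\) in the update.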
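One update step can be sketched in NumPy as below. This is a hypothetical illustration, not the library's implementation: it omits bias correction, assumes the compressed dimensions \(K\) are passed as a tuple of axis indices, and stores \(v_t\) with size 1 along those axes so that NumPy broadcasting performs the "broadcast back across \(K\)" step. `slim_adam_step` and `compressed_shape` are illustrative names.

```python
import numpy as np

def compressed_shape(shape, axes):
    """Shape of the stored second moment: size 1 along each compressed axis."""
    return tuple(1 if i in axes else n for i, n in enumerate(shape))

def slim_adam_step(theta, m, v, g, axes, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One SlimAdam-style step (no bias correction); `axes` is the tuple K."""
    m = beta1 * m + (1 - beta1) * g                      # first moment: full size
    if axes:
        g2 = np.mean(g**2, axis=axes, keepdims=True)     # E_K[g^2], size 1 along K
    else:
        g2 = g**2                                        # no compression: plain Adam
    v = beta2 * v + (1 - beta2) * g2                     # compressed second moment
    theta = theta - lr * m / (np.sqrt(v) + eps)          # v broadcasts back across K
    return theta, m, v

# Usage: compress along fan_in (assumed layout: (fan_out, fan_in)).
rng = np.random.default_rng(0)
theta = rng.standard_normal((4, 8))
axes = (1,)
m = np.zeros_like(theta)
v = np.zeros(compressed_shape(theta.shape, axes))
for _ in range(3):
    g = rng.standard_normal(theta.shape)
    theta, m, v = slim_adam_step(theta, m, v, g, axes)
print(v.shape)  # (4, 1)
```

Keeping the reduced axes with `keepdims=True` means the stored \(v_t\) never has to be expanded to full size; the division broadcasts it on the fly. When \(g_t^2\) happens to be constant along \(K\), the compressed and full accumulators coincide, so the step reduces to the uncompressed one.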
Reference: Dayal Singh Kalra, John Kirchenbauer, Maissam Barkeshli, Tom Goldstein, "When Can You Get Away with Low Memory Adam?", arXiv 2025. https://arxiv.org/abs/2503.01843