AdaRankGrad¶
Implements AdaRankGrad, low-rank Adam with an adaptively shrinking gradient subspace for memory-efficient LLM training.
The gradient \(g_t\) is projected onto a low-rank subspace whose rank \(r_t\) is chosen at runtime from the gradient's information content: \(r_t\) is the largest rank whose rank-\(r\) approximation captures a fraction \(\eta_{th}\) of the gradient energy, which provably decreases over training so the optimizer footprint shrinks. A projection \(Q_t\) (with \(r_t\) rows) maps the gradient into that subspace, where standard Adam moments are maintained, before mapping the resulting step back to the full parameter space. When the subspace changes between steps, the moments are carried over by re-expressing them in the new basis through \(R_t = Q_t Q_{t-1}^\top\).
where \(\theta\) are the parameters (weight matrix), \(\gamma\) the learning rate, \(g_t\) the gradient, \(Q_t\) the rank-\(r_t\) projection onto the adaptively chosen subspace, \(P_t(r)\) the best rank-\(r\) projector of \(g_t\), \(\eta_{th}\) the information-retention threshold, \(\hat g_t\) the projected gradient, \(m_t,v_t\) the (subspace) first and second moments, \(\beta_1,\beta_2\) their decay rates, and \(\epsilon\) a stability constant.
Reference: Yehonathan Refael, Jonathan Svirsky, Boris Shustin, Wasim Huleihel, Ofir Lindenbaum, "AdaRankGrad: Adaptive Gradient-Rank and Moments for Memory-Efficient LLMs Training and Fine-Tuning", arXiv 2024. https://arxiv.org/abs/2410.17881