GaLoreAdamW
Implements GaLoreAdamW, AdamW with gradient low-rank projection.
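The symbols defined below belong to the following update for a single weight \(W \in \mathbb{R}^{m \times n}\) with \(m \le n\) and a left projector. This is a sketch reconstructed from the paper's Adam-with-GaLore algorithm and the conventions stated in this section (folded bias correction, weight decay after the step). It is not a verbatim copy of the upstream formula. Here \(G_t\) is the gradient at step \(t = 1, 2, \ldots\), \(m_0 = v_0 = 0\), \(\beta_1, \beta_2\) are the Adam decay rates, \(\epsilon\) is the stability constant, and \(\eta\) is the learning rate. Squares, square roots, and the division act elementwise.

\[
\begin{aligned}
P_t &= \begin{cases} U_{[:,\,1:r]}, \ \text{where } U S V^\top = \operatorname{SVD}(G_t), & (t - 1) \bmod T = 0 \\ P_{t-1}, & \text{otherwise} \end{cases} \\
R_t &= P_t^\top G_t \in \mathbb{R}^{r \times n} \\
m_t &= \beta_1 m_{t-1} + (1 - \beta_1) R_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2) R_t^2 \\
\eta_t &= \eta \, \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \\
\tilde{W}_t &= W_t - \eta_t \, \alpha \, P_t \, \frac{m_t}{\sqrt{v_t} + \epsilon} \\
W_{t+1} &= \tilde{W}_t - \eta \lambda \tilde{W}_t
\end{aligned}
\]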
where \(r\) is the projection rank (rank), \(T\) the subspace change
frequency (update_proj_gap), \(\alpha\) the scale factor (scale), and
\(\lambda\) the decoupled weight decay, which is applied after the
gradient step, as upstream does. Bias correction is folded into the
step size \(\eta_t\), a formulation the official implementation
inherits from the transformers AdamW. The paper states the update for
a matrix with \(m \le n\) and a left projector. This implementation
instead picks the projector side from the gradient shape and projects
along the smaller dimension, so the stored projector is the smaller
singular-vector factor, of size \(\min(m, n) \times r\). The Adam
statistics \(m_t, v_t\) live in the rank-\(r\) subspace, each holding
\(r \times \max(m, n)\) entries instead of \(m \times n\). This is
where the optimizer-memory saving comes from.
Note: Projection is enabled per parameter group. Groups carrying the rank, update_proj_gap, scale, and proj_type keys are projected (2D parameters only); all other groups get plain AdamW. The upstream tensor projector for parameters with dim > 2 depends on tensorly and is not vendored, so such parameters cannot be projected.
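Such groups can be built with a minimal sketch that assumes the optimizer accepts a list of PyTorch-style param-group dicts, the usual torch.optim convention. The helper build_param_groups is hypothetical. Its default values (rank 128, update_proj_gap 200, scale 0.25, proj_type "std") are placeholders, so check which proj_type values this implementation accepts. NumPy arrays stand in for tensors because the helper only reads .ndim.

```python
import numpy as np

# Group keys that, per the note above, turn on projection for a group.
GALORE_KEYS = ("rank", "update_proj_gap", "scale", "proj_type")


def build_param_groups(named_params, rank=128, update_proj_gap=200,
                       scale=0.25, proj_type="std"):
    """Hypothetical helper, not part of this library's API.

    Splits (name, param) pairs into a projected group and a plain AdamW
    group. Each param only needs an .ndim attribute. Defaults are placeholders.
    """
    projected, plain = [], []
    for _name, param in named_params:
        # Only 2D parameters go to the projected group; biases, norm weights,
        # and dim > 2 tensors stay in the plain AdamW group.
        (projected if param.ndim == 2 else plain).append(param)
    groups = [
        {"params": plain},
        {"params": projected, "rank": rank, "update_proj_gap": update_proj_gap,
         "scale": scale, "proj_type": proj_type},
    ]
    # Drop empty groups so the optimizer never sees a group without params.
    return [g for g in groups if g["params"]]


named = [
    ("fc.weight", np.zeros((16, 64))),
    ("fc.bias", np.zeros(64)),
    ("norm.weight", np.zeros(64)),
]
groups = build_param_groups(named, rank=8)
```

With PyTorch, named_params would typically come from model.named_parameters(), and the returned list would be passed as the optimizer's first argument. Group-level options such as lr can be added to either dict in the usual torch.optim way.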
Reference: Jiawei Zhao, Zhenyu Zhang, Beidi Chen, Zhangyang Wang, Anima Anandkumar, Yuandong Tian, "GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection", ICML 2024. https://arxiv.org/abs/2403.03507