SUMO
Implements SUMO, subspace-aware moment orthogonalization for memory-efficient LLM training.
SUMO trains in a low-rank subspace. Every \(K\) steps it computes an orthonormal basis \(Q_t\) for the top-\(r\) subspace of the gradient (via truncated randomized SVD) and carries the existing momentum over into the new subspace. At every step it projects the gradient onto \(Q_t\) and accumulates a first-order moment inside that subspace. The projected moment is orthogonalized exactly through its SVD, \(m_t = U \Sigma V^\top \Rightarrow O_t = U V^\top\), which replaces Muon's Newton–Schulz approximation and avoids its accumulated error. The orthogonalized update is mapped back to the full space by \(Q_t\) before the weight step, with optional norm-growth clipping and decoupled weight decay.
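The description above can be summarized as the following update, written here as a sketch rather than the reference statement of the algorithm. The heavy-ball form of the momentum accumulation, the use of \(Q_t^\top Q_{t-1}\) to carry the momentum into a refreshed subspace, and the exact placement of \(\alpha\) and \(\lambda\) in the weight step are assumptions; norm-growth clipping is omitted. Consult the reference for the authoritative form.

\[
\begin{aligned}
m_{t-1} &\leftarrow Q_t^\top Q_{t-1}\, m_{t-1} && \text{(only when the basis is refreshed, every } K \text{ steps)} \\
\hat{g}_t &= Q_t^\top g_t \\
m_t &= \mu\, m_{t-1} + \hat{g}_t \\
m_t &= U \Sigma V^\top, \qquad O_t = U V^\top \\
\theta_t &= \theta_{t-1} - \eta \left( \alpha\, Q_t O_t + \lambda\, \theta_{t-1} \right)
\end{aligned}
\]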
Notation: \(\theta\) denotes the weights, \(g_t\) the full gradient, \(Q_t\) the orthonormal subspace basis of rank \(r\), \(\hat{g}_t\) the projected gradient, \(m_t\) the in-subspace first moment, \(\mu\) the momentum coefficient, \(O_t\) the SVD-orthogonalized direction, \(\eta\) the learning rate, \(\alpha\) a scale factor, and \(\lambda\) the weight-decay coefficient.
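To make the steps concrete, here is a minimal NumPy sketch of one SUMO-style step for a single 2-D weight matrix. It is an illustration under stated assumptions, not this library's implementation: the names `top_r_basis`, `orthogonalize`, and `sumo_step` and all keyword arguments are hypothetical, an exact SVD stands in for the truncated randomized SVD, the momentum is carried into a refreshed subspace with \(Q_t^\top Q_{t-1}\), the momentum accumulation uses the heavy-ball form, and norm-growth clipping is left out.

```python
import numpy as np


def top_r_basis(g, r):
    """Orthonormal basis (rows x r) for the top-r left singular subspace of g.

    An exact SVD is used for clarity; the method described above uses a
    truncated randomized SVD instead.
    """
    u, _, _ = np.linalg.svd(g, full_matrices=False)
    return u[:, :r]


def orthogonalize(m):
    """Exact orthogonalization via SVD: m = U S V^T  ->  U V^T."""
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt


def sumo_step(theta, g, state, *, step, rank=4, update_every=10,
              lr=1e-2, mu=0.9, alpha=1.0, weight_decay=0.0):
    """One sketch step. `state` is a (Q, m) pair, or (None, None) initially."""
    q, m = state
    if q is None or step % update_every == 0:
        q_new = top_r_basis(g, rank)
        if q is None:
            m = np.zeros((rank, g.shape[1]))
        else:
            # Assumed transport: re-express the old moment in the new basis.
            m = (q_new.T @ q) @ m
        q = q_new

    g_hat = q.T @ g                # project the gradient into the subspace
    m = mu * m + g_hat             # heavy-ball first moment (assumed form)
    o = orthogonalize(m)           # O_t = U V^T
    # Map back to the full space, then apply decoupled weight decay.
    theta = theta - lr * (alpha * (q @ o) + weight_decay * theta)
    return theta, (q, m)


# Usage: a few steps on random data, refreshing the basis every 2 steps.
rng = np.random.default_rng(0)
theta = rng.standard_normal((16, 8))
state = (None, None)
for step in range(3):
    g = rng.standard_normal((16, 8))
    theta, state = sumo_step(theta, g, state, step=step, rank=4, update_every=2)
```

Only the `rank x n` moment and the `rows x rank` basis are kept as optimizer state, which is where the memory saving relative to a full-size moment comes from.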
Reference: Yehonathan Refael, Guy Smorodinsky, Tom Tirer, Ofir Lindenbaum, "SUMO: Subspace-Aware Moment-Orthogonalization for Accelerating Memory-Efficient LLM Training", arXiv 2025. https://arxiv.org/abs/2505.24749