MuonBP
Implements MuonBP, a distributed Muon variant that interleaves cheap blockwise orthogonalization with periodic full orthogonalization.
Muon orthogonalizes the full momentum matrix every step, so when the parameters are sharded across devices it must all-gather the sharded momentum on every iteration. MuonBP (Block-Periodic) instead orthogonalizes each device's local momentum shard \(M_t^{(m)}\) independently on most steps, and gathers the full momentum for a global orthogonalization only once every \(P\) steps. Block and full steps use separate learning rates \(\eta_{\mathrm{block}}\) and \(\eta_{\mathrm{full}}\). RMS-norm matching scales each update by a factor based on the square root of the dimensions of the matrix being orthogonalized: the local shard on block steps, the full matrix on full steps. The period \(P\) interpolates between fully blockwise updates (\(P\to\infty\)) and standard Muon (\(P=1\)).
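The block-periodic schedule can be sketched in NumPy as follows. This is an illustrative sketch under stated assumptions, not this library's implementation. The function names (`orth`, `rms_scale`, `muonbp_step`) are hypothetical. Devices are simulated as a list of row-wise shards with no real communication, and \(\mathrm{Orth}\) is computed with an exact SVD instead of Newton-Schulz. Momentum is plain heavy-ball without Nesterov. The RMS-matching factor \(\sqrt{\max(\text{rows}, \text{cols})}\) is one common choice and may differ from MuonBP's exact scaling.

```python
import numpy as np


def orth(m: np.ndarray) -> np.ndarray:
    """Exact Orth(M) = U V^T from the thin SVD (stand-in for Newton-Schulz)."""
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt


def rms_scale(rows: int, cols: int) -> float:
    """Hypothetical RMS-matching factor: gives a full-rank U V^T unit RMS."""
    return float(np.sqrt(max(rows, cols)))


def muonbp_step(shards, momenta, grads, t, lr_block, lr_full, mu=0.95, period=5):
    """One MuonBP-style step on a matrix split row-wise into per-device shards.

    shards, momenta, grads: lists of arrays, one entry per simulated device m.
    Returns (new_shards, new_momenta, was_full_step).
    """
    # Local momentum update M_t^(m) = mu * M_{t-1}^(m) + G_t^(m): no communication.
    momenta = [mu * m + g for m, g in zip(momenta, grads)]
    full_step = t % period == 0

    if full_step:
        # Full step: stack the shards (a simulated all-gather), orthogonalize the
        # whole matrix, then hand each device back its own rows of the update.
        full = np.vstack(momenta)
        update = orth(full) * rms_scale(*full.shape)
        splits = np.cumsum([m.shape[0] for m in momenta])[:-1]
        updates = np.split(update, splits, axis=0)
        lr = lr_full
    else:
        # Block step: each device orthogonalizes only its local shard.
        updates = [orth(m) * rms_scale(*m.shape) for m in momenta]
        lr = lr_block

    new_shards = [p - lr * u for p, u in zip(shards, updates)]
    return new_shards, momenta, full_step


# Usage: an 8x8 matrix sharded over 2 simulated devices, period P = 5.
rng = np.random.default_rng(0)
shards = [rng.standard_normal((4, 8)) for _ in range(2)]
momenta = [np.zeros_like(s) for s in shards]
full_steps = []
for t in range(1, 11):
    grads = [rng.standard_normal(s.shape) for s in shards]
    shards, momenta, full = muonbp_step(
        shards, momenta, grads, t, lr_block=0.01, lr_full=0.02, period=5
    )
    full_steps.append(full)
print(full_steps)  # True only at t = 5 and t = 10
```

Setting `period=1` makes every step a full step (standard Muon in this sketch), while a period longer than the run never gathers. Because the scaling factor gives a full-rank \(UV^\top\) unit RMS for any shape, block and full updates have comparable magnitude before the separate learning rates are applied.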
In this notation, \(\theta\) denotes the matrix parameters (sharded across devices \(m\)), \(G_t^{(m)}\) the local gradient shard, \(M_t^{(m)}\) the local momentum buffer, \(\mu \in [0,1)\) the momentum coefficient, and \(P\) the orthogonalization period. \(\mathrm{Orth}(\cdot)\) is the orthogonalization operator, computed via Newton-Schulz iterations. For a thin SVD \(M = U\Sigma V^\top\), \(\mathrm{Orth}(M) = UV^\top = \big((MM^\top)^{1/2}\big)^{\dagger} M\), where the dagger denotes the Moore-Penrose pseudoinverse.
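A minimal NumPy sketch of computing \(\mathrm{Orth}(\cdot)\) by Newton-Schulz iteration follows. It checks the result against the SVD definition \(UV^\top\) and against the pseudoinverse form \(\big((MM^\top)^{1/2}\big)^{\dagger} M\). The function names are hypothetical. The sketch uses the classic cubic iteration \(X \leftarrow \tfrac{3}{2}X - \tfrac{1}{2}XX^\top X\) run to convergence. Practical Muon implementations commonly use a tuned quintic polynomial with only a handful of steps, which orthogonalizes approximately. The coefficients and step count this optimizer uses are not shown here.

```python
import numpy as np


def orth_newton_schulz(m: np.ndarray, steps: int = 40, eps: float = 1e-12) -> np.ndarray:
    """Approximate Orth(M) = U V^T with the classic cubic Newton-Schulz iteration."""
    transposed = m.shape[0] > m.shape[1]
    x = m.T if transposed else m  # iterate on the wide orientation (smaller Gram matrix)
    # Divide by the Frobenius norm so every singular value lies in (0, 1].
    x = x / (np.linalg.norm(x) + eps)
    for _ in range(steps):
        # Maps each singular value s to 1.5 s - 0.5 s^3, pushing it toward 1
        # while leaving the singular vectors unchanged.
        x = 1.5 * x - 0.5 * (x @ x.T) @ x
    return x.T if transposed else x


def orth_svd(m: np.ndarray) -> np.ndarray:
    """Reference Orth(M) = U V^T from the thin SVD."""
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt


def orth_pinv(m: np.ndarray) -> np.ndarray:
    """Same matrix via ((M M^T)^{1/2})^+ M, with + the Moore-Penrose pseudoinverse."""
    w, q = np.linalg.eigh(m @ m.T)
    # Symmetric square root of the Gram matrix: Q diag(sqrt(w)) Q^T.
    sqrt_gram = (q * np.sqrt(np.clip(w, 0.0, None))) @ q.T
    return np.linalg.pinv(sqrt_gram) @ m


rng = np.random.default_rng(0)
m = rng.standard_normal((8, 16))
err = np.max(np.abs(orth_newton_schulz(m) - orth_svd(m)))
print(f"max |NS - SVD| = {err:.1e}")
```

Each step changes only the singular values, never the singular vectors. Driving every singular value to 1 therefore yields \(UV^\top\) using only matrix multiplications, which map well onto accelerators.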
Reference: Ahmed Khaled, Kaan Ozkara, Tao Yu, Mingyi Hong, Youngsuk Park, "MuonBP: Faster Muon via Block-Periodic Orthogonalization", arXiv 2025. https://arxiv.org/abs/2510.16981