MSAM
Implements MSAM (Momentum-SAM), a variant of sharpness-aware minimization (SAM) that perturbs parameters along the accumulated momentum direction instead of the current gradient.
SAM first computes the gradient, ascends a distance \(\rho\) along its normalized direction, and then computes a second gradient at that perturbed point for the descent step. Each step therefore needs two forward-backward passes instead of one. MSAM replaces this ascent direction with the SGD momentum vector, which the optimizer already maintains. The perturbation then costs only a vector normalization and an addition, not an extra forward-backward pass. Parameters are perturbed by \(-\rho\) times the unit momentum vector, the gradient is evaluated at that perturbed point, and a standard SGD-with-momentum step is applied to the unperturbed parameters.
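\[
\begin{aligned}
\tilde{\theta}_t &= \theta_t - \rho \, \frac{v_t}{\lVert v_t \rVert}, \\
g_t &= \nabla \mathcal{L}_t(\tilde{\theta}_t), \\
v_{t+1} &= \mu \, v_t - g_t, \\
\theta_{t+1} &= \theta_t + \eta \, v_{t+1},
\end{aligned}
\]
where \(\theta_t\) are the (unperturbed) parameters, \(\tilde{\theta}_t\) the perturbed parameters at which the gradient is taken, and \(v_t\) the momentum buffer. The buffer accumulates negative gradients, so \(-v_t\) points uphill and the perturbation plays the role of SAM's ascent step. The remaining symbols are \(\rho\), the perturbation radius; \(\mu\), the momentum coefficient; \(\eta\), the learning rate; \(\mathcal{L}_t\), the loss on the \(t\)-th minibatch; and \(g_t\), the minibatch gradient evaluated at \(\tilde{\theta}_t\).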
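The update described above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the reference implementation. The names `msam_step`, `grad_fn` and `quadratic_grad` are hypothetical, as are the list-of-floats parameter representation and the `eps` guard that skips the perturbation while the momentum buffer is still all zeros. The sketch assumes the sign convention in which the buffer accumulates negative gradients (\(v_{t+1} = \mu v_t - g_t\), \(\theta_{t+1} = \theta_t + \eta v_{t+1}\)), so that stepping by \(-\rho\) along the unit buffer moves the parameters uphill.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]


def msam_step(
    theta: Vector,
    v: Vector,
    grad_fn: Callable[[Vector], Vector],
    lr: float = 0.1,
    momentum: float = 0.9,
    rho: float = 0.05,
    eps: float = 1e-12,
) -> Tuple[Vector, Vector]:
    """Hypothetical helper: one MSAM step, returning (theta_next, v_next).

    Assumed convention: v accumulates negative gradients, so
    v_next = momentum * v - g and theta_next = theta + lr * v_next.
    """
    norm = math.sqrt(sum(x * x for x in v))
    if norm > eps:
        # Step against the unit momentum vector: uphill, like SAM's ascent.
        theta_tilde = [p - rho * x / norm for p, x in zip(theta, v)]
    else:
        # Assumption: with an all-zero buffer (first step) skip the perturbation.
        theta_tilde = list(theta)
    # The only gradient evaluation in the step, taken at the perturbed point.
    g = grad_fn(theta_tilde)
    # Plain SGD-with-momentum update, applied to the unperturbed parameters.
    v_next = [momentum * x - gi for x, gi in zip(v, g)]
    theta_next = [p + lr * x for p, x in zip(theta, v_next)]
    return theta_next, v_next


# Toy usage on L(theta) = 0.5 * (theta[0]**2 + 10 * theta[1]**2).
grad_calls = 0


def quadratic_grad(theta: Vector) -> Vector:
    """Gradient of the toy quadratic; counts how often it is called."""
    global grad_calls
    grad_calls += 1
    return [theta[0], 10.0 * theta[1]]


theta, v = [1.0, 1.0], [0.0, 0.0]
for _ in range(100):
    theta, v = msam_step(theta, v, quadratic_grad, lr=0.05, rho=0.01)
print(grad_calls)  # 100: one gradient call per step
```

The counter confirms one `grad_fn` call per step. A SAM step on the same problem would call it twice: once at \(\theta_t\) to find the ascent direction and once at the perturbed point.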
Reference: Marlon Becker, Frederick Altrock, Benjamin Risse, "Momentum-SAM: Sharpness Aware Minimization without Computational Overhead", arXiv 2024. https://arxiv.org/abs/2401.12033