SAM
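Implements SAM (sharpness-aware minimization), wrapping a base optimizer. At step \(t\), with \(g_t = \nabla L(\theta_t)\), the wrapped base optimizer receives the gradient evaluated at a perturbed point,
\(\hat{g}_t = \nabla L(\theta)\big|_{\theta = \theta_t + \hat{\epsilon}_t}, \qquad \hat{\epsilon}_t = \rho \, g_t / \lVert g_t \rVert_2,\)
where \(\hat{\epsilon}_t\) solves the inner maximization
\(\max_{\lVert \epsilon \rVert_2 \leq \rho} L(\theta_t + \epsilon)\)
to first order. With adaptive=True, the perturbation is replaced by the scale-invariant perturbation of ASAM (Kwon et al., ICML 2021),
\(\hat{\epsilon}_t = \rho \, \theta_t^2 g_t / \lVert \theta_t g_t \rVert_2\),
where the products and the square are taken elementwise.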
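To first order, \(L(\theta_t + \epsilon) \approx L(\theta_t) + \epsilon^\top g_t\) with \(g_t = \nabla L(\theta_t)\). By Cauchy–Schwarz, the linear term is maximized over the \(\rho\)-ball by an \(\epsilon\) parallel to \(g_t\), which gives \(\hat{\epsilon}_t = \rho \, g_t / \lVert g_t \rVert_2\). The pure-Python sketch below computes both the SAM and the ASAM perturbation. The helper name `perturbation`, the flat-list parameter layout, and the small `tiny` guard against a zero norm are illustrative assumptions, not this library's API.

```python
import math


def perturbation(theta, grad, rho=0.05, adaptive=False, tiny=1e-12):
    """Hypothetical helper: first-order SAM / ASAM perturbation.

    theta and grad are flat lists of floats; products and squares are
    elementwise. `tiny` is an assumed guard against dividing by a zero norm.
    """
    if adaptive:
        # ASAM: eps_hat = rho * theta^2 * g / ||theta * g||_2
        norm = math.sqrt(sum((t * g) ** 2 for t, g in zip(theta, grad))) + tiny
        return [rho * t * t * g / norm for t, g in zip(theta, grad)]
    # SAM: eps_hat = rho * g / ||g||_2
    norm = math.sqrt(sum(g * g for g in grad)) + tiny
    return [rho * g / norm for g in grad]


theta, grad = [1.0, 2.0], [3.0, 4.0]
e_sam = perturbation(theta, grad, rho=0.5)
e_asam = perturbation(theta, grad, rho=0.5, adaptive=True)
```

Both perturbations land on the boundary of their constraint set. For SAM, \(\lVert \hat{\epsilon}_t \rVert_2 = \rho\). For ASAM, \(\lVert \hat{\epsilon}_t / \theta_t \rVert_2 = \rho\) (elementwise division), so the ball stretches with the magnitude of each weight.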
Reference: Pierre Foret, Ariel Kleiner, Hossein Mobahi, Behnam Neyshabur, "Sharpness-Aware Minimization for Efficiently Improving Generalization", ICLR 2021. https://arxiv.org/abs/2010.01412
Note: Each step needs two forward-backward passes. Either compute the loss and gradients, call first_step, recompute the loss and gradients at the perturbed weights, and then call second_step; or pass step a closure that zeroes the gradients, computes the loss, and calls backward().
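The following is a minimal sketch of one full step. It assumes plain SGD as the base optimizer and a toy quadratic loss. `ToySAM`, `make_closure`, and the loss are hypothetical; the method names only mirror `first_step`, `second_step`, and `step` from the note above. The sketch also assumes that `step` runs both passes itself by calling the closure twice, which may differ from the real implementation.

```python
import math


class ToySAM:
    """Hypothetical SAM wrapper with plain SGD as the base optimizer."""

    def __init__(self, params, lr=0.1, rho=0.05):
        self.params = params              # flat list of floats (the weights)
        self.grad = [0.0] * len(params)   # filled in by the closure
        self.lr, self.rho = lr, rho
        self._saved = None

    def zero_grad(self):
        self.grad = [0.0] * len(self.params)

    def first_step(self):
        # Save theta_t, then move to theta_t + rho * g / ||g||_2.
        norm = math.sqrt(sum(g * g for g in self.grad)) + 1e-12  # zero-norm guard
        self._saved = list(self.params)
        for i, g in enumerate(self.grad):
            self.params[i] += self.rho * g / norm

    def second_step(self):
        # Restore theta_t, then apply the SGD update using the gradient
        # that was computed at the perturbed point.
        self.params[:] = self._saved
        for i, g in enumerate(self.grad):
            self.params[i] -= self.lr * g

    def step(self, closure):
        loss = closure()    # pass 1: gradient at theta_t
        self.first_step()
        closure()           # pass 2: gradient at theta_t + eps_hat
        self.second_step()
        return loss


# Toy loss L(w) = 0.5 * sum(a_i * w_i^2), an illustrative assumption.
A = [1.0, 10.0]


def make_closure(opt):
    def closure():
        opt.zero_grad()
        loss = 0.5 * sum(a * w * w for a, w in zip(A, opt.params))
        opt.grad = [a * w for a, w in zip(A, opt.params)]  # stands in for backward()
        return loss
    return closure


opt = ToySAM([1.0, 1.0], lr=0.05, rho=0.1)
closure = make_closure(opt)
for _ in range(3):
    opt.step(closure)
```

The manual path from the note corresponds to `closure(); opt.first_step(); closure(); opt.second_step()`. The sketch restores the saved weights in `second_step` instead of subtracting the perturbation. This choice keeps \(\theta_t\) exact rather than subject to floating-point drift.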