GSAM
Implements GSAM (Surrogate Gap Guided Sharpness-Aware Minimization), which wraps a base optimizer.
In the GSAM update, \(g_t\) is the gradient of the loss \(f\), \(g_t^p\) is the gradient of the perturbed loss, and the perturbation radius follows the learning rate \(\eta_t\) linearly through the schedule \(\rho_t = \rho_{min} + (\rho_{max} - \rho_{min}) (\eta_t - \eta_{min}) / (\eta_{max} - \eta_{min})\).
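As a rough illustration of the schedule for \(\rho_t\), the sketch below evaluates the formula directly. The function name rho_at and its argument names are hypothetical and not part of the library; the fallback for \(\eta_{max} = \eta_{min}\) is an assumption made here only to avoid a division by zero.

```python
def rho_at(lr, lr_min, lr_max, rho_min, rho_max):
    """Linearly map the current learning rate onto [rho_min, rho_max]."""
    if lr_max == lr_min:
        # Assumption: with a constant learning rate the formula is undefined,
        # so this sketch simply returns rho_max.
        return rho_max
    # rho_t = rho_min + (rho_max - rho_min) * (lr - lr_min) / (lr_max - lr_min)
    return rho_min + (rho_max - rho_min) * (lr - lr_min) / (lr_max - lr_min)


# rho shrinks together with the learning rate as training proceeds.
for lr in (0.1, 0.05, 0.0):
    print(lr, rho_at(lr, lr_min=0.0, lr_max=0.1, rho_min=0.02, rho_max=0.2))
```

At the maximum learning rate the radius is \(\rho_{max}\), and it decays to \(\rho_{min}\) as the learning rate reaches its minimum.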
Reference: Juntang Zhuang et al., "Surrogate Gap Minimization Improves Sharpness-Aware Training", ICLR 2022. https://arxiv.org/abs/2203.08065
Note: Each step needs two forward-backward passes. Either call set_closure and then step, or pass step a closure that zeroes the gradients, computes the loss, and calls backward(). To anneal \(\rho_t\) with the learning rate as in the paper, pass a rho_scheduler (for example, ProportionScheduler) and call update_rho_t once per step; without a scheduler, \(\rho_t\) stays fixed at the constant rho. The state dict covers only the GSAM perturbation state, so save base_optimizer.state_dict() separately.
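To show why two gradient evaluations are needed per step, here is a minimal pure-Python sketch on a toy quadratic loss. It is not the library's implementation: the names loss, grad, and gsam_like_step are hypothetical, the coefficient alpha is an assumed hyperparameter, and the way the two gradients are combined (descend along \(g_t^p\) and subtract alpha times the part of \(g_t\) orthogonal to \(g_t^p\)) follows my reading of the referenced paper.

```python
import math


def loss(w):
    # Toy loss f(w) = 0.5 * (w0^2 + 10 * w1^2).
    return 0.5 * (w[0] ** 2 + 10.0 * w[1] ** 2)


def grad(w):
    # Analytic gradient of the toy loss; stands in for one forward-backward pass.
    return [w[0], 10.0 * w[1]]


def gsam_like_step(w, lr, rho, alpha, eps=1e-12):
    # Pass 1: gradient g at the current weights.
    g = grad(w)
    g_norm = math.sqrt(sum(x * x for x in g)) + eps
    # Perturb the weights by rho along the normalized gradient.
    w_adv = [wi + rho * gi / g_norm for wi, gi in zip(w, g)]
    # Pass 2: gradient g_p of the perturbed loss.
    g_p = grad(w_adv)
    # Component of g orthogonal to g_p.
    dot = sum(a * b for a, b in zip(g, g_p))
    gp_sq = sum(b * b for b in g_p) + eps
    g_orth = [a - dot / gp_sq * b for a, b in zip(g, g_p)]
    # Assumed combination: step along g_p minus alpha * g_orth.
    return [wi - lr * (b - alpha * o) for wi, b, o in zip(w, g_p, g_orth)]


w = [1.0, 1.0]
w_next = gsam_like_step(w, lr=0.01, rho=0.05, alpha=0.4)
print(loss(w), loss(w_next))
```

With rho and alpha both set to zero, the step reduces to plain gradient descent, which is a quick sanity check on the sketch.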