AE-SAM
Implements AE-SAM, an adaptive policy that applies Sharpness-Aware Minimization only on the steps that need it.
SAM improves generalization by minimizing a perturbed loss, but it roughly doubles the cost per step because each update needs two gradients: one to build the perturbation and one to update the weights. AE-SAM keeps exponential moving averages of the mean and variance of the squared stochastic gradient norm. It takes the full SAM step only when the current squared norm is large relative to those statistics, which it treats as a signal that the iterate is in a sharp region; otherwise it falls back to a cheap single-gradient (ERM) step. A linearly scheduled threshold controls how often SAM fires over the course of training.
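A back-of-the-envelope cost model makes the saving concrete. Counting only gradient evaluations (the running statistics add a few scalar operations per step), suppose SAM fires on a fraction \(p\) of the \(T\) iterations. Each ERM step costs one gradient and each SAM step costs two, so the total is

\[
\underbrace{T(1-p)\cdot 1}_{\text{ERM steps}} \;+\; \underbrace{Tp\cdot 2}_{\text{SAM steps}} \;=\; T(1+p),
\qquad
\frac{T(1+p)}{2T} \;=\; \frac{1+p}{2}.
\]

With a hypothetical \(p = 0.5\), AE-SAM would need \(1.5T\) gradient evaluations, 75% of what running SAM on every step requires. The actual \(p\) is set indirectly by the threshold schedule.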
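At iteration \(t\), AE-SAM updates its statistics and parameters as

\[
\begin{aligned}
\mu_t &= \delta\,\mu_{t-1} + (1-\delta)\,\|g_t\|^2, \\
\sigma_t^2 &= \delta\,\sigma_{t-1}^2 + (1-\delta)\,\bigl(\|g_t\|^2 - \mu_t\bigr)^2, \\
\theta_{t+1} &=
\begin{cases}
\theta_t - \eta\,\nabla\mathcal{L}\!\left(\theta_t + \rho\,\dfrac{g_t}{\|g_t\|}\right) & \text{if } \|g_t\|^2 \ge \mu_t + c_t\,\sigma_t \quad \text{(SAM step)}, \\
\theta_t - \eta\, g_t & \text{otherwise (ERM step)},
\end{cases}
\end{aligned}
\]

where \(g_t = \nabla\mathcal{L}(\theta_t)\) is the stochastic gradient on batch \(\mathcal{B}_t\), \(\eta\) is the learning rate, \(\rho\) the SAM neighborhood radius, \(\delta \in (0,1)\) the EMA forgetting rate, \(\mu_t\) and \(\sigma_t^2\) the running mean and variance of \(\|g_t\|^2\) (with \(\sigma_t = \sqrt{\sigma_t^2}\)), \(T\) the total number of iterations, and \(\lambda_1, \lambda_2\) the endpoints of the linear threshold schedule \(c_t\).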
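The AE-SAM iteration can be sketched in plain Python, using lists as parameter vectors and a caller-supplied gradient function, with the symbols defined above. This is a minimal illustration, not the reference implementation: the class name `AESAM`, its constructor arguments, the zero initialization of \(\mu\) and \(\sigma^2\), and the schedule orientation (\(c_t\) moving linearly from \(\lambda_1\) at \(t = 0\) to \(\lambda_2\) at \(t = T\)) are all assumptions, and the hyperparameter values in the toy run are hypothetical.

```python
import math
from typing import Callable, List

Vector = List[float]


def sq_norm(v: Vector) -> float:
    """Squared Euclidean norm ||v||^2."""
    return sum(x * x for x in v)


class AESAM:
    """Hypothetical AE-SAM sketch; not a library class."""

    def __init__(self, lr: float, rho: float, delta: float,
                 lambda1: float, lambda2: float, total_steps: int):
        self.lr, self.rho, self.delta = lr, rho, delta
        self.lambda1, self.lambda2 = lambda1, lambda2
        self.total_steps = total_steps
        self.t = 0
        # Assumption: the running statistics start at zero.
        self.mu = 0.0
        self.var = 0.0

    def threshold_coeff(self) -> float:
        # Assumption: c_t goes linearly from lambda1 (t = 0) to lambda2 (t = T).
        frac = min(self.t / self.total_steps, 1.0)
        return self.lambda1 + (self.lambda2 - self.lambda1) * frac

    def step(self, theta: Vector, grad_fn: Callable[[Vector], Vector]) -> bool:
        """Update theta in place; return True if a SAM step was taken."""
        g = grad_fn(theta)  # first gradient g_t
        sq = sq_norm(g)
        d = self.delta
        # EMA mean and variance of ||g_t||^2 (the variance uses the new mean mu_t).
        self.mu = d * self.mu + (1 - d) * sq
        self.var = d * self.var + (1 - d) * (sq - self.mu) ** 2
        c = self.threshold_coeff()
        use_sam = sq > 0 and sq >= self.mu + c * math.sqrt(self.var)
        if use_sam:
            # Move to the first-order worst-case point within radius rho,
            # then take the second gradient there (the SAM gradient).
            scale = self.rho / math.sqrt(sq)
            perturbed = [p + scale * gi for p, gi in zip(theta, g)]
            update = grad_fn(perturbed)
        else:
            update = g  # ERM step reuses g_t: no second gradient
        for i, u in enumerate(update):
            theta[i] -= self.lr * u
        self.t += 1
        return use_sam


# Toy usage: a 2-D quadratic with curvatures 1 and 10 (hypothetical setup).
curv = [1.0, 10.0]


def quad_grad(th: Vector) -> Vector:
    return [a * x for a, x in zip(curv, th)]


def quad_loss(th: Vector) -> float:
    return 0.5 * sum(a * x * x for a, x in zip(curv, th))


theta = [1.0, 1.0]
T = 100
opt = AESAM(lr=0.05, rho=0.05, delta=0.9, lambda1=-1.0, lambda2=1.0, total_steps=T)
sam_flags = [opt.step(theta, quad_grad) for _ in range(T)]
print(f"loss={quad_loss(theta):.4f}, SAM steps={sum(sam_flags)}/{T}")
```

Only the SAM branch calls `grad_fn` a second time, which is where the saving over plain SAM comes from; the gate itself is a handful of scalar operations per step.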
Reference: Weisen Jiang, Hansi Yang, Yu Zhang, James Kwok, "An Adaptive Policy to Employ Sharpness-Aware Minimization", ICLR 2023. https://arxiv.org/abs/2304.14647