ESAM
Implements ESAM, an efficient sharpness-aware minimizer that approximates SAM at a fraction of its cost.
SAM seeks parameters in flat regions of the loss landscape by perturbing the weights toward the worst-case direction \(\hat{\epsilon}=\rho\,g_t/\lVert g_t\rVert\) and then taking a step on the gradient evaluated at \(\theta+\hat{\epsilon}\). This requires two forward-backward passes per step, which roughly doubles the per-step cost. ESAM cuts this overhead with two strategies. Stochastic Weight Perturbation (SWP) perturbs only a random subset of parameters: each coordinate is kept with probability \(\beta\) via a Bernoulli mask \(m\), and the kept coordinates are rescaled by \(1/\beta\) so that the masked perturbation equals the full SAM perturbation in expectation. Sharpness-sensitive Data Selection (SDS) computes the final gradient on only the subset \(\mathcal{B}^{+}\) of the batch whose loss increases most under the perturbation, since those samples dominate the sharpness measure.
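Putting the two strategies together, one ESAM step can be written roughly as below. This is a sketch assembled from the description above, not a transcription of the paper: the symbols \(\tilde{\epsilon}\) (the masked perturbation) and \(\ell_i\) (the loss of the \(i\)-th sample) are introduced here for illustration, and the paper's notation may differ in detail.

\[
\begin{aligned}
\hat{\epsilon} &= \rho\,\frac{g_t}{\lVert g_t\rVert}, \qquad g_t=\nabla_\theta L_{\mathcal{B}}(\theta),\\
\tilde{\epsilon} &= \frac{m\odot\hat{\epsilon}}{\beta}, \qquad m_j\sim\mathrm{Bernoulli}(\beta),\\
\mathcal{B}^{+} &= \bigl\{(x_i,y_i)\in\mathcal{B} \;:\; \ell_i(\theta+\tilde{\epsilon})-\ell_i(\theta)>\alpha\bigr\}, \qquad \lvert\mathcal{B}^{+}\rvert\approx\gamma\,\lvert\mathcal{B}\rvert,\\
\theta &\leftarrow \theta-\eta\,\nabla_\theta L_{\mathcal{B}^{+}}(\theta+\tilde{\epsilon}).
\end{aligned}
\]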
In the ESAM update, \(\rho\) is the neighborhood radius, \(g_t=\nabla_\theta L_{\mathcal{B}}(\theta)\) is the batch gradient, \(\beta\in(0,1]\) is the SWP keep probability, \(m\) is the per-coordinate Bernoulli mask, \(\odot\) is the elementwise product, \(\gamma\in(0,1]\) is the SDS selection ratio, \(\alpha\) is the loss-increase threshold induced by \(\gamma\), \(L_{\mathcal{B}^{+}}\) is the loss over the selected subset, and \(\eta\) is the learning rate.
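To make the roles of \(\rho\), \(\beta\), \(\gamma\), and \(\eta\) concrete, here is a minimal pure-Python sketch of one ESAM-style step on a toy linear model with squared loss. It is not this library's implementation: the function name `esam_step`, its signature, the toy model, and the default values are all hypothetical. It also assumes that the threshold \(\alpha\) can be realized by keeping the top \(\lceil\gamma\lvert\mathcal{B}\rvert\rceil\) samples ranked by loss increase.

```python
import math
import random


def esam_step(theta, batch, rho=0.05, beta=0.6, gamma=0.5, lr=0.1, rng=random):
    """One ESAM-style step for a toy linear model (illustrative sketch only).

    theta: list of floats; batch: list of (x, y) pairs, x a list of floats.
    Returns the updated parameters and the samples selected by SDS.
    """

    def loss_i(w, x, y):
        # Per-sample squared loss of the linear model w . x
        pred = sum(wj * xj for wj, xj in zip(w, x))
        return 0.5 * (pred - y) ** 2

    def grad(w, samples):
        # Gradient of the mean squared loss over `samples`
        g = [0.0] * len(w)
        for x, y in samples:
            err = sum(wj * xj for wj, xj in zip(w, x)) - y
            for j, xj in enumerate(x):
                g[j] += err * xj / len(samples)
        return g

    # SAM direction: eps_hat = rho * g / ||g|| (small constant avoids 0/0)
    g = grad(theta, batch)
    norm = math.sqrt(sum(gj * gj for gj in g)) + 1e-12
    eps_hat = [rho * gj / norm for gj in g]

    # SWP: keep each coordinate with probability beta, rescale by 1/beta
    mask = [1.0 if rng.random() < beta else 0.0 for _ in theta]
    eps = [mj * ej / beta for mj, ej in zip(mask, eps_hat)]
    perturbed = [tj + ej for tj, ej in zip(theta, eps)]

    # SDS (assumed top-k form): keep the samples whose loss rises the most
    increase = [loss_i(perturbed, x, y) - loss_i(theta, x, y) for x, y in batch]
    k = max(1, math.ceil(gamma * len(batch)))
    order = sorted(range(len(batch)), key=lambda i: increase[i], reverse=True)
    selected = [batch[i] for i in order[:k]]

    # Descent step using the gradient at the perturbed weights on B+
    g_plus = grad(perturbed, selected)
    new_theta = [tj - lr * gj for tj, gj in zip(theta, g_plus)]
    return new_theta, selected
```

With `beta=1.0` and `gamma=1.0` the mask keeps every coordinate and the whole batch is selected, so the sketch reduces to a plain SAM step. Smaller values trade fidelity to SAM for less work in the second pass.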
Reference: Jiawei Du, Hanshu Yan, Jiashi Feng, Joey Tianyi Zhou, Liangli Zhen, Rick Siow Mong Goh, Vincent Y. F. Tan, "Efficient Sharpness-aware Minimization for Improved Training of Neural Networks", ICLR 2022. https://arxiv.org/abs/2110.03141