AdaSAM
Implements AdaSAM, sharpness-aware minimization combined with an adaptive learning rate and momentum.
SAM seeks parameters in flat loss regions by computing the gradient at an approximate worst-case perturbation \(\delta_t = \rho\, g_t / \lVert g_t \rVert\) within a \(\rho\)-ball, but it applies that gradient with a fixed global step size. AdaSAM instead feeds the perturbed gradient through Adam-style first and second moments with an AMSGrad maximum on the second moment, giving each coordinate its own adaptive step. Numerical stability is provided by initializing \(\hat{v}_{-1} = \epsilon^2\) rather than by an added constant in the denominator.
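Written out, one step can be sketched as follows. This is reconstructed from the symbols defined on this page and should be checked against the referenced paper. The loss symbol \(f\), the zero initialization \(m_{-1} = v_{-1} = 0\), and the absence of bias correction are assumptions.

\[
\begin{aligned}
g_t &= \nabla f(\theta_t), \\
\delta_t &= \rho\, \frac{g_t}{\lVert g_t \rVert}, \\
\tilde{g}_t &= \nabla f(\theta_t + \delta_t), \\
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, \tilde{g}_t, \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, \tilde{g}_t^{2}, \\
\hat{v}_t &= \max(\hat{v}_{t-1},\, v_t), \\
\theta_{t+1} &= \theta_t - \gamma\, \frac{m_t}{\sqrt{\hat{v}_t}}.
\end{aligned}
\]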
In this notation, \(\theta\) are the parameters, \(\gamma\) is the base learning rate, \(\rho\) is the perturbation radius, \(g_t\) is the gradient at \(\theta_t\), \(\tilde{g}_t\) is the gradient at the perturbed point, \(m_t\) and \(v_t\) are the first and second moments, \(\hat{v}_t\) is the running coordinate-wise maximum of the second moment, \(\beta_1, \beta_2\) are the decay rates, and \(\epsilon\) is the stability constant, which enters only through the initialization \(\hat{v}_{-1}=\epsilon^2\). All squaring and division are element-wise.
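The quantities above can be traced in a small, dependency-free sketch. It is not this library's implementation. The names `init_state`, `adasam_step`, `quad_grad`, and `quad_loss` are hypothetical, and the default hyperparameters, the zero-initialized moments, the lack of bias correction, and the zero-gradient guard are all assumptions made for illustration.

```python
import math


def init_state(n, eps=1e-8):
    """Return (m, v, v_hat): zero moments (assumed), v_hat seeded with eps**2."""
    return [0.0] * n, [0.0] * n, [eps * eps] * n


def adasam_step(theta, grad_fn, state, lr=0.01, rho=0.05, beta1=0.9, beta2=0.999):
    """One illustrative AdaSAM step on a flat list of floats."""
    m, v, v_hat = state

    # Gradient at the current point and the perturbation rho * g / ||g||.
    g = grad_fn(theta)
    norm = math.sqrt(sum(gi * gi for gi in g))
    scale = rho / norm if norm > 0.0 else 0.0  # assumption: no perturbation if g == 0

    # Gradient at the perturbed point (a second gradient evaluation).
    g_tilde = grad_fn([t + scale * gi for t, gi in zip(theta, g)])

    # Adam-style moments of the perturbed gradient.
    m = [beta1 * mi + (1.0 - beta1) * gi for mi, gi in zip(m, g_tilde)]
    v = [beta2 * vi + (1.0 - beta2) * gi * gi for vi, gi in zip(v, g_tilde)]

    # AMSGrad: running coordinate-wise maximum of the second moment.
    v_hat = [max(a, b) for a, b in zip(v_hat, v)]

    # No epsilon in the denominator: v_hat >= eps**2 > 0 by construction.
    theta = [t - lr * mi / math.sqrt(vh) for t, mi, vh in zip(theta, m, v_hat)]
    return theta, (m, v, v_hat)


# Toy problem (hypothetical): f(theta) = 0.5 * (theta_0**2 + 10 * theta_1**2).
CURVATURE = [1.0, 10.0]


def quad_loss(theta):
    return 0.5 * sum(a * t * t for a, t in zip(CURVATURE, theta))


def quad_grad(theta):
    return [a * t for a, t in zip(CURVATURE, theta)]


theta = [1.0, -2.0]
state = init_state(len(theta))
for _ in range(300):
    theta, state = adasam_step(theta, quad_grad, state)
```

Each step calls `grad_fn` twice, once at \(\theta_t\) and once at the perturbed point, so a step costs roughly two gradient evaluations. Because \(\hat{v}_t\) never decreases, the per-coordinate step size \(\gamma / \sqrt{\hat{v}_t}\) never increases.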
Reference: Hao Sun, Li Shen, Qihuang Zhong, Liang Ding, Shixiang Chen, Jingwei Sun, Jing Li, Guangzhong Sun, Dacheng Tao, "AdaSAM: Boosting Sharpness-Aware Minimization with Adaptive Learning Rate and Momentum for Training Deep Neural Networks", arXiv 2023. https://arxiv.org/abs/2303.00565