WSAM
Implements WSAM, a variant of sharpness-aware minimization (SAM) that adds the sharpness of the loss to the objective as a regularization term, weighted through \(\gamma\).
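Each step evaluates the gradient \(g_t\) at the current parameters \(\theta_t\) and the gradient \(\tilde{g}_t\) at the perturbed point \(\theta_t + \rho\, g_t / \lVert g_t \rVert\), as in SAM. With decouple=True the sharpness term is applied outside the base optimizer:
\[
\theta_{t+1} = \theta_t - u_t - \eta \frac{\gamma}{1 - \gamma} (\tilde{g}_t - g_t),
\]
where \(u_t\) is the base optimizer update computed from \(g_t\) and \(\eta\) is the learning rate.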
With decouple=False the weighted gradient
\(g_t + \frac{\gamma}{1 - \gamma} (\tilde{g}_t - g_t)\) is fed to the
base optimizer instead.
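The two decouple modes can be compared with a small dependency-free sketch. It assumes that decouple=True subtracts \(\eta \frac{\gamma}{1 - \gamma} (\tilde{g}_t - g_t)\) from the parameters after the base update, that parameters are plain Python lists, and that the base optimizer is SGD with heavy-ball momentum; `sgd_momentum`, `wsam_update` and `run` are hypothetical helpers, and the gradients in `run` are fixed, made-up numbers rather than real backward passes.

```python
# Sketch of one WSAM parameter update given g_t and g~_t (hypothetical helpers).
# Assumptions: parameters are lists of floats; the base optimizer is SGD with
# heavy-ball momentum; decouple=True subtracts lr * gamma/(1-gamma) * (g~_t - g_t)
# after the base update.


def sgd_momentum(grad, buf, lr, momentum):
    """Return (u, new_buf) where u is the step subtracted from the parameters."""
    new_buf = [momentum * b + g for b, g in zip(buf, grad)]
    return [lr * b for b in new_buf], new_buf


def wsam_update(theta, g, g_tilde, buf, lr, gamma, momentum, decouple):
    """Apply one WSAM update and return (new_theta, new_buf)."""
    alpha = gamma / (1.0 - gamma)                    # weight on the sharpness term
    sharp = [gt - gi for gt, gi in zip(g_tilde, g)]  # g~_t - g_t
    if decouple:
        # The base optimizer sees only g_t; the sharpness term is applied separately.
        u, buf = sgd_momentum(g, buf, lr, momentum)
        theta = [p - ui - lr * alpha * s for p, ui, s in zip(theta, u, sharp)]
    else:
        # The base optimizer sees the weighted gradient g_t + alpha * (g~_t - g_t).
        weighted = [gi + alpha * s for gi, s in zip(g, sharp)]
        u, buf = sgd_momentum(weighted, buf, lr, momentum)
        theta = [p - ui for p, ui in zip(theta, u)]
    return theta, buf


def run(decouple, momentum, steps=2):
    """Repeat the update with fixed, made-up gradients g_t = 0.5 and g~_t = 0.7."""
    theta, buf = [1.0], [0.0]
    for _ in range(steps):
        theta, buf = wsam_update(theta, [0.5], [0.7], buf, lr=0.1, gamma=0.5,
                                 momentum=momentum, decouple=decouple)
    return theta


if __name__ == "__main__":
    # Plain SGD (momentum=0): both modes land on the same parameters.
    print(run(True, 0.0, steps=1), run(False, 0.0, steps=1))
    # With momentum they differ: only decouple=False feeds the sharpness term
    # into the momentum buffer.
    print(run(True, 0.9), run(False, 0.9))
```

Under this assumed rule the modes coincide for plain SGD, because \(\theta_t - \eta g_t - \eta \frac{\gamma}{1-\gamma}(\tilde{g}_t - g_t) = \theta_t - \eta \big(g_t + \frac{\gamma}{1-\gamma}(\tilde{g}_t - g_t)\big)\); the choice only matters for stateful or adaptive base optimizers, where decouple=False lets the sharpness term enter the optimizer's internal state.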
Reference: Yun Yue, Jiadi Jiang, Zhiling Ye, Ning Gao, Yongchao Liu, Ke Zhang, "Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term", KDD 2023. https://arxiv.org/abs/2305.15817
Note: WSAM wraps a base optimizer and needs two forward-backward passes per step. Either call step with a closure that recomputes the loss and gradients, or call first_step after the first backward pass and second_step after the second. Pass model so that BatchNorm running statistics are frozen during the second pass.
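To make the two-pass call pattern concrete, here is a dependency-free toy that mirrors the first_step / second_step / step(closure) flow on a list of floats. `ToyWSAM`, its constructor arguments and `make_closure` are hypothetical names for illustration, not the library's API; the base optimizer is assumed to be plain SGD fed the weighted gradient (the decouple=False form), and the "backward pass" is the analytic gradient of \(f(\theta) = \tfrac{1}{2}\lVert\theta\rVert^2\).

```python
import math


class ToyWSAM:
    """Hypothetical two-pass optimizer over a list of floats (not the library's class).

    Base optimizer: plain SGD, fed the weighted gradient (the decouple=False form).
    """

    def __init__(self, params, lr=0.1, rho=0.05, gamma=0.9, eps=1e-12):
        self.params = params             # updated in place
        self.grad = [0.0] * len(params)  # written by the caller's "backward pass"
        self.lr, self.rho, self.gamma, self.eps = lr, rho, gamma, eps
        self._old_params = None          # theta_t, saved by first_step
        self._old_grad = None            # g_t, saved by first_step

    def first_step(self):
        """Save theta_t and g_t, then move to theta_t + rho * g_t / ||g_t||."""
        norm = math.sqrt(sum(g * g for g in self.grad)) + self.eps
        self._old_params = list(self.params)
        self._old_grad = list(self.grad)
        for i, g in enumerate(self.grad):
            self.params[i] += self.rho * g / norm

    def second_step(self):
        """Restore theta_t and take an SGD step on g_t + gamma/(1-gamma) * (g~_t - g_t)."""
        alpha = self.gamma / (1.0 - self.gamma)
        for i, g_tilde in enumerate(self.grad):  # self.grad now holds g~_t
            g = self._old_grad[i]
            weighted = g + alpha * (g_tilde - g)
            self.params[i] = self._old_params[i] - self.lr * weighted

    def step(self, closure):
        """Run both passes; closure() must refill self.grad and return the loss."""
        loss = closure()  # pass 1: g_t at theta_t
        self.first_step()
        closure()         # pass 2: g~_t at the perturbed point
        self.second_step()
        return loss


def make_closure(opt):
    """Hypothetical closure for f(theta) = 0.5 * ||theta||^2, whose gradient is theta."""
    def closure():
        opt.grad[:] = opt.params
        return 0.5 * sum(p * p for p in opt.params)
    return closure


if __name__ == "__main__":
    opt = ToyWSAM([3.0, 4.0], lr=0.1, rho=0.05, gamma=0.5)
    print(opt.step(make_closure(opt)), opt.params)

    # Equivalent manual form: first_step and second_step around the second backward pass.
    opt2 = ToyWSAM([3.0, 4.0], lr=0.1, rho=0.05, gamma=0.5)
    closure = make_closure(opt2)
    closure()
    opt2.first_step()
    closure()
    opt2.second_step()
    print(opt2.params)
```

The closure form and the manual form perform the same two gradient evaluations in the same order, so they reach the same parameters; the closure form simply hides the sequencing inside step.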