Functional SAM
Implements Functional-SAM, a variant of sharpness-aware minimization (SAM) that evaluates only the network Jacobian at the perturbed parameters and keeps the loss derivative at the unperturbed point.
For a loss \(L\) composed with a network function \(F\) (the concatenated logits), the chain rule splits the SAM gradient at the perturbed point \(\theta+\rho\,\epsilon^*\) into two contributions: a "logit path", which moves the loss derivative \(\nabla_F L\) along the perturbation, and a "functional path", which moves the Jacobian \(\nabla_\theta F\). Singh et al. argue that the logit path drives spurious sharpness minimization, which fails to improve generalization on harder problems. Functional-SAM discards the logit-path contribution: it evaluates the network Jacobian at the perturbed parameters but multiplies it by the loss derivative taken at the original parameters, so only the functional path contributes to the descent direction.
The resulting update is

\[
\theta_{t+1} = \theta_t - \eta\, g_t, \qquad g_t = \nabla_\theta F(\theta_t + \rho\,\epsilon^*)^\top \, \nabla_F L\big(F(\theta_t)\big),
\]

where \(\theta\) are the parameters, \(\eta\) the learning rate, \(\rho\) the perturbation radius, \(\epsilon^*\) the normalized ascent direction, \(F(\theta)\) the network function (logits), \(\nabla_\theta F\) its Jacobian with respect to the parameters, and \(\nabla_F L\) the gradient of the loss with respect to the function outputs. Standard SAM instead uses \(g_t = \nabla_\theta L(\theta_t + \rho\,\epsilon^*)\), which the chain rule splits into this functional path plus the discarded logit path.
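The difference between the two gradients can be made concrete with a small NumPy sketch. It is illustrative only and is not this library's API: the two-layer tanh network used as \(F\), the mean softmax cross-entropy used as \(L\), and every function name below are hypothetical choices. It also assumes the standard SAM ascent direction \(\epsilon^* = \nabla_\theta L(\theta_t)/\|\nabla_\theta L(\theta_t)\|\). The hand-written `vjp` computes \(\nabla_\theta F(\theta)^\top c\) for a logit-space cotangent \(c\), so `sam_grad` and `functional_sam_grad` differ only in where that cotangent is evaluated.

```python
import numpy as np

# Hypothetical toy model: F(theta) = tanh(X @ W1) @ W2, where theta is the flat
# concatenation of W1 (d x k) and W2 (k x m).
def unpack(theta, dims):
    d, k, m = dims
    W1 = theta[: d * k].reshape(d, k)
    W2 = theta[d * k :].reshape(k, m)
    return W1, W2

def forward(theta, X, dims):
    """Network function F(theta): logits of shape (n, m)."""
    W1, W2 = unpack(theta, dims)
    return np.tanh(X @ W1) @ W2

def loss_grad_wrt_logits(F, Y):
    """nabla_F L for mean softmax cross-entropy with one-hot targets Y."""
    Z = np.exp(F - F.max(axis=1, keepdims=True))
    P = Z / Z.sum(axis=1, keepdims=True)
    return (P - Y) / F.shape[0]

def vjp(theta, X, dims, C):
    """Return nabla_theta F(theta)^T C, pulling logit cotangent C back to parameters."""
    W1, W2 = unpack(theta, dims)
    H = np.tanh(X @ W1)
    dW2 = H.T @ C
    dW1 = X.T @ ((C @ W2.T) * (1.0 - H**2))
    return np.concatenate([dW1.ravel(), dW2.ravel()])

def grad(theta, X, Y, dims):
    """Plain gradient: nabla_theta F(theta)^T nabla_F L(F(theta))."""
    return vjp(theta, X, dims, loss_grad_wrt_logits(forward(theta, X, dims), Y))

def ascent_direction(theta, X, Y, dims):
    """Assumed SAM ascent direction epsilon*: the normalized plain gradient."""
    g = grad(theta, X, Y, dims)
    return g / (np.linalg.norm(g) + 1e-12)

def sam_grad(theta, X, Y, dims, rho):
    """Standard SAM: full gradient at the perturbed point (functional + logit path)."""
    return grad(theta + rho * ascent_direction(theta, X, Y, dims), X, Y, dims)

def functional_sam_grad(theta, X, Y, dims, rho):
    """Functional-SAM: Jacobian at the perturbed point, loss derivative at theta."""
    perturbed = theta + rho * ascent_direction(theta, X, Y, dims)
    dL_dF = loss_grad_wrt_logits(forward(theta, X, dims), Y)  # unperturbed point
    return vjp(perturbed, X, dims, dL_dF)

# Usage on random data: one step theta_{t+1} = theta_t - eta * g_t.
rng = np.random.default_rng(0)
dims = (3, 8, 4)  # (inputs d, hidden units k, classes m)
X = rng.normal(size=(16, dims[0]))
Y = np.eye(dims[2])[rng.integers(0, dims[2], size=16)]
theta = 0.5 * rng.normal(size=dims[0] * dims[1] + dims[1] * dims[2])
eta, rho = 0.1, 0.05
theta_next = theta - eta * functional_sam_grad(theta, X, Y, dims, rho)
```

Because both gradients pull back through the same Jacobian \(\nabla_\theta F(\theta_t+\rho\,\epsilon^*)\) and `vjp` is linear in its cotangent, their difference in this sketch is exactly \(\nabla_\theta F(\theta_t+\rho\,\epsilon^*)^\top\big[\nabla_F L(F(\theta_t+\rho\,\epsilon^*)) - \nabla_F L(F(\theta_t))\big]\), the logit-path term that Functional-SAM drops. Setting `rho=0` recovers the plain gradient.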
Reference: Sidak Pal Singh, Hossein Mobahi, Atish Agarwala, Yann Dauphin, "Avoiding spurious sharpness minimization broadens applicability of SAM", arXiv 2025. https://arxiv.org/abs/2502.02407