Atlas
Implements Atlas, a curvature-aware optimizer that combines a Hutchinson trace estimate, a trust-radius clamp, and a fixed three-phase descent cascade.
Atlas centralizes and median-clips the gradient, then periodically probes the Hessian trace with a single Hutchinson product \(\widehat{\mathrm{tr}}(H_t)=|v^\top \nabla^2 L(\theta_t)\,v|\), where \(v\) is a Rademacher vector, and uses it to set a trust radius \(r_t=\sqrt{\widehat{\mathrm{tr}}(H_t)}\). The candidate step comes from one of three phase-specific rules selected by training progress \(\phi_t=t/T\): AdaGrad-momentum, then a rectified AdamW step with trust-ratio scaling, then SGD-Nesterov. The step is rescaled so that its norm never exceeds \(r_t\) and is applied with decoupled weight decay. A Safe-Step monitor re-evaluates the loss and rolls the step back when the loss rises by more than the tolerance factor; optional LookAhead averaging and a CheapSAM perturbation wrap the core step.
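The trace probe and trust-radius clamp described above can be illustrated with a minimal pure-Python sketch. It is not the Atlas implementation: the helper names `hutchinson_trace` and `clamp_to_radius`, the toy diagonal Hessian, and the candidate step are all assumptions made for illustration. A diagonal Hessian is chosen because a single Rademacher probe then recovers the trace exactly, which makes the numbers easy to follow.

```python
import math
import random


def hutchinson_trace(hvp, dim, rng):
    """Single-probe estimate |v^T H v| using a Rademacher vector v."""
    v = [rng.choice((-1.0, 1.0)) for _ in range(dim)]
    hv = hvp(v)  # Hessian-vector product H v
    return abs(sum(vi * hvi for vi, hvi in zip(v, hv)))


def clamp_to_radius(step, radius):
    """Rescale `step` so that its Euclidean norm does not exceed `radius`."""
    norm = math.sqrt(sum(s * s for s in step))
    if norm <= radius or norm == 0.0:
        return list(step)
    scale = radius / norm
    return [s * scale for s in step]


# Toy quadratic loss L(theta) = 0.5 * sum(d_i * theta_i^2), so H = diag(d).
# (Hypothetical values, chosen so the trace is 9 and the radius is 3.)
diag = [4.0, 3.0, 2.0]


def hvp(v):
    return [d * vi for d, vi in zip(diag, v)]


rng = random.Random(0)
trace_est = hutchinson_trace(hvp, len(diag), rng)  # exact for a diagonal H
radius = math.sqrt(trace_est)                      # r_t = sqrt(trace estimate)

candidate = [3.0, 0.0, 4.0]                        # candidate step with norm 5
step = clamp_to_radius(candidate, radius)          # norm reduced to the radius
```

For a non-diagonal Hessian a single probe is a noisy estimate, which is presumably why the description refreshes it periodically rather than relying on one value for the whole run.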
Here \(\theta_t\) are the parameters, \(\eta\) is the cosine-scheduled learning rate, \(g_t^\star\) is the centralized and adaptively clipped gradient, \(m_t/v_t/r_t\) are the first- and second-moment buffers, \(\tau_t\) is the running trust ratio, \(\widehat{\mathrm{tr}}(H_t)\) is the Hutchinson Hessian-trace estimate (refreshed every \(h\) steps), \(\phi_t=t/T\) is the fractional training progress that selects the active phase, \(\lambda\) is the decoupled weight decay, and \(\epsilon\) is a small stabilizer. \(\mathrm{AGC2}\) rescales the gradient when its norm exceeds a median-of-weights threshold.
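As a rough illustration of how \(\phi_t=t/T\) could drive both the cosine-scheduled learning rate and the phase selection, the sketch below uses the standard cosine-annealing formula. The phase boundaries at one third and two thirds are purely hypothetical, since the description does not state where the phases switch; the function names `cosine_lr` and `active_phase` and the phase labels are likewise illustrative and not part of any Atlas API.

```python
import math


def cosine_lr(step, total_steps, base_lr, min_lr=0.0):
    """Standard cosine annealing from base_lr (phi = 0) to min_lr (phi = 1)."""
    phi = min(max(step / total_steps, 0.0), 1.0)  # fractional progress t / T
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * phi))


# Hypothetical phase boundaries: the source does not specify them.
PHASE_BOUNDS = (1.0 / 3.0, 2.0 / 3.0)


def active_phase(step, total_steps, bounds=PHASE_BOUNDS):
    """Pick the descent rule for the current progress phi = t / T."""
    phi = step / total_steps
    if phi < bounds[0]:
        return "adagrad_momentum"
    if phi < bounds[1]:
        return "rectified_adamw"
    return "sgd_nesterov"


# Print the schedule at a few points of a 100-step run.
for t in (0, 50, 90):
    print(t, active_phase(t, 100), round(cosine_lr(t, 100, 0.1), 4))
```

Because the cascade is fixed, the active rule depends only on the step counter and the total step budget \(T\), so \(T\) has to be known before training starts.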
Reference: János Horváth, "Atlas – Rethinking Optimizer Design for Stability and Speed", OPT 2025: 17th Annual Workshop on Optimization for Machine Learning (NeurIPS workshop). https://opt-ml.org/papers/2025/paper6.pdf