Amos
Implements Amos, an Adam-style optimizer with adaptive weight decay towards a model-oriented scale.
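Amos replaces the fixed, hand-tuned weight-decay coefficient used with Adam by a decay schedule driven by a per-variable model-oriented scale \(\xi\), an estimate of the magnitude each weight is expected to settle at. The second-moment estimate is a scalar: the mean of the squared gradient over the whole parameter tensor. As a result, the second-moment and decay buffers each hold a single value per parameter tensor rather than one value per element (an optional momentum buffer, if used, still has the full shape of the tensor).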
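The scalar second-moment statistic can be sketched in plain Python. This is a minimal illustration, not the optimizer's implementation: it assumes the running estimate is an exponential moving average of \(\overline{g_t^2}\), and the names mean_squared_grad, update_second_moment and the rate beta are hypothetical. The point is that the state stays a single float regardless of how many elements the tensor has.

```python
def mean_squared_grad(grad):
    """Scalar mean of g**2 over every element of a flattened parameter tensor."""
    return sum(g * g for g in grad) / len(grad)


def update_second_moment(v, grad, beta=0.98):
    """One step of an (assumed) EMA over the scalar mean squared gradient.

    `beta` is a hypothetical decay rate chosen only for illustration.
    """
    return beta * v + (1.0 - beta) * mean_squared_grad(grad)


# Gradient of a 2x3 weight matrix, flattened to 6 elements.
grad = [0.1, -0.2, 0.3, 0.0, 0.4, -0.1]

v = 0.0  # the whole second-moment state for this tensor: one float
for _ in range(3):
    v = update_second_moment(v, grad)
print(v)
```

With a constant gradient, three steps of this recurrence give \(v = (1 - \beta^3)\,\overline{g^2}\), and the buffer never grows beyond one number per tensor.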
In the Amos update rule, \(\overline{g_t^2}\) is the mean of the squared gradient over the parameter tensor at step \(t\), \(\xi\) is the model-oriented scale returned by get_scale, \(b_t\) is the accumulated decay buffer, \(c\) and \(d\) are the decay coefficients c_coef and d_coef, and \(\lambda\) is the coefficient of the additional L2 regularization term, extra_l2. Optionally, the update is smoothed with a moving average whose rate is momentum before it is applied to the weights.
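The optional moving average of the update can be sketched as follows. The convention m ← momentum · m + (1 − momentum) · u is an assumption made for illustration (the optimizer's exact form may differ), and apply_momentum and the raw update values are hypothetical. Unlike the scalar statistics, this buffer has one entry per element of the parameter tensor.

```python
def apply_momentum(m, update, momentum=0.9):
    """Element-wise moving average of the update.

    Assumes the convention m <- momentum * m + (1 - momentum) * update.
    With momentum=0.0 the raw update passes through unchanged.
    """
    return [momentum * mi + (1.0 - momentum) * ui for mi, ui in zip(m, update)]


update = [0.5, -1.0, 2.0]   # hypothetical raw step for a 3-element tensor
m = [0.0] * len(update)     # full-size buffer, unlike the scalar statistics

for _ in range(2):
    m = apply_momentum(m, update)
print(m)
```

Starting from zero with a constant update, two steps leave \(m = (1 - 0.9^2)\,u\), so the smoothed step ramps up toward the raw update rather than jumping to it.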
Reference: Ran Tian, Ankur P. Parikh, "Amos: An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale", 2022. https://arxiv.org/abs/2210.11693