GrokAdamW
Implements GrokAdamW: AdamW with Grokfast-style amplification of the slowly varying component of the gradients.
Its update uses the following symbols: \(\mu_t\) is the slow-gradient EMA, \(s_t\) the grokking
signal averaged over grokking_signal_fns, \(\kappa\) the signal
decay rate, \(\lambda\) the amplification factor lamb, \(w\)
the decoupled weight decay, and \(l\) the index of the parameter
within its group, so that \(\gamma\) reduces the momentum of parameters that come later in the group (later layers).
Gradients are norm-clipped per parameter before the update when
gradient_clipping is positive.
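The step described above can be sketched in pure Python for one parameter stored as a list of floats. This is a minimal illustration under stated assumptions, not this library's implementation. It assumes the EMA coefficient is \(\alpha_t = \alpha_0 e^{-\kappa s_t}\) (the starting value alpha_init is a hypothetical name), that the amplified gradient is \(g_t + \lambda \mu_t\) as in Grokfast, that the layer's first-moment coefficient is \(\beta_1 (1-\gamma)^l\), and that weight decay is applied in the decoupled AdamW style. Bias correction with the layer-scaled coefficient is a choice made here for internal consistency. ParamState, init_state, clip_by_norm and grokadamw_step are all hypothetical names.

```python
import math
from dataclasses import dataclass


@dataclass
class ParamState:
    """Per-parameter state (hypothetical layout, one float per element)."""
    mu: list  # slow-gradient EMA, mu_t
    m: list   # first moment
    v: list   # second moment
    step: int = 0


def init_state(n: int) -> ParamState:
    """Zero-initialized state for a parameter with n elements."""
    return ParamState(mu=[0.0] * n, m=[0.0] * n, v=[0.0] * n)


def clip_by_norm(grad, max_norm):
    """Rescale grad to L2 norm <= max_norm; no clipping when max_norm <= 0."""
    if max_norm <= 0:
        return list(grad)
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= max_norm:
        return list(grad)
    return [g * (max_norm / norm) for g in grad]


def grokadamw_step(theta, grad, state, *, l, s_t, lr=1e-3, betas=(0.9, 0.999),
                   eps=1e-8, w=1e-2, alpha_init=0.98, lamb=2.0, gamma=0.1,
                   kappa=0.1, gradient_clipping=1.0):
    """Return updated parameter values for one step (sketch of an assumed update form)."""
    beta1, beta2 = betas
    g = clip_by_norm(grad, gradient_clipping)  # per-parameter norm clipping
    state.step += 1
    # Assumption: the grokking signal shrinks the EMA coefficient exponentially.
    alpha = alpha_init * math.exp(-kappa * s_t)
    # Assumption: gamma scales beta1 geometrically with the in-group index l.
    beta1_l = beta1 * (1.0 - gamma) ** l
    bc1 = 1.0 - beta1_l ** state.step  # bias correction with the scaled beta1 (our choice)
    bc2 = 1.0 - beta2 ** state.step
    new_theta = []
    for i, (p, gi) in enumerate(zip(theta, g)):
        state.mu[i] = alpha * state.mu[i] + (1.0 - alpha) * gi  # slow-gradient EMA
        g_hat = gi + lamb * state.mu[i]                          # Grokfast-style amplification
        state.m[i] = beta1_l * state.m[i] + (1.0 - beta1_l) * g_hat
        state.v[i] = beta2 * state.v[i] + (1.0 - beta2) * g_hat * g_hat
        p = p * (1.0 - lr * w)                                   # decoupled weight decay
        m_hat = state.m[i] / bc1
        v_hat = state.v[i] / bc2
        new_theta.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
    return new_theta


# Example: one step on a two-element parameter that is first in its group (l=0).
state = init_state(2)
theta = grokadamw_step([1.0, -1.0], [0.5, 0.0], state, l=0, s_t=0.0)
# theta is approximately [0.99899, -0.99999]
```

Because \(\alpha_t\) falls as \(s_t\) grows, a stronger grokking signal makes \(\mu_t\) discard older gradients faster in this sketch, while a larger \(l\) lowers the first-moment coefficient for that parameter.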
Note: When no grokking_signal_fns are given, the signal is computed from train_loss and eval_loss entries set on the parameter group and is zero while those are absent. Optimizer state is kept in CPU memory and moved to the parameter device for each step.
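The signal selection in the note above can be sketched as follows. Assumptions: each entry of grokking_signal_fns is a zero-argument callable returning a float, the parameter group is a dict as in PyTorch, and the fallback formula (the relative gap between eval_loss and train_loss, clamped at zero) is one plausible choice that this page does not specify. The helper name grokking_signal is hypothetical.

```python
from typing import Callable, Optional, Sequence


def grokking_signal(group: dict,
                    signal_fns: Optional[Sequence[Callable[[], float]]] = None) -> float:
    """Return s_t for one parameter group (hypothetical helper, sketch only)."""
    if signal_fns:
        # Average the user-supplied signals (assumed zero-argument callables).
        values = [float(fn()) for fn in signal_fns]
        return sum(values) / len(values)
    train_loss = group.get("train_loss")
    eval_loss = group.get("eval_loss")
    if train_loss is None or eval_loss is None:
        return 0.0  # losses not yet set on the group, so the signal is zero
    # Assumed fallback: relative generalization gap, clamped at zero.
    gap = max(0.0, eval_loss - train_loss)
    scale = max(train_loss, eval_loss)
    return gap / scale if scale > 0 else 0.0


group = {"lr": 1e-3}
print(grokking_signal(group))  # 0.0 while the losses are absent
group.update(train_loss=0.5, eval_loss=1.0)
print(grokking_signal(group))  # 0.5
```

In a training loop this means writing the latest losses into each group before calling the step; until then the sketch returns zero, which leaves \(\alpha_t\) at its starting value.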
GrokAdamW was written by Eric Hartford and has no dedicated paper; its slow-gradient amplification follows Grokfast.
Reference: Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee, "Grokfast: Accelerated Grokking by Amplifying Slow Gradients", arXiv 2024. https://arxiv.org/abs/2405.20233