SGDScheduleFree
Implements Schedule-Free SGD, which replaces momentum with interpolation and averaging.
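The update, as it is usually stated for Schedule-Free SGD in the cited paper, can be sketched as follows. The symbol \(\beta\) for the interpolation (momentum) coefficient, the symbol \(\gamma_t\) for the step size at step \(t\), and the placement of the weight-decay term at \(y_t\) are assumptions made for this sketch:

\[
\begin{aligned}
y_t &= (1 - \beta)\, z_t + \beta\, x_t \\
z_{t+1} &= z_t - \gamma_t \left( \nabla f(y_t) + \lambda\, y_t \right) \\
x_{t+1} &= (1 - c_{t+1})\, x_t + c_{t+1}\, z_{t+1}
\end{aligned}
\]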
where \(z_t\) is the base SGD iterate, gradients are evaluated at the
interpolated point \(y_t\), the parameters used for evaluation are the
average \(x_t\), \(\lambda\) is weight_decay, and
\(c_{t+1} = \gamma_t^2 / \sum_{i=1}^{t} \gamma_i^2\). No learning
rate schedule is needed; linear warmup is available through
warmup_steps.
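The recurrence described above can be traced on a scalar problem with a few lines of plain Python. This is a minimal sketch and not the optimizer's actual implementation: the function name `schedule_free_sgd`, the interpolation coefficient `beta`, the linear warmup rule `t / warmup_steps`, and applying weight decay at \(y_t\) are all assumptions made for illustration.

```python
def schedule_free_sgd(grad, x0, lr=0.1, beta=0.9, weight_decay=0.0,
                      warmup_steps=0, steps=100):
    """Toy scalar Schedule-Free SGD; returns the averaged iterate x."""
    z = x = float(x0)   # z: base SGD iterate, x: running average
    lr_sq_sum = 0.0     # running sum of gamma_i^2
    for t in range(1, steps + 1):
        # Assumed linear warmup: scale the step size up to lr over warmup_steps.
        sched = min(1.0, t / warmup_steps) if warmup_steps > 0 else 1.0
        gamma = lr * sched
        # Gradients are evaluated at the interpolated point y.
        y = (1.0 - beta) * z + beta * x
        g = grad(y) + weight_decay * y  # assumption: decay applied at y
        # Plain SGD step on the base iterate z.
        z = z - gamma * g
        # c_{t+1} = gamma_t^2 / sum_{i<=t} gamma_i^2
        lr_sq_sum += gamma ** 2
        c = gamma ** 2 / lr_sq_sum
        # Weighted running average of the z iterates.
        x = (1.0 - c) * x + c * z
    return x


# Minimise f(w) = 0.5 * (w - 3)^2, whose gradient is w - 3.
x_final = schedule_free_sgd(lambda w: w - 3.0, x0=0.0, steps=2000)
print(x_final)
```

With a constant step size the weight reduces to \(c_{t+1} = 1/t\), so \(x_t\) is a uniform average of the \(z\) iterates; during warmup the early, smaller steps receive proportionally less weight.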
Reference: Aaron Defazio, Xingyu Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky, "The Road Less Scheduled", NeurIPS 2024. https://arxiv.org/abs/2405.15682
Note: Call optimizer.train() before training and optimizer.eval() before evaluation or checkpointing, alongside the matching model.train() / model.eval() calls. Gradients are computed at \(y_t\), whereas evaluation losses and saved checkpoints should use the averaged point \(x_t\), so the parameter buffer must be switched between the two points.
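To make the switch between the two points concrete, the following toy class keeps one scalar buffer that holds either \(y_t\) or \(x_t\) and converts between them using the stored \(z_t\). It is a hypothetical illustration only: the class name `ToyScheduleFreeState`, the coefficient `beta`, and the interpolation \(y = (1 - \beta) z + \beta x\) are assumptions, and the real optimizer may organize its state differently.

```python
class ToyScheduleFreeState:
    """Hypothetical single-scalar illustration; not the library's code."""

    def __init__(self, value, beta=0.9):
        self.beta = beta
        self.z = float(value)      # base SGD iterate z_t
        self.param = float(value)  # shared buffer: holds y_t or x_t
        self.train_mode = True     # True means the buffer holds y_t

    def eval(self):
        # y -> x: invert y = (1 - beta) * z + beta * x.
        if self.train_mode:
            self.param = (self.param - (1.0 - self.beta) * self.z) / self.beta
            self.train_mode = False

    def train(self):
        # x -> y: interpolate the average back towards z.
        if not self.train_mode:
            self.param = (1.0 - self.beta) * self.z + self.beta * self.param
            self.train_mode = True


state = ToyScheduleFreeState(1.0)
state.z = 2.0                       # pretend some SGD steps moved z
state.param = 0.1 * 2.0 + 0.9 * 1.0  # buffer holds y for x = 1, z = 2
state.eval()                        # buffer now holds x again
print(state.param)
state.train()                       # buffer holds y again
print(state.param)
```

Both methods are guarded by the mode flag, so calling eval() or train() twice in a row leaves the buffer unchanged. Forgetting the switch would mean evaluating or saving \(y_t\) instead of \(x_t\).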