Lookahead
Implements Lookahead ("k steps forward, 1 step back"), which can be wrapped around any optimizer.
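Lookahead keeps two sets of weights. The fast weights \(\theta\) are advanced for \(k\) inner steps by a base optimizer. The slow weights \(\phi\) are then pulled toward them by linear interpolation, and the fast weights are reset to the slow ones:
\[
\theta_{t,0} = \phi_t, \qquad \theta_{t,i} = \theta_{t,i-1} + A(\theta_{t,i-1}, d), \quad i = 1, \dots, k,
\]
\[
\phi_{t+1} = \phi_t + \alpha\,(\theta_{t,k} - \phi_t), \qquad \theta_{t+1,0} = \phi_{t+1},
\]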
where \(A\) is the inner optimizer's update on minibatch \(d\), \(\alpha\) is the slow-weights step size, and \(k\) is the synchronization period.
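To make the update rule concrete, here is a minimal pure-Python sketch of the algorithm. It assumes plain SGD as the inner optimizer \(A\) and uses a deterministic full-batch gradient in place of minibatches \(d\). The function name lookahead_sgd and its parameters (lr, outer_steps) are hypothetical illustrations and are not part of this library's API.

```python
from typing import Callable, List


def lookahead_sgd(
    grad: Callable[[List[float]], List[float]],
    phi: List[float],
    lr: float = 0.1,
    k: int = 5,
    alpha: float = 0.5,
    outer_steps: int = 100,
) -> List[float]:
    """Hypothetical sketch: Lookahead with plain SGD as the inner optimizer A.

    Assumption: `grad` returns an exact gradient (no minibatch noise).
    Returns the slow weights phi after `outer_steps` synchronizations.
    """
    phi = list(phi)  # slow weights
    for _ in range(outer_steps):
        theta = list(phi)  # reset fast weights: theta_{t,0} = phi_t
        for _ in range(k):
            g = grad(theta)
            # Inner update A: one SGD step on the fast weights.
            theta = [th - lr * gi for th, gi in zip(theta, g)]
        # Slow update: phi <- phi + alpha * (theta_{t,k} - phi)
        phi = [p + alpha * (th - p) for p, th in zip(phi, theta)]
    return phi


# Usage: minimize f(x, y) = 0.5*(x - 3)**2 + 2*(y + 1)**2,
# whose gradient is [x - 3, 4*(y + 1)].
result = lookahead_sgd(lambda p: [p[0] - 3.0, 4.0 * (p[1] + 1.0)], [0.0, 0.0])
print(result)  # close to [3.0, -1.0]
```

With k=1 and alpha=1 the slow weights simply copy the fast ones after every step, so the sketch reduces to plain SGD. This is a quick sanity check on the interpolation step.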
Reference: Michael R. Zhang, James Lucas, Geoffrey Hinton, Jimmy Ba, "Lookahead Optimizer: k steps forward, 1 step back", NeurIPS 2019. https://arxiv.org/abs/1907.08610
Note: this is a wrapper around a base optimizer. Pass it an already-constructed optimizer instance, for example:
Lookahead(torch.optim.Adam(model.parameters(), lr=1e-3), k=5, alpha=0.5)
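The following pure-Python sketch illustrates what wrapping an already-constructed optimizer looks like. ToySGD and LookaheadWrapper are hypothetical stand-ins. They are not torch.optim classes and not this library's implementation. The wrapper forwards each step to the base optimizer, counts inner steps, and every k steps moves the slow weights toward the fast ones and copies them back.

```python
from typing import List


class ToySGD:
    """Hypothetical stand-in for an already-constructed base optimizer."""

    def __init__(self, params: List[float], lr: float) -> None:
        self.params = params  # fast weights, updated in place
        self.lr = lr

    def step(self, grads: List[float]) -> None:
        for i, g in enumerate(grads):
            self.params[i] -= self.lr * g


class LookaheadWrapper:
    """Hypothetical wrapper: k inner steps of `base`, then one slow-weight update."""

    def __init__(self, base: ToySGD, k: int = 5, alpha: float = 0.5) -> None:
        if k < 1 or not 0.0 < alpha <= 1.0:
            raise ValueError("expected k >= 1 and 0 < alpha <= 1")
        self.base = base
        self.k = k
        self.alpha = alpha
        self.slow = list(base.params)  # phi starts as a copy of the fast weights
        self.counter = 0

    def step(self, grads: List[float]) -> None:
        self.base.step(grads)  # one fast-weight update by the inner optimizer
        self.counter += 1
        if self.counter % self.k == 0:
            for i, fast in enumerate(self.base.params):
                # phi <- phi + alpha * (theta - phi), then reset theta <- phi
                self.slow[i] += self.alpha * (fast - self.slow[i])
                self.base.params[i] = self.slow[i]


# Usage: minimize 0.5 * (x - 3)**2. `params` is shared with the base optimizer.
params = [0.0]
opt = LookaheadWrapper(ToySGD(params, lr=0.1), k=5, alpha=0.5)
for _ in range(500):
    grads = [params[0] - 3.0]  # gradient evaluated at the current fast weights
    opt.step(grads)
print(params[0])  # close to 3.0
```

The wrapper only calls base.step(), so any optimizer exposing that method and mutable parameters could be wrapped the same way. A PyTorch version would instead iterate over the base optimizer's param_groups to read and overwrite the parameter tensors.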