TRAC
Implements TRAC, a parameter-free scale tuner for any base optimizer.
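TRAC keeps a reference point \(\theta_{ref}\) (the parameters before optimization began) and, after each base-optimizer update, rescales the cumulative displacement from \(\theta_{ref}\) by a learned scale \(S_t\). Because the stored parameters already carry the previous-step scale \(S_{t-1}\), the displacement is first un-scaled to recover the base optimizer's own displacement, giving \(\Delta_t = (\theta_t - \theta_{ref}) / (S_{t-1} + \epsilon)\). The scale is the sum of \(n\) one-dimensional discounted tuners, one per discount factor \(\beta_i\). With the base update producing \(\theta_t\), gradient \(g_t\), and inner product \(h_t = \langle g_t, \Delta_t \rangle\), each tuner keeps running statistics \(v_{t,i}\) and \(\sigma_{t,i}\) (both initialized to zero):
\[
\begin{aligned}
v_{t,i} &= \beta_i^2 \, v_{t-1,i} + h_t^2, \\
\sigma_{t,i} &= \beta_i \, \sigma_{t-1,i} - h_t, \\
S_{t,i} &= \frac{s_{init}}{\mathrm{erfi}(1/\sqrt{2})} \, \mathrm{erfi}\!\left(\frac{\sigma_{t,i}}{\sqrt{2 v_{t,i}} + \epsilon}\right), \\
S_t &= \sum_{i=1}^{n} S_{t,i}, \qquad \theta_t \leftarrow \theta_{ref} + S_t \, \Delta_t,
\end{aligned}
\]
where \(\mathrm{erfi}\) is the imaginary error function and \(s_{init}\) is the initial scale s_prev.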
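To make the tuner recursions concrete, here is a minimal pure-Python sketch of the \(n\) discounted tuners, assuming the per-tuner updates from the Fast TRAC paper: \(v \leftarrow \beta^2 v + h^2\), \(\sigma \leftarrow \beta\sigma - h\), and \(S_i = s_{init}\,\mathrm{erfi}(\sigma/(\sqrt{2v}+\epsilon))/\mathrm{erfi}(1/\sqrt{2})\), summed over all tuners. The class name DiscountedTuners, its defaults, and the series-based erfi helper are illustrative assumptions, not the library's code (Python's math module has no erfi; scipy.special.erfi is one library alternative).

```python
import math


def erfi(x: float, rel_tol: float = 1e-16, max_terms: int = 1000) -> float:
    """Imaginary error function via its Maclaurin series:
    erfi(x) = 2/sqrt(pi) * sum_{k>=0} x^(2k+1) / (k! * (2k+1)).
    All terms share the sign of x, so there is no cancellation. Adequate for
    the moderate |x| seen in this sketch; not a production routine."""
    term = x   # x^(2k+1) / k!, starting at k = 0
    total = x  # running sum of term / (2k+1)
    x2 = x * x
    for k in range(1, max_terms):
        term *= x2 / k
        contrib = term / (2 * k + 1)
        total += contrib
        if abs(contrib) <= rel_tol * abs(total):
            break
    return 2.0 / math.sqrt(math.pi) * total


class DiscountedTuners:
    """Hypothetical helper: n one-dimensional discounted tuners, one per beta.
    update(h) advances every tuner and returns their summed scale S_t."""

    def __init__(self, betas=(0.9, 0.99, 0.999), s_init=1e-8, eps=1e-8):
        self.betas = list(betas)
        self.eps = eps
        self.v = [0.0] * len(self.betas)      # discounted sum of h^2
        self.sigma = [0.0] * len(self.betas)  # discounted sum of -h
        # Normalizer: a tuner's first output for nonzero h is about s_init
        # in magnitude, with sign opposite to h.
        self.f_term = s_init / erfi(1.0 / math.sqrt(2.0))

    def update(self, h: float) -> float:
        scale = 0.0
        for i, beta in enumerate(self.betas):
            self.v[i] = beta ** 2 * self.v[i] + h ** 2
            self.sigma[i] = beta * self.sigma[i] - h
            x = self.sigma[i] / (math.sqrt(2.0 * self.v[i]) + self.eps)
            scale += self.f_term * erfi(x)
        return scale
```

A run of negative \(h_t\) (the gradient points against the displacement, so moving further along it lowers the loss) grows every \(\sigma_{t,i}\) and therefore the scale; positive \(h_t\) shrinks it and can drive it negative.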
Reference: Aneesh Muppidi, Zhiyu Zhang, Heng Yang, "Fast TRAC: A Parameter-Free Optimizer for Lifelong Reinforcement Learning", NeurIPS 2024. https://arxiv.org/abs/2405.16642
Note: TRAC is a wrapper around a base optimizer. Pass it an already-constructed
optimizer instance, for example
TRAC(torch.optim.AdamW(model.parameters())).
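Wrapping an existing instance means the base optimizer takes its usual step and TRAC only post-processes the parameters afterwards. The pure-Python sketch below mimics that pattern with hypothetical stand-ins: ToySGD plays the base optimizer, ScaleWrapper plays TRAC, and scale_fn stands in for the discounted tuners. It uses no PyTorch and is not the library's implementation.

```python
class ToySGD:
    """Hypothetical base optimizer: gradient descent on a list of floats,
    updated in place, as a torch optimizer's step() updates tensors."""

    def __init__(self, params, lr=0.1):
        self.params = params
        self.lr = lr

    def step(self, grads):
        for i, g in enumerate(grads):
            self.params[i] -= self.lr * g


class ScaleWrapper:
    """Hypothetical TRAC-style wrapper around an already-constructed base
    optimizer: snapshot theta_ref, let the base step, un-scale the
    displacement by the previous scale, then write back
    theta_ref + S_t * delta. scale_fn maps h_t to S_t."""

    def __init__(self, base, scale_fn, s_init=1e-8, eps=1e-8):
        self.base = base
        self.scale_fn = scale_fn
        self.theta_ref = list(base.params)  # copy of the starting parameters
        self.s_prev = s_init
        self.eps = eps

    def step(self, grads):
        self.base.step(grads)  # the base optimizer's ordinary update
        delta = [(p - r) / (self.s_prev + self.eps)
                 for p, r in zip(self.base.params, self.theta_ref)]
        h = sum(g * d for g, d in zip(grads, delta))  # h_t = <g_t, Delta_t>
        s_new = self.scale_fn(h)
        for i, (r, d) in enumerate(zip(self.theta_ref, delta)):
            self.base.params[i] = r + s_new * d
        self.s_prev = s_new
        return s_new


def grad(params):
    # Gradient of f(p) = sum(p_i^2).
    return [2.0 * p for p in params]


plain = ToySGD([1.0, -2.0], lr=0.1)
base = ToySGD([1.0, -2.0], lr=0.1)
wrapper = ScaleWrapper(base, scale_fn=lambda h: 1.0, s_init=1.0, eps=0.0)
for _ in range(5):
    plain.step(grad(plain.params))
    wrapper.step(grad(base.params))
```

With scale_fn fixed at 1.0, s_init=1.0 and eps=0.0, the wrapped run reproduces the plain base optimizer's trajectory up to rounding, a quick sanity check that the un-scale and re-scale bookkeeping is consistent.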