MADA
Implements MADA, a meta-adaptive optimizer that learns where to sit between known adaptive methods via hyper-gradient descent.
MADA defines a single parameterized update that subsumes Adam, AMSGrad, Adan, and Yogi as special points in a continuous coefficient space. The first moment blends plain momentum with Adan's gradient-difference term (\(\beta_3\)), and the second moment blends an Adam-style running average, a Yogi-style sign correction (\(c\)), and an AMSGrad-style running maximum (\(\rho\)). These interpolation coefficients are not fixed: they are treated as additional variables and updated during training by descending the validation/training loss with respect to them (hyper-gradient descent), so the optimizer drifts toward whichever known method works best for the task. Bias-correction terms are omitted below for clarity, as in the paper.
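One way to write an update consistent with this description is sketched below. This is an assumed parameterization, not a transcription of the paper: the auxiliary symbols \(\tilde v_t\) (blended Adam/Yogi second moment) and \(v^{\max}_t\) (running maximum) are introduced here for illustration, and the exact placement of each coefficient in the paper or in this implementation may differ.

\[
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,\bigl(g_t + \beta_3\,(g_t - g_{t-1})\bigr) \\
\tilde v_t &= c\,\bigl[\beta_2 \tilde v_{t-1} + (1-\beta_2)\, g_t^2\bigr] + (1-c)\,\bigl[\tilde v_{t-1} - (1-\beta_2)\,\operatorname{sign}(\tilde v_{t-1} - g_t^2)\, g_t^2\bigr] \\
v^{\max}_t &= \max\bigl(v^{\max}_{t-1},\, \tilde v_t\bigr) \\
v_t &= \rho\,\tilde v_t + (1-\rho)\, v^{\max}_t \\
\theta_t &= \theta_{t-1} - \eta_t\, \frac{m_t}{\sqrt{v_t} + \epsilon}
\end{aligned}
\]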
Here \(\theta\) denotes the parameters, \(\eta_t\) the learning rate, \(g_t\) the gradient, \(m_t\) and \(v_t\) the first and second moments, \(\beta_1\) and \(\beta_2\) the moment decay rates, and \(\epsilon\) the numerical-stability constant. The interpolation coefficients \(\beta_3\), \(c\), and \(\rho\) are learned by hyper-gradient descent: \(\beta_3\) weights Adan's gradient-difference term, \(c\) interpolates Yogi's sign correction, and \(\rho\) interpolates the AMSGrad running maximum. Setting \(\beta_3=0\), \(c=1\), \(\rho=1\) recovers Adam.
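To make the roles of \(\beta_3\), \(c\), and \(\rho\) concrete, the following is a minimal scalar sketch of one interpolated step without bias correction. The function `mada_step`, its argument names, and the exact way each coefficient enters the update are assumptions made for illustration; they are not the API or the exact formulation of this implementation or of the paper. The coefficients are held fixed here, so the hyper-gradient update of \(\beta_3\), \(c\), and \(\rho\) is not shown.

```python
import math


def sign(x):
    """Return -1.0, 0.0, or 1.0 according to the sign of x."""
    return (x > 0) - (x < 0) + 0.0


def mada_step(theta, g, g_prev, m, v, v_max, *, lr=1e-3,
              beta1=0.9, beta2=0.999, beta3=0.0, c=1.0, rho=1.0, eps=1e-8):
    """One hypothetical interpolated step on scalars (no bias correction).

    Assumed parameterization, for illustration only:
      beta3 = 0 disables the Adan-style gradient-difference term,
      c = 1 is the Adam second moment, c = 0 the Yogi one,
      rho = 1 uses the current second moment, rho = 0 the running max.
    """
    # First moment: momentum plus a weighted gradient-difference term.
    m = beta1 * m + (1 - beta1) * (g + beta3 * (g - g_prev))

    # Second moment: blend of the Adam average and the Yogi sign-corrected update.
    g2 = g * g
    adam_v = beta2 * v + (1 - beta2) * g2
    yogi_v = v - (1 - beta2) * sign(v - g2) * g2
    v = c * adam_v + (1 - c) * yogi_v

    # AMSGrad-style running maximum, blended with the current value via rho.
    v_max = max(v_max, v)
    v_hat = rho * v + (1 - rho) * v_max

    theta = theta - lr * m / (math.sqrt(v_hat) + eps)
    return theta, m, v, v_max


# With the default coefficients (beta3=0, c=1, rho=1) this reduces to a plain
# Adam step without bias correction.
theta, m, v, v_max = mada_step(1.0, 0.5, 0.2, 0.0, 0.0, 0.0)
```

In this sketch, moving `rho` toward 0 makes the denominator rely on the running maximum, and moving `c` toward 0 replaces the exponential average with the Yogi-style additive update, so each named optimizer sits at a corner of the coefficient space.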
Reference: Kaan Ozkara, Can Karakus, Parameswaran Raman, Mingyi Hong, Shoham Sabach, Branislav Kveton, Volkan Cevher, "MADA: Meta-Adaptive Optimizers through hyper-gradient Descent", arXiv 2024. https://arxiv.org/abs/2401.08893