WarpAdam¶
Implements WarpAdam, Adam with a learnable distortion matrix applied to the gradient.
WarpAdam imports the warped-gradient-descent idea from meta-learning into Adam. Instead of feeding the raw gradient into the moment estimates, it first linearly transforms it by a square matrix \(P\) learned across tasks. The matrix preconditions the gradient so the optimizer can adapt to the characteristics of a given dataset, while the rest of the Adam machinery (exponential moving averages, bias correction, adaptive step) is unchanged.
where \(g_t = \nabla_\theta f(\theta_t)\) is the gradient, \(P\) is an \(n \times n\) distortion matrix learned by meta-learning, \(m_t, v_t\) are the first and second moment estimates with decay rates \(\beta_1, \beta_2\), \(\hat{m}_t, \hat{v}_t\) are their bias-corrected forms, \(\eta\) is the learning rate, and \(\epsilon\) guards the denominator.
Reference: Chengxi Pan, Junshang Chen, Jingrui Ye, "WarpAdam: A new Adam optimizer based on Meta-Learning approach", arXiv 2024. https://arxiv.org/abs/2409.04244