GWT
Implements GWT (Gradient Wavelet Transform), a memory-efficient optimizer that applies Adam to the wavelet-compressed gradient.
GWT decomposes each gradient \(g_t\) with a discrete Haar wavelet transform, splitting it into low-frequency approximation coefficients \(A_t\) and high-frequency detail coefficients \(D_t\). Adam's optimizer states \(m_t\) and \(v_t\) are maintained only over the smaller approximation part, which roughly halves the moment-buffer memory per level of decomposition. The normalized approximation is reconstructed together with the (state-free) detail coefficients via the inverse transform, scaled by a constant \(\alpha\), and used to update the weights.
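Read literally, the description above corresponds roughly to the per-step recurrences below. This is a reconstruction from the prose and symbol definitions on this page, not a quotation from the paper: \(\hat{A}_t\), \(\tilde{g}_t\), and \(w_t\) are names introduced here for illustration, bias correction of the moments is omitted, and passing \(D_t\) into the inverse transform without rescaling is an assumption.

\[
\begin{aligned}
[A_t,\ D_t] &= H g_t \\
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, A_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, A_t^2 \\
\hat{A}_t &= \frac{m_t}{\sqrt{v_t} + \epsilon} \\
\tilde{g}_t &= \tilde{H}\,[\hat{A}_t,\ D_t] \\
w_{t+1} &= w_t - \eta\,\alpha\,\tilde{g}_t
\end{aligned}
\]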
Notation: \(H\) is the forward Haar wavelet matrix and \(\tilde{H}\) its inverse (\(H\tilde{H}=I\)); \(A_t\) and \(D_t\) are the approximation and detail coefficients; \(m_t\) and \(v_t\) are the Adam first and second moments over \(A_t\); \(\beta_1,\beta_2\) are their decay rates; \(\eta\) is the learning rate; \(\alpha\) is a scale factor (default \(0.25\)); and \(\epsilon\) is a stability constant.
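To make the roles of these symbols concrete, here is a minimal pure-Python sketch of a single one-level GWT step on a flat, even-length gradient. It is not the library's implementation. The function names (`haar_forward`, `haar_inverse`, `gwt_step`) are hypothetical, bias correction is left out, and the detail coefficients are passed through unchanged, which is an assumption about how the state-free part is handled.

```python
import math

SQRT2 = math.sqrt(2.0)


def haar_forward(g):
    """One-level orthonormal Haar transform of an even-length list."""
    approx = [(g[i] + g[i + 1]) / SQRT2 for i in range(0, len(g), 2)]
    detail = [(g[i] - g[i + 1]) / SQRT2 for i in range(0, len(g), 2)]
    return approx, detail


def haar_inverse(approx, detail):
    """Inverse of haar_forward: rebuild each pair and interleave them."""
    out = []
    for a, d in zip(approx, detail):
        out.extend(((a + d) / SQRT2, (a - d) / SQRT2))
    return out


def gwt_step(w, g, m, v, lr=1e-3, beta1=0.9, beta2=0.999, alpha=0.25, eps=1e-8):
    """One hypothetical GWT update; m and v (length len(g) // 2) change in place."""
    A, D = haar_forward(g)
    # Adam moments are kept only over the approximation coefficients.
    for i, a in enumerate(A):
        m[i] = beta1 * m[i] + (1.0 - beta1) * a
        v[i] = beta2 * v[i] + (1.0 - beta2) * a * a
    # Normalize the approximation part (no bias correction: an assumption).
    A_hat = [mi / (math.sqrt(vi) + eps) for mi, vi in zip(m, v)]
    # Assumption: detail coefficients carry no state and are used as-is.
    update = haar_inverse(A_hat, D)
    # The reconstructed update is scaled by alpha and the learning rate.
    return [wi - lr * alpha * ui for wi, ui in zip(w, update)]


w = [0.5, -0.5, 1.0, 2.0]
g = [0.1, 0.3, -0.2, 0.4]
m = [0.0] * (len(g) // 2)  # half the size of the gradient
v = [0.0] * (len(g) // 2)
w_new = gwt_step(w, g, m, v)
```

The moment buffers `m` and `v` each hold half as many entries as the gradient, which is where the memory saving comes from. Applying `haar_forward` again to the approximation output would give a second level and halve the buffers once more.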
Reference: Ziqing Wen, Ping Luo, Jiahuan Wang, Xiaoge Deng, Jinping Zou, Kun Yuan, Tao Sun, Dongsheng Li, "Breaking Memory Limits: Gradient Wavelet Transform Enhances LLMs Training", arXiv 2025. https://arxiv.org/abs/2501.07237