Hybrid SignSGD-SGD switching
Implements Hybrid SignSGD-SGD switching, a SWATS-style schedule that begins with momentum SignSGD and transitions to plain SGD using a projection-calibrated learning rate.
SignSGD-M compresses each momentum coordinate to its sign. This saves communication and memory, but it discards gradient magnitude and leaves a generalization gap relative to SGD. This method runs SignSGD-M during the early phase of training while continuously estimating the SGD step size that best matches the sign step, obtained by projecting the gradient onto the sign direction. The estimate is tracked with an exponential moving average. Once the switch step is reached, the optimizer hands over to SGD with that calibrated rate, recovering magnitude information for the later phase.
where \(\theta\) are the parameters, \(g_t\) the stochastic gradient, \(m_t\) the momentum with decay \(\beta_1\), \(\gamma\) the fixed sign step size, \(\lambda_t\) the per-step projection of the gradient onto the sign direction, \(\bar{\lambda}_t\) its EMA with decay \(\beta_2\), \(\epsilon\) a stability constant, and \(T_{\mathrm{switch}}\) the epoch at which the optimizer transitions from SignSGD-M to SGD.
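A minimal pure-Python sketch of the two phases follows. It is illustrative only, not the reference implementation from the paper, and it rests on several assumptions: the momentum is an EMA \(m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t\); \(\lambda_t\) uses the SWATS projection with sign step \(p_t = -\gamma\,\mathrm{sign}(m_t)\), giving \(\lambda_t = \gamma\,\|\mathrm{sign}(m_t)\|^2 / (\mathrm{sign}(m_t)^\top g_t + \epsilon)\), and it is only recorded when the denominator is positive; \(\bar{\lambda}_t\) is bias-corrected as in SWATS; and \(T_{\mathrm{switch}}\) is counted in optimizer steps rather than epochs. The class name `HybridSignSGD`, the `step` and `calibrated_lr` methods, and the fallback to \(\gamma\) when no projection was recorded are hypothetical.

```python
from dataclasses import dataclass, field


def sign(x: float) -> float:
    """Sign with sign(0) = 0, the usual SignSGD convention."""
    return float((x > 0) - (x < 0))


@dataclass
class HybridSignSGD:
    """Illustrative SignSGD-M -> SGD switcher (hypothetical, not the paper's code)."""

    gamma: float = 0.01    # fixed sign step size
    beta1: float = 0.9     # momentum decay
    beta2: float = 0.999   # EMA decay for the projected step size
    eps: float = 1e-8      # stability constant in the projection denominator
    t_switch: int = 100    # switch point, counted in optimizer steps (assumption)
    step_count: int = 0
    m: list = field(default_factory=list)
    lam_ema: float = 0.0
    ema_updates: int = 0

    def calibrated_lr(self) -> float:
        """Bias-corrected EMA of lambda_t (SWATS-style correction, an assumption)."""
        if self.ema_updates == 0:
            return self.gamma  # hypothetical fallback when no projection was recorded
        return self.lam_ema / (1.0 - self.beta2 ** self.ema_updates)

    def step(self, params: list, grads: list) -> list:
        """Return the updated parameters after one optimizer step."""
        if not self.m:
            self.m = [0.0] * len(params)
        self.step_count += 1

        if self.step_count > self.t_switch:
            # SGD phase: plain SGD with the calibrated learning rate.
            lr = self.calibrated_lr()
            return [p - lr * g for p, g in zip(params, grads)]

        # SignSGD-M phase: EMA momentum, then a fixed-size step along its sign.
        self.m = [self.beta1 * mi + (1.0 - self.beta1) * gi
                  for mi, gi in zip(self.m, grads)]
        s = [sign(mi) for mi in self.m]

        # SWATS-style projection: the SGD rate lambda whose step -lambda * g
        # projects exactly onto the sign step p = -gamma * s.
        s_dot_g = sum(si * gi for si, gi in zip(s, grads))
        s_norm_sq = sum(si * si for si in s)
        if s_dot_g > 0:  # skip steps where the sign step is not a descent direction
            lam = self.gamma * s_norm_sq / (s_dot_g + self.eps)
            self.lam_ema = self.beta2 * self.lam_ema + (1.0 - self.beta2) * lam
            self.ema_updates += 1

        return [p - self.gamma * si for p, si in zip(params, s)]


if __name__ == "__main__":
    # Toy quadratic loss 0.5 * sum(a_i * theta_i^2) with gradient a_i * theta_i.
    a = [1.0, 4.0]
    theta = [1.0, -1.0]
    opt = HybridSignSGD(gamma=0.01, t_switch=50)
    for _ in range(200):
        grads = [ai * x for ai, x in zip(a, theta)]
        theta = opt.step(theta, grads)
    loss = 0.5 * sum(ai * x * x for ai, x in zip(a, theta))
    print(f"final loss {loss:.4f}, calibrated lr {opt.calibrated_lr():.5f}")
```

In one dimension the calibration can be checked by hand: with \(\gamma = 0.1\) and \(g_t = 2\), the sketch gives \(\lambda_t \approx 0.05\), so the SGD step \(\lambda_t g_t = 0.1\) has the same length as the sign step. Because the bias-corrected EMA is a weighted average of the recorded \(\lambda_t\), the rate handed to SGD always lies between the smallest and largest projections seen before the switch.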
Reference: Haoran Chen, Wentao Wang, "Enhancing SignSGD: Small-Batch Convergence Analysis and a Hybrid Switching Strategy", arXiv 2026. https://arxiv.org/abs/2604.25550