Coupled Adam
Implements Coupled Adam, an Adam variant that couples the second moment across output embedding vectors.
Standard Adam normalizes each embedding vector's update by that vector's own second moment, which the authors show drives anisotropy and gives rare tokens disproportionately large updates. Coupled Adam targets only the output embedding matrix. It computes per-vector first and second moments as usual, then replaces each vector's bias-corrected second moment with the average second moment taken over the whole vocabulary. Sharing \(\bar{v}_t\) removes the per-token scaling difference in the denominator, which the authors report yields less anisotropic embeddings, while the rest of the network stays on ordinary Adam.
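A minimal sketch of why the shared denominator matters. It makes two simplifying assumptions that are not taken from the paper: each token has a single embedding dimension, and a rare token differs from a frequent one only by having a smaller gradient. At \(t = 1\) the bias-corrected moments reduce to \(\hat{m} = g\) and \(\hat{v} = g^2\), so ordinary Adam's step has magnitude close to \(\eta\) however small \(g\) is, while the coupled step scales with \(g\). The helper first_step_updates is hypothetical and exists only for this illustration.

```python
import math


def first_step_updates(grads, lr=1e-3, eps=1e-8):
    """Return (adam, coupled) update magnitudes at t = 1 for 1-D per-token gradients.

    Hypothetical illustration helper. At t = 1 the bias-corrected moments are
    m_hat = g and v_hat = g**2, so the beta values cancel out and are not needed.
    """
    v_hat = [g * g for g in grads]
    # Standard Adam: each token is normalized by its own second moment.
    adam = [lr * abs(g) / (math.sqrt(v) + eps) for g, v in zip(grads, v_hat)]
    # Coupled Adam: every token shares the vocabulary-averaged second moment.
    v_bar = sum(v_hat) / len(v_hat)
    coupled = [lr * abs(g) / (math.sqrt(v_bar) + eps) for g in grads]
    return adam, coupled


# A "frequent" token with a large gradient and a "rare" token with a tiny one.
adam, coupled = first_step_updates([1.0, 0.01])
print(adam)     # both entries close to lr = 1e-3
print(coupled)  # entries keep the 100:1 ratio of the gradients
```

With gradients 1.0 and 0.01, both Adam steps come out near the learning rate, so the rare token moves about as far as the frequent one; the coupled steps preserve the 100:1 ratio of the gradients.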
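For each token \(i\) at step \(t\), Coupled Adam applies
\[
\begin{aligned}
m_{i,t} &= \beta_1 m_{i,t-1} + (1-\beta_1)\, g_{i,t}, \\
v_{i,t} &= \beta_2 v_{i,t-1} + (1-\beta_2)\, g_{i,t}^2, \\
\hat{m}_{i,t} &= \frac{m_{i,t}}{1-\beta_1^t}, \qquad \hat{v}_{i,t} = \frac{v_{i,t}}{1-\beta_2^t}, \\
\bar{v}_t &= \frac{1}{V}\sum_{j=1}^{V} \hat{v}_{j,t}, \\
\theta_{i,t} &= \theta_{i,t-1} - \eta\, \frac{\hat{m}_{i,t}}{\sqrt{\bar{v}_t} + \epsilon},
\end{aligned}
\]
where \(\theta_i\) is the embedding vector for token \(i\), \(V\) is the vocabulary size, \(g_{i,t}\) is its gradient, \(m_{i,t}\) and \(v_{i,t}\) are the first and second moments with decay rates \(\beta_1\) and \(\beta_2\), \(\hat{m}_{i,t}\) and \(\hat{v}_{i,t}\) are their bias-corrected forms, \(\bar{v}_t\) is the vocabulary-averaged second moment shared by all tokens, \(\eta\) is the learning rate, and \(\epsilon\) is a small stability constant. Squares, square roots, and division act elementwise.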
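Using the symbols defined above, one Coupled Adam step on the output embedding matrix can be sketched in dependency-free Python. CoupledAdamSketch and its step method are hypothetical names, not the authors' code or any library's API. The sketch assumes that the vocabulary average is taken separately for each embedding dimension, that \(\epsilon\) sits outside the square root as in standard Adam, and that there is no weight decay; the rest of the network would keep using ordinary Adam.

```python
import math


class CoupledAdamSketch:
    """Hypothetical pure-Python Coupled Adam for a V x d output embedding matrix.

    Assumptions (not from the paper's code): the vocabulary average is taken per
    embedding dimension, eps is added outside the square root, no weight decay.
    """

    def __init__(self, vocab_size, dim, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = [[0.0] * dim for _ in range(vocab_size)]  # first moments m_{i,t}
        self.v = [[0.0] * dim for _ in range(vocab_size)]  # second moments v_{i,t}
        self.t = 0

    def step(self, theta, grads):
        """Update theta (a list of V rows of length d) in place; grads has the same shape."""
        self.t += 1
        b1, b2 = self.beta1, self.beta2
        bc1 = 1.0 - b1 ** self.t  # bias correction for m
        bc2 = 1.0 - b2 ** self.t  # bias correction for v
        V, d = len(theta), len(theta[0])
        # Per-token moment updates, exactly as in ordinary Adam.
        for i in range(V):
            for j in range(d):
                g = grads[i][j]
                self.m[i][j] = b1 * self.m[i][j] + (1.0 - b1) * g
                self.v[i][j] = b2 * self.v[i][j] + (1.0 - b2) * g * g
        # Coupling: bias-corrected second moment averaged over the vocabulary (v_bar_t).
        v_bar = [sum(self.v[i][j] for i in range(V)) / (V * bc2) for j in range(d)]
        # Every token divides by the same shared denominator.
        for i in range(V):
            for j in range(d):
                m_hat = self.m[i][j] / bc1
                theta[i][j] -= self.lr * m_hat / (math.sqrt(v_bar[j]) + self.eps)
        return theta


# Usage: a toy vocabulary of 3 tokens with 2-dimensional embeddings.
opt = CoupledAdamSketch(vocab_size=3, dim=2, lr=0.1)
theta = [[0.0, 0.0] for _ in range(3)]
grads = [[1.0, -2.0], [0.1, 0.0], [-1.1, 2.0]]
opt.step(theta, grads)
print(theta)
```

Only the denominator differs from ordinary Adam: dropping v_bar and dividing each entry by its own bias-corrected v recovers the standard update.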
Reference: Felix Stollenwerk, Tobias Stollenwerk, "Better Embeddings with Coupled Adam", arXiv 2025. https://arxiv.org/abs/2502.08441