SLTrain
Implements SLTrain, a sparse-plus-low-rank weight reparameterization for parameter- and memory-efficient pretraining.
SLTrain replaces each dense weight matrix with the sum of a low-rank factorization \(BA\) and a sparse matrix \(S\). The sparse support is drawn once by uniform random sampling and held fixed throughout training, so only the non-zero values are learned alongside the low-rank factors. The three components \(B\), \(A\), and the sparse values are optimized jointly with Adam; storing only the factors and the sparse entries cuts both parameter count and optimizer-state memory relative to full-rank training.
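The mechanics above can be illustrated with a minimal NumPy sketch of one reparameterized weight trained by Adam. It uses a toy least-squares loss on the weight itself. Several details are assumptions, not the authors' code: the names `SparsePlusLowRank` and `adam_step`, the initialization scales, zero-initialized sparse values, and storing the support as flat indices. The sketch also materializes the dense weight for clarity.

```python
import numpy as np


class SparsePlusLowRank:
    """Hypothetical sketch: W = (alpha / r) * B @ A + S with a fixed random support.

    Initialization scales and flat-index storage are illustrative assumptions.
    """

    def __init__(self, m, n, r, alpha, delta, seed=0):
        rng = np.random.default_rng(seed)
        self.scale = alpha / r
        self.B = rng.normal(0.0, 0.02, size=(m, r))
        self.A = rng.normal(0.0, 0.02, size=(r, n))
        # Sample the support once, uniformly without replacement; never resampled.
        nnz = int(round(delta * m * n))
        self.idx = np.sort(rng.choice(m * n, size=nnz, replace=False))
        # Only these values are trainable for the sparse part.
        self.V = np.zeros(nnz)

    def weight(self):
        W = self.scale * (self.B @ self.A)
        W.flat[self.idx] += self.V  # scatter the sparse values onto the fixed support
        return W

    def grads(self, G):
        """Map dL/dW to gradients of the trainable parameters {B, A, V}."""
        return {
            "B": self.scale * G @ self.A.T,
            "A": self.scale * self.B.T @ G,
            "V": G.flat[self.idx],  # gather: off-support entries have no parameter
        }


def adam_step(params, grads, state, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    """One standard Adam update, applied in place to every array in `params`."""
    state["t"] = state.get("t", 0) + 1
    t = state["t"]
    for name, g in grads.items():
        m = state.setdefault("m_" + name, np.zeros_like(g))
        v = state.setdefault("v_" + name, np.zeros_like(g))
        m *= b1
        m += (1 - b1) * g
        v *= b2
        v += (1 - b2) * g * g
        m_hat = m / (1 - b1**t)
        v_hat = v / (1 - b2**t)
        params[name] -= lr * m_hat / (np.sqrt(v_hat) + eps)


# Toy usage: fit W to a random target with L = 0.5 * ||W - target||_F^2.
layer = SparsePlusLowRank(m=32, n=48, r=4, alpha=8.0, delta=0.05)
support_before = layer.idx.copy()
target = np.random.default_rng(1).normal(size=(32, 48))
params = {"B": layer.B, "A": layer.A, "V": layer.V}  # same arrays, updated in place
state = {}
losses = []
for _ in range(200):
    G = layer.weight() - target  # dL/dW for the toy loss
    losses.append(0.5 * float(np.sum(G * G)))
    adam_step(params, layer.grads(G), state)
```

Gradients reach the low-rank factors through the chain rule on the effective weight. The sparse part receives only the entries of dL/dW gathered at the fixed support, so the index set never changes during training and only the values move.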
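\[ W = \frac{\alpha}{r} BA + S, \qquad S_{ij} = \begin{cases} v_{ij}, & (i,j) \in \mathcal{I}, \\ 0, & \text{otherwise}, \end{cases} \qquad |\mathcal{I}| = \delta\, mn, \]
where \(W \in \mathbb{R}^{m\times n}\) is the effective weight, \(B \in \mathbb{R}^{m\times r}\) and \(A \in \mathbb{R}^{r\times n}\) are the rank-\(r\) low-rank factors scaled by \(\alpha/r\) (\(\alpha\) a balancing hyperparameter), \(S\) is the sparse matrix supported on the fixed index set \(\mathcal{I}\) with learnable values \(\mathcal{V} = \{v_{ij} : (i,j) \in \mathcal{I}\}\), \(\delta\) is the sparsity density (the fraction of entries in the support), and the trainable parameters \(\{B, A, \mathcal{V}\}\) are updated by Adam.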
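To make the savings concrete, a small helper can count trainable parameters and Adam moment buffers for one \(m \times n\) layer. It counts \(r(m+n)\) low-rank entries plus \(\delta mn\) sparse values, versus \(mn\) for a dense weight. The function name `sltrain_counts` and the example sizes are illustrative assumptions, not configurations from the paper. The fixed indices must be stored, but they are not trained, so they add no optimizer state.

```python
def sltrain_counts(m, n, r, delta):
    """Hypothetical helper: entry counts for one m x n layer (bias ignored)."""
    dense = m * n                        # full-rank weight
    low_rank = r * (m + n)               # B is m x r, A is r x n
    nnz = int(round(delta * m * n))      # size of the fixed support
    trainable = low_rank + nnz           # B, A and the sparse values
    return {
        "dense_params": dense,
        "sltrain_params": trainable,
        "stored_indices": nnz,           # fixed, so no gradient or Adam state
        "dense_adam_state": 2 * dense,   # Adam keeps first and second moments
        "sltrain_adam_state": 2 * trainable,
    }


counts = sltrain_counts(m=4096, n=4096, r=256, delta=0.03)
ratio = counts["sltrain_params"] / counts["dense_params"]
print(counts["sltrain_params"], round(ratio, 3))  # → 2600468 0.155
```

For these sizes the layer trains about 15.5% of the dense parameter count. Because Adam's two moment buffers scale with the trainable parameters, its optimizer state shrinks by the same factor.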
Reference: Andi Han, Jiaxiang Li, Wei Huang, Mingyi Hong, Akiko Takeda, Pratik Jawanpuria, Bamdev Mishra, "SLTrain: a sparse plus low-rank approach for parameter and memory efficient pretraining", NeurIPS 2024. https://arxiv.org/abs/2406.02214