M-SAM (Modality-Aware SAM)¶
Implements M-SAM (Modality-Aware SAM), sharpness-aware minimization restricted to the dominant modality in multimodal learning.
In multimodal training, a single modality often dominates the loss, so applying SAM uniformly across all modalities can over-regularize the weaker ones. M-SAM decomposes the per-batch loss by modality contribution using Shapley values, identifies the dominant modality \(m_d\), and applies the SAM ascent perturbation only to that modality's loss \(\mathcal{L}_d\). The remaining modalities contribute through ordinary first-order gradients of \(\mathcal{L}_s\) evaluated at the unperturbed parameters.
where \(\theta\) are the parameters, \(\eta_t\) the learning rate, \(\rho\) the neighborhood radius, \(\lambda\) the weight decay, \(v_m\) the Shapley contribution of modality \(m\) to the loss \(\ell = \mathcal{L}(f(x_1^i,\dots,x_M^i;\theta), y^i)\), \(\mathcal{L}_d\) the dominant-modality loss, and \(\mathcal{L}_s\) the aggregate non-dominant loss.
Reference: Hossein Rajoli, Jie Ji, Xiaolong Ma, Fatemeh Afghah, "Modality-Aware SAM: Sharpness-Aware-Minimization Driven Gradient Modulation for Harmonized Multimodal Learning", arXiv 2025. https://arxiv.org/abs/2510.24919