# MoT Code Walkthrough
KB references: MoT paper note
## Overview

Mixture-of-Transformers (MoT) introduces modality-aware sparsity to every non-embedding block so that each modality owns its feed-forward, attention, and normalization routes while still sharing global self-attention. In practice this lets a 7B text+image MoT hit dense-model quality with only 55.8 % of the FLOPs, extend to speech with 37.2 % of the dense compute, and run multi-branch generation faster on commodity A100s.^[external_repos/MoT/README.md, lines 6-30]
## At-a-Glance

| Architecture | Params / FLOPs | Context | Inputs | Key capabilities | Repo |
|---|---|---|---|---|---|
| Attach modality-untied feed-forward + attention experts to an existing transformer, using binary `modality_masks` to route tokens while keeping shared global attention.^[external_repos/MoT/README.md; external_repos/MoT/src/simple_ModalityUntiedAttention.py, lines 16-151] | 7B MoT (text+image) matches dense baselines at 55.8 % FLOPs; 443 M MoT (text+image+speech) hits dense speech quality at 37.2 % FLOPs.^[external_repos/MoT/README.md, lines 15-23] | Chameleon (autoregressive text + raster image), Transfusion (autoregressive text + image diffusion), and broader “native multimodal” projects.^[external_repos/MoT/README.md, lines 15-30] | Any packed token sequence, as long as each token is tagged in `modality_masks`; examples show text/image/speech masks and the associated norm-placement rules.^[external_repos/MoT/README.md, lines 129-137; external_repos/MoT/src/simple_ModalityUntiedAttention.py, lines 87-137] | Step-by-step tutorial covering FFN experts, attention experts, and normalization placement so you can graft MoT onto proprietary stacks.^[external_repos/MoT/README.md, lines 75-330] | `external_repos/MoT` |
## Environment & Integration Notes

- Designed as a playbook on top of your own transformer: start from any stack that exposes attention/FFN modules and thread MoT’s modality-specific replacements through it.^[external_repos/MoT/README.md, lines 40-71]
- Efficiency gains hinge on accurate routing masks; the README demonstrates simple boolean lists and emphasises deterministic routing per modality (see the mask-building sketch after this list).^[external_repos/MoT/README.md, lines 129-137]
- Norm placement matters: either keep residual norms inside the expert modules (preferred) or refactor your `TransformerBlock` to avoid double-normalizing.^[external_repos/MoT/README.md, lines 307-330]
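A minimal sketch of how such boolean masks could be built from per-token modality IDs; `build_modality_masks`, `modality_ids`, and `n_modalities` are illustrative names, not part of the repo:

```python
import torch

def build_modality_masks(modality_ids: torch.Tensor, n_modalities: int) -> list[torch.Tensor]:
    """Turn a per-token modality-ID tensor of shape [batch, seq] into one
    boolean mask per modality, suitable for routing tokens to MoT experts."""
    return [modality_ids == i for i in range(n_modalities)]

# Example: one packed sequence mixing text (0), image (1), and speech (2) tokens.
modality_ids = torch.tensor([[0, 0, 1, 1, 1, 2, 2, 0]])
modality_masks = build_modality_masks(modality_ids, n_modalities=3)
# modality_masks[0] -> tensor([[ True,  True, False, False, False, False, False,  True]])
```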
## Key Components

### Modality-Untied Feed-Forward (`src/simple_ModalityUntiedFeedForward.py`)

`SimpleModalityUntiedFeedForward` replicates a SiLU-gated MLP per modality, normalizes each expert’s output, then stitches results back into the original token order via `merge_modalities`. Swapping only this block already covers ≈67 % of non-embedding parameters, so this step alone delivers most of the FLOP savings.^[external_repos/MoT/README.md, lines 75-119]
```python
class SimpleModalityUntiedFeedForward(torch.nn.Module):
    def __init__(..., n_modalities: int = 2):
        ...
        self.local_experts = torch.nn.ModuleList([
            SimpleFeedForward(...)
            for _ in range(self.n_modalities)
        ])
        self.local_experts_ffn_norm = torch.nn.ModuleList(
            [SimpleRMSNorm(dim, eps=1e-5) for _ in range(self.n_modalities)]
        )

    def forward(self, x, modality_masks):
        expert_outputs = []
        for i in range(self.n_modalities):
            expert_input = x[modality_masks[i]]
            expert_output = self.local_experts[i](expert_input)
            expert_output = self.local_experts_ffn_norm[i](expert_output)
            expert_outputs.append(expert_output)
        return merge_modalities(expert_outputs, modality_masks)
```
Because each expert only sees its own modality’s tokens, you can scale specialization (e.g., text-heavy vs. image-heavy hidden sizes) without perturbing other branches. `SimpleFeedForward` itself is the Lingua-style gated MLP that preserves tensor-parallel friendliness.^[external_repos/MoT/src/simple_ModalityUntiedFeedForward.py, lines 64-107]
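For orientation, here is a minimal sketch of a SiLU-gated MLP of the kind `SimpleFeedForward` implements; the class name `GatedMLP` and the exact signature are illustrative, not the repo’s:

```python
import torch

class GatedMLP(torch.nn.Module):
    """SiLU-gated MLP sketch: out = w2(silu(w1(x)) * w3(x))."""
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = torch.nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w3 = torch.nn.Linear(dim, hidden_dim, bias=False)  # value projection
        self.w2 = torch.nn.Linear(hidden_dim, dim, bias=False)  # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
```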
### Modality-Untied Attention (`src/simple_ModalityUntiedAttention.py`)

The attention module mirrors the FFN pattern: per-modality projections and RMSNorms for Q/K/V/outputs, shared global attention via `torch.nn.MultiheadAttention`, and a final per-modality projection back to the model dimension.^[external_repos/MoT/README.md, lines 141-301; external_repos/MoT/src/simple_ModalityUntiedAttention.py, lines 16-151]
Per-modality Q/K/V projections with shared attention:
```python
class SimpleModalityUntiedAttention(torch.nn.Module):
    def __init__(...):
        self.local_experts_wq = self._create_experts(dim, n_heads * head_dim)
        self.local_experts_wk = self._create_experts(dim, n_heads * head_dim)
        self.local_experts_wv = self._create_experts(dim, n_heads * head_dim)
        self.local_experts_wo = self._create_experts(n_heads * head_dim, dim)
        ...
        # torch.nn.MultiheadAttention takes embed_dim / num_heads
        self.attention_comp = torch.nn.MultiheadAttention(
            embed_dim=n_heads * head_dim,
            num_heads=n_heads,
            dropout=dropout,
        )
```
During the forward pass, tokens are first split by mask, projected and normed per modality, concatenated back into packed order for standard attention, and finally projected/normed per modality again.^[external_repos/MoT/src/simple_ModalityUntiedAttention.py, lines 86-151] Optional QK normalization reshapes tensors to `[*, num_heads, head_dim]` before applying `SimpleRMSNorm`, which keeps the rotary-scaled statistics stable.^[external_repos/MoT/src/simple_ModalityUntiedAttention.py, lines 112-174]
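To make that routing concrete, here is a minimal sketch of the split → project → attend → project flow (not the repo’s code: `attend`, the `experts[i]` container, and the `wq/wk/wv/wo`, `q_norm/k_norm/o_norm` attribute names are illustrative stand-ins):

```python
def untied_attention_forward(x, modality_masks, experts, attend):
    """Sketch: per-modality projections around a shared attention call.
    `experts[i]` bundles that modality's projections and norms; `attend`
    stands in for the shared global attention over the re-merged sequence."""
    q_parts, k_parts, v_parts = [], [], []
    for i, mask in enumerate(modality_masks):
        tokens = x[mask]                                    # split by modality
        q_parts.append(experts[i].q_norm(experts[i].wq(tokens)))
        k_parts.append(experts[i].k_norm(experts[i].wk(tokens)))
        v_parts.append(experts[i].wv(tokens))
    q = merge_modalities(q_parts, modality_masks)           # back to packed order
    k = merge_modalities(k_parts, modality_masks)
    v = merge_modalities(v_parts, modality_masks)
    attn_out = attend(q, k, v)                              # shared global attention
    outputs = [experts[i].o_norm(experts[i].wo(attn_out[mask]))
               for i, mask in enumerate(modality_masks)]
    return merge_modalities(outputs, modality_masks)        # per-modality output path
```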
### Utility Primitives (`src/utils.py`)

`merge_modalities` reconstructs the packed sequence according to mask order, so expert outputs can be arbitrarily sharded while still producing a contiguous tensor for the residual path. `SimpleRMSNorm` is the Lingua-derived RMSNorm variant used consistently across experts.^[external_repos/MoT/src/utils.py, lines 14-66]
Modality merging utility:
```python
def merge_modalities(expert_outputs, modality_masks):
    # Allocate the full packed sequence: one slot per token, feature dim from the experts.
    merged = expert_outputs[0].new_empty(
        *modality_masks[0].shape, expert_outputs[0].shape[-1]
    )
    for i in range(len(expert_outputs) - 1, -1, -1):
        merged[modality_masks[i]] = expert_outputs[i]  # scatter each expert back to its slots
    return merged
```
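A quick usage illustration with toy shapes, using the `merge_modalities` shown above:

```python
import torch

modality_masks = [torch.tensor([True, True, False, False, True, False]),   # modality 0 tokens
                  torch.tensor([False, False, True, True, False, True])]   # modality 1 tokens
expert_outputs = [torch.ones(3, 4), torch.full((3, 4), 2.0)]               # one output per modality
merged = merge_modalities(expert_outputs, modality_masks)
# merged[:, 0] -> tensor([1., 1., 2., 2., 1., 2.]): original token order restored
```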
## Implementation Checklist

- Start from your baseline `TransformerBlock`, then replace the FFN/attention classes with the modality-untied versions, ensuring the residual structure still performs `x + module(x)` as shown in the README (see the block sketch after this list).^[external_repos/MoT/README.md, lines 307-330]
- Provide `modality_masks` for every forward pass. The README’s three-modality example demonstrates boolean masks; in production you can precompute these from tokenizer metadata or image/video region plans.^[external_repos/MoT/README.md, lines 129-137]
- Keep norm layers inside the modality-specific modules to avoid double-scaling outputs; only keep block-level norms if your baseline requires them.^[external_repos/MoT/README.md, lines 307-330]
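A minimal block sketch of that wiring, assuming both MoT modules take `(x, modality_masks)`; the class name `MoTTransformerBlock` and its constructor are illustrative:

```python
import torch

class MoTTransformerBlock(torch.nn.Module):
    """Sketch: norms live inside the MoT modules, so the block itself is
    just two x + module(x) residual connections."""
    def __init__(self, untied_attention: torch.nn.Module, untied_feed_forward: torch.nn.Module):
        super().__init__()
        self.attention = untied_attention        # e.g. SimpleModalityUntiedAttention
        self.feed_forward = untied_feed_forward  # e.g. SimpleModalityUntiedFeedForward

    def forward(self, x: torch.Tensor, modality_masks) -> torch.Tensor:
        x = x + self.attention(x, modality_masks)      # residual around attention experts
        x = x + self.feed_forward(x, modality_masks)   # residual around FFN experts
        return x
```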
## Integration Hooks

- **Routing data from KB assets.** When generating multimodal batches (e.g., fMRI tokens + gene tokens), build boolean masks once per modality and pass them into the MoT blocks; only the masks need awareness of modality boundaries.
- **Progressive specialization.** Because experts are independent `nn.ModuleList` entries, you can freeze or reinitialize selected modalities while fine-tuning others (useful when only one KB modality changes); see the freezing sketch after this list.
- **FLOP budgeting.** The FLOP-savings callouts (55.8 % / 37.2 % / one-third) provide targets for profiling when adapting MoT to new neuro-omics settings.^[external_repos/MoT/README.md, lines 15-30]
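A small sketch of the freezing pattern, assuming the attribute names from `SimpleModalityUntiedFeedForward` above (`local_experts`, `local_experts_ffn_norm`); the helper name is illustrative:

```python
def freeze_modality_expert(ffn_module, modality_index: int) -> None:
    """Freeze one modality's FFN expert (and its norm) while the other
    experts keep training; assumes the simple module's attribute names."""
    for param in ffn_module.local_experts[modality_index].parameters():
        param.requires_grad = False
    for param in ffn_module.local_experts_ffn_norm[modality_index].parameters():
        param.requires_grad = False
```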