
BrainMT Code Walkthrough

KB references: Model card · fMRI feature spec · Integration strategy · Experiment config stub

Overview

BrainMT pairs bidirectional Mamba mixers (temporal-first scanning) with MHSA transformer blocks to model long-range fMRI dynamics, delivering state-of-the-art regression/classification on UKB and HCP phenotypes.^[external_repos/brainmt/README.md, lines 3-170]^[external_repos/brainmt/src/brainmt/models/brain_mt.py, lines 294-462] The architecture is now described in an official conference paper (SpringerLink, Lecture Notes in Computer Science, pp. 150–160; first online 19 September 2025), so reference the proceedings PDF at docs/generated/kb_curated/papers-pdf/brainmt_2025.pdf when citing.

At-a-Glance

  • Architecture: 3D Conv patch embed → bidirectional Mamba blocks → Transformer attention blocks^[external_repos/brainmt/src/brainmt/models/brain_mt.py]
  • Params: configurable (default hidden 512, depth [12, 8])^[external_repos/brainmt/src/brainmt/models/brain_mt.py, lines 293-375]
  • Context: 91×109×91 voxels × 200 frames (default)^[external_repos/brainmt/src/brainmt/models/brain_mt.py, lines 294-339]
  • Inputs: preprocessed .pt tensors from data/datasets.py^[external_repos/brainmt/src/brainmt/data/datasets.py, lines 15-80]
  • Key capabilities: DDP training with regression/classification heads, inference utilities^[external_repos/brainmt/src/brainmt/train.py, lines 1-330]^[external_repos/brainmt/src/brainmt/inference.py, lines 1-390]
  • Repo: github.com/arunkumar-kannan/brainmt-fmri

Environment & Hardware Notes

  • Exact environment commands. The README targets Python 3.9.18 + PyTorch 2.6/CUDA 12.4, created via python -m venv brainmt_env, source brainmt_env/bin/activate, and pip install -r requirements.txt.^[external_repos/brainmt/README.md, lines 44-60]
  • Gradient checkpoint flag. Every Mamba block accepts use_checkpoint and conditionally calls checkpoint.checkpoint(...), so you can instantiate BrainMT(..., use_checkpoint=True) to reduce memory usage on long temporal contexts (see the sketch after this list).^[external_repos/brainmt/src/brainmt/models/brain_mt.py, lines 95-125]^[external_repos/brainmt/src/brainmt/models/brain_mt.py, lines 293-334]
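
To see what the flag buys you, here is a minimal sketch of the conditional checkpointing pattern; the block below is a simplified stand-in for a Mamba block, not the repo's implementation:

import torch
import torch.nn as nn
from torch.utils import checkpoint

class CheckpointedBlock(nn.Module):
    """Illustrative stand-in for a block with an optional use_checkpoint flag."""
    def __init__(self, dim, use_checkpoint=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm = nn.LayerNorm(dim)
        self.mixer = nn.Linear(dim, dim)  # placeholder for the bidirectional Mamba mixer

    def _forward_impl(self, x):
        return x + self.mixer(self.norm(x))

    def forward(self, x):
        if self.use_checkpoint:
            # Recompute activations during backward instead of storing them,
            # trading compute for memory on long temporal contexts.
            return checkpoint.checkpoint(self._forward_impl, x, use_reentrant=False)
        return self._forward_impl(x)

# Enable checkpointing when the temporal context (e.g. 200 frames) exhausts GPU memory.
block = CheckpointedBlock(dim=512, use_checkpoint=True)
out = block(torch.randn(2, 200, 512, requires_grad=True))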

Key Components

Dataset Loader (src/brainmt/data/datasets.py)

The dataset stores fMRI volumes as fp16 tensors (func_data_MNI_fp16.pt), slices contiguous time segments, permutes them into [frames, channel, depth, height, width], and returns (tensor, target) pairs.

fMRI dataset with temporal slicing:

external_repos/brainmt/src/brainmt/data/datasets.py (lines 15-80)
class fMRIDataset(Dataset):
    def __getitem__(self, idx):
        # (excerpt) __init__, path resolution, and frame bookkeeping are elided;
        # img_file points at the subject's func_data_MNI_fp16.pt tensor
        data = torch.load(img_file)
        # sample a random contiguous window of num_frames time points
        start_index = torch.randint(0, total_frames - num_frames + 1, (1,)).item()
        data_sliced = data[:, :, :, start_index:end_index]   # end_index = start_index + num_frames
        # add a channel axis and move time to the leading (frames) dimension
        data_global = data_sliced.unsqueeze(0).permute(4, 0, 2, 1, 3)
        target = self.target_dict[subject_dir]
        return data_global, torch.tensor(target, dtype=torch.float16)
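
A quick usage sketch for the loader; the constructor arguments (data_dir, target_dict, num_frames) are assumptions inferred from the excerpt, not verified against the repo's signature:

from torch.utils.data import DataLoader

target_dict = {"sub-0001": 27.5}  # hypothetical subject -> phenotype mapping
dataset = fMRIDataset(data_dir="data/ukb", target_dict=target_dict, num_frames=200)
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)

volumes, targets = next(iter(loader))
# volumes: [B, frames, channel, depth, height, width] fp16 tensor; targets: [B] fp16 labels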

Patch Embed & Conv Blocks (src/brainmt/models/brain_mt.py)

PatchEmbed downsamples each frame's 3D volume with strided 3D convolutions; two ConvBlock + Downsample stages then reduce spatial resolution further while keeping the temporal length unchanged.

3D convolution patch embedding:

external_repos/brainmt/src/brainmt/models/brain_mt.py (lines 202-263)
class PatchEmbed(nn.Module):
    def __init__(self, in_chans, in_dim, dim):   # signature abridged from the repo
        super().__init__()
        # two stride-2 Conv3d layers: each roughly halves depth/height/width
        self.conv_down = nn.Sequential(
            nn.Conv3d(in_chans, in_dim, 3, 2, 1, bias=False),
            nn.ReLU(),
            nn.Conv3d(in_dim, dim, 3, 2, 1, bias=False),
            nn.ReLU()
        )
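
As a quick sanity check on the spatial downsampling, here is a back-of-the-envelope sketch using the standard Conv3d output-size formula (the later ConvBlock/Downsample stages are not included):

def conv3d_out(size, kernel=3, stride=2, padding=1):
    """Output length along one spatial axis for a single Conv3d."""
    return (size + 2 * padding - kernel) // stride + 1

dims = (91, 109, 91)                                        # default MNI volume
after_first = tuple(conv3d_out(d) for d in dims)            # (46, 55, 46)
after_second = tuple(conv3d_out(d) for d in after_first)    # (23, 28, 23)
print(after_first, after_second)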

Hybrid Mamba + Transformer Backbone (src/brainmt/models/brain_mt.py)

Temporal-first processing reshapes tokens, feeds them through create_block (bidirectional Mamba) and then through transformer attention + MLP to capture residual spatial dependencies.

Bidirectional Mamba blocks followed by transformer attention:

external_repos/brainmt/src/brainmt/models/brain_mt.py (lines 331-462)
self.layers = nn.ModuleList([
    # depth[0] bidirectional Mamba blocks operating on the temporal-first sequence
    create_block(embed_dim, ssm_cfg=ssm_cfg, ..., drop_path=inter_dpr[i], ...)
    for i in range(depth[0])
])
self.blocks = nn.ModuleList([
    # depth[1] transformer attention blocks for residual spatial mixing
    Attention(embed_dim, num_heads=num_heads, ...)
    for i in range(depth[1])
])
...
def forward_features(self, x, ...):
    x = self.patch_embed(x)
    x = self.conv_block0(x); x = self.downsample0(x)
    x = self.conv_block1(x); x = self.downsample1(x)
    # temporal-first layout: one sequence of T frames per spatial token
    x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T)
    x = x + self.temporal_pos_embedding
    # (elided) the excerpt omits how x is handed off to hidden_states/residual
    for layer in self.layers:
        hidden_states, residual = layer(hidden_states, residual, ...)
    for block in self.blocks:
        hidden_states = hidden_states + drop_path_attn(block(self.norm(hidden_states)))
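
The temporal-first layout is easiest to see on shapes alone; here is a standalone sketch of the rearrange step with arbitrary illustrative dimension sizes:

import torch
from einops import rearrange

B, T, N, M = 2, 200, 64, 512          # batch, frames, spatial tokens, hidden dim (illustrative)
x = torch.randn(B * T, N, M)          # conv stages emit one token grid per frame: [(b t), n, m]

# Temporal-first layout: each spatial token becomes its own sequence of T frames,
# so the Mamba scan runs along time rather than space.
x_temporal = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T)
print(x_temporal.shape)               # torch.Size([128, 200, 512]) == [(b n), t, m]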

Forward & Head (src/brainmt/models/brain_mt.py)

A CLS token is prepended before the Mamba blocks; forward_features returns the CLS vector, which the final MLP head maps to the regression/classification output in forward.

CLS token prepending and final head:

external_repos/brainmt/src/brainmt/models/brain_mt.py (lines 400-461)
cls_token = self.cls_token.expand(x.shape[0], -1, -1)   # one learnable CLS token per sample in the batch
x = torch.cat((cls_token, x), dim=1)                    # prepend so token 0 is the CLS slot
...
return hidden_states[:, 0, :]                           # [B, hidden] CLS vector fed to the MLP head
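
A minimal smoke-test sketch; the import path, default constructor, and the [B, frames, channel, D, H, W] input layout are assumptions inferred from the repo layout and the dataset loader above (a full-size volume may be heavy without use_checkpoint=True):

import torch
from brainmt.models.brain_mt import BrainMT  # import path assumed from src/brainmt/...

model = BrainMT(use_checkpoint=True).eval()   # defaults: hidden 512, depth [12, 8]
with torch.no_grad():
    dummy = torch.randn(1, 200, 1, 91, 109, 91)   # [B, frames, channel, D, H, W]
    preds = model(dummy)
# preds holds the MLP-head output; its trailing dimension depends on the configured head.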

Training Loop (src/brainmt/train.py)

Hydra config builds datasets, wraps the model in DDP, selects loss (MSE or BCEWithLogits), constructs layer-wise LR decay groups, and trains with GradScaler + cosine warm restarts.

DDP training with mixed precision:

external_repos/brainmt/src/brainmt/train.py (lines 132-234)
model = BrainMT(**model_config).to(device)
model = nn.parallel.DistributedDataParallel(model, device_ids=[device_id], ...)
if cfg.task.loss_fn == "mse":
    criteria = nn.MSELoss()
...
train_loss, train_outputs, train_targets = train_one_epoch(model, criteria, train_loader, optimizer, scaler, device, epoch, cfg)
val_loss, val_outputs, val_targets = evaluate(model, criteria, val_loader, device)
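
For readers unfamiliar with the GradScaler + cosine-warm-restart combination, a generic mixed-precision step is sketched below; it reuses the names from the excerpt above (model, criteria, train_loader, device), and the optimizer/scheduler hyperparameters are illustrative rather than the repo's values:

import torch
from torch.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
scaler = GradScaler("cuda")

for volumes, targets in train_loader:
    volumes, targets = volumes.to(device), targets.to(device)
    optimizer.zero_grad(set_to_none=True)
    with autocast(device_type="cuda"):            # run forward + loss in reduced precision
        loss = criteria(model(volumes).squeeze(-1), targets.float())
    scaler.scale(loss).backward()                 # scale to avoid fp16 gradient underflow
    scaler.step(optimizer)                        # unscale, skip step if inf/nan detected
    scaler.update()
scheduler.step()                                  # advance the cosine warm-restart schedule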

Inference (src/brainmt/inference.py)

The inference script mirrors dataset splits, loads checkpoints, and computes detailed metrics (accuracy/AUROC for classification, MSE/MAE/R²/Pearson for regression), plus optional plots.

Checkpoint loading and metric computation:

external_repos/brainmt/src/brainmt/inference.py (lines 26-210)
model = BrainMT(**model_config).to(device)
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
if cfg.task.name == 'classification':
    metrics = calculate_classification_metrics(test_outputs, test_targets)
else:
    metrics = calculate_regression_metrics(test_outputs, test_targets)
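
As an illustration of the regression metrics the script reports, here is a sketch using scipy/sklearn; the repo's calculate_regression_metrics may differ in detail:

import numpy as np
from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

def regression_metrics_sketch(outputs, targets):
    """Illustrative MSE/MAE/R^2/Pearson computation on flat arrays."""
    outputs, targets = np.asarray(outputs).ravel(), np.asarray(targets).ravel()
    return {
        "mse": mean_squared_error(targets, outputs),
        "mae": mean_absolute_error(targets, outputs),
        "r2": r2_score(targets, outputs),
        "pearson_r": pearsonr(targets, outputs)[0],
    }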

Integration Hooks (Brain ↔ Genetics)

  • Embedding shape. forward_features returns [B, hidden] CLS vectors (hidden default 512). To access intermediate token embeddings, tap hidden_states[:, 1:, :] before the final average/MLP.^[external_repos/brainmt/src/brainmt/models/brain_mt.py, lines 400-462]
  • Pooling choices. The CLS token encodes temporal-first, globally attentive context. For voxel-conditioned embeddings, reshape the post-Mamba tensor back to [B, voxels, hidden] prior to the transformer block and average along the voxel axis (see the pooling sketch after this list).
  • Projection to shared latent. Map [B, 512] BrainMT vectors into a 512-D multimodal space with a lightweight projector:

import torch.nn as nn

class BrainMTProjector(nn.Module):
    def __init__(self, input_dim=512, output_dim=512, dropout=0.1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(1024, output_dim),
            nn.LayerNorm(output_dim),
        )

    def forward(self, x):
        return self.layers(x)
  • Normalization. Because BrainMT ends with LayerNorm (self.norm_f), the additional LayerNorm in the projector keeps the scale comparable to genetic embeddings.
  • Temporal handling. The temporal-first scan (rearrange(..., '(b n) t m -> b (n t) m')) is crucial for long-range modeling; preserve this ordering if you export intermediate features for contrastive alignment with DNA sequences.^[external_repos/brainmt/src/brainmt/models/brain_mt.py, lines 421-444]
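
A sketch of the two pooling routes mentioned above, reusing the BrainMTProjector defined earlier; the tensor shapes and the point at which hidden_states is tapped inside forward_features are assumptions for illustration:

import torch

# hidden_states: [B, 1 + num_tokens, hidden] as it would look inside forward_features,
# with token 0 being the CLS slot (shapes assumed for illustration)
hidden_states = torch.randn(4, 1 + 128, 512)

cls_embedding = hidden_states[:, 0, :]           # [B, 512] globally attentive CLS context
token_embeddings = hidden_states[:, 1:, :]       # [B, num_tokens, 512] per-token features
mean_embedding = token_embeddings.mean(dim=1)    # [B, 512] simple average pooling

projector = BrainMTProjector(input_dim=512, output_dim=512)
z_brain = projector(cls_embedding)               # [B, 512] vector in the shared latent space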

Projected BrainMT features can then be concatenated or contrastively aligned with Evo 2/GENERator/Caduceus embeddings to study genetics↔fMRI correspondences.
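
If you take the contrastive route, a minimal symmetric InfoNCE sketch over paired brain/genetic embeddings looks like this; the temperature and the assumption that genetic embeddings are already projected to the same 512-D space are illustrative:

import torch
import torch.nn.functional as F

def symmetric_infonce(brain_z, gene_z, temperature=0.07):
    """CLIP-style contrastive loss for paired [B, D] brain and genetic embeddings."""
    brain_z = F.normalize(brain_z, dim=-1)
    gene_z = F.normalize(gene_z, dim=-1)
    logits = brain_z @ gene_z.t() / temperature                  # [B, B] scaled cosine similarities
    labels = torch.arange(brain_z.size(0), device=brain_z.device)
    return 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels))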