BrainHarmony Code Walkthrough¶
KB references: Model card · fMRI feature spec · sMRI feature spec · Integration strategy · Experiment config stub
Overview¶
BrainHarmony (a.k.a. BrainHarmonix) is a three-stage pipeline that first extracts modality-specific embeddings from fMRI ROI time-series and structural T1 volumes, then performs JEPA-style token-space pretraining, and finally fine-tunes classification heads on downstream cohorts (e.g., ABIDE). Stage 0 runs fused encoders with anatomical + functional positional priors, Stage 1 trains a latent-token predictor with a smooth L1 loss between student and EMA targets, and Stage 2 attaches lightweight heads for clinical prediction.^[external_repos/brainharmony/README.md (lines 1-94)][external_repos/brainharmony/configs/harmonizer/stage0_embed/conf_embed_pretrain.py (lines 32-138)]
At-a-Glance¶
| Aspect | Details |
|---|---|
| Architecture | Dual FlexVisionTransformer encoders (fMRI + 3D T1) feeding a JEPA predictor with latent tokens, FlashAttention-2 blocks, and mask-conditioned regressors.^[external_repos/brainharmony/libs/flex_transformer.py][external_repos/brainharmony/modules/harmonizer/stage1_pretrain/models.py (lines 22-260)][external_repos/brainharmony/libs/ssl_models/jepa_flex.py (lines 23-112)] |
| Params | Default “base” uses 768-d embeddings, 12 encoder blocks, a 6-layer predictor, mask ratio 0.75, and 128 latent tokens (configurable via scripts).^[external_repos/brainharmony/configs/harmonizer/stage0_embed/conf_embed_pretrain.py (lines 48-138)][external_repos/brainharmony/scripts/harmonizer/stage1_pretrain/run_pretrain.sh (lines 38-49)] |
| Context | 400 cortical ROIs × 490 TRs are chunked into 18 patches (48-step windows) plus optional 50 subcortical channels; structural MRI cubes are normalized to 160 × 192 × 160 voxels.^[external_repos/brainharmony/datasets/datasets.py (lines 317-465)][external_repos/brainharmony/datasets/datasets.py (lines 499-561)] |
| Inputs | Stage 0 ingests (fMRI, T1) pairs via UKB_FusionDataset; Stage 1/2 read .npz embeddings with attention masks/labels using GenerateEmbedDataset(_downstream).^[external_repos/brainharmony/datasets/datasets.py (lines 566-581)][external_repos/brainharmony/datasets/datasets.py (lines 803-857)] |
| Key capabilities | Provided scripts wrap embedding extraction (Accelerate), JEPA pretraining, and downstream finetuning for reproducibility.^[external_repos/brainharmony/scripts/harmonizer/stage0_embed/run_embed_pretrain.sh (lines 1-23)][external_repos/brainharmony/scripts/harmonizer/stage1_pretrain/run_pretrain.sh (lines 1-49)][external_repos/brainharmony/scripts/harmonizer/stage2_finetune/run_finetune.sh (lines 1-59)] |
| Repo | external_repos/brainharmony |
Environment & Hardware Notes¶
- Conda + pip workflow. Create a brainharmonix environment (Python 3.10), install CUDA 12.4 wheels for PyTorch 2.6, then run pip install -r requirements.txt and pip install -e . from the repo root.^[external_repos/brainharmony/README.md (lines 40-56)]
- Checkpoint placement. Download the pretrained encoders (harmonix-f/s) plus the harmonizer checkpoints and place them under checkpoints/{harmonix-f,harmonix-s,harmonizer} before running Stage 0/1/2.^[external_repos/brainharmony/README.md (lines 58-71)]
- FlashAttention 2 expectation. FlexVisionTransformer selects FlashAttention 2 when it is installed (see is_flash_attn_2_available), so ensure compatible GPU builds or fall back to “eager” attention; a minimal availability check is sketched after this list.^[external_repos/brainharmony/libs/attn_utils/fa2_utils.py (lines 8-52)][external_repos/brainharmony/libs/flex_transformer.py (lines 138-214)]
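The repo gates this choice via is_flash_attn_2_available; as a hedged, stand-alone sketch of the same idea (not the repo's code), preferring FlashAttention 2 only when the flash_attn package imports cleanly:

```python
# Minimal sketch: probe for FlashAttention 2 and fall back to standard ("eager")
# attention when the package (or a compatible GPU build) is missing.
try:
    from flash_attn import flash_attn_func  # noqa: F401 (needs a compatible GPU build)
    ATTENTION_IMPL = "flash_attention_2"
except ImportError:
    ATTENTION_IMPL = "eager"

print(f"Selected attention implementation: {ATTENTION_IMPL}")
```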
Key Components¶
Stage 0: Embedding Extraction (modules/harmonizer/stage0_embed)¶
The Accelerate launch script (run_embed_pretrain.sh) calls embedding_pretrain.py, which loads configurable datasets (default UKB_FusionDataset) and wraps the pretrained fMRI/T1 encoders specified in conf_embed_pretrain.py. The fMRI encoder receives gradient + geometric-harmonic positional embeddings, while the MAE-style T1 encoder reuses volumetric patches. Each batch returns fMRI tokens, T1 tokens, attention masks, and subject IDs; Stage 0 runs both encoders, concatenates their representations, and persists .npz files along with the attention mask for later stages.^[external_repos/brainharmony/scripts/harmonizer/stage0_embed/run_embed_pretrain.sh (lines 1-23)][external_repos/brainharmony/configs/harmonizer/stage0_embed/conf_embed_pretrain.py (lines 48-138)][external_repos/brainharmony/modules/harmonizer/stage0_embed/embedding_pretrain.py (lines 81-185)]
class UKB_FusionDataset(UKBDataset, UKB_T1_Dataset):
def __init__(self, **kwargs):
UKBDataset.__init__(self, **kwargs)
UKB_T1_Dataset.__init__(self, **kwargs)
def __getitem__(self, index):
ts, _ = self.load_fmri(index)
attn_mask = self.signal_attn_mask()
t1 = self.load_t1(index)
return ts, t1, self.patch_size, attn_mask, self.ids[index]
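The exact writer lives in embedding_pretrain.py; as a rough sketch of the on-disk contract (the keys data and attn_mask are what GenerateEmbedDataset reads back, while the shapes below are illustrative placeholders, not the repo's actual dimensions):

```python
import numpy as np

# One subject's concatenated fMRI+T1 token embeddings plus the attention mask
# marking observed vs. padded positions (illustrative shapes only).
embeddings = np.random.randn(1, 1024, 768).astype(np.float32)
attn_mask = np.ones((1, 1024), dtype=np.int64)

np.savez("sub-0001.npz", data=embeddings, attn_mask=attn_mask)

loaded = np.load("sub-0001.npz")
print(loaded["data"].shape, loaded["attn_mask"].shape)
```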
Stage 1: Token-Space JEPA Pretraining (modules/harmonizer/stage1_pretrain)¶
The generated .npz files are streamed by GenerateEmbedDataset, which yields concatenated embeddings and their attention masks; distributed loaders feed OneTokRegViT, a latent-token ViT that appends learnable latent vectors and mask tokens before passing through decoder blocks. train_one_epoch applies a Smooth L1 loss between the predictor output and EMA targets from the frozen teacher encoder, mirroring the JEPA objective.^[external_repos/brainharmony/datasets/datasets.py (lines 803-825)][external_repos/brainharmony/modules/harmonizer/stage1_pretrain/models.py (lines 22-260)][external_repos/brainharmony/modules/harmonizer/stage1_pretrain/engine_pretrain.py (lines 13-84)][external_repos/brainharmony/scripts/harmonizer/stage1_pretrain/run_pretrain.sh (lines 38-49)]
class GenerateEmbedDataset(Dataset):
def __init__(self, root_dir, portion=1.0, seed=42):
pattern = os.path.join(root_dir, "*.npz")
all_files = sorted(glob.glob(pattern))
self.files = all_files
if len(all_files) == 0:
raise RuntimeError(f"No .npz files found in {root_dir}")
def __len__(self):
return len(self.files)
def __getitem__(self, idx):
filepath = self.files[idx]
arr = np.load(filepath)
tensor = torch.from_numpy(arr["data"])
if tensor.dtype != torch.float32:
tensor = tensor.float()
return tensor.squeeze(), arr["attn_mask"].squeeze()
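A usage sketch for feeding these embeddings to the Stage 1 trainer; the root_dir path is a placeholder, and the repo's own launcher additionally wires up distributed samplers:

```python
from torch.utils.data import DataLoader

# Hypothetical directory of Stage 0 outputs; each .npz holds "data" and "attn_mask".
dataset = GenerateEmbedDataset(root_dir="outputs/stage0_embeddings")

loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
for embeddings, attn_mask in loader:
    # embeddings and attn_mask arrive batched exactly as __getitem__ returns them.
    print(embeddings.shape, attn_mask.shape)
    break
```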
Encoder with latent tokens:
def forward_encoder(self, x, attn_mask):
target = x
if self.add_pre_mapping:
x = self.pre_map(x)
x = x + self.pos_embed[:, 1:, :]
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
latent_tokens = self.latent_tokens.expand(x.shape[0], -1, -1)
latent_tokens = latent_tokens + self.enc_latent_token_positional_embedding
x = torch.cat([x, latent_tokens], dim=1)
pad = torch.ones(attn_mask.shape[0], 1200 + latent_tokens.shape[1], dtype=attn_mask.dtype, device=attn_mask.device)
pad_0 = torch.ones(attn_mask.shape[0], 1, dtype=attn_mask.dtype, device=attn_mask.device)
attn_mask = torch.cat([pad_0, attn_mask, pad], dim=1)
for blk in self.blocks:
x = blk(x, attention_mask=attn_mask)
x = self.norm(x)
latent_tokens = torch.cat([x[:, :1, :], x[:, -self.num_latent_tokens :]], dim=1)
return latent_tokens, target
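The full training loop lives in engine_pretrain.py; below is a minimal, self-contained sketch of the objective it describes, with the momentum value and tensor shapes as illustrative placeholders rather than the repo's settings:

```python
import torch
import torch.nn.functional as F

def ema_update(teacher: torch.nn.Module, student: torch.nn.Module, momentum: float = 0.996):
    """Exponential-moving-average update of the frozen teacher from the student."""
    with torch.no_grad():
        for t_param, s_param in zip(teacher.parameters(), student.parameters()):
            t_param.mul_(momentum).add_(s_param, alpha=1.0 - momentum)

# One step of the JEPA-style regression: the student predictor output is pulled
# toward the EMA teacher's latent-token targets with Smooth L1.
student_pred = torch.randn(8, 128, 768, requires_grad=True)   # placeholder predictor output
with torch.no_grad():
    teacher_target = torch.randn(8, 128, 768)                 # placeholder EMA-teacher tokens

loss = F.smooth_l1_loss(student_pred, teacher_target)
loss.backward()
# ema_update(teacher, student) would then refresh the teacher after the optimiser step.
```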
Stage 2: Downstream Harmonizer Heads (modules/harmonizer/stage2_finetune)¶
For tasks like ABIDE diagnosis, GenerateEmbedDataset_downstream reads the saved embeddings plus labels, and VisionTransformer attaches either a CLS-token head or a global-pooling head atop the latent-token-expanded sequence. Training mixes standard augmentation knobs (mixup/cutmix) with layer-wise LR decay, and evaluation logs accuracy + F1.^[external_repos/brainharmony/datasets/datasets.py (lines 828-857)][external_repos/brainharmony/modules/harmonizer/stage2_finetune/main_finetune.py (lines 1-350)][external_repos/brainharmony/modules/harmonizer/stage2_finetune/engine_finetune.py (lines 16-166)][external_repos/brainharmony/scripts/harmonizer/stage2_finetune/run_finetune.sh (lines 1-59)]
def forward_features(self, x, attn_mask):
if self.add_pre_mapping:
x = self.pre_map(x)
B = x.shape[0]
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
x = self.pos_drop(x)
latent_tokens = self.latent_tokens.expand(x.shape[0], -1, -1)
latent_tokens = latent_tokens + self.enc_latent_token_positional_embedding
x = torch.cat([x, latent_tokens], dim=1)
pad = torch.ones(attn_mask.shape[0], 1200 + latent_tokens.shape[1], dtype=attn_mask.dtype, device=attn_mask.device)
pad_0 = torch.ones(attn_mask.shape[0], 1, dtype=attn_mask.dtype, device=attn_mask.device)
attn_mask = torch.cat([pad_0, attn_mask, pad], dim=1)
for blk in self.blocks:
x = blk(x, attention_mask=attn_mask)
if self.global_pool:
x = torch.cat([x[:, :1, :], x[:, -self.num_latent_tokens :]], dim=1)
x = x[:, 1:, :].mean(dim=1)
outcome = self.fc_norm(x)
else:
x = self.norm(x)
x = torch.cat([x[:, :1, :], x[:, -self.num_latent_tokens :]], dim=1)
outcome = x[:, 0]
return outcome
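Metric computation in engine_finetune.py reports accuracy and F1; as a minimal stand-in sketch using scikit-learn (not necessarily the repo's own metric code), with placeholder logits and labels for a binary diagnosis head:

```python
import torch
from sklearn.metrics import accuracy_score, f1_score

logits = torch.randn(16, 2)              # placeholder head outputs
labels = torch.randint(0, 2, (16,))      # placeholder binary labels

preds = logits.argmax(dim=-1)
acc = accuracy_score(labels.numpy(), preds.numpy())
f1 = f1_score(labels.numpy(), preds.numpy())
print(f"accuracy={acc:.3f}  f1={f1:.3f}")
```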
FlexVisionTransformer & Masked Predictor (libs/flex_transformer.py)¶
FlexVisionTransformer supports flexible patch sizes via FlexiPatchEmbed, optional gradient checkpointing, and either FlashAttention-2 or eager attention blocks. Predictor heads (VisionTransformerPredictor) project encoder outputs into predictor space, tile positional embeddings for masked regions, append learnable mask tokens, and regress back to encoder dimensionality; they reuse apply_masks to select context/target indices.^[external_repos/brainharmony/libs/flex_transformer.py (lines 482-610)][external_repos/brainharmony/libs/flex_transformer.py (lines 322-463)]
def forward(self, x, masks_x, masks, attention_masks=None, return_attention=False):
assert (masks is not None) and (masks_x is not None)
if not isinstance(masks_x, list):
masks_x = [masks_x]
if not isinstance(masks, list):
masks = [masks]
B = len(x) // len(masks_x)
x = self.predictor_embed(x)
predictor_pos_embed = self.predictor_pos_embed()[1]
if self.cls_token is not None:
x_pos_embed = predictor_pos_embed.repeat(B, 1, 1)
x_pos_embed = apply_masks(x_pos_embed, masks_x, cls_token=True)
x += x_pos_embed
_, N_ctxt, D = x.shape
pos_embs = predictor_pos_embed.repeat(B, 1, 1)
pos_embs = apply_masks(pos_embs[:, 1:, :], masks)
pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_x))
pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1)
pred_tokens += pos_embs
else:
x_pos_embed = predictor_pos_embed.repeat(B, 1, 1)
x += apply_masks(x_pos_embed, masks_x)
_, N_ctxt, D = x.shape
pos_embs = predictor_pos_embed.repeat(B, 1, 1)
pos_embs = apply_masks(pos_embs, masks)
pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_x))
pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1)
pred_tokens += pos_embs
x = x.repeat(len(masks), 1, 1)
x = torch.cat([x, pred_tokens], dim=1)
for blk in self.predictor_blocks:
x = blk(x, attention_masks)
x = self.predictor_norm(x)
x = x[:, N_ctxt:]
x = self.predictor_proj(x)
return x
Positional Embeddings & Mask Utilities¶
BrainHarmony blends anatomical gradients with geometric harmonics to produce shared positional priors across the encoder/predictor stacks, and the same module can emit decoder embeddings for Stage 1. Mask helpers expose gather-style APIs for re-indexing context/target tokens.^[external_repos/brainharmony/libs/position_embedding.py (lines 137-209)][external_repos/brainharmony/libs/masks/utils.py (lines 11-31)]
geo_harm_pos_embed = self.geo_harm_proj(self.geo_harm)
gradient_pos_embed = self.grad_proj(self.gradient)
pos_embed = (gradient_pos_embed + geo_harm_pos_embed) * 0.5
emb_w = pos_embed.squeeze().repeat_interleave(self.repeat_time, dim=0)
emb_w = (emb_w - emb_w.min()) / (emb_w.max() - emb_w.min()) * 2 - 1
emb_encoder = torch.cat([self.emb_h_encoder, emb_w], dim=1).unsqueeze(0)
if self.cls_token:
pos_embed_encoder = torch.concat(
[torch.zeros([1, 1, emb_encoder.shape[2]], requires_grad=False).to(self.device), emb_encoder],
dim=1,
)
else:
pos_embed_encoder = emb_encoder
if self.use_pos_embed_decoder:
emb_w_decoder = self.decoder_pos_embed_proj(emb_w.detach())
emb_decoder = torch.cat([self.emb_h_decoder, emb_w_decoder], dim=1).unsqueeze(0)
if self.cls_token:
pos_embed_decoder = torch.concat(
[torch.zeros([1, 1, emb_decoder.shape[2]], requires_grad=False).to(self.device), emb_decoder],
dim=1,
)
else:
pos_embed_decoder = emb_decoder
return pos_embed_encoder, pos_embed_decoder
return pos_embed_encoder, None
def apply_masks(x, masks, cls_token=False):
all_x = []
if cls_token:
cls_t = x[:, :1, :]
for m in masks:
mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
all_x += [torch.cat((cls_t, torch.gather(x[:, 1:, :], dim=1, index=mask_keep)), dim=1)]
else:
for m in masks:
mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
all_x += [torch.gather(x, dim=1, index=mask_keep)]
return torch.cat(all_x, dim=0)
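A toy usage example: each mask is a tensor of token indices to keep, and gather pulls those positions out of every sequence in the batch (shapes here are illustrative, not the repo's):

```python
import torch

# Batch of 2 sequences, 6 tokens each, embedding dim 4; keep tokens 0, 2, 4 per sample.
x = torch.randn(2, 6, 4)
keep_indices = torch.tensor([[0, 2, 4], [0, 2, 4]])

context = apply_masks(x, [keep_indices])  # gathered tokens, stacked along the batch dim
print(context.shape)                      # torch.Size([2, 3, 4])
```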
Integration Hooks (Brain ↔ Genetics)¶
- Token shapes. FlexVisionTransformer.forward outputs [B, N_tokens, embed_dim] (CLS + patches [+ latent tokens]); Stage 2 heads either take the CLS vector or mean-pool the latent tokens, so downstream genetics encoders should expect 768-d (base) or 1024-d (large) vectors per sample.^[external_repos/brainharmony/libs/flex_transformer.py (lines 563-610)][external_repos/brainharmony/modules/harmonizer/stage2_finetune/models.py (lines 117-170)]
- Attention masks. Both GenerateEmbedDataset variants surface per-sample masks; reusing them when aligning with long genomic sequences preserves which ROI/time windows were padded vs. observed.^[external_repos/brainharmony/datasets/datasets.py (lines 803-857)]
- Stage bridging. Stage 0 writes .npz files with data and attn_mask arrays; you can append additional modalities (e.g., gene-expression embeddings) to the same data vector before Stage 1, as long as the downstream models' positional encoders are updated accordingly.^[external_repos/brainharmony/modules/harmonizer/stage0_embed/embedding_pretrain.py (lines 138-184)]
- Projecting to shared latent spaces. A lightweight projector keeps BrainHarmony tokens compatible with genetics embeddings:
import torch.nn as nn
class BrainHarmonyProjector(nn.Module):
def __init__(self, input_dim=768, output_dim=512, dropout=0.1):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(input_dim, 1024),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(1024, output_dim),
nn.LayerNorm(output_dim),
)
def forward(self, x):
return self.layers(x)
Map mean-pooled latent tokens through this projector, concatenate with genetics embeddings, and fine-tune a fusion head while reusing BrainHarmony’s attention masks for masking-aware losses.
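A hedged sketch of that fusion step, assuming a hypothetical 256-d genetics embedding and the BrainHarmonyProjector defined above (none of the names or dimensions below come from the repo):

```python
import torch
import torch.nn as nn

projector = BrainHarmonyProjector(input_dim=768, output_dim=512)
fusion_head = nn.Sequential(
    nn.Linear(512 + 256, 256),
    nn.GELU(),
    nn.Linear(256, 2),  # e.g. a binary diagnosis head
)

# Placeholder inputs: mean-pooled BrainHarmony latent tokens and genetics embeddings.
brain_latents = torch.randn(8, 768)
genetics_embed = torch.randn(8, 256)

logits = fusion_head(torch.cat([projector(brain_latents), genetics_embed], dim=-1))
print(logits.shape)  # torch.Size([8, 2])
```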