
SwiFT Code Walkthrough

KB references: Model card · fMRI feature spec · Integration strategy · Experiment config stub

Overview

SwiFT (Swin 4D fMRI Transformer) tokenizes 4D fMRI volumes with 3D convolutions, processes them with windowed 4D self-attention (spatial + temporal windows), and trains contrastive or supervised heads via PyTorch Lightning. [external_repos/swift/project/module/models/swin4d_transformer_ver7.py, lines 1-400; external_repos/swift/project/main.py, lines 1-188]

At-a-Glance

| Field | Details |
|---|---|
| Architecture | Swin-inspired 4D transformer with window attention and patch merging [external_repos/swift/project/module/models/swin4d_transformer_ver7.py] |
| Params | Configurable (e.g., embed_dim=96; depths set in the config) [external_repos/swift/project/module/models/swin4d_transformer_ver7.py, lines 402-565] |
| Context | 96×96×96 voxels × 20 frames (default) [external_repos/swift/project/module/utils/data_module.py, lines 250-300] |
| Inputs | Preprocessed volumes from fMRIDataModule (UKB, HCP, etc.) [external_repos/swift/project/module/utils/data_module.py, lines 13-260] |
| Key capabilities | Lightning training with contrastive or supervised heads; downstream evaluation scripts [external_repos/swift/project/main.py, lines 21-187; external_repos/swift/project/module/pl_classifier.py, lines 32-395] |
| Repo | github.com/Transconnectome/SwiFT |

Environment & Hardware Notes

  • Conda environment. The README instructs you to run conda env create -f envs/py39.yaml followed by conda activate py39, which installs the exact PyTorch/Lightning versions used for the released checkpoints. [external_repos/swift/README.md, lines 45-55]
  • Gradient checkpointing knobs. Every Swin4D stage accepts use_checkpoint and calls torch.utils.checkpoint.checkpoint(...) when it is set. Adding use_checkpoint=True to your model config trades extra compute (activations are recomputed during the backward pass) for lower activation memory, which lets you fit longer contexts on the same GPU. [external_repos/swift/project/module/models/swin4d_transformer_ver7.py, lines 224-312 and 507-744]
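As a rough illustration of what use_checkpoint does, the sketch below wraps each block of a hypothetical ToyStage (a small stand-in, not SwiFT's actual stage) in torch.utils.checkpoint.checkpoint. It assumes a PyTorch release that accepts the use_reentrant keyword. The outputs match the non-checkpointed stage exactly; only where the activations come from during backward changes.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyStage(nn.Module):
    """Hypothetical stand-in for a Swin4D stage: a stack of residual MLP blocks.

    With use_checkpoint=True, each block's intermediate activations are not kept
    in the forward pass; they are recomputed during backward (compute for memory).
    """

    def __init__(self, dim: int, depth: int, use_checkpoint: bool = False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU()) for _ in range(depth)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for blk in self.blocks:
            if self.use_checkpoint and self.training:
                # Non-reentrant checkpointing also propagates gradients to the block's
                # parameters when the input itself does not require grad.
                x = x + checkpoint(blk, x, use_reentrant=False)
            else:
                x = x + blk(x)
        return x


torch.manual_seed(0)
plain = ToyStage(dim=16, depth=2, use_checkpoint=False)
ckpt = ToyStage(dim=16, depth=2, use_checkpoint=True)
ckpt.load_state_dict(plain.state_dict())  # identical weights in both stages

x = torch.randn(4, 10, 16)
out_plain = plain(x)
out_ckpt = ckpt(x)
out_ckpt.sum().backward()  # gradients flow through the recomputed blocks
```

The memory saving grows with depth and token count, at the cost of roughly one extra forward pass per checkpointed block.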

Key Components

Data Module (project/module/utils/data_module.py)

fMRIDataModule loads datasets (UKB, HCP, etc.), splits subjects, and returns PyTorch DataLoaders. Augmentations (affine/noise) are applied in the Lightning module.
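Splitting by subject rather than by scan keeps every scan of a participant in a single partition, so evaluation never sees someone the model trained on. Here is a minimal sketch, assuming a flat list of subject IDs. split_subjects, its fractions, and its seed are hypothetical and do not reproduce fMRIDataModule's actual split logic.

```python
import random
from typing import Dict, List, Sequence


def split_subjects(
    subject_ids: Sequence[str],
    train_frac: float = 0.7,
    val_frac: float = 0.15,
    seed: int = 1234,
) -> Dict[str, List[str]]:
    """Hypothetical helper: split *subjects* (not individual scans) into train/val/test."""
    ids = sorted(set(subject_ids))   # deduplicate and fix the order before shuffling
    rng = random.Random(seed)        # local RNG so the split is reproducible
    rng.shuffle(ids)
    n_train = round(len(ids) * train_frac)
    n_val = round(len(ids) * val_frac)
    return {
        "train": ids[:n_train],
        "val": ids[n_train:n_train + n_val],
        "test": ids[n_train + n_val:],  # whatever remains goes to test
    }


subjects = [f"sub-{i:03d}" for i in range(20)]
splits = split_subjects(subjects)
print({k: len(v) for k, v in splits.items()})  # {'train': 14, 'val': 3, 'test': 3}
```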

PyTorch Lightning data module:

external_repos/swift/project/module/utils/data_module.py (lines 13-230)
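import pytorch_lightning as pl
from torch.utils.data import DataLoader

class fMRIDataModule(pl.LightningDataModule):
    def get_dataset(self):
        # map the dataset_name hyperparameter to a Dataset class (S1200 = HCP S1200 release)
        if self.hparams.dataset_name == "S1200":
            return S1200
        ...
    def setup(self, stage=None):
        Dataset = self.get_dataset()
        # keyword arguments shared by the train/val/test datasets
        params = {"root": self.hparams.image_path, "sequence_length": self.hparams.sequence_length, ...}
        # train_dict holds the training-subject split (construction elided)
        self.train_dataset = Dataset(**params, subject_dict=train_dict, ...)
        self.train_loader = DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, ...)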
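The sequence_length parameter sets how many frames each sample carries. The toy dataset below (ToyfMRIWindows, hypothetical, with random data and a much smaller volume than the 96×96×96 default) sketches one way to cut fixed-length temporal windows from a 4D scan. The channel-first, time-last layout [1, H, W, D, T] and the stride-based window sampling are assumptions for illustration, not SwiFT's dataset code.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyfMRIWindows(Dataset):
    """Hypothetical dataset: one random 4D volume per subject, served as
    contiguous windows of `sequence_length` frames, each shaped [1, H, W, D, T]."""

    def __init__(self, n_subjects: int, img_size=(8, 8, 8), n_frames=50,
                 sequence_length=20, stride=10):
        self.sequence_length = sequence_length
        self.volumes = [torch.randn(1, *img_size, n_frames) for _ in range(n_subjects)]
        # Precompute (subject index, start frame) for every full-length window.
        self.index = [
            (s, t)
            for s in range(n_subjects)
            for t in range(0, n_frames - sequence_length + 1, stride)
        ]

    def __len__(self):
        return len(self.index)

    def __getitem__(self, i):
        s, t = self.index[i]
        return self.volumes[s][..., t:t + self.sequence_length]  # slice along time


ds = ToyfMRIWindows(n_subjects=3)
loader = DataLoader(ds, batch_size=4, shuffle=True)
batch = next(iter(loader))
print(len(ds), batch.shape)  # 12 torch.Size([4, 1, 8, 8, 8, 20])
```

Each subject contributes four windows (start frames 0, 10, 20, 30), so three subjects yield 12 samples.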

Patch Embedding & Window Attention (swin4d_transformer_ver7.py)

PatchEmbed downsamples volumes with strided 3D convs, WindowAttention4D computes self-attention inside local 4D windows, and SwinTransformerBlock4D alternates regular and shifted windows so that information can flow between neighboring windows. PatchMergingV2 reduces spatial resolution while keeping the temporal size.

4D windowed attention with patch embedding:

external_repos/swift/project/module/models/swin4d_transformer_ver7.py (lines 202-399)
class PatchEmbed(nn.Module):
    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=(1, patch_size), stride=(1, patch_size))
...
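class WindowAttention4D(nn.Module):
    def forward(self, x, mask):
        # x: [num_windows * B, tokens_per_window, C]
        qkv = self.qkv(x).reshape(...)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # scaled dot-product scores, normalized within each window (mask handling elided)
        attn = self.softmax((q @ k.transpose(-2, -1)) * self.scale)
        x = attn @ v  # head merge and output projection elided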
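To make the windowing concrete, here is a self-contained sketch of three pieces: clamping window sizes to the feature map (clamp_window, a hypothetical helper modeled on the usual Swin get_window_size rule), partitioning a channels-last [B, D, H, W, T, C] tensor into non-overlapping 4D windows, and plain multi-head attention restricted to each window. All names are illustrative; the sketch omits learned q/k/v projections, window shifting, and attention masks.

```python
import torch
import torch.nn.functional as F


def clamp_window(feature_size, window_size, shift_size):
    """If a feature axis is no larger than the window, use the whole axis as the
    window and disable shifting along it (hypothetical re-implementation)."""
    win, shift = list(window_size), list(shift_size)
    for i, n in enumerate(feature_size):
        if n <= win[i]:
            win[i] = n
            shift[i] = 0
    return tuple(win), tuple(shift)


def window_partition_4d(x, ws):
    """[B, D, H, W, T, C] -> [num_windows * B, wd*wh*ww*wt, C].
    Assumes every axis is divisible by its window size (pad first otherwise)."""
    b, d, h, w, t, c = x.shape
    x = x.view(b, d // ws[0], ws[0], h // ws[1], ws[1], w // ws[2], ws[2], t // ws[3], ws[3], c)
    x = x.permute(0, 1, 3, 5, 7, 2, 4, 6, 8, 9).contiguous()  # window-grid axes first
    return x.view(-1, ws[0] * ws[1] * ws[2] * ws[3], c)


def window_attention(x_windows, num_heads):
    """Multi-head self-attention computed independently inside each window."""
    n_win, n_tok, c = x_windows.shape
    head_dim = c // num_heads
    # Reuse the input as q, k and v, split into heads: [n_win, heads, n_tok, head_dim].
    q = k = v = x_windows.view(n_win, n_tok, num_heads, head_dim).transpose(1, 2)
    attn = F.softmax((q @ k.transpose(-2, -1)) * head_dim ** -0.5, dim=-1)
    return (attn @ v).transpose(1, 2).reshape(n_win, n_tok, c)


x = torch.randn(2, 8, 8, 8, 4, 16)  # [B, D, H, W, T, C]
ws, shift = clamp_window(x.shape[1:5], (4, 4, 4, 6), (2, 2, 2, 3))
print(ws, shift)                    # (4, 4, 4, 4) (2, 2, 2, 0)
windows = window_partition_4d(x, ws)
print(windows.shape)                # torch.Size([16, 256, 16])
out = window_attention(windows, num_heads=4)
```

Because the time axis (4 frames) is smaller than the requested temporal window (6), the window shrinks to 4 and the temporal shift drops to 0. Each token attends to 256 neighbors instead of all 8·8·8·4 = 2048 positions.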

Swin4D Backbone (swin4d_transformer_ver7.py)

BasicLayer stacks windowed blocks, handles padding, applies attention masks, and optionally downsamples. The main SwinTransformer4D builds multiple stages with positional embeddings, patch merging, and normalization.

Multi-stage Swin transformer with patch merging:

external_repos/swift/project/module/models/swin4d_transformer_ver7.py (lines 400-796)
class BasicLayer(nn.Module):
    def forward(self, x):
        ...  # padding and attn_mask construction elided
        for blk in self.blocks:
            # each block applies (shifted-)window attention plus its own MLP residual
            x = blk(x, attn_mask)
        if self.downsample is not None:
            x = self.downsample(x)  # patch merging at the end of the stage
        return x
...
class SwinTransformer4D(nn.Module):
    self.patch_embed = PatchEmbed(...)
    self.layers = nn.ModuleList([...])
    def forward(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x)
        for i, layer in enumerate(self.layers):
            x = self.pos_embeds[i](x)  # stage-specific positional embedding
            x = layer(x.contiguous())
        return x
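PatchMergingV2 is described above as shrinking space while keeping time. Below is a hypothetical sketch of that idea, not SwiFT's code. It follows the original Swin recipe: concatenate each 2×2×2 spatial neighborhood, apply LayerNorm, then a linear reduction. Time is left untouched. The channels-last layout and the output width (doubled here, as in Swin) are assumptions.

```python
import torch
import torch.nn as nn


class SpatialPatchMerging(nn.Module):
    """Hypothetical: [B, D, H, W, T, C] -> [B, D/2, H/2, W/2, T, out_dim].
    Assumes D, H and W are even."""

    def __init__(self, dim: int, out_dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(8 * dim)
        self.reduction = nn.Linear(8 * dim, out_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gather the 8 spatial neighbours of each 2x2x2 cell; the time axis is not strided.
        parts = [
            x[:, i::2, j::2, k::2, :, :]
            for i in (0, 1) for j in (0, 1) for k in (0, 1)
        ]
        x = torch.cat(parts, dim=-1)         # [B, D/2, H/2, W/2, T, 8C]
        return self.reduction(self.norm(x))  # project 8C -> out_dim


merge = SpatialPatchMerging(dim=16, out_dim=32)
x = torch.randn(2, 8, 8, 8, 5, 16)
y = merge(x)
print(y.shape)  # torch.Size([2, 4, 4, 4, 5, 32])
```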

Lightning Module (project/module/pl_classifier.py)

LitClassifier wraps the encoder, applies augmentations if requested, and attaches task-specific heads (classification/regression/contrastive). _calculate_loss routes to BCE, MSE, or contrastive losses.

Task-specific heads with loss routing:

external_repos/swift/project/module/pl_classifier.py (lines 32-205)
import torch.nn.functional as F

self.model = load_model(self.hparams.model, self.hparams)       # 4D Swin encoder
if self.hparams.downstream_task == 'sex':
    self.output_head = load_model("clf_mlp", self.hparams)      # classification head
elif self.hparams.downstream_task == 'age':
    self.output_head = load_model("reg_mlp", self.hparams)      # regression head
...
def _calculate_loss(self, batch, mode):
    if self.hparams.pretraining:
        ...  # contrastive losses (NT-Xent)
    else:
        subj, logits, target = self._compute_logits(batch)
        if classification:  # true for classification tasks such as 'sex' (condition elided)
            loss = F.binary_cross_entropy_with_logits(logits, target)
        else:
            loss = F.mse_loss(logits.squeeze(), target.squeeze())
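The excerpt only names NT-Xent for pretraining. For reference, here is a generic SimCLR-style version, a sketch rather than SwiFT's implementation: its temperature handling and the choice of which views are paired may differ. z1[i] and z2[i] stand for embeddings of two views of sample i, and every other embedding in the 2N batch serves as a negative.

```python
import torch
import torch.nn.functional as F


def nt_xent(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """Normalized-temperature cross-entropy over a batch of positive pairs."""
    n = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # [2N, D], unit norm
    sim = z @ z.t() / temperature                        # scaled cosine similarities
    sim.fill_diagonal_(float("-inf"))                    # a sample is never its own positive
    # Row i's positive is row i + N, and row i + N's positive is row i.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    return F.cross_entropy(sim, targets)


torch.manual_seed(0)
a = torch.randn(8, 32)
loss_aligned = nt_xent(a, a + 0.01 * torch.randn(8, 32))  # near-identical views
loss_random = nt_xent(a, torch.randn(8, 32))              # unrelated "views"
print(loss_aligned.item() < loss_random.item())           # True
```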

Training Entry Point (project/main.py)

The CLI parses dataset, model, and task arguments, instantiates the Lightning module and data module, and launches a PyTorch Lightning Trainer with callbacks (checkpointing, learning-rate monitoring).

CLI entrypoint with Lightning trainer:

external_repos/swift/project/main.py (lines 18-187)
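from argparse import ArgumentParser
import pytorch_lightning as pl

# Classifier / Dataset refer to LitClassifier / fMRIDataModule (assignments elided)
parser = ArgumentParser(...)
parser = Classifier.add_model_specific_args(parser)   # model and task flags
parser = Dataset.add_data_specific_args(parser)       # dataset and loader flags
parser = pl.Trainer.add_argparse_args(parser)         # Trainer flags (Lightning 1.x API)
args = parser.parse_args()
data_module = Dataset(**vars(args))
model = Classifier(data_module=data_module, **vars(args))
# logger and callbacks (checkpointing, LR monitor) are built earlier (elided)
trainer = pl.Trainer.from_argparse_args(args, logger=logger, callbacks=callbacks)
trainer.fit(model, datamodule=data_module)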
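The parser-chaining pattern above lets each component register its own flags on a shared parser. Here is a stdlib-only sketch with hypothetical ToyClassifier and ToyDataModule classes and illustrative flag defaults (not SwiFT's), built on argparse's parents mechanism.

```python
from argparse import ArgumentParser


class ToyClassifier:
    """Hypothetical stand-in for the Lightning module's argument hook."""

    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        # add_help=False avoids a duplicate -h; the parent already defines it.
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--learning_rate", type=float, default=1e-3)
        parser.add_argument("--downstream_task", type=str, default="sex")
        return parser


class ToyDataModule:
    """Hypothetical stand-in for the data module's argument hook."""

    @staticmethod
    def add_data_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--batch_size", type=int, default=4)
        parser.add_argument("--sequence_length", type=int, default=20)
        return parser


parser = ArgumentParser(description="toy entry point")
parser = ToyClassifier.add_model_specific_args(parser)
parser = ToyDataModule.add_data_specific_args(parser)
args = parser.parse_args(["--downstream_task", "age", "--batch_size", "8"])
print(vars(args))
```

Note that Trainer.add_argparse_args and Trainer.from_argparse_args belong to the Lightning 1.x API and were removed in Lightning 2.0, so the pinned py39 environment matters when running main.py as-is.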

Integration Hooks (Brain ↔ Genetics)

  • Embedding shape. The encoder returns a channels-first feature map, [B, C, X', Y', Z', T'] (three spatial axes, then time), where C is the final-stage channel width (typically larger than embed_dim, since patch merging widens channels). Downstream heads either average over the spatial axes (mean(dim=[2,3,4])) or use CLS-like features, depending on the head. Use _compute_logits to capture the tensor before the head for multimodal projection. [external_repos/swift/project/module/pl_classifier.py, lines 108-205]
  • Pooling choices. Because patch merging keeps the temporal axis, spatial mean pooling (features.mean(dim=[2,3,4])) yields [B, C, T']. Average over the remaining time axis as well (or pool it another way) to obtain a single [B, C] vector per scan.
  • Projection to shared latent. Apply a lightweight projector to map the pooled [B, C] features into a 512-D shared space (set input_dim to your encoder's C):

import torch.nn as nn

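class SwiFTProjector(nn.Module):
    """Map pooled SwiFT features [B, input_dim] into a shared [B, output_dim] space.
    input_dim must equal the encoder's final channel width."""
    def __init__(self, input_dim=768, output_dim=512, dropout=0.1):
        super().__init__()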
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(1024, output_dim),
            nn.LayerNorm(output_dim),
        )

    def forward(self, x):
        return self.layers(x)
  • Augmentation awareness. When extracting embeddings for alignment, disable augmentations (augment_during_training=False) so that random affine/noise perturbations do not misalign the brain embeddings with the genetic features. [external_repos/swift/project/module/pl_classifier.py, lines 108-205]
  • Window constraints. Make sure inference volumes match the training img_size and window_size. When a feature-map axis is no larger than the window, get_window_size uses the whole axis as the window and sets the shift along it to zero, so undersized inputs lose the cross-window connections that shifted windows provide. [external_repos/swift/project/module/models/swin4d_transformer_ver7.py, lines 110-200]
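The pooling steps and the projector combine as shown below. This sketch assumes a hypothetical final-stage map of shape [2, 768, 3, 3, 3, 20] filled with random values, channels-first with time last; real spatial sizes depend on img_size and the number of merging stages. The projector class is repeated from above so the block runs on its own.

```python
import torch
import torch.nn as nn


class SwiFTProjector(nn.Module):
    # Same projector as above, repeated so this block is self-contained.
    def __init__(self, input_dim=768, output_dim=512, dropout=0.1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(1024, output_dim),
            nn.LayerNorm(output_dim),
        )

    def forward(self, x):
        return self.layers(x)


def pool_features(feats: torch.Tensor) -> torch.Tensor:
    """[B, C, X, Y, Z, T] -> [B, C]: average over space first, then over time."""
    spatial = feats.mean(dim=[2, 3, 4])  # [B, C, T], time axis kept
    return spatial.mean(dim=-1)          # [B, C]


feats = torch.randn(2, 768, 3, 3, 3, 20)  # hypothetical encoder output
projector = SwiFTProjector(input_dim=768, output_dim=512).eval()  # eval() disables dropout
with torch.no_grad():
    z = projector(pool_features(feats))
print(z.shape)  # torch.Size([2, 512])
```

Calling eval() matters for embedding extraction: with dropout active, the same scan would map to a slightly different vector on every pass.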

After projection, SwiFT embeddings (global pooled or CLS) can be concatenated or contrastively aligned with Evo 2/GENERator/Caduceus projections for multimodal neurogenomics.
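For the contrastive option, one common choice is a CLIP-style symmetric InfoNCE loss over paired brain and genetic projections. Below is a minimal sketch under two assumptions: row i of each matrix belongs to the same subject, and the genetic embeddings are synthetic stand-ins (no Evo 2, GENERator, or Caduceus model is called). The concatenation alternative is shown on the same tensors.

```python
import torch
import torch.nn.functional as F


def symmetric_info_nce(brain: torch.Tensor, gene: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """CLIP-style loss: matching subjects lie on the diagonal of the [B, B] logit matrix."""
    brain = F.normalize(brain, dim=-1)
    gene = F.normalize(gene, dim=-1)
    logits = brain @ gene.t() / temperature  # cross-modal cosine similarities
    targets = torch.arange(brain.shape[0])
    # Average the brain->gene and gene->brain directions.
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))


torch.manual_seed(0)
brain_z = torch.randn(16, 512)                 # e.g., projected SwiFT embeddings
gene_z = brain_z + 0.1 * torch.randn(16, 512)  # hypothetical paired genetic embeddings
loss_paired = symmetric_info_nce(brain_z, gene_z)
loss_shuffled = symmetric_info_nce(brain_z, gene_z.roll(1, dims=0))  # every pair mismatched
fused = torch.cat([brain_z, gene_z], dim=-1)   # concatenation alternative: [16, 1024]
print(loss_paired.item() < loss_shuffled.item())  # True
```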