GENERator Code Walkthrough¶
KB references: Model card · Genomics feature spec · Integration strategy · Experiment config stub
Overview¶
GENERator wraps GPT-style causal decoders (1.2 B and 3 B parameters, with both eukaryote and prokaryote checkpoints) with a strict 6-mer tokenizer and long-context optimizations (FlashAttention, Liger kernels, sliding-window decoding), so you can score or generate up to one million base pairs per prompt.^[5:125:external_repos/generator/README.md]
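For orientation, here is a minimal generation sketch with a short prompt (long-context settings are covered below); the Hub checkpoint ID is an assumption and may differ from the checkpoint you actually use.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hub ID is an assumption; substitute the checkpoint named in the README.
model_id = "GenerTeam/GENERator-eukaryote-1.2b-base"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

# Keep the prompt length a multiple of 6 so the 6-mer tokenizer sees whole k-mers.
prompt = "ATGGCC" * 20
inputs = tokenizer(prompt, return_tensors="pt")

# Sample a short continuation; generation arguments are illustrative defaults.
output_ids = model.generate(**inputs, max_new_tokens=32, do_sample=True, temperature=0.8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```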
At-a-Glance¶
| Architecture | Params | Context | Tokenization / Inputs | Key capabilities | Repo |
|---|---|---|---|---|---|
| HuggingFace AutoModelForCausalLM decoder w/ optional ChunkEnsemble Llama heads^[external_repos/generator/src/tasks/downstream/fine_tuning.py][508:688:external_repos/generator/src/tasks/downstream/sequence_understanding.py] | 1.2 B & 3 B checkpoints for euk/prok^[52:118:external_repos/generator/README.md] | 1 Mbp prompts via sliding windows + FlashAttention^[84:99:external_repos/generator/README.md][612:667:external_repos/generator/src/tasks/downstream/sequence_understanding.py] | 6-mer tokenizer; sequences must be multiples of 6, enforced in preprocessing^[118:125:external_repos/generator/src/tasks/downstream/variant_effect_prediction.py][115:235:external_repos/generator/src/tasks/downstream/fine_tuning.py] | Variant effect scoring, sequence recovery, classification/regression fine-tuning^[141:406:external_repos/generator/src/tasks/downstream/variant_effect_prediction.py][400:687:external_repos/generator/src/tasks/downstream/sequence_understanding.py] | github.com/GenerTeam/GENERator |
Environment & Hardware Notes¶
- Long-context dependencies. For million-base contexts the README recommends installing the custom kernels explicitly:^[84:89:external_repos/generator/README.md]

  ```bash
  pip install liger-kernel
  pip install flash-attn --no-build-isolation
  ```

- Gradient checkpointing flag. When operating on >10 kbp sequences, the authors enable `model.gradient_checkpointing_enable()` to trade compute for memory.^[420:424:external_repos/generator/README.md]
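A minimal sketch of how these pieces typically come together when loading a checkpoint for long contexts; the attention and dtype arguments are standard Hugging Face options, and the checkpoint ID is a placeholder assumption rather than something taken from the repo.

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholder checkpoint ID (assumption); substitute the GENERator checkpoint you use.
model = AutoModelForCausalLM.from_pretrained(
    "GenerTeam/GENERator-eukaryote-1.2b-base",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,               # half precision keeps million-base prompts in memory
    attn_implementation="flash_attention_2",  # requires the flash-attn install above
)

# For >10 kbp sequences the README enables activation checkpointing.
model.gradient_checkpointing_enable()
```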
Key Components¶
Tokenizer & Preprocessing (variant_effect_prediction.py, fine_tuning.py)¶
The downstream scripts consistently load the HF tokenizer with `trust_remote_code=True`, force the pad token to EOS if missing, and either truncate or pad every sequence to the nearest 6-mer boundary (the `pad_to_multiple_of_six` flag) so the 6-mer tokenizer never emits `<oov>` tokens.
Tokenizer initialization:
```python
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
...
inputs = tokenizer(batch_sequences, return_tensors="pt", padding=True)
```
6-mer boundary truncation:
```python
# Trim leading bases so the remaining length is a multiple of 6
truncate_length = len(sequence) % 6
if truncate_length > 0:
    sequence = sequence[truncate_length:]
```
Padding to 6-mer multiples:
```python
if pad_to_multiple_of_six:
    remainder = len(seq) % 6
    if remainder != 0:
        pad_len = 6 - remainder
        seq = seq + "A" * pad_len

tokenized = tokenizer(
    sequences,
    truncation=True,
    max_length=max_length,
    add_special_tokens=True,
    padding=False,
)
```
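If you preprocess sequences outside these scripts, a small helper in the same spirit keeps inputs 6-mer aligned; the function name and the choice between padding and trimming are illustrative, not part of the repo.

```python
def make_6mer_compatible(seq: str, pad: bool = True) -> str:
    """Return a sequence whose length is a multiple of 6.

    Illustrative helper mirroring the repo's preprocessing: either pad the
    3' end with "A" (as fine_tuning.py does) or trim leading bases (as
    variant_effect_prediction.py does).
    """
    remainder = len(seq) % 6
    if remainder == 0:
        return seq
    if pad:
        return seq + "A" * (6 - remainder)
    return seq[remainder:]


assert len(make_6mer_compatible("ATGC" * 7)) % 6 == 0
```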
Positional & Long-Context Handling (sequence_understanding.py)¶
sequence_understanding.py either scales RoPE via YaRN or injects sliding-window attention patches so you can extend Llama-based classifiers to >1 M tokens while staying numerically stable.
```python
elif length_extension_mode == "sliding_window":
    config.sliding_window = int(original_model_max_length_for_scaling)
    ...

    def _sliding_llama_forward(...):
        kwargs["sliding_window"] = self.config.sliding_window
        return _orig_forward(...)

    LlamaAttention.forward = _sliding_llama_forward
    attn_implementation = "flash_attention_2"
```
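For the YaRN path, the equivalent idea in stock Hugging Face Transformers is to set `rope_scaling` on the config before loading; the exact keys and scaling factor the repo writes may differ, so treat this as an assumption about the mechanism rather than a copy of the script.

```python
from transformers import AutoConfig, AutoModelForSequenceClassification

config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

# Assumed sketch: extend the base context window 64x with YaRN RoPE scaling.
config.rope_scaling = {
    "rope_type": "yarn",
    "factor": 64.0,
    "original_max_position_embeddings": config.max_position_embeddings,
}
config.max_position_embeddings = int(config.max_position_embeddings * 64)

model = AutoModelForSequenceClassification.from_pretrained(
    model_path,
    config=config,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
)
```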
Backbone Instantiation (fine_tuning.py, sequence_understanding.py)¶
Fine-tuning uses AutoModelForCausalLM with optional pad ID fixes, while sequence-understanding swaps in AutoModelForSequenceClassification or the ChunkEnsemble wrapper to keep a rolling window over million-token sequences.
Causal LM model loading:
```python
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
)
if model.config.pad_token_id is None and hasattr(model.config, "eos_token_id"):
    model.config.pad_token_id = model.config.eos_token_id
```
ChunkEnsemble classification head (excerpt):
```python
class ChunkEnsembleLlamaForSequenceClassification(LlamaPreTrainedModel):
    def forward(...):
        # Slide a fixed-size window over the token dimension
        input_ids_chunks = input_ids.unfold(dimension=1, size=self.chunk_size, step=self.stride)
        ...
        # Grab the EOS hidden state of each chunk
        chunk_eos_embedding = hidden_states[
            torch.arange(batch_size, device=hidden_states.device),
            sequence_lengths,
        ]
        stacked_embeddings = torch.stack(all_chunk_eos_embeddings, dim=1)
        final_representation = padded_embeddings.view(batch_size, -1)
        logits = self.classifier(final_representation)
```
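To see what the `unfold` call produces, here is a toy example with made-up `chunk_size` and `stride` values (the repo's defaults may differ); the shapes follow directly from `torch.Tensor.unfold`.

```python
import torch

batch_size, seq_len = 2, 24
chunk_size, stride = 8, 8  # illustrative values; stride == chunk_size tiles the sequence without overlap

input_ids = torch.arange(batch_size * seq_len).view(batch_size, seq_len)

# unfold yields windows laid out as [batch, num_chunks, chunk_size]
chunks = input_ids.unfold(dimension=1, size=chunk_size, step=stride)
print(chunks.shape)  # torch.Size([2, 3, 8]) -> (24 - 8) // 8 + 1 = 3 chunks
```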
Objective & Training Loop (fine_tuning.py)¶
The script wraps everything in `transformers.Trainer` with `DataCollatorForLanguageModeling(mlm=False)` so causal LM losses line up with the HF training stack, with distributed options (DeepSpeed, FSDP) set via CLI.
```python
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)
trainer.train()
```
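The distributed options arrive through standard HF `TrainingArguments`; the values below are illustrative, and the DeepSpeed config path is a placeholder rather than a file shipped with the repo.

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="outputs/generator_ft",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=1e-5,
    num_train_epochs=1,
    bf16=True,
    gradient_checkpointing=True,        # mirrors the >10 kbp recommendation
    deepspeed="configs/ds_zero3.json",  # placeholder path to a ZeRO-3 config
    logging_steps=10,
    save_strategy="epoch",
)
```

Launching the script through `torchrun` or the `deepspeed` launcher then distributes the run in the usual Trainer fashion; the exact CLI flags the script exposes are defined in its own argument parser.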
Inference Helpers (variant_effect_prediction.py)¶
Variant effect prediction shards ClinVar sequences across GPUs, caches logits, and computes per-base probabilities by summing over all tokens starting with ref/alt characters. This utility powers the headline ClinVar AUROC numbers.
```python
def compute_logits_parallel(...):
    num_gpus = torch.cuda.device_count()
    shards.append({'shard_id': i, 'sequences_data': sequences_data[start_idx:end_idx], ...})
    with ctx.Pool(processes=num_gpus) as pool:
        results = list(pool.imap(compute_logits_shard, args_list))

def parallel_compute_probabilities(...):
    vocab = tokenizer.get_vocab()
    char_indices = get_char_indices(vocab)
    results = list(pool.imap(compute_prob, args_list, chunksize=chunksize))
    p_ref, p_alt = zip(*results)
```
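The per-base probability described above amounts to marginalizing the next-token distribution over every vocabulary token that starts with the reference or alternate base. The sketch below is a simplification: the helper name, signature, and exact position handling are assumptions, and the real script precomputes the index lists once (via `get_char_indices`).

```python
import torch

def base_probability(logits_at_pos: torch.Tensor, vocab: dict, base: str) -> torch.Tensor:
    """Sum softmax probability mass over all vocabulary tokens starting with `base`.

    Simplified sketch of the ref/alt scoring idea; not the repo's implementation.
    """
    token_ids = torch.tensor(
        [idx for tok, idx in vocab.items() if tok.upper().startswith(base.upper())]
    )
    probs = torch.softmax(logits_at_pos, dim=-1)
    return probs[token_ids].sum()

# Usage sketch: logits at the token position covering the variant.
# p_ref = base_probability(logits[0, variant_token_pos], tokenizer.get_vocab(), "A")
# p_alt = base_probability(logits[0, variant_token_pos], tokenizer.get_vocab(), "G")
```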
Embedding Extraction (sequence_understanding.py)¶
ChunkEnsemble accumulates the EOS vector from each sliding chunk, pads or truncates them to a fixed count, and flattens them into a `[B, max_chunks * hidden]` representation before the classifier head, which is exactly the tensor you can reuse for downstream alignment.
```python
stacked_embeddings = torch.stack(all_chunk_eos_embeddings, dim=1)
num_padding_chunks = self.max_chunks - stacked_embeddings.shape[1]
...
final_representation = padded_embeddings.view(batch_size, -1)
logits = self.classifier(final_representation)
```
Sequence Constraints (variant_effect_prediction.py, fine_tuning.py)¶
Both inference and training enforce the 6-mer constraint by trimming or padding raw strings and, for dataset preprocessing, only accepting columns named `sequence`, `seq`, `dna_sequence`, etc., so you cannot silently feed invalid tokens.
if "sequence" in examples:
sequences = examples["sequence"]
elif "seq" in examples:
sequences = examples["seq"]
...
else:
raise ValueError("No sequence column found in dataset.")
Integration Hooks (Genetics ↔ Brain)¶
- Embedding shapes. GENERator decoders yield `[B, L_tokens, hidden]` tensors; ChunkEnsemble condenses them into `[B, max_chunks, hidden]` before flattening to `[B, max_chunks * hidden]`. You can stop just before the final classifier to grab the stacked embeddings for pooling.^[446:505:external_repos/generator/src/tasks/downstream/sequence_understanding.py]
- Pooling strategies. Use mean pooling along the chunk dimension for overall sequence summaries, max pooling for motif emphasis, or take the final chunk (equivalent to the autoregressive “last token”). Because chunk embeddings correspond to non-overlapping windows, pooling behaves like low-resolution downsampling; see the sketch after this list.
- Projection to shared latent. After pooling to `[B, H]` (H ≈ 1536 for the 1.2 B model), apply a projector to map into the same 512-D space used by your brain encoder:

  ```python
  import torch.nn as nn

  class GeneratorProjector(nn.Module):
      def __init__(self, input_dim=1536, output_dim=512, dropout=0.1):
          super().__init__()
          self.net = nn.Sequential(
              nn.Linear(input_dim, 1024),
              nn.GELU(),
              nn.Dropout(dropout),
              nn.Linear(1024, output_dim),
              nn.LayerNorm(output_dim),
          )

      def forward(self, x):
          return self.net(x)
  ```
- Normalization. LayerNorm (as in the projector above) keeps token-averaged embeddings comparable to fMRI CLS tokens (BrainLM/SwiFT), especially before cosine-similarity objectives.
- Sequence hygiene. Reuse the `pad_to_multiple_of_six` logic or the `ensure_6mer_compatible` helper whenever you extract embeddings outside the packaged scripts; otherwise, HF will inject `<oov>` tokens that shift chunk boundaries and misalign pooling.^[118:125:external_repos/generator/src/tasks/downstream/variant_effect_prediction.py]
- Memory tips. For million-token prompts, lean on ChunkEnsemble (`length_extension_mode="chunk_ensemble"`) or sliding-window RoPE to avoid editing HF internals; both paths keep per-chunk lengths manageable and let FlashAttention v2 handle the heavy lifting.^[566:666:external_repos/generator/src/tasks/downstream/sequence_understanding.py]
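Putting the bullets above together, here is a hedged end-to-end sketch: pool the chunk embeddings, project them with the `GeneratorProjector` defined above, and score cosine similarity against a brain embedding. The random tensors and the 1536-dim hidden size are stand-ins for illustration, not outputs of the packaged scripts.

```python
import torch
import torch.nn.functional as F

# Assumed shapes: 4 sequences, 16 chunks, hidden size 1536 (1.2 B model).
chunk_embeddings = torch.randn(4, 16, 1536)   # stand-in for padded_embeddings before .view()
brain_embeddings = torch.randn(4, 512)        # stand-in for BrainLM/SwiFT pooled features

# Pooling options from the list above.
mean_pooled = chunk_embeddings.mean(dim=1)        # [B, H] sequence summary
max_pooled = chunk_embeddings.max(dim=1).values   # [B, H] motif emphasis
last_chunk = chunk_embeddings[:, -1, :]           # [B, H] autoregressive-style "last token"

projector = GeneratorProjector(input_dim=1536, output_dim=512)
genetic_latent = projector(mean_pooled)           # [B, 512]

# Cosine-similarity alignment against the brain encoder's outputs.
similarity = F.cosine_similarity(genetic_latent, brain_embeddings, dim=-1)
print(similarity.shape)  # torch.Size([4])
```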
Following these steps yields `[B, 512]` genetic embeddings that can be concatenated with or contrastively aligned against brain-model outputs such as BrainLM CLS vectors or BrainMT/SwiFT pooled features.