Sequence and Structure Modeling
Protein language models, equivariant graph networks, and structure prediction models introduce distributed training challenges that do not appear in standard CV or NLP pipelines. Sequence length variation is extreme. Pair representations are quadratic in memory. Geometric equivariance adds architectural constraints that interact badly with standard parallelism strategies.
The Biotech Model Landscape
```mermaid
flowchart LR
    A[Protein sequence] --> B[PLM encoder: ESM-style]
    C[Molecular graph] --> D[GNN encoder: MPNN / SchNet]
    E[3D structure] --> F[Equivariant encoder: SE3 / EGNN]
    B --> G[Task head]
    D --> G
    F --> G
    G --> H[Property or structure output]
```
Protein Language Models
ESM-style protein language models are transformer encoders trained on amino acid sequences with masked token prediction. The distributed training challenges differ from standard NLP in two ways:
- Sequence length distribution is bimodal and heavy-tailed. UniRef50 spans lengths from 10 to 35,000 residues. A single long protein in a batch can dominate memory and cause a rank-level OOM that does not affect other ranks.
- The vocabulary is small (20–33 tokens) but the activation footprint is large. Memory pressure comes from activations and attention matrices, not embedding tables.
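A rough back-of-envelope comparison makes the second point concrete. The numbers below reuse the ESM-style dimensions from the config that follows (vocab 33, d_model 1280, 33 layers, 20 heads, 1024 tokens) and assume bfloat16; they are illustrative arithmetic, not measurements:

```python
# Illustrative arithmetic only: embedding table vs. per-sequence activations,
# assuming bfloat16 (2 bytes) and the ESM-style dimensions used below.
vocab, d_model, layers, heads, seq_len, dtype_bytes = 33, 1280, 33, 20, 1024, 2

embedding_table = vocab * d_model * dtype_bytes             # ~84 KB total
hidden_states = layers * seq_len * d_model * dtype_bytes    # ~86 MB per sequence
attention_maps = layers * heads * seq_len**2 * dtype_bytes  # ~1.3 GB per sequence (unfused attention)

print(embedding_table, hidden_states, attention_maps)
```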
```python
from dataclasses import dataclass, field

from training_systems_architecture import TrainConfig


@dataclass
class ProteinLMConfig(TrainConfig):
    max_length: int = 1024
    vocab_size: int = 33
    d_model: int = 1280
    num_layers: int = 33
    num_heads: int = 20
    length_bucket_boundaries: list[int] = field(
        default_factory=lambda: [128, 256, 512, 1024]
    )

    @property
    def pair_representation_memory_gb(self) -> float:
        return (self.max_length**2 * self.d_model * 4) / (1024**3)
```

At max_length=1024 and d_model=1280, a pair representation alone consumes 5.0 GB per sequence in FP32. That number drives the decisions about sequence parallelism and activation checkpointing.
Sequence Parallelism
Standard data parallelism (DDP) replicates the full model across ranks. For transformers on long sequences, the bottleneck is per-sequence activation footprint, not parameter count.
Sequence parallelism partitions the sequence dimension across ranks so each rank processes a contiguous subsequence. The attention layer requires an all-gather across the sequence axis before computing attention scores:
```mermaid
flowchart TD
    A[Sequence of length L] --> B[Split: L/N tokens per rank]
    B --> C[Rank 0: tokens 0 to L/N]
    B --> D[Rank 1: tokens L/N to 2L/N]
    B --> E[Rank N-1: remaining tokens]
    C --> F[Local Q/K/V projection]
    D --> F
    E --> F
    F --> G[All-gather keys and values]
    G --> H[Full attention on each rank]
    H --> I[Scatter output back to sequence shards]
```
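A minimal forward-only sketch of that pattern, assuming a single sequence-parallel process group; the function name and tensor layout are illustrative, and a real implementation needs an autograd-aware gather rather than the plain collective shown here:

```python
import torch
import torch.distributed as dist


def sequence_parallel_attention(
    q_local: torch.Tensor,  # (batch, local_len, heads, head_dim): this rank's shard
    k_local: torch.Tensor,
    v_local: torch.Tensor,
) -> torch.Tensor:
    # Forward-only sketch. Keys and values are gathered across the sequence-
    # parallel group; queries stay local, so the output remains sharded by rank.
    world_size = dist.get_world_size()
    k_shards = [torch.empty_like(k_local) for _ in range(world_size)]
    v_shards = [torch.empty_like(v_local) for _ in range(world_size)]
    dist.all_gather(k_shards, k_local)  # note: not autograd-aware; training code
    dist.all_gather(v_shards, v_local)  # needs a differentiable gather instead
    k_full = torch.cat(k_shards, dim=1)
    v_full = torch.cat(v_shards, dim=1)

    # Local queries attend to the full key/value sequence.
    scale = q_local.shape[-1] ** -0.5
    scores = torch.einsum("blhd,bLhd->bhlL", q_local, k_full) * scale
    probs = scores.softmax(dim=-1)
    return torch.einsum("bhlL,bLhd->blhd", probs, v_full)
```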
| Condition | Action |
|---|---|
| Sequences fit in device memory at target batch size | DDP first; simpler failure model |
| Sequences exceed device memory even at micro_batch=1 | Sequence parallelism is required |
| Mixed short and long sequences | Dynamic batching; sequence parallelism for long-tail only |
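The "dynamic batching" row above can be as simple as a token-budget sampler. A sketch, with max_tokens as an illustrative knob rather than a value prescribed by this page:

```python
from collections.abc import Iterator, Sequence


def token_budget_batches(
    lengths: Sequence[int], max_tokens: int = 16_384
) -> Iterator[list[int]]:
    # Sort by length so padding waste stays low, then cut a new batch whenever
    # the padded size (longest sequence x batch size) would exceed the budget.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batch: list[int] = []
    longest = 0
    for idx in order:
        candidate_longest = max(longest, lengths[idx])
        if batch and candidate_longest * (len(batch) + 1) > max_tokens:
            yield batch
            batch, candidate_longest = [], lengths[idx]
        batch.append(idx)
        longest = candidate_longest
    if batch:
        yield batch
```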
Pair Representation Memory
AlphaFold-style architectures maintain a pair representation of shape (L, L, d) where L is sequence length and d is the channel dimension. Memory grows quadratically in L:
```python
def pair_memory_bytes(seq_len: int, d_pair: int = 128, dtype_bytes: int = 2) -> int:
    return seq_len * seq_len * d_pair * dtype_bytes


for length in [256, 512, 1024, 2048]:
    gb = pair_memory_bytes(length) / (1024**3)
    print(f"L={length:5d}: {gb:.3f} GB per sequence (bfloat16, d_pair=128)")
```

| Sequence length | Pair memory (bfloat16, d_pair=128) |
|---|---|
| 256 | 0.016 GB |
| 512 | 0.062 GB |
| 1024 | 0.25 GB |
| 2048 | 1.0 GB |
Gradient accumulation does not help pair-representation memory because the tensor lives in the forward graph, not the optimizer state. The two effective mitigations are activation checkpointing over Evoformer-style pair update blocks and a sequence length curriculum that caps length during early training.
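A sequence length curriculum can be as small as a schedule function consulted by the dataset's cropping logic. A hypothetical linear ramp; the specific lengths and step count are placeholders:

```python
def max_length_at_step(
    step: int,
    start_length: int = 256,
    final_length: int = 1024,
    ramp_steps: int = 10_000,
) -> int:
    # Linear ramp from start_length to final_length over ramp_steps, then flat.
    # All three defaults are placeholders, not values prescribed by the text.
    if step >= ramp_steps:
        return final_length
    return start_length + int((final_length - start_length) * step / ramp_steps)
```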
SE(3)-Equivariant Networks
Equivariant networks for 3D molecular structures produce outputs that transform predictably under rotation and translation. This constraint affects the distributed training setup in non-obvious ways:
```mermaid
flowchart LR
    A[3D atom coordinates] --> B[Canonicalize coordinate frame]
    B --> C[Equivariant message passing]
    C --> D[Scalar invariant features]
    C --> E[Vector equivariant features]
    D --> F[Task readout]
    E --> F
    B -.->|Must precede sharding| G[Data partition]
```
Three training constraints that matter operationally:
- Frame canonicalization before sharding. Re-centering or aligning a molecular structure must happen in preprocessing, not inside a DataLoader worker (a minimal sketch follows this list). A rank that receives a shard of a protein cannot independently re-center because it lacks the full structural context.
- No rotation augmentation. An equivariant model already accounts for rotations by construction, so random-rotation augmentation adds no information; it wastes compute and can degrade training when equivariance is approximate rather than exact.
- Layer normalization over batch normalization. Batch normalization aggregates across the batch dimension. For variable-size molecular graphs, this is inconsistent across batches. Layer normalization operates per-node and is safe.
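A minimal sketch of the canonicalization step from the first constraint, run in preprocessing so every rank only ever sees already-centered coordinates; the helper name is illustrative:

```python
import torch


def canonicalize_frame(coords: torch.Tensor) -> torch.Tensor:
    # coords: (num_atoms, 3). Center the structure on its centroid; any further
    # alignment (e.g. to a reference frame) would also happen here, before the
    # structure is ever split across ranks or DataLoader workers.
    return coords - coords.mean(dim=0, keepdim=True)
```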
Multi-Task Training Across Assay Types
Biotech models are frequently trained simultaneously on dozens to hundreds of biological assays. This creates correctness challenges that resemble class imbalance but are harder to detect because the label matrix is sparse:
```python
import torch
import torch.nn as nn


def multitask_loss(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    mask: torch.Tensor,
    weights: torch.Tensor,
) -> torch.Tensor:
    per_task_loss = nn.functional.binary_cross_entropy_with_logits(
        predictions, targets, reduction="none"
    )
    masked = per_task_loss * mask.float()
    weighted = masked * weights.unsqueeze(0)
    return weighted.sum() / mask.float().sum().clamp(min=1.0)
```

The mask is critical. In a multi-task bioassay dataset, most compounds are not tested in most assays. Treating missing labels as negatives is a common silent error that systematically biases every head toward predicting inactivity and inflates apparent performance.
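A hypothetical usage pattern, assuming missing assay measurements are stored as NaN in the raw label matrix:

```python
# Hypothetical label handling: untested compound-assay pairs are stored as NaN.
raw_labels = torch.tensor([[1.0, float("nan"), 0.0],
                           [float("nan"), float("nan"), 1.0]])

mask = ~torch.isnan(raw_labels)                  # True only where a measurement exists
targets = torch.nan_to_num(raw_labels, nan=0.0)  # fill value is irrelevant where mask is False
weights = torch.ones(raw_labels.shape[1])        # per-task weights, uniform here

predictions = torch.zeros(raw_labels.shape)      # stand-in for model output logits
loss = multitask_loss(predictions, targets, mask, weights)
```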
Distributed Training Config Adjustments
The standard TrainConfig from the baseline trainer needs adjustment for long-sequence biological models:
| Parameter | Generic default | Protein LM adjustment | Why |
|---|---|---|---|
| micro_batch_size | 16–64 | 1–4 | Activation footprint per sequence |
| grad_accum_steps | 1–4 | 8–32 | Recover effective batch size |
| use_amp | True (float16) | True (bfloat16 preferred) | Larger dynamic range; avoids overflow on long sequences |
| max_length | N/A | 512–2048 with curriculum | Memory budget sets the ceiling |
| num_workers | 2–4 | 1–2 | Variable-length batching reduces prefetch benefit |
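Putting the table together as a config, under the assumption that the baseline TrainConfig exposes fields with exactly these names (micro_batch_size, grad_accum_steps, use_amp):

```python
# Hedged example: assumes the baseline TrainConfig defines micro_batch_size,
# grad_accum_steps, and use_amp fields under these names.
config = ProteinLMConfig(
    micro_batch_size=2,    # 1–4: per-sequence activation footprint dominates
    grad_accum_steps=16,   # 8–32: recover the effective batch size
    use_amp=True,          # with bfloat16 preferred over float16
    max_length=1024,       # 512–2048, raised by the length curriculum
)
```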
Activation Checkpointing for Long Sequences
For protein LMs and structure prediction models, activation checkpointing is not optional. It is the primary tool for fitting large sequences in device memory:
```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedTransformerBlock(nn.Module):
    def __init__(self, block: nn.Module) -> None:
        super().__init__()
        self.block = block

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        return checkpoint(self.block, x, mask, use_reentrant=False)
```

use_reentrant=False aligns with the current PyTorch recommendation. The reentrant variant has known issues with some forward-graph patterns common in attention implementations and does not support torch.compile.
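A usage sketch with hypothetical dimensions, reusing the wrapper and imports above; any transformer block with an (x, mask) forward signature works the same way, here with stock PyTorch encoder layers as stand-ins:

```python
# Usage sketch: wrap each layer of a small encoder stack so activations inside
# each block are recomputed during backward instead of stored.
layers = nn.ModuleList(
    CheckpointedTransformerBlock(
        nn.TransformerEncoderLayer(d_model=1280, nhead=20, batch_first=True)
    )
    for _ in range(4)
)

x = torch.randn(2, 1024, 1280, requires_grad=True)
attn_mask = torch.zeros(1024, 1024, dtype=torch.bool)  # all-False: nothing masked
for layer in layers:
    x = layer(x, attn_mask)
x.sum().backward()
```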
Staff-Level Tradeoffs
| Decision | Why you choose it | What it costs |
|---|---|---|
| DDP over FSDP for PLMs | Simpler baseline | Breaks down once model exceeds per-device memory |
| Sequence parallelism | Required for L > 2048 on 80 GB devices | Communication overhead per attention layer |
| Activation checkpointing on pair blocks | Keeps pair-representation memory bounded | Recomputes the forward pass during backward, roughly 33% more total compute |
| Sequence length curriculum | Stable early training; prevents OOM from long-tail examples | Complicates epoch definition and sampler resume |
| bfloat16 over float16 | Larger dynamic range; fewer overflow events | Slightly lower throughput on older hardware |
| Masking missing assay labels | Prevents silent negative-label contamination | Reduces effective batch density per task |