Training Systems Architecture
This page gives you the mental model to explain a distributed training system before you write a line of PyTorch.
The Simplest Architecture That Still Looks Senior
```mermaid
flowchart LR
    A[Job Spec] --> B[Orchestrator]
    B --> C[Rendezvous + World Setup]
    B --> D[Trainer Rank 0]
    B --> E[Trainer Rank 1..N]
    F[Dataset / Feature Store] --> G[Shard + Sampler Layer]
    G --> D
    G --> E
    D --> H[Checkpoint Store]
    E --> H
    D --> I[Metrics / Logs / Traces]
    E --> I
```
The interviewer does not need your exact internal platform. They need evidence that you can separate concerns:
- orchestration decides when work starts and where it runs
- rendezvous decides which ranks belong to the job
- trainer runtime performs forward, backward, and optimizer steps
- data plane ensures each rank gets the correct sample stream
- artifact plane stores checkpoints and model outputs
- observability plane tells operators whether the run is healthy
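One way to keep those concerns visibly separate is a single config object whose fields map onto the planes above. A minimal sketch, assuming illustrative field names rather than any required PyTorch API; the control loop further down refers to this object as cfg:

```python
from dataclasses import dataclass


@dataclass
class TrainConfig:
    # Rendezvous / world setup
    backend: str = "nccl"
    init_method: str = "env://"
    world_size: int = 1
    rank: int = 0
    local_rank: int = 0
    device: str = "cuda"

    # Data plane
    dataset: str = "placeholder-dataset-uri"  # illustrative, not a real path
    seed: int = 1234

    # Trainer runtime
    max_epochs: int = 10
    micro_batch_size: int = 8

    # Artifact and observability planes
    checkpoint_dir: str = "checkpoints"
    metrics_sink: str = "stdout"
```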
Rank Topology
In PyTorch distributed training, these terms should be automatic:
| Term | Meaning | Why interviewers care |
|---|---|---|
| world_size | Total process count in the job | Determines collective behavior and global batch semantics |
| rank | Unique process index across the job | Used for role assignment and output suppression |
| local_rank | Device index local to a node | Needed for correct device pinning |
| process group | Communicator over a set of ranks | Becomes central when discussing hybrid parallelism |
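Assuming a torchrun-style launcher, these values typically arrive as environment variables; a hedged sketch (the two-rank subgroup at the end is purely illustrative):

```python
import os

import torch
import torch.distributed as dist

# torchrun exports these for every process it launches.
world_size = int(os.environ["WORLD_SIZE"])   # total processes in the job
rank = int(os.environ["RANK"])               # globally unique process index
local_rank = int(os.environ["LOCAL_RANK"])   # index of this process on its node

torch.cuda.set_device(local_rank)            # pin this rank to its own GPU
dist.init_process_group(backend="nccl")      # reads MASTER_ADDR / MASTER_PORT from env

# A process group over a subset of ranks; every rank must enter this call,
# even ranks that will not belong to the new group.
pair_group = dist.new_group(ranks=[0, 1])
```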
The Core Control Loop
```python
import torch
import torch.distributed as dist


def train_job(cfg: TrainConfig) -> None:
    # Join the process group before building the model or touching collectives.
    dist.init_process_group(
        backend=cfg.backend,
        init_method=cfg.init_method,
        world_size=cfg.world_size,
        rank=cfg.rank,
    )
    torch.cuda.set_device(cfg.local_rank)

    model = build_model(cfg).to(cfg.device)
    optimizer = build_optimizer(model, cfg)
    sampler = build_sampler(cfg.dataset, cfg.rank, cfg.world_size, cfg.seed)
    loader = build_loader(cfg.dataset, sampler, cfg)
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[cfg.local_rank],
        output_device=cfg.local_rank,
        gradient_as_bucket_view=True,
    )

    state = maybe_restore_checkpoint(model, optimizer, sampler, cfg)
    for epoch in range(state.epoch, cfg.max_epochs):
        sampler.set_epoch(epoch)  # re-seed the shuffle so shards differ per epoch
        for batch in loader:
            loss = train_step(model, optimizer, batch, cfg)
            emit_metrics(loss=loss.item(), rank=cfg.rank)
        save_checkpoint(model, optimizer, sampler, epoch, cfg)
```

This is intentionally boring. Boring is good in interviews if you can explain why:
- DDP is the most legible baseline
- sampler state is part of correctness, not just convenience
- rank-aware metrics prevent log duplication
- checkpoint save frequency is an RPO tradeoff, not a cosmetic choice
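To make the last two bullets concrete, here is a hedged sketch of what emit_metrics and a checkpoint-cadence check could look like; the helper names and the every-N-steps parameter are assumptions for illustration, not part of the page's runtime:

```python
import time


def emit_metrics(loss: float, rank: int) -> None:
    # Every rank produces a structured record; only rank 0 prints it, so the
    # console shows one stream while per-rank data can still go to a sink.
    record = {"ts": time.time(), "rank": rank, "loss": loss}
    if rank == 0:
        print(record)
    # non-zero ranks would hand `record` to the observability plane here


def should_checkpoint(step: int, checkpoint_every_n_steps: int) -> bool:
    # Cadence is a recovery-point-objective decision: saving every N steps
    # bounds how much work a restart can lose.
    return step > 0 and step % checkpoint_every_n_steps == 0
```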
Current PyTorch Notes
The current official DDP docs make two points worth quoting into your mental model:
- DDP synchronizes gradients across model replicas, but it does not shard inputs for you. PyTorch explicitly puts input partitioning on the user side, usually through DistributedSampler.
- gradient_as_bucket_view=True can reduce peak memory and avoid gradient-to-bucket copies, but the gradients become views and some in-place assumptions break.
That matters in interviews because it separates:
- model replication semantics
- communication semantics
- data partitioning semantics
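A minimal sketch of the data-partitioning side, assuming dataset, rank, world_size, seed, and micro_batch_size are available from the surrounding config; DistributedSampler and set_epoch are the standard PyTorch pieces referenced above:

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# DDP replicates the model; partitioning the input stream is on you.
sampler = DistributedSampler(
    dataset, num_replicas=world_size, rank=rank, shuffle=True, seed=seed
)
loader = DataLoader(dataset, batch_size=micro_batch_size, sampler=sampler, drop_last=True)

for epoch in range(max_epochs):
    sampler.set_epoch(epoch)  # changes the shuffle ordering so shards rotate per epoch
    for batch in loader:
        ...
```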
Architectural Invariants
Section titled “Architectural Invariants”1. Global batch semantics must be explicit
If each rank processes micro_batch_size=8 and world_size=8, your effective global batch is at least 64, and larger if you add gradient accumulation. State it clearly.
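The arithmetic is worth writing down rather than implying; a small sketch with the numbers from this example:

```python
micro_batch_size = 8      # samples per rank per step
world_size = 8            # ranks in the job
grad_accum_steps = 1      # optimizer steps happen every N micro-batches

global_batch = micro_batch_size * world_size * grad_accum_steps
# 64 with no accumulation; 256 if grad_accum_steps were 4, and so on
```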
2. Storage is part of the training algorithm
Checkpoint throughput, atomicity, and restore latency affect whether your pipeline is operationally viable. Treat storage as a first-class design dependency.
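One concrete piece of that dependency is atomic checkpoint writes; a minimal sketch, assuming a local POSIX filesystem (object stores need a different commit protocol) and a hypothetical helper name:

```python
import os

import torch


def save_checkpoint_atomically(state: dict, path: str) -> None:
    # Write to a sibling temp file, then rename into place. os.replace is atomic
    # on POSIX filesystems, so a crash mid-save never leaves a truncated file
    # at the path the restore logic reads from.
    tmp_path = path + ".tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)
```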
3. Logging is a distributed systems problem
If all ranks write identical logs, you have noise. If only rank 0 writes, you may miss local failures. The practical answer, sketched after this list, is:
- structured logs on every rank
- dense console output only on rank 0
- counters and histograms aggregated centrally
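A hedged sketch of that split using the standard logging module; the file naming and function name are assumptions for illustration:

```python
import logging
import sys


def configure_logging(rank: int, log_dir: str) -> logging.Logger:
    # Every rank writes to its own file so local failures stay visible;
    # only rank 0 also gets a console handler for dense human-readable output.
    logger = logging.getLogger("trainer")
    logger.setLevel(logging.INFO)

    per_rank = logging.FileHandler(f"{log_dir}/trainer.rank{rank}.log")
    per_rank.setFormatter(
        logging.Formatter(f"%(asctime)s rank={rank} %(levelname)s %(message)s")
    )
    logger.addHandler(per_rank)

    if rank == 0:
        console = logging.StreamHandler(sys.stdout)
        console.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
        logger.addHandler(console)

    return logger
```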
What Belongs In The Notebook Version
```mermaid
flowchart TD
    A[Notebook Cell] --> B[Config dataclass]
    A --> C[Dataset stub]
    A --> D[Trainer wrapper]
    D --> E[Mock metrics sink]
    D --> F[Checkpoint adapter]
    D --> G[Launch abstraction]
```
A strong notebook does not fully implement Kubernetes launch, object-store auth, or Prometheus exporters. It does preserve the boundaries where those concerns would attach.
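One way to preserve those boundaries is a pair of tiny interfaces with in-memory mocks; everything below is a hypothetical sketch of the notebook shape, not an API from the text:

```python
from typing import Optional, Protocol


class MetricsSink(Protocol):
    def emit(self, name: str, value: float, step: int) -> None: ...


class CheckpointAdapter(Protocol):
    def save(self, state: dict, tag: str) -> None: ...
    def restore(self, tag: str) -> Optional[dict]: ...


class InMemoryMetricsSink:
    # Mock sink for the notebook; a Prometheus exporter would attach at this seam.
    def __init__(self) -> None:
        self.records: list[tuple[str, float, int]] = []

    def emit(self, name: str, value: float, step: int) -> None:
        self.records.append((name, value, step))
```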
Staff-Level Tradeoffs To Mention
| Decision | Why you choose it first | What cost it creates later |
|---|---|---|
| DDP over FSDP | Lowest cognitive load, reliable default for live coding | Less memory efficiency at larger model sizes |
| full checkpoints | Simplest restore semantics | Large I/O footprint and slower save times |
| map-style dataset | Easy deterministic sharding | Harder to model streaming datasets |
| rank-0 orchestration decisions | Clear control path | Can become a scaling bottleneck for advanced coordination |