All-In-One Handbook
This page compresses the rest of the site into one long-form reference. It is the page to open if you want the entire story in one place: architecture, trainer code, data sharding, checkpointing, observability, performance tuning, and hard interview questions.
What This Page Is For
Use this when you want:
- one continuous explanation instead of many short docs
- one full trainer example you can study line by line
- one place to rehearse the talk track for a senior/staff interview
- one page you can skim in the final hour before the interview
The System Model
Before writing code, define the system in terms of planes:
- control plane: launch, retry policy, placement, config, run metadata
- training plane: forward, backward, optimizer step, gradient synchronization
- data plane: dataset partitioning, loading, preprocessing, host-to-device movement
- artifact plane: checkpoints, manifests, final model outputs
- observability plane: metrics, logs, traces, alerts
```mermaid
flowchart LR
    A[Job Spec] --> B[Launcher / Orchestrator]
    B --> C[Rendezvous]
    C --> D[Rank 0]
    C --> E[Rank 1..N]
    F[Dataset or Feature Store] --> G[Distributed Sampler]
    G --> D
    G --> E
    D --> H[Checkpoint Store]
    E --> H
    D --> I[Metrics + Logs]
    E --> I
```
The First Principles You Need To Say Out Loud
In a good interview answer, the first few sentences should establish invariants:
- Each rank must know its identity and device ownership.
- The input stream must be partitioned deterministically across ranks.
- The optimizer state must move forward in lock-step with the model state.
- Resume semantics must be explicit instead of hand-waved.
- We need enough observability to separate input stalls, compute inefficiency, and communication bottlenecks.
That framing is often more important than the exact code.
Process Topology
The default distributed baseline is:
- one process per GPU
- one DDP-wrapped model replica per process
- one rank-aware sampler per process
- one shared checkpoint contract across all processes
```mermaid
flowchart TD
    A[Node] --> B[local_rank 0 -> cuda:0]
    A --> C[local_rank 1 -> cuda:1]
    B --> D[rank 0]
    C --> E[rank 1]
    D --> F[full model replica]
    E --> G[full model replica]
    F --> H[gradient all-reduce]
    G --> H
```
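Under a torchrun-style launcher, each process derives this identity from environment variables. A minimal sketch of that mapping — `process_identity` is an illustrative helper, not a library API:

```python
def process_identity(env: dict[str, str]) -> dict[str, object]:
    """Derive rank identity and device ownership from launcher-style env vars.

    Mirrors the variables torchrun exports (RANK, LOCAL_RANK, WORLD_SIZE);
    the helper itself is illustrative, not part of any library.
    """
    rank = int(env.get("RANK", "0"))
    local_rank = int(env.get("LOCAL_RANK", "0"))
    world_size = int(env.get("WORLD_SIZE", "1"))
    return {
        "rank": rank,
        "local_rank": local_rank,
        "world_size": world_size,
        # one process per GPU: local_rank selects the device on this node
        "device": f"cuda:{local_rank}",
    }

# Example: global rank 5 is the second process on its node in an 8-process job
identity = process_identity({"RANK": "5", "LOCAL_RANK": "1", "WORLD_SIZE": "8"})
```

The point to narrate: global rank defines identity across the job, local rank defines device ownership on one node, and conflating the two is a classic multi-node bug.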
Full Reference Implementation
The goal of this code is not to be a turnkey production package. The goal is to show a complete, production-shaped training skeleton that you can explain under pressure.
```python
from __future__ import annotations

import json
import os
import random
import time
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler


@dataclass
class TrainConfig:
    backend: str = "nccl"
    init_method: str = "env://"
    world_size: int = 1
    rank: int = 0
    local_rank: int = 0
    seed: int = 17
    dataset_size: int = 20_000
    input_width: int = 256
    hidden_width: int = 512
    num_classes: int = 8
    micro_batch_size: int = 16
    grad_accum_steps: int = 1
    num_workers: int = 2
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    max_epochs: int = 3
    log_every: int = 20
    checkpoint_every_steps: int = 200
    checkpoint_dir: str = "/tmp/torch-control-plane/checkpoints"
    run_id: str = "demo-run"
    use_amp: bool = True

    @property
    def is_distributed(self) -> bool:
        return self.world_size > 1

    @property
    def is_main_rank(self) -> bool:
        return self.rank == 0

    @property
    def device(self) -> torch.device:
        if torch.cuda.is_available():
            return torch.device(f"cuda:{self.local_rank}")
        return torch.device("cpu")

    @property
    def global_batch_size(self) -> int:
        return self.micro_batch_size * self.grad_accum_steps * self.world_size

    def digest(self) -> str:
        payload = json.dumps(asdict(self), sort_keys=True)
        return str(abs(hash(payload)))


class ToyClassificationDataset(Dataset):
    def __init__(self, size: int, width: int, num_classes: int) -> None:
        g = torch.Generator().manual_seed(1234)
        self.features = torch.randn(size, width, generator=g)
        self.labels = torch.randint(0, num_classes, (size,), generator=g)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        return {
            "inputs": self.features[index],
            "targets": self.labels[index],
        }


class ResumeAwareDistributedSampler(DistributedSampler):
    def __init__(
        self,
        dataset: Dataset,
        num_replicas: int,
        rank: int,
        seed: int = 0,
        consumed: int = 0,
        **kwargs,
    ) -> None:
        super().__init__(
            dataset,
            num_replicas=num_replicas,
            rank=rank,
            seed=seed,
            **kwargs,
        )
        self.consumed = consumed

    def state_dict(self) -> dict[str, int]:
        return {"epoch": self.epoch, "consumed": self.consumed}

    def load_state_dict(self, state: dict[str, int]) -> None:
        self.epoch = state["epoch"]
        self.consumed = state["consumed"]

    def __iter__(self):
        indices = list(super().__iter__())
        start = min(self.consumed, len(indices))
        for index in indices[start:]:
            yield index
            self.consumed += 1


class TinyNet(nn.Module):
    def __init__(self, width: int, hidden: int, num_classes: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def setup_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def maybe_init_dist(cfg: TrainConfig) -> None:
    if not cfg.is_distributed:
        return

    if torch.cuda.is_available():
        torch.cuda.set_device(cfg.local_rank)

    dist.init_process_group(
        backend=cfg.backend,
        init_method=cfg.init_method,
        world_size=cfg.world_size,
        rank=cfg.rank,
    )


def cleanup_dist() -> None:
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def unwrap(model: nn.Module) -> nn.Module:
    return model.module if hasattr(model, "module") else model


def latest_checkpoint_path(checkpoint_dir: str) -> Path | None:
    base = Path(checkpoint_dir)
    if not base.exists():
        return None
    candidates = sorted(base.glob("step-*.pt"))
    return candidates[-1] if candidates else None


def ensure_checkpoint_dir(path: str) -> None:
    Path(path).mkdir(parents=True, exist_ok=True)


def log_rank_event(cfg: TrainConfig, event: str, **fields) -> None:
    payload = {
        "event": event,
        "rank": cfg.rank,
        "local_rank": cfg.local_rank,
        "run_id": cfg.run_id,
        **fields,
    }
    print(json.dumps(payload, sort_keys=True))


@contextmanager
def phase(timer_store: dict[str, list[float]], name: str):
    started = time.perf_counter()
    try:
        yield
    finally:
        timer_store.setdefault(name, []).append(time.perf_counter() - started)


def move_batch_to_device(batch: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
    return {key: value.to(device, non_blocking=True) for key, value in batch.items()}


def build_dataloader(cfg: TrainConfig, dataset: Dataset, sampler: DistributedSampler) -> DataLoader:
    return DataLoader(
        dataset,
        batch_size=cfg.micro_batch_size,
        sampler=sampler,
        num_workers=cfg.num_workers,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=cfg.num_workers > 0,
        drop_last=True,
    )


def build_checkpoint(
    step: int,
    epoch: int,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler | None,
    sampler: ResumeAwareDistributedSampler,
    cfg: TrainConfig,
) -> dict:
    return {
        "step": step,
        "epoch": epoch,
        "model": unwrap(model).state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict() if scaler else None,
        "sampler": sampler.state_dict(),
        "rng": {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "torch": torch.get_rng_state(),
            "cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        },
        "metadata": {
            "run_id": cfg.run_id,
            "world_size": cfg.world_size,
            "config_digest": cfg.digest(),
            "global_batch_size": cfg.global_batch_size,
        },
    }


def save_checkpoint(
    step: int,
    epoch: int,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler | None,
    sampler: ResumeAwareDistributedSampler,
    cfg: TrainConfig,
) -> None:
    if not cfg.is_main_rank:
        return

    ensure_checkpoint_dir(cfg.checkpoint_dir)
    state = build_checkpoint(step, epoch, model, optimizer, scaler, sampler, cfg)
    path = Path(cfg.checkpoint_dir) / f"step-{step:08d}.pt"
    torch.save(state, path)
    log_rank_event(cfg, "checkpoint_saved", path=str(path), step=step, epoch=epoch)


def restore_if_present(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler | None,
    sampler: ResumeAwareDistributedSampler,
    cfg: TrainConfig,
) -> SimpleNamespace:
    path = latest_checkpoint_path(cfg.checkpoint_dir)
    if path is None:
        return SimpleNamespace(step=0, epoch=0)

    state = torch.load(path, map_location="cpu")
    unwrap(model).load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    if scaler and state["scaler"] is not None:
        scaler.load_state_dict(state["scaler"])
    sampler.load_state_dict(state["sampler"])

    random.setstate(state["rng"]["python"])
    np.random.set_state(state["rng"]["numpy"])
    torch.set_rng_state(state["rng"]["torch"])
    if torch.cuda.is_available() and state["rng"]["cuda"] is not None:
        torch.cuda.set_rng_state_all(state["rng"]["cuda"])

    log_rank_event(
        cfg,
        "checkpoint_restored",
        path=str(path),
        step=state["step"],
        epoch=state["epoch"],
    )
    return SimpleNamespace(step=state["step"], epoch=state["epoch"])


def train_step(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    batch: dict[str, torch.Tensor],
    scaler: GradScaler | None,
    cfg: TrainConfig,
    timer_store: dict[str, list[float]],
) -> float:
    with phase(timer_store, "h2d"):
        batch = move_batch_to_device(batch, cfg.device)

    with phase(timer_store, "forward"):
        with autocast(device_type=cfg.device.type, enabled=scaler is not None):
            logits = model(batch["inputs"])
            loss = nn.functional.cross_entropy(logits, batch["targets"])

    with phase(timer_store, "backward"):
        if scaler:
            scaler.scale(loss).backward()
        else:
            loss.backward()

    with phase(timer_store, "optimizer"):
        if scaler:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    return float(loss.detach().cpu().item())


def summarize_timers(timer_store: dict[str, list[float]]) -> dict[str, float]:
    return {
        name: float(sum(values) / max(len(values), 1))
        for name, values in timer_store.items()
    }


def run(cfg: TrainConfig) -> None:
    setup_seed(cfg.seed)
    maybe_init_dist(cfg)

    dataset = ToyClassificationDataset(
        size=cfg.dataset_size,
        width=cfg.input_width,
        num_classes=cfg.num_classes,
    )
    sampler = ResumeAwareDistributedSampler(
        dataset,
        num_replicas=cfg.world_size,
        rank=cfg.rank,
        seed=cfg.seed,
        shuffle=True,
    )
    loader = build_dataloader(cfg, dataset, sampler)

    model = TinyNet(
        width=cfg.input_width,
        hidden=cfg.hidden_width,
        num_classes=cfg.num_classes,
    ).to(cfg.device)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
    )
    scaler = GradScaler(device=cfg.device.type, enabled=cfg.use_amp and cfg.device.type == "cuda")

    if cfg.is_distributed:
        model = DDP(
            model,
            device_ids=[cfg.local_rank] if cfg.device.type == "cuda" else None,
            output_device=cfg.local_rank if cfg.device.type == "cuda" else None,
            gradient_as_bucket_view=True,
        )

    state = restore_if_present(model, optimizer, scaler, sampler, cfg)
    step = state.step
    epoch = state.epoch  # guards the final save if the epoch loop does not run

    try:
        for epoch in range(state.epoch, cfg.max_epochs):
            sampler.set_epoch(epoch)
            timer_store: dict[str, list[float]] = {}

            for batch_index, batch in enumerate(loader):
                loss = train_step(model, optimizer, batch, scaler, cfg, timer_store)
                step += 1

                if cfg.is_main_rank and step % cfg.log_every == 0:
                    metrics = summarize_timers(timer_store)
                    log_rank_event(
                        cfg,
                        "train_progress",
                        step=step,
                        epoch=epoch,
                        batch_index=batch_index,
                        loss=round(loss, 4),
                        avg_h2d_ms=round(metrics.get("h2d", 0.0) * 1000, 2),
                        avg_forward_ms=round(metrics.get("forward", 0.0) * 1000, 2),
                        avg_backward_ms=round(metrics.get("backward", 0.0) * 1000, 2),
                        avg_optimizer_ms=round(metrics.get("optimizer", 0.0) * 1000, 2),
                    )

                if step % cfg.checkpoint_every_steps == 0:
                    save_checkpoint(step, epoch, model, optimizer, scaler, sampler, cfg)

        save_checkpoint(step, epoch, model, optimizer, scaler, sampler, cfg)
    finally:
        cleanup_dist()


if __name__ == "__main__":
    cfg = TrainConfig(
        world_size=int(os.environ.get("WORLD_SIZE", "1")),
        rank=int(os.environ.get("RANK", "0")),
        local_rank=int(os.environ.get("LOCAL_RANK", "0")),
    )
    run(cfg)
```

How To Explain The Code
The code is easier to defend if you explain it in layers rather than top to bottom.
1. Configuration
TrainConfig carries:
- process identity
- batch semantics
- I/O policy
- optimizer settings
- checkpoint settings
The strong sentence here is:
“I keep the effective batch semantics on the config because topology changes should not silently change training behavior.”
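That sentence has a concrete arithmetic behind it. The effective batch per optimizer step is the product of micro-batch, accumulation steps, and world size, so any topology change silently changes it unless the config makes the math explicit (a standalone sketch of the same property the code computes):

```python
def global_batch_size(micro_batch: int, grad_accum_steps: int, world_size: int) -> int:
    """Effective optimizer-step batch: per-rank micro-batches accumulated, then replicated."""
    return micro_batch * grad_accum_steps * world_size

# Moving from 1 GPU to 8 GPUs without touching micro-batch or accumulation
# multiplies the effective batch by 8 -- and with it, the learning dynamics:
single = global_batch_size(micro_batch=16, grad_accum_steps=4, world_size=1)  # 64
multi = global_batch_size(micro_batch=16, grad_accum_steps=4, world_size=8)   # 512
```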
2. Dataset and sampler
The dataset itself is boring. That is fine. The interesting part is the sampler:
- it is rank-aware
- it supports deterministic reshuffling via set_epoch
- it stores consumed progress so resume semantics are explicit
This is one of the most important parts of the entire page. Many weak distributed-training answers talk about DDP and never talk about sampler correctness.
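The core partitioning invariant can be demonstrated without torch at all. In the non-shuffled case, DistributedSampler pads the index list to a multiple of world_size and strides it by rank; a plain-Python sketch of that scheme (`shard_indices` is an illustrative name):

```python
def shard_indices(dataset_size: int, world_size: int, rank: int) -> list[int]:
    """Rank-strided sharding in the spirit of DistributedSampler (shuffle=False).

    Indices are padded by wrapping around so every rank receives the same
    count, then strided: rank r takes indices r, r + world_size, ...
    """
    indices = list(range(dataset_size))
    total = ((dataset_size + world_size - 1) // world_size) * world_size
    indices += indices[: total - dataset_size]  # pad by repeating the head
    return indices[rank:total:world_size]

# 12 samples over 4 ranks: disjoint shards that cover every index exactly once
shards = [shard_indices(12, world_size=4, rank=r) for r in range(4)]
```

When the dataset size is not divisible by world_size, the padding duplicates a few samples; being able to name that edge case is itself a strong interview signal.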
3. Distributed initialization
The maybe_init_dist() function exists for one reason: to separate local single-process behavior from distributed process-group behavior.
That lets you say:
- the training loop stays mostly the same
- the launcher changes the process context
- DDP is a wrapper over a process group, not a magical all-in-one training platform
4. Model wrapping
The DDP wrap is intentionally minimal:
- wrap after moving to the right device
- pass one device id per process for GPU training
- keep the baseline simple and explainable
This is the right place to mention that current PyTorch DDP docs still make clear that DDP handles gradient synchronization, not input sharding.
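DDP's contract is narrow enough to state in a few lines of plain Python: after backward, each parameter's gradient is replaced by the mean of that gradient across ranks, and the input pipeline is untouched. A torch-free simulation of that averaging step:

```python
def allreduce_mean(per_rank_grads: list[list[float]]) -> list[float]:
    """Simulate the gradient all-reduce DDP performs after backward.

    per_rank_grads[r][i] is rank r's gradient for parameter i; every rank
    ends up with the element-wise mean, so optimizer steps stay in lock-step.
    """
    world_size = len(per_rank_grads)
    return [sum(col) / world_size for col in zip(*per_rank_grads)]

# Two ranks, two parameters: both ranks step with the averaged gradient.
synced = allreduce_mean([[1.0, 2.0], [3.0, 6.0]])
```

Because every rank applies the same averaged gradient to identical replicas, the replicas stay bit-identical without ever exchanging weights — which is exactly why sampler correctness, not DDP, decides what data the model learns from.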
5. AMP
The code uses autocast() and GradScaler.
Important talking points:
- autocast scopes the forward pass and loss computation
- GradScaler protects optimizer steps under mixed precision
- AMP improves throughput and reduces memory, but it introduces another piece of state that must be checkpointed
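The scaler's job is mechanical enough to sketch without torch: multiply the loss so small fp16 gradients survive, skip the step and back off when an overflow appears, and grow the scale after healthy steps. A toy state machine in that spirit (the real GradScaler grows only every growth_interval successful steps; `ToyGradScaler` is illustrative):

```python
import math

class ToyGradScaler:
    """Minimal loss-scaling state machine in the spirit of torch.amp.GradScaler."""

    def __init__(self, scale: float = 2.0**16, growth: float = 2.0, backoff: float = 0.5):
        self.scale = scale
        self.growth = growth
        self.backoff = backoff

    def scale_loss(self, loss: float) -> float:
        return loss * self.scale

    def step(self, scaled_grads: list[float]) -> bool:
        """Return True if the optimizer step should run (with unscaled grads)."""
        if any(math.isinf(g) or math.isnan(g) for g in scaled_grads):
            self.scale *= self.backoff   # overflow: skip this step, back off
            return False
        self.scale *= self.growth        # healthy step: grow toward a larger scale
        return True

scaler = ToyGradScaler(scale=8.0)
ok = scaler.step([4.0, 2.0])        # finite grads: step runs, scale grows
bad = scaler.step([float("inf")])   # overflow: step skipped, scale shrinks
```

This is also why scaler state must be checkpointed: the current scale is learned from the run's own overflow history.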
6. Checkpointing
The checkpoint contains:
- model state
- optimizer state
- scaler state
- sampler state
- RNG state
- metadata
That is the right modern answer for a strong interview. Saving weights alone is not a recovery plan.
Data Correctness
This is the section where you sound senior instead of merely API-literate.
Invariants
- each rank sees a deterministic shard
- reshuffle order changes by epoch when set_epoch() is called
- the checkpoint captures enough information to restart without optimizer drift
- global batch math is explicit
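The resume invariant in particular reduces to a consumed counter that advances only as samples are actually yielded, which is what the sampler above implements. A torch-free sketch of the same semantics (`ResumableShard` is an illustrative name):

```python
class ResumableShard:
    """Track per-epoch progress so a resumed run neither repeats nor skips samples."""

    def __init__(self, indices: list[int], consumed: int = 0):
        self.indices = indices
        self.consumed = consumed  # samples already trained on this epoch

    def __iter__(self):
        start = min(self.consumed, len(self.indices))
        for index in self.indices[start:]:
            yield index
            self.consumed += 1  # advances per yielded sample, so a mid-epoch
                                # checkpoint records exact progress

# Resuming an epoch where 2 of 6 samples were already consumed:
shard = ResumableShard([3, 7, 1, 9, 5, 0], consumed=2)
remaining = list(shard)
```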
Common correctness failures
```mermaid
flowchart TD
    A[Incorrect training results] --> B[Duplicate samples across ranks]
    A --> C[Skipped samples after resume]
    A --> D[Changed batch semantics after topology shift]
    A --> E[Stale optimizer or scaler state]
```
What to say when challenged
If an interviewer asks “how do you know this distributed trainer is correct?”, answer in this order:
- define correctness
- name the state that must survive
- explain the sample-partitioning invariant
- explain the resume validation checks
Performance Model
Do not discuss performance as a single number. Discuss it by phase.
```mermaid
flowchart LR
    A[Data wait] --> B[H2D copy]
    B --> C[Forward]
    C --> D[Backward]
    D --> E[Gradient sync]
    E --> F[Optimizer step]
    F --> G[Checkpoint / side work]
```
Why the timer store matters
The timer store in the sample code supports the most important performance conversation:
- Are we input-bound?
- Are we communication-bound?
- Are we compute-bound?
- Are we stalling on checkpointing or artifact writes?
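One way to turn the timer store into an answer is to compare average phase times and name the dominant one; a minimal sketch over the same per-phase dict the trainer builds (`dominant_phase` is illustrative):

```python
def dominant_phase(timer_store: dict[str, list[float]]) -> str:
    """Name the phase consuming the largest share of step time."""
    averages = {
        name: sum(values) / len(values)
        for name, values in timer_store.items()
        if values
    }
    return max(averages, key=averages.get)

# Example phase timings in seconds: a communication-heavy backward dominates.
store = {
    "h2d": [0.004, 0.005],
    "forward": [0.010, 0.012],
    "backward": [0.030, 0.028],
    "optimizer": [0.003, 0.004],
}
bottleneck = dominant_phase(store)
```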
Common bottlenecks
Section titled “Common bottlenecks”| Symptom | Likely source | First move |
|---|---|---|
| GPU idle before forward | data loader, preprocessing, host-to-device copy | inspect loader workers, caching, prefetch |
| long backward tail | collective communication | inspect DDP sync cost, topology, bucket behavior |
| periodic step spikes | checkpoint or storage jitter | correlate spikes with save cadence |
| OOM after scaling | activation footprint or optimizer state | lower micro-batch, use checkpointing, reconsider parallelism |
Checkpointing and Recovery
A mature answer distinguishes:
- checkpoint frequency
- checkpoint format
- checkpoint atomicity
- checkpoint restore policy
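Atomicity deserves one concrete line of defense: write to a temporary file in the same directory, then rename, so a reader never observes a half-written checkpoint. A filesystem-only sketch using JSON in place of torch.save (the file layout is illustrative):

```python
import json
import os
import tempfile
from pathlib import Path

def atomic_save(state: dict, path: Path) -> None:
    """Write a checkpoint atomically: temp file in the same dir + os.replace.

    os.replace is atomic on POSIX filesystems, so a crash mid-write leaves
    either the old checkpoint or the new one, never a truncated file.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as handle:
            json.dump(state, handle)
        os.replace(tmp, path)  # atomic publish of the finished file
    except BaseException:
        os.unlink(tmp)         # never leave a partial temp file behind
        raise

target = Path(tempfile.mkdtemp()) / "step-00000200.json"
atomic_save({"step": 200, "epoch": 1}, target)
```

The same temp-plus-rename pattern applies unchanged when the payload is torch.save bytes instead of JSON.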
The restart model
The page’s code assumes a conservative default:
- full job restart
- latest good checkpoint
- same topology preferred
This is a good default for interviews because it is correct and easy to explain.
When to bring up Distributed Checkpoint
Bring up torch.distributed.checkpoint when:
- model state is large enough that rank-parallel checkpointing matters
- you want to discuss resharding across topology changes
- the interviewer wants a current PyTorch-native answer beyond plain torch.save
Resume validation checks
After restore, validate:
- step monotonicity
- learning rate continuity
- optimizer state presence
- scaler state presence
- sampler progress continuity
- loss continuity over the next few steps
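These checks are cheap to automate. A sketch that validates a restored checkpoint against the run's expectations — field names follow the checkpoint dict from the reference code, and `validate_resume` is an illustrative helper:

```python
def validate_resume(checkpoint: dict, expected_world_size: int, last_logged_step: int) -> list[str]:
    """Return human-readable violations; an empty list means the resume looks sane."""
    problems = []
    if checkpoint["step"] < last_logged_step:
        problems.append("step went backwards: checkpoint is older than logs suggest")
    if checkpoint.get("optimizer") is None:
        problems.append("optimizer state missing: momentum/variance would reset")
    if checkpoint["metadata"]["world_size"] != expected_world_size:
        problems.append("world size changed: batch semantics and sharding differ")
    return problems

# A restore into a 4-rank job from an 8-rank checkpoint should be flagged:
ckpt = {
    "step": 400,
    "optimizer": {"state": "..."},
    "metadata": {"world_size": 8},
}
issues = validate_resume(ckpt, expected_world_size=4, last_logged_step=380)
```

Loss continuity is the one check this cannot cover statically; that requires watching the next few training steps after resume.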
Observability
You do not need a specific vendor story. You do need a telemetry story.
- structured logs on all ranks
- dense human-readable progress on rank 0
- explicit events for checkpoint save and restore
Metrics
- loss
- learning rate
- samples/sec
- step time by phase
- checkpoint latency
- restart count
- all-reduce time if you expose it
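samples/sec in particular should be derived from the global batch, not the per-rank micro-batch, or scaled-out runs will report misleading throughput. A minimal sketch of the arithmetic:

```python
def samples_per_second(micro_batch: int, grad_accum: int, world_size: int, step_seconds: float) -> float:
    """Global throughput per optimizer step: effective batch divided by wall time."""
    return (micro_batch * grad_accum * world_size) / step_seconds

# 16 samples x 2 accumulation steps x 8 ranks, 0.5 s per optimizer step:
throughput = samples_per_second(micro_batch=16, grad_accum=2, world_size=8, step_seconds=0.5)
```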
Traces
Use traces when you want to correlate:
- launcher events
- storage stalls
- long-tail step latency
- checkpoint publishing
```mermaid
flowchart LR
    A[Trainer ranks] --> B[Structured logs]
    A --> C[Metrics]
    A --> D[Spans]
    B --> E[Log backend]
    C --> F[Metrics backend]
    D --> G[Trace backend]
    E --> H[Operator view]
    F --> H
    G --> H
```
Parallelism Tradeoffs
Start with DDP unless memory pressure proves it is insufficient.

DDP
- easiest to explain
- easiest to debug
- clearest baseline for shared-screen interviews

FSDP
- helps when model state is the memory bottleneck
- introduces more checkpoint, wrapping, and state-dict complexity
- is worth mentioning, but not usually worth live-coding first
Activation checkpointing
Section titled “Activation checkpointing”- reduces memory pressure
- increases compute
- can affect determinism and runtime characteristics
The strong answer is never “use every optimization.” The strong answer is “introduce the next complexity only when the current bottleneck is clear.”
How To Present This In A Colab Notebook
If the exercise is in Colab, do not paste an entire giant module at once. Instead, split the same logic into notebook-friendly cells:
- config
- dataset and sampler
- model and optimizer
- distributed init helper
- train step
- checkpoint helpers
- main loop
This gives you narration points between cells and keeps the interviewer inside your reasoning.
Hard Questions And Strong Answers
“Why DDP instead of FSDP?”
Because DDP is the smallest correct distributed baseline. If the model fits, DDP keeps failure semantics and debugging simpler. I would move to FSDP when replicated model state becomes the bottleneck.
“What is the most dangerous silent failure?”
Silent sample duplication or omission across ranks, because the system can look healthy while learning on the wrong data distribution.
“What state must survive resume?”
Model, optimizer, scaler, sampler, RNG, step metadata, and enough configuration identity to detect incompatible restores.
“How would you debug a slowdown that happens every 20 minutes?”
I would correlate the slowdown with checkpoint cadence, storage writes, and rank-level step-time breakdowns before assuming a model-side issue.
“What would you deliberately not build during the interview?”
Cloud-specific auth, scheduler-specific orchestration, vendor-specific telemetry exporters, and advanced hybrid parallelism. I would preserve interfaces for them, but I would not burn interview time implementing them.
One Rehearsable Staff-Level Summary
If you need one closing paragraph:
“I built the smallest correct distributed trainer that still preserves production boundaries: rank-aware initialization, deterministic data partitioning, synchronized optimizer progress, resumable state, and observable step behavior. I would start with DDP because it maximizes clarity and correctness in a live exercise, then add FSDP, distributed checkpointing, or more complex topology only when memory, I/O, or communication data proves the simpler baseline is insufficient.”