Checkpointing and Recovery

Checkpointing is not “save model weights every N steps.” It is the contract that defines your recovery point objective and the correctness of resumed training.

| State | Why it matters |
| --- | --- |
| model parameters | obvious, but insufficient alone |
| optimizer state | needed for momentum, adaptive moments, and learning-rate continuity |
| scheduler state | prevents LR jumps after resume |
| AMP scaler state | required if mixed precision is in use |
| sampler progress | prevents accidental replay / skip patterns |
| random seeds / RNG state | necessary for reproducibility-sensitive flows |
| metadata | step, epoch, git SHA, config digest, dataset snapshot identifier |
```mermaid
sequenceDiagram
  participant Trainer as Trainer ranks
  participant Store as Checkpoint store
  participant Scheduler as Orchestrator
  Trainer->>Trainer: quiesce step boundary
  Trainer->>Store: write shard(s) + manifest
  Store-->>Trainer: commit confirmation
  Note over Trainer,Store: atomic publish of checkpoint id
  Scheduler->>Trainer: restart after failure
  Trainer->>Store: resolve latest good manifest
  Store-->>Trainer: model, optimizer, sampler, metadata
  Trainer->>Trainer: rebuild process groups and continue
```

Step-boundary saves with manifest-based publication are easier to reason about than ad hoc partial writes.

```python
import random

import numpy as np
import torch


def build_checkpoint(step, model, optimizer, scheduler, scaler, sampler, cfg):
    """Collect everything needed to resume training exactly where it stopped."""
    return {
        "step": step,
        "epoch": sampler.epoch,
        # model is assumed to be DDP-wrapped; unwrap so keys stay portable
        "model": model.module.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict() if scheduler else None,
        "scaler": scaler.state_dict() if scaler else None,
        # sampler is assumed to be a stateful/resumable sampler
        "sampler": sampler.state_dict(),
        "rng": {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "torch": torch.get_rng_state(),
            "cuda": torch.cuda.get_rng_state_all(),
        },
        "metadata": {
            "run_id": cfg.run_id,
            "world_size": cfg.world_size,
            "config_digest": cfg.digest(),
        },
    }
```
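
The "atomic publish of checkpoint id" step in the diagram is where many homegrown schemes go wrong. Below is a minimal single-writer sketch, assuming rank 0 does the publishing and a POSIX filesystem where os.replace() is atomic; the save_checkpoint name, directory layout, and JSON manifest are illustrative assumptions, not a fixed format.

```python
import json
import os

import torch


def save_checkpoint(state, step, ckpt_dir):
    """Write a checkpoint payload, then atomically publish it via a manifest."""
    payload_path = os.path.join(ckpt_dir, f"step_{step:08d}.pt")
    tmp_path = payload_path + ".tmp"

    # 1. Write the payload to a temporary file and force it to disk.
    with open(tmp_path, "wb") as f:
        torch.save(state, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp_path, payload_path)

    # 2. Publish: swap the manifest in one atomic rename, so readers only
    #    ever observe a fully written "latest good" checkpoint.
    manifest_path = os.path.join(ckpt_dir, "manifest.json")
    manifest_tmp = manifest_path + ".tmp"
    with open(manifest_tmp, "w") as f:
        json.dump({"step": step, "payload": payload_path}, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(manifest_tmp, manifest_path)
```

Note that os.replace() is only atomic within a single filesystem; an object store needs an equivalent commit primitive, such as a versioned key or a pointer object written last.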

| Format | Good for | Pain points |
| --- | --- | --- |
| full checkpoint | simple restore, easy debugging | expensive I/O and large rank-0 pressure |
| sharded checkpoint | large-scale models, FSDP, faster distributed save patterns | harder portability, more restore logic, format coordination |

The current PyTorch Distributed Checkpoint documentation adds three especially useful ideas:

  • DCP saves and loads from multiple ranks in parallel.
  • DCP supports load-time resharding, which means a checkpoint saved under one cluster topology can be loaded under another.
  • DCP writes multiple files per checkpoint, typically at least one per rank, so it should not be modeled as a single monolithic blob.

That makes for a strong modern answer:

“If I needed topology-flexible restores, I would look at torch.distributed.checkpoint rather than treating distributed restart as a plain torch.save problem.”
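
As a directional sketch of that answer, assuming a recent PyTorch 2.x where torch.distributed.checkpoint exposes save/load plus FileSystemWriter/FileSystemReader (the exact entry points have shifted across versions), with CKPT_DIR and the stand-in model as illustrative assumptions:

```python
import torch
import torch.distributed.checkpoint as dcp

CKPT_DIR = "/checkpoints/run_1234/step_00010000"  # illustrative path
model = torch.nn.Linear(8, 8)  # stand-in; in practice an FSDP/DDP-wrapped module

# Save: every rank passes its local view of the state; DCP writes shards
# from all ranks in parallel, producing multiple files per checkpoint.
state = {"model": model.state_dict()}
dcp.save(state, storage_writer=dcp.FileSystemWriter(CKPT_DIR))

# Load: possibly under a different world size or sharding layout. DCP
# reshards the saved tensors into the freshly constructed module.
state = {"model": model.state_dict()}
dcp.load(state, storage_reader=dcp.FileSystemReader(CKPT_DIR))
model.load_state_dict(state["model"])
```

Because a DCP checkpoint is a directory of per-rank files plus metadata, retention, copying, and integrity checks have to treat that directory as the unit, not any single file.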

```mermaid
flowchart TD
  A[Training interruption] --> B{Failure scope}
  B --> C[Rank-local process crash]
  B --> D[Whole node loss]
  B --> E[Object store outage]
  B --> F[Poisoned communicator / hang]
  C --> G[Usually full job restart in simple DDP]
  D --> G
  E --> H[Retry save, hold progress, alert]
  F --> I[Tear down group, restart cleanly]
```

The simplest correct answer in many environments is full job restart from latest good checkpoint.

“I prefer a full-job restart model unless the platform already provides proven elastic semantics, because partial member replacement in synchronous training can create subtle state divergence.”
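
Concretely, the restore path in that model is the mirror image of build_checkpoint plus the manifest lookup from the save sketch above. The following reuses those illustrative names and assumes a DDP-wrapped model and a stateful sampler:

```python
import json
import os
import random

import numpy as np
import torch


def restore_latest(ckpt_dir, model, optimizer, scheduler, scaler, sampler):
    """Resolve the latest good manifest and restore all training state (sketch)."""
    with open(os.path.join(ckpt_dir, "manifest.json")) as f:
        manifest = json.load(f)
    ckpt = torch.load(manifest["payload"], map_location="cpu")

    model.module.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    if scheduler and ckpt["scheduler"] is not None:
        scheduler.load_state_dict(ckpt["scheduler"])
    if scaler and ckpt["scaler"] is not None:
        scaler.load_state_dict(ckpt["scaler"])
    sampler.load_state_dict(ckpt["sampler"])

    # Restore RNG streams so reproducibility-sensitive flows line up.
    random.setstate(ckpt["rng"]["python"])
    np.random.set_state(ckpt["rng"]["numpy"])
    torch.set_rng_state(ckpt["rng"]["torch"])
    torch.cuda.set_rng_state_all(ckpt["rng"]["cuda"])

    return ckpt["step"]
```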

After a resume, a few signals confirm the restore actually worked:

  • step count monotonicity
  • optimizer-state presence
  • sampler position consistency
  • learning rate continuity
  • loss curve continuity after one or two steps
  • metrics tags showing a resume event
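
Several of these checks are cheap enough to assert in-process immediately after restore; the helper below is illustrative, and resumed_step, last_saved_step, and expected_lr are assumed names rather than fields defined earlier. Loss-curve continuity and resume-event tags are better verified in the metrics system, since they span the restart boundary.

```python
def validate_resume(resumed_step, last_saved_step, optimizer, sampler, expected_lr):
    """Fail fast on bad restored state instead of training on it (illustrative)."""
    # Step count monotonicity: never resume behind the published checkpoint.
    assert resumed_step == last_saved_step, (resumed_step, last_saved_step)

    # Optimizer-state presence: adaptive optimizers should carry per-param state.
    assert len(optimizer.state) > 0, "optimizer state is empty after restore"

    # Learning-rate continuity: param-group LR should match the expected value.
    lrs = [group["lr"] for group in optimizer.param_groups]
    assert all(abs(lr - expected_lr) < 1e-9 for lr in lrs), lrs

    # Sampler position consistency: surface the restored position for logging.
    print("resumed", {"step": resumed_step, "sampler": sampler.state_dict()})
```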