Checkpointing and Recovery
Checkpointing is not “save model weights every N steps.” It is the contract that defines your recovery point objective and the correctness of resumed training.
Minimum Recoverable State
| State | Why it matters |
|---|---|
| model parameters | obvious, but insufficient alone |
| optimizer state | needed for momentum, adaptive moments, and learning-rate continuity |
| scheduler state | prevents LR jumps after resume |
| AMP scaler state | required if mixed precision is in use |
| sampler progress | prevents accidental replay / skip patterns |
| random seeds / RNG state | necessary for reproducibility-sensitive flows |
| metadata | step, epoch, git SHA, config digest, dataset snapshot identifier |
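Each row above has a restore counterpart. A minimal resume sketch, assuming a stateful sampler and a checkpoint dict shaped like the contract later in this section (all object names here are illustrative):

```python
import random

import numpy as np
import torch


def restore_from_checkpoint(ckpt, model, optimizer, scheduler, scaler, sampler, cfg):
    # Parameters alone are not enough: optimizer and scheduler state carry the
    # momentum/adaptive moments and keep the learning rate continuous.
    model.module.load_state_dict(ckpt["model"])  # mirror of the unwrapped save
    optimizer.load_state_dict(ckpt["optimizer"])
    if scheduler is not None and ckpt["scheduler"] is not None:
        scheduler.load_state_dict(ckpt["scheduler"])
    if scaler is not None and ckpt["scaler"] is not None:
        scaler.load_state_dict(ckpt["scaler"])

    # Sampler progress prevents replaying or skipping examples after resume.
    sampler.load_state_dict(ckpt["sampler"])

    # RNG state, for reproducibility-sensitive flows.
    random.setstate(ckpt["rng"]["python"])
    np.random.set_state(ckpt["rng"]["numpy"])
    torch.set_rng_state(ckpt["rng"]["torch"])
    torch.cuda.set_rng_state_all(ckpt["rng"]["cuda"])

    # Metadata guards against resuming into a silently different run.
    assert ckpt["metadata"]["config_digest"] == cfg.digest(), "config drift since save"
    return ckpt["step"]
```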
Save and Restore Flow
```mermaid
sequenceDiagram
    participant Trainer as Trainer ranks
    participant Store as Checkpoint store
    participant Scheduler as Orchestrator
    Trainer->>Trainer: quiesce step boundary
    Trainer->>Store: write shard(s) + manifest
    Store-->>Trainer: commit confirmation
    Note over Trainer,Store: atomic publish of checkpoint id
    Scheduler->>Trainer: restart after failure
    Trainer->>Store: resolve latest good manifest
    Store-->>Trainer: model, optimizer, sampler, metadata
    Trainer->>Trainer: rebuild process groups and continue
```
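The “atomic publish of checkpoint id” note is where most homegrown schemes go wrong. A hedged sketch of one way to do it on a filesystem-like store (paths, layout, and the single-writer assumption are illustrative): write everything to a temporary directory, fsync the manifest, then flip a small LATEST pointer with an atomic rename so readers never observe a half-written checkpoint.

```python
import json
import os

import torch


def publish_checkpoint(ckpt, step, root="checkpoints"):
    """Write a checkpoint so a crash mid-save can never become the 'latest' checkpoint."""
    tmp_dir = os.path.join(root, f".tmp_step_{step}")
    final_dir = os.path.join(root, f"step_{step}")
    os.makedirs(tmp_dir, exist_ok=True)

    # 1. Write the payload plus a manifest that describes it.
    torch.save(ckpt, os.path.join(tmp_dir, "state.pt"))
    manifest = {"step": step, "files": ["state.pt"], "metadata": ckpt["metadata"]}
    with open(os.path.join(tmp_dir, "manifest.json"), "w") as f:
        json.dump(manifest, f)
        f.flush()
        os.fsync(f.fileno())

    # 2. Publish atomically: rename the directory, then flip the LATEST pointer.
    #    In multi-rank jobs, ranks write their shards first and exactly one rank
    #    performs this flip after a barrier.
    os.rename(tmp_dir, final_dir)
    latest_tmp = os.path.join(root, ".LATEST.tmp")
    with open(latest_tmp, "w") as f:
        f.write(f"step_{step}")
        f.flush()
        os.fsync(f.fileno())
    os.replace(latest_tmp, os.path.join(root, "LATEST"))
```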
A Practical Checkpoint Contract
```python
import random

import numpy as np
import torch


def build_checkpoint(step, model, optimizer, scheduler, scaler, sampler, cfg):
    return {
        "step": step,
        "epoch": sampler.epoch,
        # Unwrap the DDP container so the keys stay portable across wrappers.
        "model": model.module.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict() if scheduler else None,
        "scaler": scaler.state_dict() if scaler else None,
        "sampler": sampler.state_dict(),
        "rng": {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "torch": torch.get_rng_state(),
            "cuda": torch.cuda.get_rng_state_all(),
        },
        "metadata": {
            "run_id": cfg.run_id,
            "world_size": cfg.world_size,
            "config_digest": cfg.digest(),
        },
    }
```

Full vs Sharded Checkpoints
| Format | Good for | Pain points |
|---|---|---|
| full checkpoint | simple restore, easy debugging | expensive I/O and heavy memory pressure on rank 0 |
| sharded checkpoint | large-scale models, FSDP, faster distributed save patterns | harder portability, more restore logic, format coordination |
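To make the trade concrete, here is a hedged sketch of both patterns under plain DDP (function names are illustrative): the full checkpoint funnels everything through rank 0, while the sharded save lets every rank write its own piece and pushes the complexity to restore time.

```python
import os

import torch
import torch.distributed as dist


def save_full(ckpt, path):
    # Full checkpoint: rank 0 holds the whole state and writes a single file,
    # which is simple to restore but concentrates memory and I/O on one rank.
    if dist.get_rank() == 0:
        torch.save(ckpt, path)
    dist.barrier()  # keep other ranks at the same step boundary during the save


def save_sharded(local_state, ckpt_dir):
    # Sharded checkpoint: each rank writes only what it owns; restore logic
    # must later reassemble or reshard the pieces.
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(local_state, os.path.join(ckpt_dir, f"shard_rank{dist.get_rank()}.pt"))
    dist.barrier()
```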
Current PyTorch Notes
The current PyTorch Distributed Checkpoint (DCP) documentation adds three especially useful ideas:
- DCP saves and loads from multiple ranks in parallel.
- DCP supports load-time resharding, which means a checkpoint saved under one cluster topology can be loaded under another.
- DCP writes multiple files per checkpoint, typically at least one per rank, so it should not be modeled as a single monolithic blob.
That makes for a strong modern answer:
“If I needed topology-flexible restores, I would look at `torch.distributed.checkpoint` rather than treating distributed restart as a plain `torch.save` problem.”
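A minimal sketch of that direction, assuming a recent PyTorch (roughly 2.2 or newer) where torch.distributed.checkpoint exposes save/load with filesystem readers and writers; the directory name and the tiny model are illustrative, and real FSDP runs typically build their state dicts through DCP's own helpers:

```python
import torch
import torch.distributed.checkpoint as dcp

# Assumes torchrun (or equivalent) has already initialized the process group.
model = torch.nn.Linear(16, 16)        # stand-in for the real network
CKPT_DIR = "checkpoints/step_01000"    # illustrative path

# Save: every participating rank writes its own shard(s) plus shared metadata.
state = {"model": model.state_dict()}
dcp.save(state, storage_writer=dcp.FileSystemWriter(CKPT_DIR))

# Load: tensors are restored in place and can be resharded to whatever
# topology the job was restarted with.
state = {"model": model.state_dict()}
dcp.load(state, storage_reader=dcp.FileSystemReader(CKPT_DIR))
model.load_state_dict(state["model"])
```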
Failure Taxonomy
```mermaid
flowchart TD
    A[Training interruption] --> B{Failure scope}
    B --> C[Rank-local process crash]
    B --> D[Whole node loss]
    B --> E[Object store outage]
    B --> F[Poisoned communicator / hang]
    C --> G[Usually full job restart in simple DDP]
    D --> G
    E --> H[Retry save, hold progress, alert]
    F --> I[Tear down group, restart cleanly]
```
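Two of those branches can be handled mechanically. A hedged sketch (the alerting hook and retry policy are hypothetical stand-ins for whatever your platform provides): retry store writes with backoff while holding progress, and always tear the process group down before a restart so a poisoned communicator cannot leak into the next attempt.

```python
import time

import torch.distributed as dist


def save_with_retry(publish_fn, ckpt, step, attempts=5, backoff_s=30):
    """Object-store outage branch: hold progress, retry, then alert."""
    for attempt in range(1, attempts + 1):
        try:
            publish_fn(ckpt, step)
            return True
        except OSError as exc:  # stand-in for store/client errors
            print(f"checkpoint save failed (attempt {attempt}/{attempts}): {exc}")
            time.sleep(backoff_s * attempt)
    alert_oncall(f"checkpoint save failed at step {step}")  # hypothetical pager hook
    return False


def teardown_before_restart():
    """Poisoned-communicator branch: never reuse a suspect process group."""
    if dist.is_initialized():
        dist.destroy_process_group()
    # The orchestrator then relaunches all ranks, and init_process_group()
    # runs fresh in the new processes.
```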
The Sentence That Sounds Experienced
“I prefer a full-job restart model unless the platform already provides proven elastic semantics, because partial member replacement in synchronous training can create subtle state divergence.”
What To Validate After Resume
- step count monotonicity
- optimizer-state presence
- sampler position consistency
- learning rate continuity
- loss curve continuity after one or two steps
- metrics tags showing a resume event
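Most of these checks are cheap to automate as a resume-time guard. A hedged sketch (the trainer container, the metrics hook, and the last_lr metadata field are illustrative additions, not part of the contract above):

```python
def validate_resume(ckpt, trainer):
    """Run once on the first step after restoring a checkpoint.
    `trainer` is an illustrative container for step/optimizer/sampler/metrics."""
    # Step count monotonicity: never resume behind the restored step.
    assert trainer.global_step >= ckpt["step"], "resumed behind the checkpoint"

    # Optimizer-state presence: an empty state dict means momentum was silently lost.
    assert len(trainer.optimizer.state) > 0, "optimizer state missing after resume"

    # Sampler position consistency.
    assert trainer.sampler.epoch == ckpt["epoch"], "sampler epoch drifted"

    # Learning-rate continuity: the restored scheduler must reproduce the saved LR
    # (assumes a hypothetical "last_lr" field was stashed in metadata at save time).
    saved_lr = ckpt["metadata"].get("last_lr")
    if saved_lr is not None:
        current_lr = trainer.optimizer.param_groups[0]["lr"]
        assert abs(current_lr - saved_lr) < 1e-9, "learning rate jumped across the resume"

    # Loss-curve continuity is a soft check: compare the first post-resume losses
    # against the last pre-failure values and alert on large jumps, then tag the
    # metrics stream so dashboards show an explicit resume event.
    trainer.metrics.log_event("resume", step=trainer.global_step)  # hypothetical hook
```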