Petabyte-Scale Image Training With DDP
This page is the image counterpart to the molecular walkthrough: same DDP baseline, very different data plane. At petabyte scale the model is often not the first problem. File layout, cache strategy, decode placement, and transform economics decide whether the GPUs train or wait.
The scenario below is also hypothetical. Use it as a production-shaped example, not as a claim about one specific public dataset.
Reference Scenario
Assume a web-scale vision pretraining corpus with:
- 9.5 billion images
- 1.4 PB of compressed JPEG and WebP objects
- captions or class labels stored as sidecar metadata
- a ViT-sized model that still fits per GPU, so plain DDP remains the simplest correct trainer
Example sample schema
| Field | Type | Why it exists |
|---|---|---|
| sample_id | string | Stable dedupe and replay key |
| image_uri | string | Object-store path to compressed bytes |
| caption | string | Optional contrastive text supervision |
| label | int or null | Classification target if present |
| width | int | Precomputed shape metadata for bucketing and QA |
| height | int | Precomputed shape metadata for bucketing and QA |
| mime_type | string | JPEG, PNG, or WebP handling |
| quality_flags | bitfield | Corrupt decode, NSFW filter, duplicate detection, etc. |
| shard_id | string | Training scheduling and cache accounting |
The Central Design Decision
Do not train from billions of individual files if you can avoid it.
The real system wants:
- immutable medium-sized shards
- sequential reads
- predictable cache behavior
- enough metadata in the manifest to route work without opening every image
Config Management Has To Encode The Data Contract
At petabyte scale, config management is not just about learning rate and batch size. It is how you make a run explainable later.
| Config area | Example fields | Why it matters |
|---|---|---|
| topology | world_size, micro_batch_size, grad_accum_steps | changes optimizer cadence and effective global batch |
| dataset definition | manifest_uri, manifest_version, filter_policy | binds the run to one corpus view |
| featurization | decode_backend, crop_size, augmentation_policy | controls throughput and train/eval comparability |
| cache and I/O | cache_dir, prefetch_depth, max_open_shards | determines whether storage can keep up with the trainer |
| evaluation policy | eval_interval_steps, val_manifest_version, top_k_list | ties decision-making to a stable validation contract |
Example config shape
```python
from dataclasses import dataclass


@dataclass(frozen=True)
class VisionTrainConfig:
    run_name: str
    seed: int
    world_size: int
    micro_batch_size: int
    grad_accum_steps: int
    manifest_uri: str
    manifest_version: str
    decode_backend: str
    crop_size: int
    augmentation_policy: str
    cache_dir: str
    eval_interval_steps: int
    val_manifest_version: str
```

The sentence I would say directly is:
“If I cannot reconstruct the exact corpus view, transform policy, and topology from config plus checkpoint metadata, the run is not really reproducible.”
flowchart LR
A[Config payload] --> B[Manifest version]
A --> C[Transform policy]
A --> D[Topology and batch math]
B --> E[Run metadata]
C --> E
D --> E
E --> F[Checkpoint lineage]
F --> G[Resume or postmortem]
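One way to make the config-to-lineage link concrete is to store a stable fingerprint of the data contract with every checkpoint. The sketch below is illustrative only: `DataContract`, `config_fingerprint`, and the field values are hypothetical names and placeholders, not part of any library or of the reference implementation.

```python
import hashlib
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class DataContract:
    # Hypothetical subset of the config areas listed in the table above.
    manifest_version: str
    augmentation_policy: str
    crop_size: int
    world_size: int
    micro_batch_size: int
    grad_accum_steps: int


def config_fingerprint(cfg: DataContract) -> str:
    """Hash a canonical JSON encoding so key order cannot change the result."""
    canonical = json.dumps(asdict(cfg), sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


# Placeholder values; only the comparison between fingerprints matters here.
base = DataContract("manifest-v3", "rrc-flip-v1", 224, 64, 128, 2)
same = DataContract("manifest-v3", "rrc-flip-v1", 224, 64, 128, 2)
changed = DataContract("manifest-v4", "rrc-flip-v1", 224, 64, 128, 2)

print(config_fingerprint(base) == config_fingerprint(same))     # identical contract
print(config_fingerprint(base) == config_fingerprint(changed))  # manifest changed
```

Comparing the stored fingerprint against the requested config at resume time gives a cheap first check before any deeper field-by-field comparison.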
End-To-End Data Path
flowchart LR
A[Raw image objects] --> B[Offline filtering + dedupe]
B --> C[Shard packing + manifest build]
C --> D[Rank-aware shard assignment]
D --> E[Local NVMe cache]
E --> F[CPU or GPU decode]
F --> G[Crop / resize / normalize]
G --> H[Pinned host memory]
H --> I[DDP trainer rank 0..N]
I --> J[Checkpoint + metrics]
Storage Layout
Recommended physical format
- shard compressed images into tar, parquet, or another sequential container
- target roughly 512 MB to 2 GB per shard, depending on cache and network behavior
- keep sidecar metadata with the image bytes or in a co-located index file
- avoid millions of tiny remote object requests in steady-state training
- include row counts and checksums so incomplete shards can be rejected early
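As a rough illustration of the last bullet, the sketch below packs samples into an in-memory tar shard and records a row count and checksum for the manifest. `pack_shard`, `verify_shard`, and the manifest keys are assumptions made for this example; a production packer would stream to disk and target the shard sizes listed above.

```python
import hashlib
import io
import tarfile


def pack_shard(samples: list[tuple[str, bytes]]) -> tuple[bytes, dict]:
    """Pack (name, bytes) pairs into a tar archive and describe it for the manifest."""
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w") as archive:
        for name, payload in samples:
            info = tarfile.TarInfo(name=name)
            info.size = len(payload)
            archive.addfile(info, io.BytesIO(payload))
    data = buffer.getvalue()
    entry = {
        "num_samples": len(samples),
        "num_bytes": len(data),
        "sha256": hashlib.sha256(data).hexdigest(),
    }
    return data, entry


def verify_shard(data: bytes, entry: dict) -> bool:
    """Reject shards whose checksum or row count disagrees with the manifest."""
    if hashlib.sha256(data).hexdigest() != entry["sha256"]:
        return False
    with tarfile.open(fileobj=io.BytesIO(data), mode="r") as archive:
        count = sum(1 for member in archive if member.isfile())
    return count == entry["num_samples"]


# Fake payloads stand in for real compressed image bytes.
shard_bytes, manifest_entry = pack_shard(
    [("a.jpg", b"\xff\xd8fake"), ("b.jpg", b"\xff\xd8also-fake")]
)
print(verify_shard(shard_bytes, manifest_entry))
print(verify_shard(shard_bytes[:-512], manifest_entry))  # truncated upload
```

Checking the checksum first means a truncated or corrupted shard is rejected before any tar parsing happens.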
Why shard instead of file-per-image
| Problem with file-per-image | What sharding fixes |
|---|---|
| object-store metadata overhead | one request fetches many samples |
| poor throughput under high fanout | sequential or batched reads improve bandwidth utilization |
| expensive local cache indexing | cache at shard granularity |
| resume is hard to reason about | shards provide a natural replay unit |
What To Precompute Offline
Large image jobs benefit from more offline prep than many teams expect.
| Step | Offline or online | Why |
|---|---|---|
| dedupe and content filtering | offline | too expensive and too important to repeat |
| corrupt-file detection | offline | decode failures in hot path cause rank skew |
| width / height extraction | offline | useful for sampling and crop diagnostics |
| shard packing | offline | critical for object-store efficiency |
| heavy resizing to many variants | sometimes offline | worth it if the same sizes are reused repeatedly |
| final stochastic crop / flip | online | cheap, training-specific, and intentionally random |
| normalization | online | trivial tensor work |
The clean mental split is:
move expensive global hygiene offline; keep cheap training-time randomness online
Featurization and Decode Placement
For images, “featurization” usually means decode plus transforms.
CPU-first path
- local worker reads compressed bytes from shard cache
- torchvision.io.decode_image or PIL decodes into tensor form
- worker applies resize and crop
- pin_memory=True lets host-to-device copies overlap better
GPU-heavier path
- keep bytes compressed until late
- decode or resize with GPU-aware tooling when CPU decode becomes the bottleneck
- move simple color or normalization work onto device
Use the CPU-first path until you have evidence it is the bottleneck. It is easier to debug and usually good enough for moderate cluster sizes.
Rank-Aware Shard Streaming
At this scale, an IterableDataset is usually a better fit than pretending the corpus is a local random-access array.

```python
from __future__ import annotations

import io
import random
import tarfile
from dataclasses import dataclass

import torch
from PIL import Image
from torch.utils.data import IterableDataset


@dataclass
class StreamState:
    epoch: int = 0
    shard_offset: int = 0
    sample_offset: int = 0


class ImageShardDataset(IterableDataset):
    def __init__(self, manifest, rank: int, world_size: int, seed: int, state: StreamState | None = None):
        self.manifest = manifest
        self.rank = rank
        self.world_size = world_size
        self.seed = seed
        self.state = state or StreamState()

    def state_dict(self) -> dict[str, int]:
        return {
            "epoch": self.state.epoch,
            "shard_offset": self.state.shard_offset,
            "sample_offset": self.state.sample_offset,
        }

    def set_epoch(self, epoch: int) -> None:
        self.state = StreamState(epoch=epoch)

    def _shards_for_rank(self):
        rng = random.Random(self.seed + self.state.epoch)
        shards = list(self.manifest)
        rng.shuffle(shards)
        return shards[self.rank :: self.world_size]

    def __iter__(self):
        shards = self._shards_for_rank()
        for shard_idx, shard in enumerate(shards[self.state.shard_offset :], start=self.state.shard_offset):
            with open_local_cached_shard(shard["uri"]) as fp:
                with tarfile.open(fileobj=fp, mode="r|*") as tf:
                    sample_idx = 0
                    for member in tf:
                        if not member.isfile() or not member.name.endswith(".jpg"):
                            continue
                        if shard_idx == self.state.shard_offset and sample_idx < self.state.sample_offset:
                            sample_idx += 1
                            continue
                        image = Image.open(io.BytesIO(tf.extractfile(member).read())).convert("RGB")
                        self.state.shard_offset = shard_idx
                        self.state.sample_offset = sample_idx
                        sample_idx += 1
                        yield {"image": image, "label": shard["label_lookup"].get(member.name, -1)}
            self.state.sample_offset = 0
```

That code is simplified, but the important bit is the contract:
- deterministic shard ownership by rank
- explicit replay state
- local cache boundary hidden behind open_local_cached_shard()
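The ownership part of that contract can be checked in isolation. The sketch below reuses the same rule as the dataset (seeded shuffle, then a strided slice per rank) in a standalone helper; `shards_for_rank` and the shard ids are hypothetical and exist only for this check.

```python
import random


def shards_for_rank(shard_ids: list[str], rank: int, world_size: int, seed: int, epoch: int) -> list[str]:
    """Seeded shuffle followed by a strided slice, so every rank derives the same global order."""
    order = list(shard_ids)
    random.Random(seed + epoch).shuffle(order)
    return order[rank::world_size]


shard_ids = [f"shard-{i:05d}" for i in range(10)]
world_size = 4
per_rank = [shards_for_rank(shard_ids, r, world_size, seed=17, epoch=0) for r in range(world_size)]

# Every shard is owned by exactly one rank.
owned = [shard for shards in per_rank for shard in shards]
print(sorted(owned) == sorted(shard_ids))
print([len(shards) for shards in per_rank])
```

With 10 shards and 4 ranks the split is uneven, which is one reason per-rank step counts have to be aligned explicitly (for example with drop_last or a fixed step budget) rather than assumed.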
Caching Strategy
If every worker reads every shard straight from remote storage, your cluster becomes an object-store benchmark.
Better pattern
- cache shards on node-local NVMe
- download ahead in a background thread or helper process
- keep cache accounting per node, not per worker, so workers share data
- evict by shard, not by file
- record cache hit rate as a first-class metric
For a training cluster, local cache policy is part of the data plane, not an implementation detail.
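A minimal sketch of shard-granularity cache accounting with hit-rate tracking follows. `ShardCacheIndex` is a hypothetical helper that only tracks bookkeeping; it assumes LRU eviction, which is one reasonable policy rather than the only one, and it does not touch the filesystem.

```python
from collections import OrderedDict


class ShardCacheIndex:
    """Tracks which shards are resident locally and evicts whole shards, least recently used first."""

    def __init__(self, capacity_bytes: int) -> None:
        self.capacity_bytes = capacity_bytes
        self.used_bytes = 0
        self.entries = OrderedDict()  # shard_id -> size in bytes, oldest first
        self.hits = 0
        self.misses = 0

    def access(self, shard_id: str, size_bytes: int) -> list[str]:
        """Record one shard request and return the shard ids evicted to make room."""
        if shard_id in self.entries:
            self.hits += 1
            self.entries.move_to_end(shard_id)
            return []
        self.misses += 1
        evicted = []
        while self.entries and self.used_bytes + size_bytes > self.capacity_bytes:
            old_id, old_size = self.entries.popitem(last=False)
            self.used_bytes -= old_size
            evicted.append(old_id)
        self.entries[shard_id] = size_bytes
        self.used_bytes += size_bytes
        return evicted

    @property
    def hit_rate(self) -> float:
        total = self.hits + self.misses
        return self.hits / total if total else 0.0


# Toy capacity of three unit-sized shards.
cache = ShardCacheIndex(capacity_bytes=3)
evictions = [cache.access(shard, 1) for shard in ["a", "b", "a", "c", "d", "a"]]
print(evictions, round(cache.hit_rate, 3))
```

Keeping one index per node, rather than per worker, is what lets several workers benefit from the same downloaded shard.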
Reproducibility Means Stable Contracts, Not Magic
Image pipelines include stochastic crops, shuffle order, cache timing, and occasionally non-deterministic decode behavior. So be exact about the contract.
What I would pin
- code revision or image digest
- training and validation manifest versions
- augmentation policy version
- crop size and normalization constants
- global seed and per-epoch shuffle seed rule
- checkpoint lineage and resume step
What I would not over-promise
- bitwise-identical results across different GPU counts
- identical sample order if a streaming source is intentionally at-least-once
- perfectly identical timing once cache warmup or remote storage conditions change
The senior framing is:
“I optimize for explainable and statistically repeatable runs, and I reserve exact replay for controlled debugging paths.”
flowchart TD
A[Resume request] --> B{Same val and train manifests?}
B -->|No| C[New run or offline comparison only]
B -->|Yes| D{Same topology and transform policy?}
D -->|Yes| E[Closer replay path]
D -->|No| F[Statistical continuation]
E --> G[Resume checkpoint]
F --> G
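The resume decision in the flowchart can be written as a small pure function. This is a sketch under stated assumptions: `RunContract`, `resume_mode`, and the returned labels are hypothetical names chosen to mirror the flowchart nodes.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RunContract:
    # Hypothetical fields; a real contract would cover every pinned item above.
    train_manifest_version: str
    val_manifest_version: str
    world_size: int
    augmentation_policy: str


def resume_mode(saved: RunContract, requested: RunContract) -> str:
    """Classify a resume request the same way the flowchart does."""
    same_manifests = (
        saved.train_manifest_version == requested.train_manifest_version
        and saved.val_manifest_version == requested.val_manifest_version
    )
    if not same_manifests:
        return "new_run"
    same_shape = (
        saved.world_size == requested.world_size
        and saved.augmentation_policy == requested.augmentation_policy
    )
    return "closer_replay" if same_shape else "statistical_continuation"


saved = RunContract("train-v3", "val-v1", 64, "rrc-flip-v1")
print(resume_mode(saved, saved))
print(resume_mode(saved, RunContract("train-v3", "val-v1", 32, "rrc-flip-v1")))
print(resume_mode(saved, RunContract("train-v4", "val-v1", 64, "rrc-flip-v1")))
```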
Efficient Batch Construction
Vision workloads often waste performance in one of two ways:
- tiny batches because decode is slow
- expensive dynamic shapes because crops or aspect ratios are unmanaged
Practical batching rules
- normalize final tensor shapes before the model boundary
- use drop_last=True so DDP sees even work in steady state
- if aspect ratio matters, use a bounded set of crop buckets instead of arbitrary image shapes
- prefetch enough batches that short storage hiccups do not starve the device
```python
import torch
import torchvision.transforms.v2 as T

train_tfms = T.Compose(
    [
        T.ToImage(),
        T.RandomResizedCrop(size=(224, 224), antialias=True),
        T.RandomHorizontalFlip(p=0.5),
        T.ToDtype(torch.float32, scale=True),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)


def collate_images(batch):
    images = torch.stack([train_tfms(sample["image"]) for sample in batch])
    labels = torch.tensor([sample["label"] for sample in batch], dtype=torch.long)
    return {"images": images, "labels": labels}
```

If your transforms are much heavier than this, measure whether they belong offline or on a different execution device.
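The crop-bucket rule from the batching list can be sketched with the precomputed width and height from the manifest. The bucket set and the `pick_bucket` helper below are illustrative assumptions; the point is only that the set of output shapes stays small and fixed.

```python
import math

# Hypothetical bucket set as (height, width) pairs: square, landscape, portrait.
CROP_BUCKETS = [(224, 224), (192, 256), (256, 192)]


def pick_bucket(width: int, height: int, buckets=CROP_BUCKETS) -> tuple[int, int]:
    """Choose the bucket whose aspect ratio is closest to the image's, compared in log space."""
    target = math.log(width / height)
    return min(buckets, key=lambda hw: abs(math.log(hw[1] / hw[0]) - target))


print(pick_bucket(1000, 1000))  # square source image
print(pick_bucket(1920, 1080))  # landscape source image
print(pick_bucket(1080, 1920))  # portrait source image
```

Samples still have to be batched per bucket so that `torch.stack` sees one shape per batch; the bucket choice only bounds how many distinct shapes the model and compiler ever encounter.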
DataLoader Settings That Usually Pay Off
| Setting | Why it helps |
|---|---|
| persistent_workers=True | avoid repeated worker startup cost |
| prefetch_factor | smooth short stalls between storage, decode, and train |
| pin_memory=True | better host-to-device transfer overlap |
| drop_last=True | keep per-rank step counts aligned |
| num_workers tuning | balance decode throughput against CPU oversubscription |
These do not fix bad storage layout, but they matter once the basics are correct.
DDP Training Loop
```python
from contextlib import nullcontext

import torch
from torch.amp import GradScaler, autocast


def train_epoch(model, loader, optimizer, scaler: GradScaler, cfg):
    model.train()
    optimizer.zero_grad(set_to_none=True)

    for step, batch in enumerate(loader):
        images = batch["images"].to(
            cfg.device,
            non_blocking=True,
            memory_format=torch.channels_last,
        )
        labels = batch["labels"].to(cfg.device, non_blocking=True)

        # Skip the gradient all-reduce on micro-batches that do not close an accumulation window.
        sync_ctx = model.no_sync() if (step + 1) % cfg.grad_accum_steps != 0 else nullcontext()
        with sync_ctx:
            with autocast(device_type="cuda", dtype=torch.bfloat16, enabled=cfg.use_amp):
                logits = model(images)
                loss = torch.nn.functional.cross_entropy(logits, labels)
                loss = loss / cfg.grad_accum_steps
            scaler.scale(loss).backward()

        # One optimizer step per accumulation window.
        if (step + 1) % cfg.grad_accum_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
```

What to say out loud
- DDP does synchronization; it does not solve image ingest.
- channels_last often helps convolution-heavy models.
- AMP is a throughput lever, but only after data starvation is under control.
- Gradient accumulation changes optimizer cadence and should be explained, not hidden.
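To make the accumulation point concrete, here is the batch arithmetic as a short sketch. The helper names are hypothetical, and the world size of 64 is an assumed example value; the micro-batch size and accumulation steps match the reference config defaults, and the sample count matches the reference scenario.

```python
def effective_global_batch(micro_batch_size: int, grad_accum_steps: int, world_size: int) -> int:
    """Samples consumed per optimizer step across all ranks."""
    return micro_batch_size * grad_accum_steps * world_size


def optimizer_steps(num_samples: int, micro_batch_size: int, grad_accum_steps: int, world_size: int) -> int:
    """Full optimizer steps available when incomplete global batches are dropped."""
    return num_samples // effective_global_batch(micro_batch_size, grad_accum_steps, world_size)


# Assumed topology: 64 ranks, micro-batch 128, 2 accumulation steps.
global_batch = effective_global_batch(128, 2, 64)
steps_per_pass = optimizer_steps(9_500_000_000, 128, 2, 64)
print(global_batch, steps_per_pass)
```

Doubling grad_accum_steps halves the number of optimizer steps per pass over the corpus, which is why learning-rate schedules and eval intervals expressed in steps need to be revisited whenever topology changes.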
Evaluation Design For Vision Training
At petabyte scale, evaluation needs to be smaller, cleaner, and much more stable than the training corpus.
Validation design
- freeze a versioned validation manifest that is never mixed with train shards
- keep validation transforms deterministic, typically resize plus center crop
- maintain at least one cheap smoke-eval set and one more representative holdout
- if the job is multimodal, version the text side of the validation data too
Metrics by objective
| Objective | Better metric set | Why one scalar is not enough |
|---|---|---|
| image classification | top-1, top-5, per-class recall | long-tail classes disappear in a global average |
| multimodal contrastive | retrieval recall@K, median rank | loss alone does not expose retrieval usefulness |
| ranking or recommendation | NDCG, recall@K, slice metrics by traffic segment | deployment quality depends on who the model fails for |
| generative or representation | downstream probe metrics and drift slices | proxy loss may improve while representation quality drops |
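The classification row is easy to demonstrate with a toy example. The sketch below uses plain Python lists and made-up scores; `top_k_accuracy` and `per_class_recall` are illustrative helpers, not library functions.

```python
from collections import defaultdict


def top_k_accuracy(scores: list[list[float]], labels: list[int], k: int) -> float:
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    hits = 0
    for row, label in zip(scores, labels):
        top_k = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        hits += label in top_k
    return hits / len(labels)


def per_class_recall(scores: list[list[float]], labels: list[int]) -> dict[int, float]:
    """Top-1 recall computed separately for each class that appears in the labels."""
    correct: dict[int, int] = defaultdict(int)
    total: dict[int, int] = defaultdict(int)
    for row, label in zip(scores, labels):
        pred = max(range(len(row)), key=lambda c: row[c])
        total[label] += 1
        correct[label] += pred == label
    return {c: correct[c] / total[c] for c in total}


# Made-up scores: class 0 dominates, classes 1 and 2 are rare.
scores = [
    [0.9, 0.05, 0.05],
    [0.8, 0.1, 0.1],
    [0.7, 0.2, 0.1],
    [0.6, 0.3, 0.1],
    [0.5, 0.1, 0.4],
]
labels = [0, 0, 0, 1, 2]

print(top_k_accuracy(scores, labels, k=1))
print(top_k_accuracy(scores, labels, k=2))
print(per_class_recall(scores, labels))
```

In this toy data the global top-1 number looks moderate while both rare classes have zero recall, which is the failure a single scalar hides.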
Evaluation cadence
- run frequent small validation passes to catch regressions early
- run larger validation passes less often to reduce cluster interruption
- keep eval throughput separate from train throughput so the cost is visible
flowchart LR
A[Versioned train shards] --> B[DDP train loop]
C[Versioned validation shards] --> D[Deterministic val transforms]
B --> E[Checkpoint]
E --> F[Smoke eval]
E --> G[Full holdout eval]
D --> F
D --> G
F --> H[Fast guardrail decision]
G --> I[Promotion or rollback decision]
Interview Language For Metrics
When asked how you monitor the system, split the answer into layers:
- data-plane health: read bandwidth, cache hit rate, decode time, corrupt-sample rate
- trainer health: images/sec, batch wait, all-reduce fraction, OOM and restart counts
- model quality: top-1 or recall@K, per-slice performance, calibration or drift checks
That is stronger than saying “we watch loss and GPU utilization.”
flowchart TD
A[Monitoring stack] --> B[Data-plane health]
A --> C[Trainer health]
A --> D[Model quality]
B --> E[read BW, cache hit, decode time]
C --> F[images/sec, wait, all-reduce, restarts]
D --> G[top-1, recall@K, slice drift]
torch.compile Guidance
Image models are often better compile candidates than molecular sequence models because shapes can be made more regular.
Good conditions for torch.compile:
- fixed final crop size
- stable forward graph
- limited control flow
- enough steady-state steps to amortize compile overhead
Bad conditions:
- lots of dynamic aspect-ratio branches in-model
- constant re-specialization from unconstrained shapes
- debugging a new training job where compile obscures basic failures
Checkpointing and Resume
Petabyte-scale datasets make epoch-only checkpoints a weak story because an epoch may be extremely long.
Checkpoint at step intervals and store:
- model, optimizer, and scaler state
- logical epoch
- dataset stream state or sampler state
- manifest version
- any cache or prefetch cursor state that affects replay semantics
For the data plane, at-least-once with bounded duplication is often a more honest target than pretending you have perfectly exact sample replay on a giant streaming corpus.
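The metadata side of a step-interval checkpoint can be sketched without any framework. The example below writes JSON only and leaves out model, optimizer, and scaler tensors; `save_step_checkpoint`, `latest_checkpoint`, and the payload keys are assumptions for illustration.

```python
from __future__ import annotations

import json
import os
import tempfile
from pathlib import Path


def save_step_checkpoint(directory: Path, global_step: int, payload: dict) -> Path:
    """Write to a temp file, then rename, so a crash never leaves a partial checkpoint."""
    directory.mkdir(parents=True, exist_ok=True)
    final_path = directory / f"step-{global_step:08d}.json"
    tmp_path = final_path.with_suffix(".tmp")
    tmp_path.write_text(json.dumps(payload, sort_keys=True))
    os.replace(tmp_path, final_path)
    return final_path


def latest_checkpoint(directory: Path) -> dict | None:
    """Zero-padded step numbers make lexicographic order match numeric order."""
    candidates = sorted(directory.glob("step-*.json"))
    return json.loads(candidates[-1].read_text()) if candidates else None


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    for step, shard_offset in ((500, 3), (1000, 7)):
        save_step_checkpoint(
            ckpt_dir,
            step,
            {
                "global_step": step,
                "epoch": 0,
                "manifest_version": "train-v3",  # placeholder value
                "dataset_state": {"epoch": 0, "shard_offset": shard_offset, "sample_offset": 0},
            },
        )
    restored = latest_checkpoint(ckpt_dir)

print(restored["global_step"], restored["dataset_state"])
```

Resuming from a stored shard offset replays at most the samples consumed since the last save, which is the bounded-duplication behaviour described above.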
Metrics That Distinguish Real Bottlenecks
| Metric | Why it matters |
|---|---|
| images/sec per rank | top-line throughput |
| object-store read bandwidth | tells you whether storage is the actual limiter |
| cache hit rate | proves local staging is doing useful work |
| decode time / batch | separates CPU pressure from model compute |
| batch wait time | direct signal for GPU starvation |
| step time skew across ranks | exposes corrupt shards, slow nodes, or uneven work |
| dropped or corrupt sample rate | catches upstream data quality regressions |
| all-reduce fraction of step | tells you when communication tuning is worth your time |
| validation freshness | tells you whether current quality signals still reflect the run |
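Batch wait time is one of the cheaper metrics in this table to collect. The sketch below wraps any batch iterator and accumulates the time spent blocked on `next()`; `WaitTimer` and `slow_loader` are hypothetical names, and the sleeping generator merely stands in for a slow input pipeline.

```python
import time
from typing import Iterable, Iterator


class WaitTimer:
    """Wraps a batch iterable and accumulates time spent waiting for the next batch."""

    def __init__(self, batches: Iterable) -> None:
        self.batches = batches
        self.wait_seconds = 0.0
        self.num_batches = 0

    def __iter__(self) -> Iterator:
        iterator = iter(self.batches)
        while True:
            start = time.perf_counter()
            try:
                batch = next(iterator)
            except StopIteration:
                return
            self.wait_seconds += time.perf_counter() - start
            self.num_batches += 1
            yield batch


def slow_loader(num_batches: int, delay_seconds: float):
    """Stand-in for a DataLoader whose batches arrive late."""
    for index in range(num_batches):
        time.sleep(delay_seconds)
        yield index


timed = WaitTimer(slow_loader(3, 0.01))
loop_start = time.perf_counter()
for _ in timed:
    pass  # a real loop would run forward and backward here
wait_fraction = timed.wait_seconds / (time.perf_counter() - loop_start)
print(timed.num_batches, round(wait_fraction, 2))
```

A wait fraction near 1.0 means the trainer is starved by the input path; a value near 0 points the investigation at compute or communication instead.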
Common Failure Modes
flowchart TD
A[GPU utilization falls] --> B{Primary symptom}
B --> C[High batch wait]
B --> D[High decode time]
B --> E[Rank skew]
B --> F[OOM after increasing batch]
C --> G[Storage or cache pipeline issue]
D --> H[Too few workers or expensive transforms]
E --> I[Corrupt shards or uneven input]
F --> J[Activation or optimizer-state pressure]
When To Escalate Beyond DDP
If the model stops fitting, or optimizer state dominates device memory, then DDP may no longer be sufficient. Until that point, resist the urge to complicate the trainer.
The first escalation is often not FSDP. It is:
- better shard layout
- better cache hit rate
- faster decode path
- better crop bucketing
- cleaner observability
Only after the model-state problem is real should you add model-state complexity.
Full Reference Implementation
This is the full article in code form: config, shard streaming, local-cache boundary, transforms, DDP setup, evaluation, and checkpointing in one script.
```python
from __future__ import annotations

import io
import json
import os
import random
import shutil
import tarfile
import time
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tvm
import torchvision.transforms.v2 as T
from PIL import Image
from torch.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


@dataclass
class TrainConfig:
    backend: str = "nccl"
    seed: int = 17
    world_size: int = int(os.environ.get("WORLD_SIZE", "1"))
    rank: int = int(os.environ.get("RANK", "0"))
    local_rank: int = int(os.environ.get("LOCAL_RANK", "0"))
    micro_batch_size: int = 128
    grad_accum_steps: int = 2
    max_epochs: int = 2
    max_steps: int = 10_000
    image_size: int = 224
    num_classes: int = 1_000
    learning_rate: float = 2e-4
    weight_decay: float = 0.05
    eval_interval_steps: int = 500
    checkpoint_interval_steps: int = 500
    train_manifest_path: str = "manifests/train_images.json"
    val_manifest_path: str = "manifests/val_images.json"
    cache_dir: str = "/tmp/tcp-image-cache"
    checkpoint_dir: str = "artifacts/checkpoints/images"
    use_amp: bool = True

    @property
    def is_distributed(self) -> bool:
        return self.world_size > 1

    @property
    def is_main_rank(self) -> bool:
        return self.rank == 0

    @property
    def device(self) -> torch.device:
        if torch.cuda.is_available():
            return torch.device(f"cuda:{self.local_rank}")
        return torch.device("cpu")

    @property
    def global_batch_size(self) -> int:
        return self.micro_batch_size * self.grad_accum_steps * self.world_size


@dataclass
class StreamState:
    epoch: int = 0
    shard_offset: int = 0
    sample_offset: int = 0


def seed_everything(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def setup_distributed(cfg: TrainConfig) -> None:
    if not cfg.is_distributed:
        return
    if torch.cuda.is_available():
        torch.cuda.set_device(cfg.local_rank)
    dist.init_process_group(backend=cfg.backend, rank=cfg.rank, world_size=cfg.world_size)


def cleanup_distributed() -> None:
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def load_manifest(path: str) -> list[dict]:
    return json.loads(Path(path).read_text())


def ensure_local_shard(uri: str, cache_dir: str) -> Path:
    cache_root = Path(cache_dir)
    cache_root.mkdir(parents=True, exist_ok=True)
    destination = cache_root / Path(uri).name
    if destination.exists():
        return destination
    source = Path(uri.replace("file://", ""))
    # Stream into a per-process temp file and rename atomically so concurrent
    # workers never open a half-written shard.
    tmp_path = destination.with_name(f"{destination.name}.{os.getpid()}.tmp")
    shutil.copyfile(source, tmp_path)
    os.replace(tmp_path, destination)
    return destination


class ImageShardDataset(IterableDataset):
    def __init__(
        self,
        manifest: list[dict],
        rank: int,
        world_size: int,
        seed: int,
        cache_dir: str,
        state: StreamState | None = None,
    ) -> None:
        self.manifest = manifest
        self.rank = rank
        self.world_size = world_size
        self.seed = seed
        self.cache_dir = cache_dir
        self.state = state or StreamState()

    def state_dict(self) -> dict[str, int]:
        return asdict(self.state)

    def load_state_dict(self, state: dict[str, int]) -> None:
        self.state = StreamState(**state)

    def set_epoch(self, epoch: int) -> None:
        # Keep a restored mid-epoch cursor; only reset when the epoch actually changes.
        if epoch != self.state.epoch:
            self.state = StreamState(epoch=epoch)

    def _rank_shards(self) -> list[dict]:
        shards = list(self.manifest)
        random.Random(self.seed + self.state.epoch).shuffle(shards)
        return shards[self.rank :: self.world_size]

    def __iter__(self):
        shards = self._rank_shards()
        # Split this rank's shards across DataLoader workers; otherwise every
        # worker would replay the same samples.
        worker = get_worker_info()
        if worker is not None:
            shards = shards[worker.id :: worker.num_workers]
        # Caveat: each worker mutates its own copy of self.state, so the cursor
        # visible to the main process is only exact with num_workers=0.
        start_shard, start_sample = self.state.shard_offset, self.state.sample_offset
        for shard_idx in range(start_shard, len(shards)):
            shard = shards[shard_idx]
            local_path = ensure_local_shard(shard["uri"], self.cache_dir)
            with tarfile.open(local_path, mode="r:*") as archive:
                sample_idx = 0
                for member in archive:
                    if not member.isfile() or not member.name.lower().endswith(
                        (".jpg", ".jpeg", ".png", ".webp")
                    ):
                        continue
                    if shard_idx == start_shard and sample_idx < start_sample:
                        sample_idx += 1
                        continue
                    handle = archive.extractfile(member)
                    if handle is None:
                        continue
                    image = Image.open(io.BytesIO(handle.read())).convert("RGB")
                    self.state.shard_offset = shard_idx
                    self.state.sample_offset = sample_idx
                    sample_idx += 1
                    yield {
                        "image": image,
                        "label": int(shard["label_lookup"].get(member.name, -1)),
                    }
        # Pass complete: rewind so the next pass starts from the first shard.
        self.state.shard_offset = 0
        self.state.sample_offset = 0


def build_transforms(image_size: int, training: bool) -> T.Compose:
    if training:
        return T.Compose(
            [
                T.ToImage(),
                T.RandomResizedCrop(size=(image_size, image_size), antialias=True),
                T.RandomHorizontalFlip(p=0.5),
                T.ToDtype(torch.float32, scale=True),
                T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ]
        )
    return T.Compose(
        [
            T.ToImage(),
            T.Resize(size=image_size + 32, antialias=True),
            T.CenterCrop(size=(image_size, image_size)),
            T.ToDtype(torch.float32, scale=True),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]
    )


def make_collate(transform: T.Compose):
    def collate(batch: list[dict]) -> dict[str, torch.Tensor]:
        images = torch.stack([transform(sample["image"]) for sample in batch])
        labels = torch.tensor([sample["label"] for sample in batch], dtype=torch.long)
        return {"images": images, "labels": labels}

    return collate


def make_loader(dataset: IterableDataset, cfg: TrainConfig, training: bool) -> DataLoader:
    # Caveat: persistent workers keep their own dataset copy, so set_epoch()
    # called on the main process does not reach them.
    return DataLoader(
        dataset,
        batch_size=cfg.micro_batch_size,
        num_workers=8,
        prefetch_factor=4,
        persistent_workers=True,
        pin_memory=torch.cuda.is_available(),
        collate_fn=make_collate(build_transforms(cfg.image_size, training=training)),
        drop_last=training,
    )


def unwrap(model: nn.Module) -> nn.Module:
    return model.module if hasattr(model, "module") else model


def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler,
    dataset: ImageShardDataset,
    cfg: TrainConfig,
    epoch: int,
    global_step: int,
) -> None:
    if not cfg.is_main_rank:
        return
    checkpoint_dir = Path(cfg.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    payload = {
        "model": unwrap(model).state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),
        "dataset_state": dataset.state_dict(),
        "epoch": epoch,
        "global_step": global_step,
        "config": asdict(cfg),
    }
    torch.save(payload, checkpoint_dir / f"step-{global_step:08d}.pt")


def restore_if_available(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler,
    dataset: ImageShardDataset,
    cfg: TrainConfig,
) -> tuple[int, int]:
    checkpoint_dir = Path(cfg.checkpoint_dir)
    if not checkpoint_dir.exists():
        return 0, 0
    checkpoints = sorted(checkpoint_dir.glob("step-*.pt"))
    if not checkpoints:
        return 0, 0
    payload = torch.load(checkpoints[-1], map_location="cpu")
    unwrap(model).load_state_dict(payload["model"])
    optimizer.load_state_dict(payload["optimizer"])
    scaler.load_state_dict(payload["scaler"])
    dataset.load_state_dict(payload["dataset_state"])
    return int(payload["epoch"]), int(payload["global_step"])


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, cfg: TrainConfig) -> dict[str, float]:
    model.eval()
    loss_sum = torch.zeros(1, device=cfg.device)
    count_sum = torch.zeros(1, device=cfg.device)
    correct_sum = torch.zeros(1, device=cfg.device)

    for batch in loader:
        images = batch["images"].to(cfg.device, non_blocking=True, memory_format=torch.channels_last)
        labels = batch["labels"].to(cfg.device, non_blocking=True)
        logits = model(images)
        # Accumulate a per-sample loss sum: a streaming loader has no len(), and
        # a sum divided by the global count stays correct after all_reduce.
        loss_sum += F.cross_entropy(logits, labels, reduction="sum")
        count_sum += labels.numel()
        correct_sum += (logits.argmax(dim=-1) == labels).sum()

    if cfg.is_distributed:
        dist.all_reduce(loss_sum)
        dist.all_reduce(count_sum)
        dist.all_reduce(correct_sum)

    total = count_sum.clamp_min(1)
    return {
        "val_loss": (loss_sum / total).item(),
        "top1": (correct_sum / total).item(),
    }


def train(cfg: TrainConfig) -> None:
    seed_everything(cfg.seed)
    setup_distributed(cfg)

    train_dataset = ImageShardDataset(
        manifest=load_manifest(cfg.train_manifest_path),
        rank=cfg.rank,
        world_size=cfg.world_size,
        seed=cfg.seed,
        cache_dir=cfg.cache_dir,
    )
    val_dataset = ImageShardDataset(
        manifest=load_manifest(cfg.val_manifest_path),
        rank=cfg.rank,
        world_size=cfg.world_size,
        seed=cfg.seed + 1000,
        cache_dir=cfg.cache_dir,
    )
    train_loader = make_loader(train_dataset, cfg, training=True)
    val_loader = make_loader(val_dataset, cfg, training=False)

    model = tvm.resnet50(num_classes=cfg.num_classes).to(
        cfg.device, memory_format=torch.channels_last
    )
    model = torch.compile(model) if hasattr(torch, "compile") else model
    if cfg.is_distributed:
        model = DDP(
            model,
            device_ids=[cfg.local_rank] if torch.cuda.is_available() else None,
            output_device=cfg.local_rank if torch.cuda.is_available() else None,
            gradient_as_bucket_view=True,
            static_graph=True,
            broadcast_buffers=False,
        )

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
    scaler = GradScaler("cuda", enabled=cfg.use_amp and torch.cuda.is_available())
    start_epoch, global_step = restore_if_available(model, optimizer, scaler, train_dataset, cfg)

    for epoch in range(start_epoch, cfg.max_epochs):
        train_dataset.set_epoch(epoch)
        model.train()
        optimizer.zero_grad(set_to_none=True)
        step_start = time.perf_counter()
        for batch_idx, batch in enumerate(train_loader):
            images = batch["images"].to(
                cfg.device,
                non_blocking=True,
                memory_format=torch.channels_last,
            )
            labels = batch["labels"].to(cfg.device, non_blocking=True)

            is_update_step = (batch_idx + 1) % cfg.grad_accum_steps == 0
            sync_ctx = (
                model.no_sync()
                if cfg.is_distributed and not is_update_step
                else nullcontext()
            )
            with sync_ctx:
                with autocast(
                    device_type="cuda",
                    dtype=torch.bfloat16,
                    enabled=cfg.use_amp and torch.cuda.is_available(),
                ):
                    logits = model(images)
                    loss = F.cross_entropy(logits, labels) / cfg.grad_accum_steps
                scaler.scale(loss).backward()

            if not is_update_step:
                continue

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1

            # Logging, evaluation, and checkpointing run once per optimizer step,
            # not once per micro-batch.
            if cfg.is_main_rank and global_step % 50 == 0:
                step_time = time.perf_counter() - step_start
                print(
                    json.dumps(
                        {
                            "step": global_step,
                            "loss": round(loss.item() * cfg.grad_accum_steps, 5),
                            "images_per_sec": round(
                                cfg.global_batch_size / max(step_time, 1e-6),
                                2,
                            ),
                            "cache_dir": cfg.cache_dir,
                        }
                    )
                )

            if global_step % cfg.eval_interval_steps == 0:
                metrics = evaluate(model, val_loader, cfg)
                if cfg.is_main_rank:
                    print(json.dumps({"step": global_step, **metrics}))
                model.train()

            if global_step % cfg.checkpoint_interval_steps == 0:
                save_checkpoint(model, optimizer, scaler, train_dataset, cfg, epoch, global_step)

            if global_step >= cfg.max_steps:
                save_checkpoint(model, optimizer, scaler, train_dataset, cfg, epoch, global_step)
                cleanup_distributed()
                return

            # Time the next optimizer step from here, including loader wait.
            step_start = time.perf_counter()

    save_checkpoint(model, optimizer, scaler, train_dataset, cfg, cfg.max_epochs - 1, global_step)
    cleanup_distributed()


if __name__ == "__main__":
    train(TrainConfig())
```

Interview-Ready Summary
- Pack images into immutable sequential shards and train from manifests, not individual files.
- Cache shards on local NVMe and treat cache hit rate as a core production metric.
- Decode and crop efficiently enough that DDP workers stay fed before worrying about exotic model parallelism.
- Use fixed final image shapes, AMP, and channels_last to get more from PyTorch once the input path is healthy.
- Save stream position alongside weights, or restarts will be operationally expensive and logically ambiguous.