ml-training-recipes

Battle-tested PyTorch training recipes for all domains — LLMs, vision, diffusion, medical imaging, protein/drug discovery, spatial omics, genomics. Covers training loops, optimizer selection (AdamW, Muon), LR scheduling, mixed precision, debugging, and systematic experimentation. Use when training or fine-tuning neural networks, debugging loss spikes or OOM, choosing architectures, or optimizing GPU throughput.


NPX Install

npx skill4agent add orchestra-research/ai-research-skills ml-training-recipes

ML Training Recipes

Battle-tested patterns for PyTorch training across domains. Drawn from production codebases (Karpathy's autoresearch/nanochat, torchvision, HuggingFace) and modern training practice.

Reference files (read when needed)

  • references/architecture.md
    — Transformer/LLM architecture code patterns, weight init
  • references/optimizers.md
    — Muon, AdamW hybrid, per-group LR, compiled optimizer steps
  • references/domain-specific.md
    — Vision, diffusion, contrastive, distributed, checkpointing, data loading
  • references/scaling-and-selection.md
    — Scaling laws, compute budget tables, decision trees, DGX Spark
  • references/biomedical.md
    — Drug discovery, protein models, medical imaging, genomics, clinical NLP
  • references/experiment-loop.md
    — Autonomous experiment loop (autoresearch keep/discard/revert)

Architecture Selection

Pick the right model by data type and data scale:
| Data Type   | < 10K samples              | 10K-100K samples           | > 100K samples          |
|-------------|----------------------------|----------------------------|-------------------------|
| Images      | Pretrained CNN + fine-tune | Fine-tune ViT or CNN       | ViT from scratch        |
| Text (gen)  | Few-shot prompting         | Fine-tune GPT/LLaMA (LoRA) | Pretrain from scratch   |
| Tabular     | XGBoost/LightGBM           | Still XGBoost              | Neural viable           |
| Audio       | Pretrained Whisper         | Fine-tune AST              | Train from scratch      |
| Molecules   | Pretrained GNN             | Fine-tune molecular LM     | Train GNN from scratch  |
| Proteins    | ESM-2 embeddings + head    | Fine-tune ESM-2            | Train protein LM        |
| Medical img | Pretrained CNN             | nnU-Net (auto-config)      | Swin-UNETR / MedSAM     |
Key principle: architecture matters less than training recipe at equal compute. A well-tuned ResNet beats a poorly-tuned ViT (ref: "ResNet Strikes Back", Wightman 2021).
For biomedical domains, see references/biomedical.md. For sequence model selection and compute planning, see references/scaling-and-selection.md.

Scaling Laws

Chinchilla rule (Hoffmann et al., 2022)

Compute-optimal training: ~20 tokens per parameter.
| Model Size | Compute-Optimal (20 tok/param) | Inference-Optimal (~100 tok/param) |
|------------|--------------------------------|------------------------------------|
| 125M       | 2.5B tokens                    | 12.5B tokens                       |
| 1B         | 20B tokens                     | 100B tokens                        |
| 7B         | 140B tokens                    | 700B tokens                        |
FLOPs ≈ 6 × N × D (N=params, D=tokens). Data repetition limit: ~4 epochs before diminishing returns.
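The 20-tokens-per-parameter rule and the 6 × N × D estimate fold into a small helper (a sketch; the function name is ours, not a library API):

```python
def chinchilla_budget(n_params: float, tokens_per_param: float = 20.0):
    """Compute-optimal token count and training FLOPs (FLOPs ≈ 6 · N · D)."""
    d_tokens = n_params * tokens_per_param
    flops = 6 * n_params * d_tokens
    return d_tokens, flops

tokens, flops = chinchilla_budget(1e9)  # 1B params → 20B tokens, ~1.2e20 FLOPs
```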

Training Loop

python
import gc, time, torch

torch.manual_seed(42)
torch.set_float32_matmul_precision("high")  # TF32 on Ampere+
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

grad_accum_steps = total_batch_size // (batch_size * seq_len)
step = 0
x, y = next(train_loader)  # fetch the first batch before the loop

while not done:
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        (loss / grad_accum_steps).backward()
        x, y = next(train_loader)

    update_lr(optimizer, progress)
    optimizer.step()
    model.zero_grad(set_to_none=True)  # frees memory vs zeroing

    if loss.item() > 100:  # fast-fail on divergence
        print("FAIL: loss exploded"); exit(1)

    torch.cuda.synchronize()
    if step == 0:
        gc.collect(); gc.freeze(); gc.disable()  # avoid ~500ms GC stalls
    step += 1

Key principles

  • Gradient clipping: clip_grad_norm_(params, 1.0) — near-universal for Transformers. Exception: the Muon optimizer normalizes updates via orthogonalization, so clipping is optional.
  • Tensor Core alignment: batch size and hidden dims should be multiples of 8 (bf16) or 64 (A100).
  • Time-based budgets make experiments comparable across hardware.
  • cudnn.benchmark = True for fixed-size vision inputs.
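The clipping call sits between backward() and step(); a minimal sketch with a toy model (note that clip_grad_norm_ returns the pre-clip total norm, which is worth logging):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

loss = model(torch.randn(4, 16)).square().mean()
loss.backward()

# Clip AFTER backward, BEFORE step. The return value is the total norm
# *before* clipping — spikes there precede loss spikes.
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad(set_to_none=True)
```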

Optimizer Configuration

Modern LLM training uses different optimizers per parameter group:
| Parameter Type         | Optimizer | LR (base)     | Weight Decay |
|------------------------|-----------|---------------|--------------|
| 2D weight matrices     | Muon      | 0.04          | 0.2          |
| Token embeddings       | AdamW     | 0.6 × scale   | 0.0          |
| Unembedding (lm_head)  | AdamW     | 0.004 × scale | 0.0          |
| Per-layer scalars      | AdamW     | 0.005 × scale | 0.0          |
LR scaling by dimension: lr * (d_model / 768)^(-0.5) — keeps dynamics stable across sizes.
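The table above maps onto PyTorch's per-group optimizer construction. A sketch of the split (AdamW stands in for Muon here, since Muon is not in core PyTorch; the module layout is illustrative):

```python
import torch

# Illustrative two-part model: an embedding and a 2D weight matrix.
model = torch.nn.ModuleDict({
    "embed": torch.nn.Embedding(1000, 768),
    "mlp": torch.nn.Linear(768, 768, bias=False),
})

d_model = 768
scale = (d_model / 768) ** -0.5  # LR scaling by width; = 1.0 at d_model = 768

matrix_params = [p for p in model["mlp"].parameters() if p.ndim == 2]
embed_params = list(model["embed"].parameters())

optimizer = torch.optim.AdamW([
    # In practice this first group would go to Muon per the table.
    {"params": matrix_params, "lr": 0.04, "weight_decay": 0.2},
    {"params": embed_params, "lr": 0.6 * scale, "weight_decay": 0.0},
], betas=(0.9, 0.95), eps=1e-10)
```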

Rules of thumb

  • Embeddings need higher LR (sparse updates). Never weight-decay embeddings.
  • Weight decay scheduling: linearly decay WD to 0 over training.
  • AdamW defaults: β1=0.9, β2=0.95, eps=1e-10 (not default 1e-8 — prevents stale updates in bf16).
For Muon details (polar express orthogonalization, NorMuon), see references/optimizers.md.

Learning Rate Scheduling

Time-based (autoresearch style)

python
def get_lr_multiplier(progress):  # progress = elapsed_time / time_budget
    if progress < warmup_ratio:
        return progress / warmup_ratio
    elif progress < 1.0 - warmdown_ratio:
        return 1.0
    else:
        cooldown = (1.0 - progress) / warmdown_ratio
        return cooldown + (1 - cooldown) * final_lr_frac

Cosine decay

python
import math

def get_lr(step, total_steps, max_lr, min_lr, warmup_steps):
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
WSD (Warmup-Stable-Decay): gaining traction — easier to resume training mid-run.
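A minimal WSD implementation might look like this (the warmup/decay fractions are illustrative defaults, not a standard):

```python
def wsd_lr(step, total_steps, max_lr, warmup_frac=0.02, decay_frac=0.2):
    """Warmup-Stable-Decay: linear warmup, flat plateau, linear decay to 0.
    Resuming mid-run is easy because the plateau LR is constant —
    extending total_steps only moves the decay phase."""
    warmup = int(total_steps * warmup_frac)
    decay_start = int(total_steps * (1 - decay_frac))
    if step < warmup:
        return max_lr * step / max(warmup, 1)
    if step < decay_start:
        return max_lr
    return max_lr * (total_steps - step) / max(total_steps - decay_start, 1)
```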

Guidance

  • Warmup: 1-5% of training. Zero warmup is valid with Muon (autoresearch uses WARMUP_RATIO=0.0).
  • Warmdown: 30-50% of training in LR decay. Matters more than warmup for final quality.
  • Final LR: 0 or ~10% of peak. Zero is simpler.

Mixed Precision & Compilation

python
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"  # before torch import

import torch
torch.set_float32_matmul_precision("high")
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
model = torch.compile(model, dynamic=False)
  • bf16 (Ampere+): same exponent as fp32, no loss scaling needed. Preferred over fp16.
  • fp16: needs GradScaler. Use only on V100 or older.
  • dynamic=False enables max optimization. Add fullgraph=True if there are no graph breaks.
  • First steps are slow (JIT) — exclude from timing.

Memory & Performance

Meta device init (large models)

python
with torch.device("meta"):
    model = GPT(config)          # zero memory
model.to_empty(device="cuda")
model.init_weights()

MFU (Model FLOPs Utilization)

python
achieved_flops_per_sec = model_flops_per_token * batch_tokens / step_time  # flops_per_token ≈ 6 × N
mfu = achieved_flops_per_sec / gpu_peak_flops
# Peak dense bf16: H100 SXM 989.5 TFLOPS | A100: 312 | RTX 4090: 165
Good targets: >30% decent, >40% good, >50% excellent (single-GPU).

OOM solutions (in order)

  1. Reduce DEVICE_BATCH_SIZE, increase grad_accum_steps
  2. PYTORCH_ALLOC_CONF=expandable_segments:True
  3. model.zero_grad(set_to_none=True)
  4. Meta device init → to_empty
  5. Activation checkpointing: torch.utils.checkpoint.checkpoint()
  6. 8-bit optimizer (bitsandbytes): ~30% savings on optimizer states
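Step 5 (activation checkpointing) in a minimal residual block — a toy sketch; use_reentrant=False is the recommended mode in current PyTorch:

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self, d):
        super().__init__()
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(d, 4 * d), torch.nn.GELU(), torch.nn.Linear(4 * d, d)
        )

    def forward(self, x):
        # Activations inside self.ff are recomputed during backward instead
        # of being stored: large memory savings for extra compute.
        return x + checkpoint(self.ff, x, use_reentrant=False)

x = torch.randn(2, 64, requires_grad=True)
y = Block(64)(x)
y.sum().backward()
```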

Hyperparameter Search

Priority order (tune first → last)

  1. Learning rate — most impactful. Always tune first.
  2. Batch size — largest that fits. Speed knob, not quality knob.
  3. Weight decay — 0.01-0.1 for AdamW.
  4. Warmup steps — 1-5% of training.

The 2025 default recipe

| Setting           | Value                               |
|-------------------|-------------------------------------|
| Optimizer         | AdamW (β1=0.9, β2=0.95, eps=1e-10)  |
| Weight decay      | 0.1                                 |
| LR schedule       | Cosine decay or WSD                 |
| Peak LR           | 3e-4 (scale down for larger models) |
| Precision         | bf16                                |
| Grad clipping     | max_norm=1.0                        |
| Normalization     | RMSNorm (pre-norm)                  |
| Activation        | SwiGLU                              |
| Position encoding | RoPE                                |
| Attention         | Flash Attention, optionally GQA     |
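The defaults above can be mirrored as a plain config dict (the keys are ours, not any library's API):

```python
# Illustrative config capturing the 2025 default recipe.
DEFAULT_RECIPE = {
    "optimizer": "AdamW",
    "betas": (0.9, 0.95),
    "eps": 1e-10,
    "weight_decay": 0.1,
    "lr_schedule": "cosine",   # or "wsd"
    "peak_lr": 3e-4,           # scale down for larger models
    "precision": "bf16",
    "grad_clip_max_norm": 1.0,
    "norm": "rmsnorm_pre",
    "activation": "swiglu",
    "pos_encoding": "rope",
    "attention": "flash",      # optionally GQA
}
```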

Debugging Checklist

Karpathy's recipe (still canonical)

  1. Become one with the data — visualize, check distributions, verify labels
  2. Get end-to-end running first — verify on a trivial case
  3. Overfit one batch — if you can't, you have a bug
  4. Then regularize — add regularization only after overfitting works
  5. Tune hyperparameters — start with known defaults
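Step 3 (overfit one batch) can be made concrete with a toy setup (all names illustrative): if a small model cannot drive the loss to near zero on a single fixed batch, suspect labels, loss wiring, or the LR.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(10, 64), torch.nn.ReLU(), torch.nn.Linear(64, 2)
)
opt = torch.optim.AdamW(model.parameters(), lr=1e-2)
x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))  # one fixed batch

for _ in range(500):
    loss = torch.nn.functional.cross_entropy(model(x), y)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()
# loss should approach ~0; if it plateaus high, there is a bug
```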

Loss exploding / NaN

  1. Reduce LR (3-10× smaller)
  2. Add gradient clipping: clip_grad_norm_(params, 1.0)
  3. Check for inf/nan in inputs
  4. Add logit soft capping: softcap * tanh(logits / softcap)
  5. Add QK-norm in attention
  6. Verify weight init (zero-init output projections?)
  7. Check loss reduction with gradient accumulation (loss / grad_accum_steps)
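Item 4 (logit soft capping) as a standalone function — the cap value of 30.0 is an illustrative choice, not a universal constant:

```python
import torch

def soft_cap(logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    """Smoothly bound logits to [-cap, cap]; keeps the loss and its
    gradients finite even when raw logits blow up."""
    return cap * torch.tanh(logits / cap)

z = torch.tensor([0.0, 50.0, -1e4])
capped = soft_cap(z)  # values bounded to [-30, 30]
```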

Slow training / Low MFU

  1. Verify torch.compile is active
  2. Check torch.set_float32_matmul_precision("high")
  3. Pin memory + non_blocking transfers
  4. Profile with torch.profiler
  5. GC stalls? gc.freeze(); gc.disable()
  6. Tensor Core alignment: dims multiples of 8/64
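Item 3 in code (num_workers kept at 0 so the sketch is self-contained; raise it in real runs):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(256, 16), torch.randint(0, 2, (256,)))
# pin_memory stages host batches in page-locked RAM so the host-to-device
# copy can overlap compute when paired with non_blocking=True on .to().
loader = DataLoader(ds, batch_size=32, shuffle=True, num_workers=0,
                    pin_memory=torch.cuda.is_available())

device = "cuda" if torch.cuda.is_available() else "cpu"
x, y = next(iter(loader))
x = x.to(device, non_blocking=True)
y = y.to(device, non_blocking=True)
```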

Loss plateau / Slow convergence

  1. LR too low — try 2-5× larger
  2. Warmup too long
  3. Weight decay too high
  4. Verify LR schedule is actually applied (print each step)
  5. Model too small for task

Silent failures

  1. Data leakage between train/val
  2. Wrong preprocessing at inference — augmentation mismatch
  3. Label errors — use cleanlab to detect
  4. Shuffling bugs — correlated batches
  5. Tokenizer mismatch with pretrained model

What to monitor

  • Gradient norms — spike precedes loss spike
  • Per-layer activation stats — reveals exploding/vanishing
  • Dead neurons — >50% zero ReLU = dying ReLU problem
  • Learning rate — verify schedule applied (common silent bug)
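A minimal gradient-norm probe for the first bullet (hypothetical helper; in a Muon/AdamW setup clip_grad_norm_'s return value gives the same number for free):

```python
import torch

def grad_norm(model: torch.nn.Module) -> float:
    """Global L2 norm across all parameter gradients. Log it every step —
    a spike here typically shows up a few steps before the loss spikes."""
    norms = [p.grad.norm() for p in model.parameters() if p.grad is not None]
    return torch.stack(norms).norm().item() if norms else 0.0

model = torch.nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()
print(f"grad norm: {grad_norm(model):.3f}")
```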

Experiment Management

Track experiments in TSV for easy comparison:
commit  val_bpb  memory_gb  status   description
a1b2c3d 0.9979   44.0       keep     baseline
b2c3d4e 0.9932   44.2       keep     increase matrix LR to 0.04
c3d4e5f 1.0050   44.0       discard  switch to GeLU (worse)
Simplicity criterion: all else equal, simpler is better. Removing something and getting equal results is a great outcome. For systematic agent-driven experimentation, see references/experiment-loop.md.

Evaluation metrics by domain

| Domain         | Primary Metric      | Notes                            |
|----------------|---------------------|----------------------------------|
| LLM            | BPB (bits per byte) | Vocab-size-independent           |
| Classification | Accuracy / F1       | Macro-F1 for imbalanced          |
| Segmentation   | mIoU / Dice         | Per-class IoU reveals weak spots |
| Generation     | FID                 | Needs >10k samples               |
| Regression     | RMSE / MAE          | Log-transform skewed targets     |