ml-training-recipes


ML Training Recipes


Battle-tested patterns for PyTorch training across domains. Drawn from production codebases (Karpathy's autoresearch/nanochat, torchvision, HuggingFace) and modern training practice.

Reference files (read when needed)


  • references/architecture.md
    — Transformer/LLM architecture code patterns, weight init
  • references/optimizers.md
    — Muon, AdamW hybrid, per-group LR, compiled optimizer steps
  • references/domain-specific.md
    — Vision, diffusion, contrastive, distributed, checkpointing, data loading
  • references/scaling-and-selection.md
    — Scaling laws, compute budget tables, decision trees, DGX Spark
  • references/biomedical.md
    — Drug discovery, protein models, medical imaging, genomics, clinical NLP
  • references/experiment-loop.md
    — Autonomous experiment loop (autoresearch keep/discard/revert)


Architecture Selection


Pick the right model by data type and data scale:

| Data Type | < 10K samples | 10K-100K | > 100K |
|---|---|---|---|
| Images | Pretrained CNN + fine-tune | Fine-tune ViT or CNN | ViT from scratch |
| Text (gen) | Few-shot prompting | Fine-tune GPT/LLaMA (LoRA) | Pretrain from scratch |
| Tabular | XGBoost/LightGBM | Still XGBoost | Neural viable |
| Audio | Pretrained Whisper | Fine-tune AST | Train from scratch |
| Molecules | Pretrained GNN | Fine-tune molecular LM | Train GNN from scratch |
| Proteins | ESM-2 embeddings + head | Fine-tune ESM-2 | Train protein LM |
| Medical img | Pretrained CNN | nnU-Net (auto-config) | Swin-UNETR / MedSAM |
Key principle: architecture matters less than training recipe at equal compute. A well-tuned ResNet beats a poorly-tuned ViT (ref: "ResNet Strikes Back", Wightman 2021).
For biomedical domains, see `references/biomedical.md`. For sequence model selection and compute planning, see `references/scaling-and-selection.md`.

Scaling Laws


Chinchilla rule (Hoffmann et al., 2022)


Compute-optimal training: ~20 tokens per parameter.

| Model Size | Compute-Optimal | Inference-Optimal (100×) |
|---|---|---|
| 125M | 2.5B tokens | 12.5B tokens |
| 1B | 20B tokens | 100B tokens |
| 7B | 140B tokens | 700B tokens |

FLOPs ≈ 6 × N × D (N=params, D=tokens). Data repetition limit: ~4 epochs before diminishing returns.

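The two rules above combine into a small sanity-check helper (a sketch; `chinchilla_budget` is a hypothetical name):

```python
def chinchilla_budget(n_params: float, tokens_per_param: float = 20.0):
    """Return (compute-optimal token count, training FLOPs) for a model size.

    Uses the Chinchilla heuristic D ≈ 20 × N and the FLOPs ≈ 6 × N × D rule.
    """
    d_tokens = n_params * tokens_per_param
    flops = 6.0 * n_params * d_tokens
    return d_tokens, flops
```

For a 1B model this reproduces the table row: 20B tokens, 1.2e20 FLOPs.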

Training Loop


```python
import gc, time, torch

torch.manual_seed(42)
torch.set_float32_matmul_precision("high")  # TF32 on Ampere+
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

grad_accum_steps = total_batch_size // (batch_size * seq_len)
x, y = next(train_loader)  # prefetch the first batch before the loop
step = 0

while not done:
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        (loss / grad_accum_steps).backward()
        x, y = next(train_loader)  # fetch next batch while the GPU works

    update_lr(optimizer, progress)
    optimizer.step()
    model.zero_grad(set_to_none=True)  # frees memory vs zeroing

    if loss.item() > 100:  # fast-fail on divergence
        print("FAIL: loss exploded"); exit(1)

    torch.cuda.synchronize()
    if step == 0:
        gc.collect(); gc.freeze(); gc.disable()  # avoid ~500ms GC stalls
    step += 1
```

Key principles


  • Gradient clipping: `clip_grad_norm_(params, 1.0)` is near-universal for Transformers. Exception: the Muon optimizer normalizes updates via orthogonalization, so clipping is optional.
  • Tensor Core alignment: batch size and hidden dims should be multiples of 8 (bf16) or 64 (A100).
  • Time-based budgets make experiments comparable across hardware.
  • `torch.backends.cudnn.benchmark = True` for fixed-size vision inputs.
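The alignment bullet reduces to rounding dimensions up; a one-line helper (hypothetical name):

```python
def pad_to_multiple(n: int, multiple: int = 64) -> int:
    """Round a dimension up to the next multiple (8 for bf16, 64 for A100)."""
    return ((n + multiple - 1) // multiple) * multiple
```

For example, padding a 50257-token vocab to a multiple of 64 gives 50304, a common throughput trick.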

Optimizer Configuration


Modern LLM training uses different optimizers per parameter group:

| Parameter Type | Optimizer | LR (base) | Weight Decay |
|---|---|---|---|
| 2D weight matrices | Muon | 0.04 | 0.2 |
| Token embeddings | AdamW | 0.6 × scale | 0.0 |
| Unembedding (lm_head) | AdamW | 0.004 × scale | 0.0 |
| Per-layer scalars | AdamW | 0.005 × scale | 0.0 |

LR scaling by dimension: `lr * (d_model / 768)^(-0.5)` keeps dynamics stable across sizes.
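A sketch of how the table maps onto optimizer param groups (duck-typed; `build_param_groups` is an illustrative name, and a real recipe would also route embeddings and lm_head by parameter name):

```python
import math

def scaled_lr(base_lr: float, d_model: int, d_ref: int = 768) -> float:
    """Apply the (d_model / 768)^(-0.5) rule so base LRs transfer across widths."""
    return base_lr * math.sqrt(d_ref / d_model)

def build_param_groups(named_params, d_model: int):
    """Route 2D weight matrices to a Muon-style group, everything else to AdamW.

    named_params: iterable of (name, param) where param exposes .ndim.
    """
    named_params = list(named_params)
    muon = [p for _, p in named_params if p.ndim == 2]
    adamw = [p for _, p in named_params if p.ndim != 2]
    return [
        {"params": muon, "lr": 0.04, "weight_decay": 0.2},
        {"params": adamw, "lr": scaled_lr(0.005, d_model), "weight_decay": 0.0},
    ]
```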

Rules of thumb


  • Embeddings need higher LR (sparse updates). Never weight-decay embeddings.
  • Weight decay scheduling: linearly decay WD to 0 over training.
  • AdamW defaults: β1=0.9, β2=0.95, eps=1e-10 (not default 1e-8 — prevents stale updates in bf16).
For Muon details (polar express orthogonalization, NorMuon), see `references/optimizers.md`.

Learning Rate Scheduling


Time-based (autoresearch style)


```python
def get_lr_multiplier(progress):  # progress = elapsed_time / time_budget
    if progress < warmup_ratio:
        return progress / warmup_ratio
    elif progress < 1.0 - warmdown_ratio:
        return 1.0
    else:
        cooldown = (1.0 - progress) / warmdown_ratio
        return cooldown + (1 - cooldown) * final_lr_frac
```

Cosine decay


```python
import math

def get_lr(step, total_steps, max_lr, min_lr, warmup_steps):
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

WSD (Warmup-Stable-Decay): gaining traction — easier to resume training mid-run.
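The WSD shape itself is easy to sketch (a hypothetical helper, not from any particular codebase): the plateau is stateless, so extending `total_steps` mid-run only moves the decay tail.

```python
def wsd_lr(step: int, total_steps: int, max_lr: float,
           warmup_steps: int, decay_frac: float = 0.3) -> float:
    """Warmup-Stable-Decay: linear warmup, flat plateau, linear decay to zero."""
    decay_start = int(total_steps * (1.0 - decay_frac))
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    if step < decay_start:
        return max_lr
    return max_lr * (total_steps - step) / (total_steps - decay_start)
```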

Guidance


  • Warmup: 1-5% of training. Zero warmup is valid with Muon (autoresearch uses `WARMUP_RATIO=0.0`).
  • Warmdown: 30-50% of training spent in LR decay. Matters more than warmup for final quality.
  • Final LR: 0 or ~10% of peak. Zero is simpler.

Mixed Precision & Compilation


```python
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"  # before torch import

import torch
torch.set_float32_matmul_precision("high")
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
model = torch.compile(model, dynamic=False)
```

  • bf16 (Ampere+): same exponent range as fp32, so no loss scaling needed. Preferred over fp16.
  • fp16: needs GradScaler. Use only on V100 or older.
  • `dynamic=False` enables maximum optimization. Add `fullgraph=True` if there are no graph breaks.
  • First steps are slow (JIT) — exclude them from timing.


Memory & Performance


Meta device init (large models)


```python
with torch.device("meta"):
    model = GPT(config)          # allocates zero memory
model.to_empty(device="cuda")
model.init_weights()
```

MFU (Model FLOPs Utilization)


```python
achieved_flops = model_flops_per_token * batch_tokens / step_time
mfu = achieved_flops / gpu_peak_flops
```

H100 SXM: 989.5 TFLOPS | A100: 312 | RTX 4090: 165

Good targets: >30% decent, >40% good, >50% excellent (single-GPU).
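Plugging the 6 × N FLOPs-per-token approximation into the formula above gives a quick calculator (a sketch; `mfu` here is a hypothetical standalone function):

```python
def mfu(n_params: float, tokens_per_step: float, step_time_s: float,
        peak_flops: float) -> float:
    """Model FLOPs Utilization, using the 6*N FLOPs-per-token approximation."""
    achieved = 6.0 * n_params * tokens_per_step / step_time_s
    return achieved / peak_flops
```

Example: a 1B model pushing 524288 tokens every 25 s on an A100 (312 TFLOPS peak) lands around 40% MFU.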

OOM solutions (in order)


  1. Reduce `DEVICE_BATCH_SIZE`, increase `grad_accum_steps`
  2. `PYTORCH_ALLOC_CONF=expandable_segments:True`
  3. `model.zero_grad(set_to_none=True)`
  4. Meta device init → `to_empty`
  5. Activation checkpointing: `torch.utils.checkpoint.checkpoint()`
  6. 8-bit optimizer (bitsandbytes): ~30% savings on optimizer states


Hyperparameter Search


Priority order (tune first → last)


  1. Learning rate — most impactful. Always tune first.
  2. Batch size — largest that fits. Speed knob, not quality knob.
  3. Weight decay — 0.01-0.1 for AdamW.
  4. Warmup steps — 1-5% of training.
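Since LR dominates, a first sweep is usually log-spaced. A minimal sketch (`lr_grid` is a hypothetical name; bounds are example values):

```python
import math

def lr_grid(lo: float = 1e-4, hi: float = 1e-2, n: int = 5):
    """Log-spaced LR candidates for a first sweep: tune LR on a log scale."""
    step = (math.log(hi) - math.log(lo)) / (n - 1)
    return [math.exp(math.log(lo) + i * step) for i in range(n)]
```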

The 2025 default recipe


| Setting | Value |
|---|---|
| Optimizer | AdamW (β1=0.9, β2=0.95, eps=1e-10) |
| Weight decay | 0.1 |
| LR schedule | Cosine decay or WSD |
| Peak LR | 3e-4 (scale down for larger models) |
| Precision | bf16 |
| Grad clipping | max_norm=1.0 |
| Normalization | RMSNorm (pre-norm) |
| Activation | SwiGLU |
| Position encoding | RoPE |
| Attention | Flash Attention, optionally GQA |


Debugging Checklist


Karpathy's recipe (still canonical)


  1. Become one with the data — visualize, check distributions, verify labels
  2. Get end-to-end running first — verify on a trivial case
  3. Overfit one batch — if you can't, you have a bug
  4. Then regularize — add regularization only after overfitting works
  5. Tune hyperparameters — start with known defaults

Loss exploding / NaN


  1. Reduce LR (3-10× smaller)
  2. Add gradient clipping: `clip_grad_norm_(params, 1.0)`
  3. Check for inf/nan in inputs
  4. Add logit soft capping: `softcap * tanh(logits / softcap)`
  5. Add QK-norm in attention
  6. Verify weight init (zero-init output projections?)
  7. Check loss reduction with gradient accumulation (`loss / grad_accum_steps`)
Slow training / Low MFU


  1. Verify `torch.compile` is active
  2. Check `torch.set_float32_matmul_precision("high")`
  3. Pin memory + non_blocking transfers
  4. Profile with `torch.profiler`
  5. GC stalls? `gc.freeze(); gc.disable()`
  6. Tensor Core alignment: dims multiples of 8/64

Loss plateau / Slow convergence


  1. LR too low — try 2-5× larger
  2. Warmup too long
  3. Weight decay too high
  4. Verify LR schedule is actually applied (print each step)
  5. Model too small for task

Silent failures


  1. Data leakage between train/val
  2. Wrong preprocessing at inference — augmentation mismatch
  3. Label errors — use cleanlab to detect
  4. Shuffling bugs — correlated batches
  5. Tokenizer mismatch with pretrained model

What to monitor


  • Gradient norms — spike precedes loss spike
  • Per-layer activation stats — reveals exploding/vanishing
  • Dead neurons — >50% zero ReLU = dying ReLU problem
  • Learning rate — verify schedule applied (common silent bug)

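Gradient norms are cheap to log every step. Flattened to plain floats the quantity is just a global L2 norm (in PyTorch, `clip_grad_norm_` returns this same pre-clip total norm):

```python
import math

def global_grad_norm(grads) -> float:
    """Global L2 norm over per-parameter gradients, given as lists of floats."""
    return math.sqrt(sum(g * g for gs in grads for g in gs))
```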

Experiment Management


Track experiments in TSV for easy comparison:

```
commit  val_bpb  memory_gb  status   description
a1b2c3d 0.9979   44.0       keep     baseline
b2c3d4e 0.9932   44.2       keep     increase matrix LR to 0.04
c3d4e5f 1.0050   44.0       discard  switch to GeLU (worse)
```

Simplicity criterion: all else equal, simpler is better. Removing something and getting equal results is a great outcome. For systematic agent-driven experimentation, see `references/experiment-loop.md`.
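Once rows of such a log are parsed into dicts, picking the current best kept run is a one-liner (hypothetical helper for the format above):

```python
def best_kept(rows):
    """Return the row with status 'keep' and the lowest val_bpb."""
    kept = [r for r in rows if r["status"] == "keep"]
    return min(kept, key=lambda r: float(r["val_bpb"]))
```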

Evaluation metrics by domain

各领域评估指标

| Domain | Primary Metric | Notes |
|---|---|---|
| LLM | BPB (bits per byte) | Vocab-size-independent |
| Classification | Accuracy / F1 | Macro-F1 for imbalanced |
| Segmentation | mIoU / Dice | Per-class IoU reveals weak spots |
| Generation | FID | Needs >10k samples |
| Regression | RMSE / MAE | Log-transform skewed targets |
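BPB can be computed from the mean per-token NLL (a sketch; `bits_per_byte` is a hypothetical name). Normalizing by bytes rather than tokens is what makes the metric independent of tokenizer vocabulary size:

```python
import math

def bits_per_byte(mean_nll_nats: float, n_tokens: int, n_bytes: int) -> float:
    """Convert mean per-token NLL (in nats) to bits per byte of raw text."""
    return (mean_nll_nats / math.log(2)) * (n_tokens / n_bytes)
```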