pytorch-research


PyTorch - Advanced Research & Engineering


Research-grade PyTorch requires moving beyond `nn.Sequential`. You need to control how gradients flow, how weights are initialized, and how computation is distributed across multiple GPUs. This guide covers the "internals" of the framework.

When to Use


  • Implementing custom layers with non-standard mathematical derivatives.
  • Debugging vanishing or exploding gradients using Hooks.
  • Scaling models to multiple GPUs (Distributed Data Parallel).
  • Fine-tuning model performance using the PyTorch Profiler.
  • Creating complex learning rate schedules (Cyclic, OneCycle).
  • Deploying models for high-performance inference (TorchScript, FX).
  • Researching Weight Initialization and Normalization techniques.

Reference Documentation


Core Principles


Beyond the Computational Graph


PyTorch is a "define-by-run" framework, but for research, you often need to intervene in the backward pass or inspect intermediate tensors without breaking the graph.
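A minimal sketch of that idea: `detach()` lets you snapshot an intermediate tensor for inspection or logging while leaving the graph intact (the tensor names here are illustrative).

```python
import torch

x = torch.randn(3, requires_grad=True)
h = torch.relu(x)        # intermediate tensor, still part of the graph
snapshot = h.detach()    # inspect/log it without touching autograd

loss = h.sum()
loss.backward()          # the graph is still intact
print(snapshot, x.grad)
```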

The Life of a Gradient


Gradients accumulate in `.grad` attributes, and `backward()` frees the graph's intermediate buffers unless `retain_graph=True` is specified.
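A tiny, self-contained illustration of both points:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x * x

# retain_graph=True keeps the graph alive for a second backward pass
y.backward(retain_graph=True)
y.backward()

# Gradients accumulate in .grad: dy/dx = 2x = 4.0, run twice -> 8.0
print(x.grad.item())  # 8.0
```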

Memory vs. Speed


In research, you often trade speed (recomputation) for memory (fewer stored activations) using techniques like checkpointing.

Quick Reference


Standard Imports


```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
```

Basic Pattern - Custom Autograd Function


```python
class MySignFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Save input so backward can mask the gradient
        ctx.save_for_backward(input)
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator (STE): pass the gradient through
        # as if forward were the identity, zeroing it where |input| > 1
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input.abs() > 1] = 0
        return grad_input
```

Usage

```python
my_sign = MySignFunction.apply
```

Critical Rules


✅ DO


  • Use `register_full_backward_hook` - To inspect or modify gradients as they flow through a specific module.
  • Initialize Weights Explicitly - Use `torch.nn.init` (Xavier, Kaiming) inside a `model.apply(fn)` pass.
  • Use DistributedDataParallel (DDP) - Instead of DataParallel (DP). DDP is faster and handles multi-process communication correctly.
  • Profile Before Optimizing - Use `torch.profiler` to find which operator (e.g., a slow `view()` or `cat()`) is actually slowing down the model.
  • Use `torch.cuda.empty_cache()` sparingly - It only returns unoccupied cached blocks to the driver; it cannot free memory held by live tensors, and subsequent allocations become slower. Use it only between independent workloads if needed.

❌ DON'T


  • Don't use `inplace=True` in custom layers - This often breaks Autograd's ability to compute gradients correctly.
  • Don't use `item()` inside the loop - Calling `.item()` on a GPU tensor forces a CPU-GPU sync, which kills performance.
  • Don't pass `shuffle=True` to the DataLoader when using a `DistributedSampler` - Let the sampler handle the shuffling logic in a multi-GPU environment.
  • Avoid Global Variables - PyTorch models should be self-contained for easy serialization and deployment.
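On the `.item()` point, the usual fix is to keep values on-device and transfer once at the end. A minimal sketch of the pattern (the loop and loss values are illustrative; the sync cost only appears on a GPU):

```python
import torch

losses = []
for _ in range(3):
    loss = torch.randn(4).pow(2).mean()
    # Appending the detached tensor avoids a per-iteration CPU-GPU sync
    losses.append(loss.detach())

# One transfer at the end instead of one per step
mean_loss = torch.stack(losses).mean().item()
print(mean_loss)
```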

Advanced Custom Layers


Hooks for Debugging and Feature Extraction


```python
def print_grad_norm(module, grad_input, grad_output):
    print(f"Module: {module.__class__.__name__}, Grad Norm: {grad_output[0].norm().item()}")
```

Attach to a specific layer

```python
model.fc1.register_full_backward_hook(print_grad_norm)
```

Extract activations (Forward Hook)

```python
activations = {}

def get_activation(name):
    def hook(model, input, output):
        activations[name] = output.detach()
    return hook

model.conv1.register_forward_hook(get_activation('conv1'))
```

Advanced Training Patterns


Distributed Data Parallel (DDP) Skeleton


```python
import os
import torch.multiprocessing as mp

def setup(rank, world_size):
    # Rendezvous info for the process group (single-node defaults)
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def train(rank, world_size):
    setup(rank, world_size)

    model = MyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    # Use DistributedSampler to ensure each GPU sees different data
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, sampler=sampler, batch_size=32)

    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.001)
    # ... training loop ...

    dist.destroy_process_group()

mp.spawn(train, args=(world_size,), nprocs=world_size)
```

Performance & Profiling


Using the Profiler


```python
from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=True) as prof:
    with record_function("model_inference"):
        model(inputs)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```

Gradient Checkpointing (Memory Saving)


If you have a very deep model and limited memory, trade computation for space.

```python
from torch.utils.checkpoint import checkpoint

class DeepModel(nn.Module):
    def forward(self, x):
        # Instead of storing all activations, recompute them during backward
        x = checkpoint(self.heavy_layer_1, x, use_reentrant=False)
        x = checkpoint(self.heavy_layer_2, x, use_reentrant=False)
        return x
```

Practical Workflows


1. Custom Weight Initialization


```python
def init_weights(m):
    if isinstance(m, nn.Linear):
        # Kaiming initialization for ReLU networks
        torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv2d):
        torch.nn.init.xavier_uniform_(m.weight)

model.apply(init_weights)
```

2. Gradient Clipping (Stability)


```python
# Inside the training loop
loss.backward()

# Clip to prevent exploding gradients (standard in RNNs/Transformers)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```

3. Dynamic Learning Rate (OneCycleLR)


```python
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.1,
                                                steps_per_epoch=len(train_loader),
                                                epochs=10)

for epoch in range(10):
    for batch in train_loader:
        train_batch()
        scheduler.step()  # Step every batch for OneCycle
```

Common Pitfalls and Solutions


In-place Modification Error


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation.

❌ Problem: `x += 1` (breaks the backward pass)

✅ Solution: `y = x + 1` (creates a new tensor)
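The error is easy to reproduce, because ops like `sigmoid` save their output tensor for backward. A minimal repro, safe to run since it catches the exception:

```python
import torch

x = torch.ones(3, requires_grad=True)
y = torch.sigmoid(x)  # sigmoid saves its output tensor for backward
y += 1                # in-place edit bumps the tensor's version counter

try:
    y.sum().backward()
    raised = False
except RuntimeError as e:
    raised = True
    print("caught:", e)
```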

CUDA Out of Memory (OOM) Strategies


  • Batch Size: Reduce it.
  • Gradient Accumulation: Compute the loss on small micro-batches, but only call `step()` every N batches.
  • Empty Cache: Use `torch.cuda.empty_cache()` between independent evaluations.
  • Mixed Precision: Use `torch.cuda.amp`, which can roughly halve activation memory.
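Gradient accumulation from the list above, as a runnable sketch (the model, stand-in data, and `accum_steps` value are all illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loader = [torch.randn(8, 10) for _ in range(8)]  # stand-in micro-batches

accum_steps = 4  # effective batch = accum_steps * micro-batch size
optimizer.zero_grad()
for i, batch in enumerate(loader):
    loss = model(batch).pow(2).mean()
    # Scale so the accumulated gradient matches one large-batch backward
    (loss / accum_steps).backward()
    if (i + 1) % accum_steps == 0:
        optimizer.step()       # update only every accum_steps micro-batches
        optimizer.zero_grad()
```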

Silent Failure: zero_grad() position


If you call `zero_grad()` after `backward()` but before `step()`, your model will never learn.

✅ Correct order:

```python
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Research PyTorch is about mastery over the mathematical engine. By leveraging custom gradients, hooks, and distributed infrastructure, you can move from training standard models to inventing the next generation of scientific AI.