pytorch-fsdp2


# Skill: Use PyTorch FSDP2 (`fully_shard`) correctly in a training script

This skill teaches a coding agent how to add PyTorch FSDP2 to a training loop with correct initialization, sharding, mixed precision/offload configuration, and checkpointing. FSDP2 in PyTorch is exposed primarily via `torch.distributed.fsdp.fully_shard` and the `FSDPModule` methods it adds in-place to modules. See: `references/pytorch_fully_shard_api.md`, `references/pytorch_fsdp2_tutorial.md`.


## When to use this skill


Use FSDP2 when:

- Your model doesn't fit on one GPU (parameters + gradients + optimizer state).
- You want eager-mode, DTensor-based per-parameter sharding, which is more inspectable and yields simpler sharded state dicts than FSDP1.
- You may later compose data parallelism with Tensor Parallel using `DeviceMesh`.

Avoid (or be careful) if:

- You need strictly backwards-compatible checkpoints across PyTorch versions (DCP warns against relying on this).
- You are forced onto older PyTorch versions without the FSDP2 stack.

## Alternatives (when FSDP2 is not the best fit)


- DistributedDataParallel (DDP): use the standard data-parallel wrapper when you want classic distributed data-parallel training.
- FullyShardedDataParallel (FSDP1): use the original FSDP wrapper for parameter sharding across data-parallel workers.

Reference: `references/pytorch_ddp_notes.md`, `references/pytorch_fsdp1_api.md`.


## Contract the agent must follow


1. Launch with `torchrun` and set the CUDA device per process (usually via `LOCAL_RANK`).
2. Apply `fully_shard()` bottom-up, i.e., shard submodules (e.g., Transformer blocks) before the root module.
3. Call `model(input)`, not `model.forward(input)`, so the FSDP2 hooks run (unless you explicitly `unshard()` or register the forward method).
4. Create the optimizer after sharding and make sure it is built on the DTensor parameters (post-`fully_shard`).
5. Checkpoint using Distributed Checkpoint (DCP) or the distributed-state-dict helpers, not naïve `torch.save(model.state_dict())`, unless you deliberately gather to full tensors.

(Each of these rules is directly described in the official API docs/tutorial; see references.)


## Step-by-step procedure


### 0) Version & environment sanity


- Prefer a recent stable PyTorch release whose documentation covers FSDP2 and the current DCP APIs.
- Use `torchrun --nproc_per_node <gpus_per_node> ...` and ensure `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` are visible.

Reference: `references/pytorch_fsdp2_tutorial.md` (launch commands and setup), `references/pytorch_fully_shard_api.md` (user contract).
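The environment check above can be sketched as a small helper that fails fast when the script is launched without `torchrun` (the function name is illustrative, not a PyTorch API):

```python
import os


def read_dist_env() -> dict:
    """Read the variables torchrun exports; raise early if any is missing.

    Launching with `torchrun --nproc_per_node <gpus_per_node> ...` guarantees
    RANK, WORLD_SIZE, and LOCAL_RANK are set for every worker process.
    """
    required = ("RANK", "WORLD_SIZE", "LOCAL_RANK")
    missing = [k for k in required if k not in os.environ]
    if missing:
        raise RuntimeError(
            f"Missing {missing}; launch this script with torchrun, not plain python."
        )
    return {k: int(os.environ[k]) for k in required}
```

Calling this once at the top of `main()` turns a confusing hang or device mismatch later into an immediate, explicit error.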


### 1) Initialize distributed and set device


Minimal, correct pattern:

- `dist.init_process_group(backend="nccl")`
- `torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))`
- Optionally create a `DeviceMesh` to describe the data-parallel group(s).

Reference: `references/pytorch_device_mesh_tutorial.md` (why DeviceMesh exists and how it manages process groups).
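As a minimal sketch, the pattern can be wrapped in one function. The `backend` parameter is an assumption added here so the same code can be exercised with `gloo` on CPU; a real GPU run uses the default `nccl`:

```python
import os

import torch
import torch.distributed as dist


def init_distributed(backend: str = "nccl") -> int:
    """Initialize the default process group and pin this process to one GPU.

    torchrun sets RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT in the environment,
    so init_process_group can discover everything via the default env:// method.
    """
    dist.init_process_group(backend=backend)
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if torch.cuda.is_available():
        # One process per GPU: every CUDA call after this targets this rank's device.
        torch.cuda.set_device(local_rank)
    return local_rank
```

Setting the device *before* building or sharding the model avoids the classic failure mode where all ranks allocate on GPU 0.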


### 2) Build model on meta device (recommended for very large models)


For big models, initialize on `meta`, apply sharding, then materialize weights on GPU:

- `with torch.device("meta"): model = ...`
- apply `fully_shard(...)` on submodules, then `fully_shard(model)`
- `model.to_empty(device="cuda")`
- `model.reset_parameters()` (or your init routine)

Reference: `references/pytorch_fsdp2_tutorial.md` (the migration guide shows this flow explicitly).
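A sketch of that flow, with two assumptions made for illustration: the `shard` callable is an injection point (pass a function that runs `fully_shard` bottom-up in a real script), and `device` is parameterized so the pattern also works on CPU:

```python
import torch
import torch.nn as nn


def build_on_meta(make_model, shard=None, device="cuda"):
    """Meta-device flow: allocate no real storage until after sharding.

    make_model: zero-arg factory, e.g. lambda: Transformer(cfg)
    shard:      optional callable applied before materialization,
                e.g. a function that applies fully_shard bottom-up
    """
    with torch.device("meta"):
        model = make_model()          # parameters are meta tensors: no memory used
    if shard is not None:
        shard(model)                  # fully_shard swaps params to sharded DTensors
    model.to_empty(device=device)     # allocate uninitialized storage on the target
    for m in model.modules():         # re-run each submodule's own init
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()
    return model
```

Note that `to_empty` leaves weights uninitialized, which is why the explicit `reset_parameters()` pass (or your own init routine) is mandatory afterwards.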


### 3) Apply `fully_shard()` bottom-up (wrapping policy = “apply where needed”)


Do not call `fully_shard` only on the topmost module. Recommended sharding pattern for transformer-like models:

- iterate modules: `if isinstance(m, TransformerBlock): fully_shard(m, ...)`
- then `fully_shard(model, ...)`

Why: `fully_shard` forms “parameter groups” for collective efficiency and excludes params already grouped by earlier calls. Bottom-up application gives better communication/compute overlap and lower peak memory.

Reference: `references/pytorch_fully_shard_api.md` (bottom-up requirement and rationale).
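The loop above can be factored into a helper. `shard_fn` is an injection point added here for illustration so the ordering logic is testable without a process group; in a real script pass `torch.distributed.fsdp.fully_shard` (plus any `mesh`/policy kwargs):

```python
import torch.nn as nn


def shard_bottom_up(model: nn.Module, block_cls, shard_fn, **kwargs):
    """Apply shard_fn (e.g. fully_shard) to every block first, then the root.

    This matches the bottom-up contract: each block becomes its own parameter
    group, and the final root call groups only the leftover params
    (embeddings, final norm, output head, ...).
    """
    for module in model.modules():
        if isinstance(module, block_cls):
            shard_fn(module, **kwargs)
    shard_fn(model, **kwargs)
    return model
```

Usage in a training script would look like `shard_bottom_up(model, TransformerBlock, fully_shard, mp_policy=policy)`.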


### 4) Configure `reshard_after_forward` for memory/perf trade-offs


Default behavior:

- `None` means `True` for non-root modules and `False` for the root module (a good default).

Heuristics:

- If you’re memory-bound: keep defaults or force `True` on many blocks.
- If you’re throughput-bound and can afford the memory: consider keeping unsharded params longer (root often `False`).
- Advanced: use an `int` to reshard to a smaller mesh after forward (e.g., intra-node), provided it is a meaningful divisor of the mesh size.

Reference: `references/pytorch_fully_shard_api.md` (full semantics).
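These heuristics can be encoded in a small decision helper (an illustrative function, not a PyTorch API) whose return value is what you would pass as `reshard_after_forward`:

```python
def pick_reshard_after_forward(is_root, memory_bound, world_size, intranode=None):
    """Encode the heuristics above; returns bool or int for fully_shard.

    intranode: optional smaller mesh size to reshard to after forward
               (e.g. GPUs per node); must evenly divide world_size.
    """
    if intranode is not None:
        # An int value reshards params to a smaller mesh after forward,
        # trading memory for cheaper intra-node all-gathers in backward.
        if world_size % intranode != 0:
            raise ValueError(f"{intranode} does not divide world size {world_size}")
        return intranode
    if memory_bound:
        return True            # free unsharded params as soon as forward ends
    return not is_root         # the None default: True for non-root, False for root
```

The last line mirrors the documented `None` semantics, so passing the helper's output is equivalent to the default when neither override applies.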


### 5) Mixed precision & offload (optional but common)


FSDP2 uses:

- `mp_policy=MixedPrecisionPolicy(param_dtype=..., reduce_dtype=..., output_dtype=..., cast_forward_inputs=...)`
- `offload_policy=CPUOffloadPolicy()` if you want CPU offload

Rules of thumb:

- Start with BF16 parameters/reductions on H100/A100-class GPUs (if numerically stable for your model).
- Keep `reduce_dtype` aligned with your gradient-reduction expectations.
- If you use CPU offload, budget for PCIe/NVLink traffic and runtime overhead.

Reference: `references/pytorch_fully_shard_api.md` (`MixedPrecisionPolicy` / `OffloadPolicy` classes).


### 6) Optimizer, gradient clipping, accumulation


- Create the optimizer after sharding so it holds DTensor params.
- If you need gradient accumulation / no-sync behavior, use the FSDP2 mechanism (`set_requires_gradient_sync`) instead of FSDP1's `no_sync()`.
- Gradient clipping: use the approach shown in the FSDP2 tutorial (“Gradient Clipping and Optimizer with DTensor”), because parameters/gradients are DTensors.

Reference: `references/pytorch_fsdp2_tutorial.md`.
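A hedged sketch of one accumulated optimizer step. The `hasattr` guard is an assumption added here so the same loop also runs on a plain (unsharded) module; on an FSDP2-wrapped model, `set_requires_gradient_sync` is the `FSDPModule` method that skips reduce-scatter on non-final microbatches:

```python
import torch
import torch.nn as nn


def train_on_microbatches(model, optimizer, microbatches, loss_fn, max_norm=1.0):
    """One optimizer step, gradients accumulated over several microbatches."""
    last = len(microbatches) - 1
    for i, (x, y) in enumerate(microbatches):
        if hasattr(model, "set_requires_gradient_sync"):   # FSDPModule method
            model.set_requires_gradient_sync(i == last)    # reduce only at the end
        loss = loss_fn(model(x), y) / len(microbatches)    # mean over microbatches
        loss.backward()
    # Under FSDP2 the params/grads here are DTensors; clipping follows the
    # tutorial's "Gradient Clipping and Optimizer with DTensor" section.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    optimizer.zero_grad()
```

Note the loss scaling by the microbatch count, so the accumulated gradient matches what one large batch would have produced.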


### 7) Checkpointing: prefer DCP or distributed state dict helpers


Two recommended approaches:

A) Distributed Checkpoint (DCP), the best default:

- DCP saves/loads from multiple ranks in parallel and supports load-time resharding.
- DCP produces multiple files (often at least one per rank) and operates “in place”.

B) Distributed state-dict helpers:

- `get_model_state_dict` / `set_model_state_dict` with `StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True, ...)`
- For the optimizer: `get_optimizer_state_dict` / `set_optimizer_state_dict`

Avoid: saving DTensor state dicts with plain `torch.save` unless you intentionally convert with `DTensor.full_tensor()` and manage memory carefully.

References:

- `references/pytorch_dcp_overview.md` (DCP behavior and caveats)
- `references/pytorch_dcp_recipe.md` and `references/pytorch_dcp_async_recipe.md` (end-to-end usage)
- `references/pytorch_fsdp2_tutorial.md` (DTensor vs DCP state-dict flows)
- `references/pytorch_examples_fsdp2.md` (working checkpoint scripts)


## Workflow checklists (copy-paste friendly)


### Workflow A: Retrofit FSDP2 into an existing training script


- Launch with `torchrun` and initialize the process group.
- Set the CUDA device from `LOCAL_RANK`; create a `DeviceMesh` if you need multi-dim parallelism.
- Build the model (use `meta` if needed), apply `fully_shard` bottom-up, then `fully_shard(model)`.
- Create the optimizer after sharding so it captures DTensor parameters.
- Use `model(inputs)` so hooks run; use `set_requires_gradient_sync` for accumulation.
- Add DCP save/load via `torch.distributed.checkpoint` helpers.

Reference: `references/pytorch_fsdp2_tutorial.md`, `references/pytorch_fully_shard_api.md`, `references/pytorch_device_mesh_tutorial.md`, `references/pytorch_dcp_recipe.md`.

### Workflow B: Add DCP save/load (minimal pattern)


- Wrap state in `Stateful` or assemble state via `get_state_dict`.
- Call `dcp.save(...)` from all ranks to a shared path.
- Call `dcp.load(...)` and restore with `set_state_dict`.
- Validate any resharding assumptions when loading into a different mesh.

Reference: `references/pytorch_dcp_recipe.md`.

## Debug checklist (what the agent should check first)


1. All ranks on distinct GPUs? If not, verify `torch.cuda.set_device(LOCAL_RANK)` and your `torchrun` flags.
2. Did you accidentally call `forward()` directly? Use `model(input)` or explicitly `unshard()` / register the forward method.
3. Is `fully_shard()` applied bottom-up? If only the root is sharded, expect worse memory/perf and possible confusion.
4. Optimizer created at the right time? It must be built on DTensor parameters after sharding.
5. Checkpointing path consistent?
   - If using DCP, don’t mix it with ad-hoc `torch.save` unless you understand the conversions.
   - Be mindful of PyTorch-version compatibility warnings for DCP.


## Common issues and fixes


- Forward hooks not running → call `model(inputs)` (or `unshard()` explicitly) instead of `model.forward(...)`.
- Optimizer sees non-DTensor params → create the optimizer after all `fully_shard` calls.
- Only the root module sharded → apply `fully_shard` bottom-up on submodules before the root.
- Memory spikes after forward → set `reshard_after_forward=True` for more modules.
- Gradient accumulation desync → use `set_requires_gradient_sync` instead of FSDP1's `no_sync()`.

Reference: `references/pytorch_fully_shard_api.md`, `references/pytorch_fsdp2_tutorial.md`.


## Minimal reference implementation outline (agent-friendly)


The coding agent should implement a script with these labeled blocks:

- `init_distributed()`: init process group, set device
- `build_model_meta()`: model on meta, apply `fully_shard`, materialize weights
- `build_optimizer()`: optimizer created after sharding
- `train_step()`: forward/backward/step with `model(inputs)` and DTensor-aware patterns
- `checkpoint_save/load()`: DCP or distributed state-dict helpers

Concrete examples live in `references/pytorch_examples_fsdp2.md` and the official tutorial reference.


## References


- `references/pytorch_fsdp2_tutorial.md`
- `references/pytorch_fully_shard_api.md`
- `references/pytorch_ddp_notes.md`
- `references/pytorch_fsdp1_api.md`
- `references/pytorch_device_mesh_tutorial.md`
- `references/pytorch_tp_tutorial.md`
- `references/pytorch_dcp_overview.md`
- `references/pytorch_dcp_recipe.md`
- `references/pytorch_dcp_async_recipe.md`
- `references/pytorch_examples_fsdp2.md`
- `references/torchtitan_fsdp_notes.md` (optional, production notes)
- `references/ray_train_fsdp2_example.md` (optional, integration example)