pytorch-fsdp2
Skill: Use PyTorch FSDP2 (`fully_shard`) correctly in a training script
This skill teaches a coding agent how to add PyTorch FSDP2 to a training loop with correct initialization, sharding, mixed precision/offload configuration, and checkpointing.
FSDP2 in PyTorch is exposed primarily via `torch.distributed.fsdp.fully_shard` and the `FSDPModule` methods it adds in-place to modules. See: `references/pytorch_fully_shard_api.md`, `references/pytorch_fsdp2_tutorial.md`.
When to use this skill
Use FSDP2 when:
- Your model doesn’t fit on one GPU (parameters + gradients + optimizer state).
- You want an eager-mode sharding approach based on DTensor per-parameter sharding, which is more inspectable and yields simpler sharded state dicts than FSDP1.
- You may later compose DP with Tensor Parallel using DeviceMesh.
Avoid (or be careful) if:
- You need strict backwards-compatible checkpoints across PyTorch versions (DCP warns against this).
- You’re forced onto older PyTorch versions without the FSDP2 stack.
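To make the first criterion concrete, here is a rough back-of-the-envelope sketch. It assumes FP32 parameters and gradients plus an Adam-style optimizer that keeps two FP32 states per parameter; those byte counts and the helper name `training_state_gib` are illustrative assumptions, not something taken from the FSDP2 docs, and activations are ignored.

```python
def training_state_gib(num_params: int,
                       param_bytes: int = 4,
                       grad_bytes: int = 4,
                       optim_bytes: int = 8,
                       num_shards: int = 1) -> float:
    """Approximate per-GPU memory (GiB) for params + grads + optimizer state.

    Assumes the state is split evenly across `num_shards` ranks, which is an
    idealized view of full sharding. Activations and buffers are ignored.
    """
    total_bytes = num_params * (param_bytes + grad_bytes + optim_bytes)
    return total_bytes / num_shards / 2**30


if __name__ == "__main__":
    seven_b = 7_000_000_000  # hypothetical 7B-parameter model
    print(f"unsharded:   {training_state_gib(seven_b):.1f} GiB")
    print(f"8-way shard: {training_state_gib(seven_b, num_shards=8):.1f} GiB")
```

If the unsharded figure exceeds a single GPU's memory while the sharded figure fits with room for activations, FSDP2 is a reasonable candidate.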
Alternatives (when FSDP2 is not the best fit)
- DistributedDataParallel (DDP): Use the standard data-parallel wrapper when you want classic distributed data parallel training.
- FullyShardedDataParallel (FSDP1): Use the original FSDP wrapper for parameter sharding across data-parallel workers.
Reference: `references/pytorch_ddp_notes.md`, `references/pytorch_fsdp1_api.md`.
Contract the agent must follow
- Launch with `torchrun` and set the CUDA device per process (usually via `LOCAL_RANK`).
- Apply `fully_shard()` bottom-up, i.e., shard submodules (e.g., Transformer blocks) before the root module.
- Call `model(input)`, not `model.forward(input)`, so the FSDP2 hooks run (unless you explicitly `unshard()` or register the forward method).
- Create the optimizer after sharding and make sure it is built on the DTensor parameters (post-`fully_shard`).
- Checkpoint using Distributed Checkpoint (DCP) or the distributed-state-dict helpers, not naïve `torch.save(model.state_dict())` unless you deliberately gather to full tensors.
(Each of these rules is directly described in the official API docs/tutorial; see references.)
Step-by-step procedure
0) Version & environment sanity
- Prefer a recent stable PyTorch release whose docs cover both FSDP2 and DCP.
- Use `torchrun --nproc_per_node <gpus_per_node> ...` and ensure `RANK`, `WORLD_SIZE`, `LOCAL_RANK` are visible.
Reference: `references/pytorch_fsdp2_tutorial.md` (launch commands and setup), `references/pytorch_fully_shard_api.md` (user contract).
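The environment check can be sketched as a small fail-fast helper that runs before any distributed call. The function name `read_torchrun_env` and the error message are hypothetical; only the three variable names come from the text above.

```python
import os
from typing import Mapping

REQUIRED_VARS = ("RANK", "WORLD_SIZE", "LOCAL_RANK")


def read_torchrun_env(env: Mapping[str, str] = os.environ) -> dict:
    """Return the torchrun-provided variables as ints, or raise if any is missing."""
    missing = [name for name in REQUIRED_VARS if name not in env]
    if missing:
        raise RuntimeError(
            f"missing {missing}; launch with "
            "torchrun --nproc_per_node <gpus_per_node> ..."
        )
    return {name: int(env[name]) for name in REQUIRED_VARS}
```

Calling this at the top of the script turns a confusing hang or `KeyError` later on into an immediate, readable error when the script is started with plain `python` instead of `torchrun`.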
1) Initialize distributed and set device
Minimal, correct pattern:
- `dist.init_process_group(backend="nccl")`
- `torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))`
- Optionally create a `DeviceMesh` to describe the data-parallel group(s)
Reference: `references/pytorch_device_mesh_tutorial.md` (why DeviceMesh exists & how it manages process groups).
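The pattern above could be wrapped into one function along these lines. This is a sketch, not the tutorial's code: the function name `init_distributed`, the `use_mesh` flag, and the choice of a 1-D mesh over all ranks are assumptions made for illustration. It must run under `torchrun` on a CUDA machine.

```python
import os


def init_distributed(backend: str = "nccl", use_mesh: bool = False):
    """Initialize the default process group and bind this process to its GPU.

    Returns (local_rank, mesh); mesh is None unless use_mesh is True.
    """
    # Imported lazily so this module can be imported without a CUDA build.
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
    dist.init_process_group(backend=backend)
    torch.cuda.set_device(local_rank)

    mesh = None
    if use_mesh:
        # Assumption: a 1-D mesh over all ranks, i.e. one sharding dimension.
        mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    return local_rank, mesh
```

When a mesh is created, it can be passed to each `fully_shard(..., mesh=mesh)` call; without one, `fully_shard` falls back to the default process group.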
2) Build model on meta device (recommended for very large models)
For big models, initialize on `meta`, apply sharding, then materialize weights on GPU:
- `with torch.device("meta"): model = ...`
- apply `fully_shard(...)` on submodules, then `fully_shard(model)`
- `model.to_empty(device="cuda")`
- `model.reset_parameters()` (or your init routine)
Reference: `references/pytorch_fsdp2_tutorial.md` (migration guide shows this flow explicitly).
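A minimal sketch of the first and last steps of this flow, with sharding happening between the two calls. Both helper names are hypothetical, and walking every submodule to call `reset_parameters()` is an assumption that stands in for "your init routine"; a model with its own weight-init method should call that instead.

```python
def build_on_meta(model_factory):
    """Construct the model without allocating real parameter storage."""
    import torch  # imported lazily; requires a PyTorch install

    with torch.device("meta"):
        return model_factory()


def materialize_and_init(model, device: str = "cuda"):
    """Allocate uninitialized storage on `device`, then initialize weights.

    Intended to run after fully_shard has been applied, so that storage is
    allocated for the already-sharded parameters.
    """
    model.to_empty(device=device)
    for module in model.modules():
        reset = getattr(module, "reset_parameters", None)
        if callable(reset):
            reset()  # assumption: each parameterized module defines this
    return model
```

Typical use would be `model = build_on_meta(lambda: MyModel(cfg))`, then the `fully_shard` calls, then `materialize_and_init(model)`, where `MyModel` and `cfg` are placeholders for your own model.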
3) Apply `fully_shard()` bottom-up (wrapping policy = “apply where needed”)
Do not only call `fully_shard` on the topmost module.
Recommended sharding pattern for transformer-like models:
- iterate modules, `if isinstance(m, TransformerBlock): fully_shard(m, ...)`
- then `fully_shard(model, ...)`
Why:
- `fully_shard` forms “parameter groups” for collective efficiency and excludes params already grouped by earlier calls. Bottom-up gives better overlap and lower peak memory.
Reference: `references/pytorch_fully_shard_api.md` (bottom-up requirement and why).
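The pattern above might look like the following helper. The name `shard_bottom_up` and the injectable `shard_fn` argument are illustrative assumptions (the latter only exists so the traversal order can be checked without GPUs); by default the helper uses the real `fully_shard`.

```python
from typing import Callable, Optional, Tuple, Type


def shard_bottom_up(model,
                    block_types: Tuple[Type, ...],
                    shard_fn: Optional[Callable] = None,
                    **fsdp_kwargs):
    """Apply shard_fn to every matching submodule first, then to the root."""
    if shard_fn is None:
        # Public import path in recent PyTorch releases; some earlier releases
        # expose FSDP2 under torch.distributed._composable.fsdp instead.
        from torch.distributed.fsdp import fully_shard
        shard_fn = fully_shard

    # modules() yields parents before children, so walking the list in
    # reverse visits children first and the root (model itself) last.
    for module in reversed(list(model.modules())):
        if module is model or isinstance(module, block_types):
            shard_fn(module, **fsdp_kwargs)
    return model
```

For example, `shard_bottom_up(model, (TransformerBlock,), reshard_after_forward=True)` shards each block and then the root with the same keyword arguments, where `TransformerBlock` is your model's own block class.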
4) Configure `reshard_after_forward` for memory/perf trade-offs
Default behavior:
- `None` means `True` for non-root modules and `False` for root modules (good default).
Heuristics:
- If you’re memory-bound: keep defaults or force `True` on many blocks.
- If you’re throughput-bound and can afford memory: consider keeping unsharded params longer (root often `False`).
- Advanced: use an `int` to reshard to a smaller mesh after forward (e.g., intra-node) if it’s a meaningful divisor.
Reference: `references/pytorch_fully_shard_api.md` (full semantics).
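For the `int` option, a small guard can catch unusable values before they reach `fully_shard`. This sketch reads "meaningful divisor" as a divisor of the shard mesh size other than 1 and the mesh size itself; that reading, and both helper names, are assumptions to verify against the API reference.

```python
def is_meaningful_reshard_size(size: int, shard_mesh_size: int) -> bool:
    """True if `size` divides the mesh size and is neither 1 nor the full size."""
    return 1 < size < shard_mesh_size and shard_mesh_size % size == 0


def reshard_setting(world_size: int, gpus_per_node: int):
    """Pick an intra-node reshard size when valid, else fall back to True."""
    if is_meaningful_reshard_size(gpus_per_node, world_size):
        return gpus_per_node
    return True
```

The returned value can be passed directly as `reshard_after_forward=...`; on a single node the helper falls back to `True`, which matches the memory-bound heuristic.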
5) Mixed precision & offload (optional but common)
FSDP2 uses:
- `mp_policy=MixedPrecisionPolicy(param_dtype=..., reduce_dtype=..., output_dtype=..., cast_forward_inputs=...)`
- `offload_policy=CPUOffloadPolicy()` if you want CPU offload
Rules of thumb:
- Start with BF16 parameters/reductions on H100/A100-class GPUs (if numerically stable for your model).
- Keep `reduce_dtype` aligned with your gradient reduction expectations.
- If you use CPU offload, budget for PCIe/NVLink traffic and runtime overhead.
Reference: `references/pytorch_fully_shard_api.md` (MixedPrecisionPolicy / OffloadPolicy classes).
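One way to keep these policies consistent is to build the keyword arguments once and reuse them for every `fully_shard` call. The helper name `build_fsdp_kwargs` and its flags are hypothetical; the `fp32_reduce` switch is an added assumption for models whose gradients are not stable when reduced in BF16.

```python
def build_fsdp_kwargs(bf16: bool = True,
                      fp32_reduce: bool = False,
                      cpu_offload: bool = False) -> dict:
    """Build keyword arguments shared by every fully_shard(...) call."""
    # Lazy imports: these names require a PyTorch release that ships FSDP2
    # under torch.distributed.fsdp.
    import torch
    from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy

    kwargs = {}
    if bf16:
        kwargs["mp_policy"] = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,  # unsharded params used in compute
            reduce_dtype=torch.float32 if fp32_reduce else torch.bfloat16,
        )
    if cpu_offload:
        kwargs["offload_policy"] = CPUOffloadPolicy()
    return kwargs
```

The result is meant to be splatted into each call, for example `fully_shard(block, **fsdp_kwargs)` and then `fully_shard(model, **fsdp_kwargs)`, so submodules and the root share one configuration.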
6) Optimizer, gradient clipping, accumulation
- Create the optimizer after sharding so it holds DTensor params.
- If you need gradient accumulation / no_sync:
  - use the FSDP2 mechanism (`set_requires_gradient_sync`) instead of FSDP1’s `no_sync()`.
Gradient clipping:
- Use the approach shown in the FSDP2 tutorial (“Gradient Clipping and Optimizer with DTensor”), because parameters/gradients are DTensors.
Reference: `references/pytorch_fsdp2_tutorial.md`.
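A sketch of gradient accumulation with the FSDP2 mechanism, assuming `model` has already been sharded (so it exposes `set_requires_gradient_sync`) and the optimizer was created afterwards. The function name, the `(inputs, targets)` batch layout, and the `loss_fn` signature are illustrative assumptions.

```python
def accumulation_step(model, optimizer, micro_batches, loss_fn, max_norm=None):
    """Run forward/backward over micro_batches, then one optimizer step."""
    last = len(micro_batches) - 1
    for i, (inputs, targets) in enumerate(micro_batches):
        # Skip gradient reduction until the final micro-batch.
        model.set_requires_gradient_sync(i == last)
        # Call model(...) rather than model.forward(...) so FSDP2 hooks run.
        loss = loss_fn(model(inputs), targets) / len(micro_batches)
        loss.backward()
    if max_norm is not None:
        from torch.nn.utils import clip_grad_norm_  # lazy import
        clip_grad_norm_(model.parameters(), max_norm=max_norm)
    optimizer.step()
    optimizer.zero_grad()
```

Skipping synchronization trades communication for memory, since gradients are held locally until the last micro-batch, so it is worth checking peak memory after enabling it.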
7) Checkpointing: prefer DCP or distributed state dict helpers
Two recommended approaches:
A) Distributed Checkpoint (DCP) — best default
- DCP saves/loads from multiple ranks in parallel and supports load-time resharding.
- DCP produces multiple files (often at least one per rank) and operates “in place”.
B) Distributed state dict helpers
- `get_model_state_dict` / `set_model_state_dict` with `StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True, ...)`
- For optimizer: `get_optimizer_state_dict` / `set_optimizer_state_dict`
Avoid:
- Saving DTensor state dicts with plain `torch.save` unless you intentionally convert with `DTensor.full_tensor()` and manage memory carefully.
References:
- `references/pytorch_dcp_overview.md` (DCP behavior and caveats)
- `references/pytorch_dcp_recipe.md` and `references/pytorch_dcp_async_recipe.md` (end-to-end usage)
- `references/pytorch_fsdp2_tutorial.md` (DTensor vs DCP state-dict flows)
- `references/pytorch_examples_fsdp2.md` (working checkpoint scripts)
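Approach B could be sketched as follows for the model weights only. The function names and the single-file layout are assumptions; loading the file on every rank with `mmap=True` is a conservative choice made here, so check the tutorial for whether non-zero ranks may skip the read when `broadcast_from_rank0=True`.

```python
def save_full_checkpoint(model, path: str) -> None:
    """Gather a full (unsharded) model state dict on CPU and save it on rank 0."""
    import torch
    import torch.distributed as dist
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
    )

    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    # Collective call: every rank must reach this line.
    state = get_model_state_dict(model, options=options)
    if dist.get_rank() == 0:
        torch.save(state, path)


def load_full_checkpoint(model, path: str) -> None:
    """Load a full state dict and let rank 0 broadcast it into the shards."""
    import torch
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        set_model_state_dict,
    )

    full_state = torch.load(path, mmap=True, weights_only=True, map_location="cpu")
    options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
    set_model_state_dict(model, full_state, options=options)
```

This route is convenient for exporting weights to single-process tools; for routine training checkpoints, DCP avoids gathering the full model on one host.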
Workflow checklists (copy-paste friendly)
Workflow A: Retrofit FSDP2 into an existing training script
- Launch with `torchrun` and initialize the process group.
- Set the CUDA device from `LOCAL_RANK`; create a `DeviceMesh` if you need multi-dim parallelism.
- Build the model (use `meta` if needed), apply `fully_shard` bottom-up, then `fully_shard(model)`.
- Create the optimizer after sharding so it captures DTensor parameters.
- Use `model(inputs)` so hooks run; use `set_requires_gradient_sync` for accumulation.
- Add DCP save/load via `torch.distributed.checkpoint` helpers.
Reference: `references/pytorch_fsdp2_tutorial.md`, `references/pytorch_fully_shard_api.md`, `references/pytorch_device_mesh_tutorial.md`, `references/pytorch_dcp_recipe.md`.
Workflow B: Add DCP save/load (minimal pattern)
- Wrap state in `Stateful` or assemble state via `get_state_dict`.
- Call `dcp.save(...)` from all ranks to a shared path.
- Call `dcp.load(...)` and restore with `set_state_dict`.
- Validate any resharding assumptions when loading into a different mesh.
Reference: `references/pytorch_dcp_recipe.md`.
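The checklist above can be sketched with the `get_state_dict` route rather than a `Stateful` wrapper. The function names, the `"model"`/`"optim"` keys, and the directory argument are illustrative assumptions; both functions are collective and must be called on every rank.

```python
def save_checkpoint(model, optimizer, checkpoint_dir: str) -> None:
    """Save sharded model and optimizer state with DCP (call on all ranks)."""
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import get_state_dict

    model_sd, optim_sd = get_state_dict(model, optimizer)
    dcp.save({"model": model_sd, "optim": optim_sd}, checkpoint_id=checkpoint_dir)


def load_checkpoint(model, optimizer, checkpoint_dir: str) -> None:
    """Load DCP state in place, then push it back into model and optimizer."""
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import (
        get_state_dict,
        set_state_dict,
    )

    # DCP loads into existing tensors, so start from the current state dicts.
    model_sd, optim_sd = get_state_dict(model, optimizer)
    state = {"model": model_sd, "optim": optim_sd}
    dcp.load(state, checkpoint_id=checkpoint_dir)
    set_state_dict(
        model,
        optimizer,
        model_state_dict=state["model"],
        optim_state_dict=state["optim"],
    )
```

`checkpoint_dir` should be a path all ranks can reach, such as shared storage, since each rank writes and reads its own shard files there.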
Debug checklist (what the agent should check first)
- All ranks on distinct GPUs?
  If not, verify `torch.cuda.set_device(LOCAL_RANK)` and your `torchrun` flags.
- Did you accidentally call `forward()` directly?
  Use `model(input)` or explicitly `unshard()` / register forward.
- Is `fully_shard()` applied bottom-up?
  If only root is sharded, expect worse memory/perf and possible confusion.
- Optimizer created at the right time?
  Must be built on DTensor parameters after sharding.
- Checkpointing path consistent?
  - If using DCP, don’t mix with ad-hoc `torch.save` unless you understand conversions.
  - Be mindful of PyTorch-version compatibility warnings for DCP.
Common issues and fixes
- Forward hooks not running → Call `model(inputs)` (or `unshard()` explicitly) instead of `model.forward(...)`.
- Optimizer sees non-DTensor params → Create optimizer after all `fully_shard` calls.
- Only root module sharded → Apply `fully_shard` bottom-up on submodules before the root.
- Memory spikes after forward → Set `reshard_after_forward=True` for more modules.
- Gradient accumulation desync → Use `set_requires_gradient_sync` instead of FSDP1’s `no_sync()`.
Reference: `references/pytorch_fully_shard_api.md`, `references/pytorch_fsdp2_tutorial.md`.
Minimal reference implementation outline (agent-friendly)
The coding agent should implement a script with these labeled blocks:
- `init_distributed()`: init process group, set device
- `build_model_meta()`: model on meta, apply `fully_shard`, materialize weights
- `build_optimizer()`: optimizer created after sharding
- `train_step()`: forward/backward/step with `model(inputs)` and DTensor-aware patterns
- `checkpoint_save/load()`: DCP or distributed state dict helpers
Concrete examples live in `references/pytorch_examples_fsdp2.md` and the official tutorial reference.
References
- `references/pytorch_fsdp2_tutorial.md`
- `references/pytorch_fully_shard_api.md`
- `references/pytorch_ddp_notes.md`
- `references/pytorch_fsdp1_api.md`
- `references/pytorch_device_mesh_tutorial.md`
- `references/pytorch_tp_tutorial.md`
- `references/pytorch_dcp_overview.md`
- `references/pytorch_dcp_recipe.md`
- `references/pytorch_dcp_async_recipe.md`
- `references/pytorch_examples_fsdp2.md`
- `references/torchtitan_fsdp_notes.md` (optional, production notes)
- `references/ray_train_fsdp2_example.md` (optional, integration example)