torch-tensor-parallelism
Tensor Parallelism Implementation Guide
This skill provides guidance for implementing tensor parallelism patterns in PyTorch, specifically for ColumnParallelLinear and RowParallelLinear layers that distribute computation across multiple devices.
Core Concepts
Tensor Parallelism Overview
Tensor parallelism splits individual layers across multiple devices to parallelize computation within a single forward/backward pass. The two primary patterns are:
- ColumnParallelLinear: Shards weights along the output dimension (columns). Each device computes a portion of the output features, then results are concatenated via all-gather.
- RowParallelLinear: Shards weights along the input dimension (rows). Each device computes partial outputs using its shard of the input, then results are summed via all-reduce.
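The two patterns can be illustrated with a minimal single-process sketch (no distributed runtime; `torch.chunk` stands in for the per-rank shards, and Python-level `cat`/`sum` stand in for the collectives):

```python
import torch

torch.manual_seed(0)
world_size = 2
x = torch.randn(3, 4)   # (batch, in_features)
W = torch.randn(6, 4)   # (out_features, in_features)

# ColumnParallelLinear: shard W along the output dim, concatenate results
col_shards = torch.chunk(W, world_size, dim=0)
y_col = torch.cat([x @ w.T for w in col_shards], dim=-1)  # simulated all-gather

# RowParallelLinear: shard W along the input dim, sum partial results
row_shards = torch.chunk(W, world_size, dim=1)
x_shards = torch.chunk(x, world_size, dim=-1)
y_row = sum(xs @ w.T for xs, w in zip(x_shards, row_shards))  # simulated all-reduce

# Both recover the unsharded computation
y_ref = x @ W.T
assert torch.allclose(y_col, y_ref, atol=1e-6)
assert torch.allclose(y_row, y_ref, atol=1e-6)
```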
Critical Implementation Requirement
When implementing tensor parallelism (especially in simulation or testing contexts), the forward pass must actually perform the collective operations, not just compute local shards:
- ColumnParallelLinear: Must concatenate outputs from all ranks (all-gather semantics)
- RowParallelLinear: Must sum outputs from all ranks (all-reduce semantics)
A common mistake is returning only the local shard and expecting an external framework to handle collective operations. Unless explicitly specified otherwise, the implementation should produce the final, complete output.
Implementation Approach
Step 1: Understand the Parallelism Pattern
Before implementing, clearly identify:
- Which dimension is being sharded (input features vs output features)
- What collective operation combines the results (all-gather vs all-reduce)
- Whether the implementation should simulate distributed execution or prepare for actual distributed execution
- How bias should be handled in the parallel context
Step 2: Weight Sharding
For weight matrix W of shape (out_features, in_features):
ColumnParallelLinear:
- Shard W along dim=0 (output features)
- Each rank gets W_shard of shape (out_features // world_size, in_features)
- Output shape per rank: (batch, out_features // world_size)
RowParallelLinear:
- Shard W along dim=1 (input features)
- Each rank gets W_shard of shape (out_features, in_features // world_size)
- Input to each rank should be corresponding shard of input
- Output shape per rank: (batch, out_features), holding a partial sum to be all-reduced
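The shard shapes above can be checked directly with `torch.chunk` (a single-process stand-in for per-rank partitioning):

```python
import torch

world_size = 4
out_features, in_features = 8, 12
W = torch.randn(out_features, in_features)

# ColumnParallelLinear: shard along dim=0 (output features)
col_shards = torch.chunk(W, world_size, dim=0)
assert col_shards[0].shape == (out_features // world_size, in_features)  # (2, 12)

# RowParallelLinear: shard along dim=1 (input features)
row_shards = torch.chunk(W, world_size, dim=1)
assert row_shards[0].shape == (out_features, in_features // world_size)  # (8, 3)
```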
Step 3: Forward Pass Implementation
ColumnParallelLinear Forward:
1. Compute local output: y_local = x @ W_shard.T + bias_shard (if bias per shard)
2. All-gather to concatenate: y = concat([y_0, y_1, ..., y_n], dim=-1)
3. Return complete output of shape (batch, out_features)

RowParallelLinear Forward:
1. Get input shard: x_shard = x[..., start:end] for this rank
2. Compute partial output: y_partial = x_shard @ W_shard.T
3. All-reduce to sum: y = sum([y_0, y_1, ..., y_n])
4. Add bias (only once, not per-rank): y = y + bias
5. Return complete output of shape (batch, out_features)

Step 4: Bias Handling
ColumnParallelLinear:
- Bias can be sharded along with output features
- Each rank adds its bias shard to its output shard
- After all-gather, the full bias has been applied
RowParallelLinear:
- Bias must NOT be sharded or added per-rank (would cause N-fold bias)
- Add bias only once after the all-reduce operation
- Typically only rank 0 adds bias, OR add bias after the sum
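Steps 3 and 4 can be combined into a minimal single-process simulation. The class names (`ColumnParallelLinearSim`, `RowParallelLinearSim`) are hypothetical, not from any library; the loop over ranks stands in for per-device execution:

```python
import torch
import torch.nn as nn

class ColumnParallelLinearSim(nn.Module):
    """Each simulated rank holds a row-shard of W and a shard of the bias;
    forward concatenates the shard outputs (all-gather semantics)."""
    def __init__(self, in_features, out_features, world_size, bias=True):
        super().__init__()
        self.world_size = world_size
        self.weight_shards = nn.ParameterList(
            [nn.Parameter(torch.randn(out_features // world_size, in_features))
             for _ in range(world_size)])
        self.bias_shards = nn.ParameterList(
            [nn.Parameter(torch.zeros(out_features // world_size))
             for _ in range(world_size)]) if bias else None

    def forward(self, x):
        outs = []
        for r in range(self.world_size):
            y = x @ self.weight_shards[r].T
            if self.bias_shards is not None:
                y = y + self.bias_shards[r]   # bias shard applied per rank
            outs.append(y)
        return torch.cat(outs, dim=-1)        # all-gather: concatenate

class RowParallelLinearSim(nn.Module):
    """Each simulated rank holds a column-shard of W; partial outputs are
    summed (all-reduce semantics) and the bias is added exactly once."""
    def __init__(self, in_features, out_features, world_size, bias=True):
        super().__init__()
        self.world_size = world_size
        self.in_per_rank = in_features // world_size
        self.weight_shards = nn.ParameterList(
            [nn.Parameter(torch.randn(out_features, self.in_per_rank))
             for _ in range(world_size)])
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(self, x):
        y = torch.zeros(*x.shape[:-1], self.weight_shards[0].shape[0],
                        dtype=x.dtype, device=x.device)
        for r in range(self.world_size):
            x_shard = x[..., r * self.in_per_rank:(r + 1) * self.in_per_rank]
            y = y + x_shard @ self.weight_shards[r].T   # all-reduce: sum
        if self.bias is not None:
            y = y + self.bias                           # added once, after the sum
        return y
```

This is a sketch of the simulation contract only; a real implementation would replace the rank loop with `torch.distributed.all_gather` / `all_reduce` calls.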
Verification Strategies
Mathematical Verification
When local testing is unavailable, verify implementation correctness through mathematical analysis:
- Simple example: Use a 2x4 weight matrix with world_size=2
- Trace computation: Manually compute what each rank produces
- Verify combination: Confirm all-gather/all-reduce produces correct final output
- Compare to baseline: Verify parallel output matches non-parallel computation
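The suggested 2x4 trace with world_size=2 can be written out explicitly (the weight values here are chosen for easy hand-checking):

```python
import torch

# Worked 2x4 example, traced against the dense baseline
x = torch.tensor([[1., 2., 3., 4.]])              # (1, 4)
W = torch.tensor([[1., 0., 0., 0.],
                  [0., 1., 0., 0.]])              # (2, 4)
baseline = x @ W.T                                # tensor([[1., 2.]])

# ColumnParallel (world_size=2): each rank owns one output row of W
rank0 = x @ W[0:1].T                              # [[1.]]
rank1 = x @ W[1:2].T                              # [[2.]]
assert torch.equal(torch.cat([rank0, rank1], dim=-1), baseline)

# RowParallel (world_size=2): each rank owns two input columns of W
part0 = x[:, 0:2] @ W[:, 0:2].T                   # [[1., 2.]]
part1 = x[:, 2:4] @ W[:, 2:4].T                   # [[0., 0.]]
assert torch.equal(part0 + part1, baseline)
```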
Shape Verification Checklist
- Input shape matches expected (batch, in_features)
- Weight shard shape matches expected partitioning
- Local output shape is correct for the parallelism type
- Final output shape matches (batch, out_features) - NOT the sharded dimension
Test Cases to Consider
- world_size=1: Trivial case, should match non-parallel implementation exactly
- world_size=2,4,8: Common parallel configurations
- Non-divisible dimensions: What happens when out_features % world_size != 0?
- Different batch sizes: Verify batch dimension is handled correctly
- With and without bias: Test both configurations
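For the non-divisible case, one possible approach (an option, not the only valid policy) is `torch.tensor_split`, which permits uneven shard sizes while still reassembling to the full output:

```python
import torch

world_size = 3
out_features, in_features = 7, 4   # 7 % 3 != 0
x = torch.randn(2, in_features)
W = torch.randn(out_features, in_features)

# torch.tensor_split allows uneven shards: sizes (3, 2, 2) along dim=0
shards = torch.tensor_split(W, world_size, dim=0)
assert [s.shape[0] for s in shards] == [3, 2, 2]

# Concatenation still recovers the full output
y = torch.cat([x @ s.T for s in shards], dim=-1)
assert torch.allclose(y, x @ W.T, atol=1e-6)
```

Real frameworks often instead pad dimensions to a multiple of world_size; whichever policy is chosen should be stated explicitly in the implementation.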
Common Pitfalls
Pitfall 1: Returning Local Shards Only
Symptom: Output tensor size is (out_features / world_size) instead of (out_features)
Cause: Implementation computes local shard but doesn't perform all-gather
Fix: Implement the collective operation to combine results from all ranks
Pitfall 2: Incorrect Bias Handling in RowParallelLinear
Symptom: Output values are N times larger than expected (where N is world_size)
Cause: Each rank adds the full bias, then values are summed
Fix: Add bias only once after all-reduce, not per-rank
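The N-fold bias bug is easy to demonstrate in a single-process sketch (`torch.chunk` and `sum` stand in for the shards and the all-reduce):

```python
import torch

world_size = 2
x = torch.randn(3, 4)
shards = torch.chunk(torch.randn(6, 4), world_size, dim=1)
x_shards = torch.chunk(x, world_size, dim=-1)
bias = torch.ones(6)

# Buggy: each rank adds the full bias before the sum, so it is counted twice
buggy = sum(xs @ w.T + bias for xs, w in zip(x_shards, shards))
# Correct: add the bias once, after the simulated all-reduce
correct = sum(xs @ w.T for xs, w in zip(x_shards, shards)) + bias

# The error is exactly (world_size - 1) extra copies of the bias
assert torch.allclose(buggy - correct, bias * (world_size - 1), atol=1e-5)
```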
Pitfall 3: Misinterpreting "Simulation" Requirements
Symptom: Implementation works for world_size=1 but fails for larger world sizes
Cause: Assuming external framework handles collective operations
Fix: Read requirements carefully - "as if using all_gather" means implement the operation
Pitfall 4: Truncated File Writes
Symptom: Implementation has syntax errors or missing code
Cause: File write operation was truncated
Fix: Always read back the complete file after writing to verify integrity
Pitfall 5: Wrong Dimension for Sharding
Symptom: Shape mismatch errors during matrix multiplication
Cause: Sharding along wrong dimension (rows vs columns confusion)
Fix: ColumnParallel shards output features (dim=0 of weight), RowParallel shards input features (dim=1 of weight)
Pre-Implementation Checklist
Before writing code, confirm understanding of:
- Which collective operation is needed (all-gather vs all-reduce)
- What the final output shape should be
- Whether simulation should actually perform collective ops or defer them
- How bias should be handled for this parallelism type
- What happens for edge cases (world_size=1, non-divisible dimensions)
Post-Implementation Checklist
After writing code:
- Read back the complete implementation file to verify no truncation
- Verify output shapes match expected dimensions for all world sizes
- Trace through a simple example manually to verify correctness
- Test trivial case (world_size=1) matches non-parallel baseline
- Test at least one non-trivial case (world_size=2 or 4)