triton-operator-code-gen


Triton Operator Code Generation

Core Principles

Computing Logic → Tiling Strategy → Code Implementation
This order must never be reversed. Incorrect computing logic produces completely wrong results, while an incorrect tiling strategy causes performance problems or memory overflow.

Reference Resource Loading Routing Table

MANDATORY - Load On Demand: load the corresponding reference documents for each task stage.

| Phase | Must Load | Do Not Load |
| --- | --- | --- |
| Understand Requirement Documents | None | All references |
| Confirm Computing Logic | None | All references |
| Design Tiling Strategy | hardware-architecture.md | templates.md |
| Generate Kernel Code | templates.md | hardware-architecture.md |
| Generate Test Code | None | All references |

Workflow

Phase 1: Understand Requirement Documents

Extract: mathematical formulas, input/output specifications, constraints, tiling strategies

Phase 2: Confirm Computing Logic

  1. Describe the computing process in pseudocode
  2. Confirm the data dependencies
  3. Confirm precision handling (reductions must use FP32)
Output: the confirmed computing logic (must be agreed with the user)
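As an illustration, the pseudocode confirmed in this phase might look like the following NumPy sketch for a hypothetical row-wise softmax operator; note the FP32 upcast before both reductions, per the precision rule above:

```python
import numpy as np

def softmax_reference(x: np.ndarray) -> np.ndarray:
    # Upcast to FP32 before any reduction, even for FP16/BF16 input
    x32 = x.astype(np.float32)
    row_max = x32.max(axis=-1, keepdims=True)   # reduction 1: row max
    e = np.exp(x32 - row_max)
    out = e / e.sum(axis=-1, keepdims=True)     # reduction 2: row sum
    return out.astype(x.dtype)                  # cast back to the input dtype
```

Writing the reference this way also doubles as the test baseline for Phase 5.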

Phase 3: Design Tiling Strategy

MANDATORY - READ ENTIRE FILE: You must read the entire hardware-architecture.md before designing the tiling strategy. Never set any line limits.
Inter-core Partitioning Principles (Must Follow):
  1. grid = number of physical cores: ensure every core is utilized and no resources are wasted
  2. Balanced intra-core load: each core computes for itself which data it should process, achieving load balance

```python
core_num = get_npu_aicore_num()  # or get_npu_vectorcore_num()

grid = (core_num,)  # Principle 1: grid must equal the number of physical cores

@triton.jit
def xxx_fwd(
    # ... pointer arguments omitted ...
    M, N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid = tl.program_id(0)
    num_core = tl.num_programs(0)

    num_block_m = tl.cdiv(M, BLOCK_M)
    num_block_n = tl.cdiv(N, BLOCK_N)

    total_blocks = num_block_m * num_block_n

    # Principle 2: each core loops over multiple blocks,
    # computing for itself which data to process
    for block_idx in range(pid, total_blocks, num_core):
        pid_m = block_idx // num_block_n
        pid_n = block_idx % num_block_n
```

UB Space Calculation:
Total UB size: 192KB (A2/A3)
Safe BLOCK_SIZE = (196608 - 32) / (number of buffers × data type size) × 0.8
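The UB formula above can be wrapped in a small helper. This is a sketch; the function and argument names are illustrative, with the A2/A3 constants taken from the formula:

```python
def safe_block_size(num_buffers: int, dtype_bytes: int,
                    ub_bytes: int = 192 * 1024,   # total UB on A2/A3
                    reserved: int = 32,           # reserved bytes
                    margin: float = 0.8) -> int:  # safety margin
    """Estimate a safe BLOCK_SIZE (in elements) per buffer."""
    usable = ub_bytes - reserved
    return int(usable / (num_buffers * dtype_bytes) * margin)

# e.g. a kernel holding 3 FP32 buffers: safe_block_size(3, 4) -> 13105
```

The 0.8 margin leaves headroom for compiler-allocated temporaries that the buffer count does not capture.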

Phase 4: Generate Kernel Code

MANDATORY - READ ENTIRE FILE: You must read the entire templates.md before generating code. Never set any line limits.
Flexibly refer to the corresponding template based on the operator type:

| Operator Type | Features | Core Type | Template |
| --- | --- | --- | --- |
| Reduction | sum/max/min reduction | vector core | Template 1 |
| GEMM | tl.dot() matrix multiplication | AI core | Template 2 |
| Activation Function | Element-wise computation | vector core | Template 3 |
| Loss Function | softmax + reduction | vector core | Template 4 |
| Index Transformation | Index computation, conditional branching | vector core | Template 5 |
| Attention | QK^T + SV multi-stage | AI core | Template 6 |
| MoE | Gating mechanism | vector core | Template 7 |
| Post-processing | Simple data transformation | vector core | Template 8 |
| Convolution | State update, sliding window | AI core | Template 9 |

Phase 5: Generate Test Code

Anti-Pattern List (NEVER)

  • ❌ Writing code before the computing logic is confirmed
  • ❌ Ignoring the UB size limit (192KB)
  • ❌ Performing reductions without FP32 precision
  • ❌ Using the int64 data type (extremely poor performance)
  • ❌ Letting the grid size exceed 65535
  • ❌ Using third-party libraries inside the kernel
  • ❌ Computing one element at a time
  • ❌ Overly complex optimizations (e.g., diagonal core partitioning)
  • ❌ Calling third-party functions to get the core count
  • ❌ Confusing the roles of Vector Core and Cube Core
  • ❌ Implementing the operator with PyTorch instead of Triton
  • ❌ Skipping operator correctness tests
  • ❌ Not testing the operator on the NPU
  • ❌ Not verifying the accuracy of the test baseline
  • ❌ A grid size unequal to the number of physical cores (violates Inter-core Partitioning Principle 1)
  • ❌ Unbalanced inter-core load (violates Inter-core Partitioning Principle 2)

Common Pitfalls

| Pitfall | Symptom | Solution |
| --- | --- | --- |
| Incorrect computing logic | Output does not match expectations | Describe the computing process in pseudocode and confirm with the user |
| UB overflow | Runtime error "ub overflow" | Compute the total buffer size and reduce BLOCK_SIZE |
| coreDim exceeded | Runtime error "coreDim can't be greater than UINT16_MAX" | Increase BLOCK_SIZE or set TRITON_ALL_BLOCKS_PARALLEL=1 |
| Precision loss | Inaccurate results with FP16 input | Upcast to FP32 before reduction operations |
| Insufficient index width | D-cache error | For very large shapes the int32 index range is insufficient; switch to int64 |
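For the coreDim pitfall, the environment-variable workaround mentioned above would be applied at launch time, e.g.:

```shell
# Workaround when the block count would exceed UINT16_MAX coreDim
export TRITON_ALL_BLOCKS_PARALLEL=1
```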

Checklist

Computing Logic

  • Mathematical formulas are correctly understood
  • Pseudocode is consistent with formulas
  • Boundary conditions are handled correctly
  • Data type conversions are correct

Tiling Strategy

  • grid = number of physical cores (Principle 1)
  • Intra-core loop handles multiple tasks with balanced load (Principle 2)
  • UB space calculation is correct
  • BLOCK_SIZE is reasonably selected

Kernel Implementation

  • The core-count function is called correctly
  • Pointer arithmetic is correct
  • Mask handling is correct
  • Precision handling is correct (FP32 for reductions)
  • No third-party library dependencies

Test Code

  • PyTorch reference implementation is correct
  • Test cases cover multiple shapes
  • Test cases cover multiple data types
  • Precision tolerance is set reasonably
  • Execute test code to ensure the operator runs correctly
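A minimal sketch of such a test harness follows. The names `my_op` and `reference_op` are hypothetical; on real hardware `my_op` would launch the Triton kernel on the NPU while the reference would normally be the PyTorch implementation (NumPy is used here only to keep the sketch self-contained):

```python
import numpy as np

SHAPES = [(1, 32), (17, 128), (256, 1024)]   # cover small, odd, and large shapes
DTYPES = {"float16": 1e-3, "float32": 1e-5}  # per-dtype accuracy tolerance

def accuracy_sweep(my_op, reference_op):
    """Compare an operator against a golden reference over shapes and dtypes."""
    for shape in SHAPES:
        for dtype, tol in DTYPES.items():
            x = np.random.default_rng(0).standard_normal(shape).astype(dtype)
            out = np.asarray(my_op(x), dtype=np.float64)
            ref = np.asarray(reference_op(x), dtype=np.float64)
            assert np.allclose(out, ref, rtol=tol, atol=tol), (shape, dtype)
```

Looser tolerances for FP16 reflect its reduced mantissa; tolerances that are too tight cause false failures, while tolerances that are too loose mask real precision bugs.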