kernel-triton-writing

Compare original and translation side by side

🇺🇸

Original

English
🇨🇳

Translation

Chinese

Triton Kernel Writing

Triton 内核编写

Principles

原则

Correctness First

正确性优先

  1. Never benchmark before verification passes.
  2. Always mask loads and stores for non-divisible shapes.
  3. Include
    kernel_fn
    ,
    reference_fn
    , and
    get_inputs()
    exports for companion scripts.
  4. Always run
    scripts/verify_kernel.py
    to validate against the reference.
  1. 验证通过前切勿进行基准测试。
  2. 对于无法整除的形状,始终对加载和存储操作进行掩码处理。
  3. 导出
    kernel_fn
    reference_fn
    get_inputs()
    供配套脚本使用。
  4. 始终运行
    scripts/verify_kernel.py
    以对照参考实现验证正确性。

FP16/BF16 Precision Rules (LOW FREEDOM -- follow exactly)

FP16/BF16 精度规则(自由度极低——严格遵循)

Transcendental functions (
tl.exp
,
tl.log
,
tl.math.erf
,
tl.math.tanh
) require fp32 inputs.
python
undefined
超越函数(
tl.exp
tl.log
tl.math.erf
tl.math.tanh
)要求输入为fp32类型。
python
undefined

WRONG -- compilation error or wrong results with fp16/bf16:

WRONG -- compilation error or wrong results with fp16/bf16:

result = tl.exp(x)
result = tl.exp(x)

CORRECT -- cast to fp32, compute, cast back:

CORRECT -- cast to fp32, compute, cast back:

x_fp32 = x.to(tl.float32) result = tl.exp(x_fp32).to(x.dtype)

Rule: any math function beyond basic arithmetic (+, -, *, /) requires fp32 cast in, original dtype cast out.

Additional precision constraints:
- `tl.sigmoid()` is unavailable in some Triton versions. Use `1.0 / (1.0 + tl.exp(-x_fp32))`.
- Always cast back to `x.dtype` before `tl.store` -- mismatches cause "Type mismatch, store Float32 to Float16".
- Unlike PyTorch, Triton does NOT auto-promote fp16/bf16 to fp32 for accumulation. Always use `tl.float32` accumulators for `tl.dot`.
- **TF32 for matmul:** On Ampere+/Hopper, `tl.dot` uses TF32 by default for fp32 inputs (same as `torch.mm`). Do NOT add `input_precision="ieee"` — it is 3-8x slower. TF32 is the correct default. If verification fails due to TF32 precision (~0.01-0.1 abs diff), ensure `reference_fn` also uses TF32 (plain `torch.mm`, no `allow_tf32=False`).
x_fp32 = x.to(tl.float32) result = tl.exp(x_fp32).to(x.dtype)

规则:除基础算术运算(+、-、*、/)之外的任何数学函数,都需要先转换为fp32计算,再转换回原始数据类型。

额外精度约束:
- 部分Triton版本中`tl.sigmoid()`不可用。使用`1.0 / (1.0 + tl.exp(-x_fp32))`替代。
- 在执行`tl.store`前必须转换回`x.dtype`——类型不匹配会导致"Type mismatch, store Float32 to Float16"错误。
- 与PyTorch不同,Triton不会自动将fp16/bf16提升为fp32进行累加。对于`tl.dot`,始终使用`tl.float32`累加器。
- **矩阵乘法的TF32支持**:在Ampere+/Hopper架构上,`tl.dot`默认对fp32输入使用TF32(与`torch.mm`一致)。请勿添加`input_precision="ieee"`——这会使速度慢3-8倍。TF32是正确的默认选项。如果因TF32精度问题(绝对误差约0.01-0.1)导致验证失败,请确保`reference_fn`也使用TF32(使用普通`torch.mm`,不要设置`allow_tf32=False`)。

CPU-GPU Sync Avoidance (LOW FREEDOM)

避免CPU-GPU同步(自由度极低)

Never call
.item()
in kernel wrappers. It forces a CPU-GPU sync (~50-100us per call).
PitfallFix
tensor.item()
for seed
x.data_ptr() % (2**31)
torch.randint(...).item()
Use tensor metadata for pseudo-random seed
Allocating output every callAccept pre-allocated output as parameter
Python loops calling kernelBatch operations
在内核包装器中切勿调用
.item()
。这会强制进行CPU-GPU同步(每次调用约50-100微秒)。
陷阱修复方案
使用
tensor.item()
获取种子
使用
x.data_ptr() % (2**31)
使用
torch.randint(...).item()
使用张量元数据生成伪随机种子
每次调用都分配输出张量接受预分配的输出作为参数
使用Python循环调用内核批量处理操作

C Integer Division Semantics (CRITICAL)

C语言整数除法语义(至关重要)

Triton uses C semantics (round toward zero) for
//
and
%
, NOT Python semantics (round toward negative infinity). This only matters when operands can be negative.
ExpressionPythonTriton/C
-7 // 2
-4
-3
-7 % 2
1
-1
Safe pattern: Ensure all index/offset values are non-negative. If negative values are possible, use
(idx % BLOCK + BLOCK) % BLOCK
.
See references/concepts-semantics.md for full rules and scalar-only exception.
Triton对
//
%
使用C语言语义(向零取整),而非Python语义(向负无穷取整)。这仅在操作数可能为负数时才会产生影响。
表达式PythonTriton/C
-7 // 2
-4
-3
-7 % 2
1
-1
安全模式:确保所有索引/偏移值均为非负数。如果可能出现负值,使用
(idx % BLOCK + BLOCK) % BLOCK
有关完整规则和标量例外情况,请参阅references/concepts-semantics.md

Kernel Design Mental Model

内核设计思维模型

  • Parallelization axis: Element-wise kernels parallelize over flattened elements. Row-wise kernels (LayerNorm, softmax) parallelize over rows. Matmul kernels tile in 2D (M, N).
  • Block size: Power-of-2 only (256, 512, 1024, 2048). Start with 1024 for H100, 512 for V100.
  • Memory coalescing: Adjacent threads must access adjacent memory addresses. The compiler handles this automatically from block-level pointer arithmetic.
  • Grid: Use
    triton.cdiv(n_elements, BLOCK_SIZE)
    . With autotune, grid must be a lambda:
    lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),)
    .
  • Decorator order:
    @triton.autotune
    (outermost) ->
    @triton.heuristics
    ->
    @triton.jit
    (innermost).
  • reset_to_zero
    :
    Required for autotune on kernels that accumulate (e.g., matmul output). Without it, later configs see leftover values from earlier trials.
  • 并行化轴:逐元素内核在扁平化元素上并行化。按行处理的内核(LayerNorm、softmax)在行上并行化。矩阵乘法内核在二维(M、N)上进行分块。
  • 块大小:仅支持2的幂(256、512、1024、2048)。H100从1024开始,V100从512开始。
  • 内存合并:相邻线程必须访问相邻的内存地址。编译器会通过块级指针运算自动处理此问题。
  • 网格:使用
    triton.cdiv(n_elements, BLOCK_SIZE)
    。结合自动调优时,网格必须是一个lambda函数:
    lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),)
  • 装饰器顺序
    @triton.autotune
    (最外层)->
    @triton.heuristics
    ->
    @triton.jit
    (最内层)。
  • reset_to_zero
    :对于进行累加的内核(例如矩阵乘法输出),自动调优时需要此选项。如果没有它,后续配置会看到之前试验留下的值。

Workflow

工作流程

Fast path: If the user explicitly requests a Triton kernel (e.g., "Write a Triton kernel for X", "Implement softmax in Triton"), start at Phase 2. Only use Phase 0-1 when the request is ambiguous about whether Triton is appropriate.
快速路径:如果用户明确请求Triton内核(例如“编写一个用于X的Triton内核”、“用Triton实现softmax”),直接从阶段2开始。仅当请求对是否适合使用Triton存在歧义时,才使用阶段0-1。

Phase 0: Route the Operator (only for ambiguous requests)

阶段0:选择算子(仅针对歧义请求)

Skip this phase if the user explicitly asks for a Triton kernel. Only use when the request is ambiguous (e.g., "optimize this operation").
Triton wins when 2+ operations can share registers instead of writing/reading global memory. Quick rules:
PatternDecision
Single element-wise op (
relu
,
sigmoid
)
SKIP — PyTorch already optimal
Standalone matmulSKIP — cuBLAS is optimized
Standard attentionSKIP — Use FlashAttention
Element-wise chain (2+ ops), reduction, matmul + epilogueUSE TRITON
If SKIP, recommend the alternative and STOP. See references/operator-routing.md for edge cases.
如果用户明确要求Triton内核,跳过此阶段。仅当请求存在歧义时使用(例如“优化此操作”)。
当2个及以上操作可以共享寄存器,而非写入/读取全局内存时,Triton更具优势。快速判断规则:
模式决策
单一逐元素操作(
relu
sigmoid
跳过——PyTorch已优化至最优
独立矩阵乘法跳过——cuBLAS已优化
标准注意力机制跳过——使用FlashAttention
逐元素操作链(2个及以上)、归约操作、矩阵乘法+结尾操作使用TRITON
如果选择跳过,推荐替代方案并停止。有关边缘情况,请参阅references/operator-routing.md

Phase 1: Analyze the Operator (only for ambiguous requests)

阶段1:分析算子(仅针对歧义请求)

From the user's request, identify: (1) operation type, (2) parallelization strategy, (3) input shapes and dtypes.
从用户的请求中确定:(1) 操作类型,(2) 并行化策略,(3) 输入形状和数据类型。

Phase 2: Design the Kernel

阶段2:设计内核

Pick the skeleton below that matches your operation. These skeletons are sufficient for element-wise, reduction, matmul, and fusion kernels — do NOT read reference files for these common patterns. Only consult
references/
when implementing uncommon patterns (grouped GEMM, TMA, extern functions) or debugging issues.
Element-wise skeleton (GELU, dropout, fused ops on flat tensors):
python
@triton.jit
def kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # ... compute ...
    tl.store(out_ptr + offsets, result, mask=mask)
Row-wise skeleton (softmax, LayerNorm, RMSNorm — one program per row):
python
@triton.jit
def kernel(x_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    x = tl.load(x_ptr + row_idx * n_cols + col_offsets, mask=mask, other=0.0)
    # ... reduce / normalize ...
    tl.store(out_ptr + row_idx * n_cols + col_offsets, result, mask=mask)
Tiled matmul skeleton (GEMM with 2D tiling, grouped ordering, and autotune):
python
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    pid = tl.program_id(0)
    num_m_blocks = tl.cdiv(M, BLOCK_M)
    num_n_blocks = tl.cdiv(N, BLOCK_N)
    # Grouped ordering for L2 cache locality
    num_pid_in_group = GROUP_SIZE_M * num_n_blocks
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_m_blocks - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
        b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
        offs_k += BLOCK_K

    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
选择与操作匹配的骨架。**这些骨架足以覆盖逐元素、归约、矩阵乘法和融合内核——对于这些常见模式,无需查阅参考文件。**仅在实现不常见模式(分组GEMM、TMA、外部函数)或调试问题时,才查阅
references/
目录。
逐元素骨架(GELU、dropout、扁平张量上的融合操作):
python
@triton.jit
def kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # ... compute ...
    tl.store(out_ptr + offsets, result, mask=mask)
按行处理骨架(softmax、LayerNorm、RMSNorm——每行一个程序):
python
@triton.jit
def kernel(x_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    x = tl.load(x_ptr + row_idx * n_cols + col_offsets, mask=mask, other=0.0)
    # ... reduce / normalize ...
    tl.store(out_ptr + row_idx * n_cols + col_offsets, result, mask=mask)
分块矩阵乘法骨架(带有二维分块、分组排序和自动调优的GEMM):
python
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    pid = tl.program_id(0)
    num_m_blocks = tl.cdiv(M, BLOCK_M)
    num_n_blocks = tl.cdiv(N, BLOCK_N)
    # Grouped ordering for L2 cache locality
    num_pid_in_group = GROUP_SIZE_M * num_n_blocks
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_m_blocks - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
        b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
        offs_k += BLOCK_K

    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)

Phase 3: Write the Kernel

阶段3:编写内核

Create an output directory, then write the kernel file to
{output_dir}/kernel.py
.
The kernel file MUST include:
  • @triton.jit
    decorated kernel function
  • @triton.autotune
    for production kernels (see references/api-core.md)
  • Python wrapper function (descriptive name for external import)
  • Fixed contract exports (companion scripts rely on these exact names):
    • kernel_fn
      — alias to the wrapper function
    • reference_fn(*args)
      — PyTorch reference with identical signature
    • get_inputs()
      — returns
      list
      of fresh CUDA tensors for testing/benchmarking
Concise example (fused GELU + dropout):
python
import triton
import triton.language as tl
import torch

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 2048}, num_warps=8),
    ],
    key=['n_elements'],
)
@triton.jit
def fused_gelu_dropout_kernel(
    x_ptr, out_ptr, n_elements, p, seed,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    x_fp32 = x.to(tl.float32)
    x = (0.5 * x_fp32 * (1.0 + tl.math.erf(x_fp32 * 0.7071067811865476))).to(x.dtype)

    random = tl.rand(seed, offsets)
    x = tl.where(random > p, x / (1.0 - p), 0.0)

    tl.store(out_ptr + offsets, x, mask=mask)


def fused_gelu_dropout_triton(x: torch.Tensor, p: float = 0.1) -> torch.Tensor:
    n_elements = x.numel()
    out = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    seed = (x.data_ptr() % (2**31)) ^ n_elements  # sync-free seed
    fused_gelu_dropout_kernel[grid](x, out, n_elements, p, seed)
    return out
创建输出目录,然后将内核文件写入
{output_dir}/kernel.py
内核文件必须包含:
  • 带有
    @triton.jit
    装饰器的内核函数
  • 生产级内核需添加
    @triton.autotune
    (参阅references/api-core.md
  • Python包装器函数(便于外部导入的描述性名称)
  • 固定契约导出(配套脚本依赖这些确切名称):
    • kernel_fn
      —— 包装器函数的别名
    • reference_fn(*args)
      —— 具有相同签名的PyTorch参考实现
    • get_inputs()
      —— 返回用于测试/基准测试的全新CUDA张量列表
简洁示例(融合GELU + dropout):
python
import triton
import triton.language as tl
import torch

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 2048}, num_warps=8),
    ],
    key=['n_elements'],
)
@triton.jit
def fused_gelu_dropout_kernel(
    x_ptr, out_ptr, n_elements, p, seed,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    x_fp32 = x.to(tl.float32)
    x = (0.5 * x_fp32 * (1.0 + tl.math.erf(x_fp32 * 0.7071067811865476))).to(x.dtype)

    random = tl.rand(seed, offsets)
    x = tl.where(random > p, x / (1.0 - p), 0.0)

    tl.store(out_ptr + offsets, x, mask=mask)


def fused_gelu_dropout_triton(x: torch.Tensor, p: float = 0.1) -> torch.Tensor:
    n_elements = x.numel()
    out = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    seed = (x.data_ptr() % (2**31)) ^ n_elements  # sync-free seed
    fused_gelu_dropout_kernel[grid](x, out, n_elements, p, seed)
    return out

--- Fixed contract (companion scripts rely on these names) ---

--- Fixed contract (companion scripts rely on these names) ---

kernel_fn = fused_gelu_dropout_triton
def reference_fn(x, p=0.1): torch.manual_seed((x.data_ptr() % (2**31)) ^ x.numel()) return torch.nn.functional.dropout( torch.nn.functional.gelu(x), p, training=True )
def get_inputs(): return [torch.randn(128 * 1024 * 1024, device="cuda")]

For more patterns (SiLU+mul, RMSNorm, linear+GELU, add+LayerNorm), see [references/patterns-fusion.md](references/patterns-fusion.md). For GEMM patterns, see [references/patterns-gemm.md](references/patterns-gemm.md).
kernel_fn = fused_gelu_dropout_triton
def reference_fn(x, p=0.1): torch.manual_seed((x.data_ptr() % (2**31)) ^ x.numel()) return torch.nn.functional.dropout( torch.nn.functional.gelu(x), p, training=True )
def get_inputs(): return [torch.randn(128 * 1024 * 1024, device="cuda")]

更多模式(SiLU+mul、RMSNorm、linear+GELU、add+LayerNorm)请参阅[references/patterns-fusion.md](references/patterns-fusion.md)。有关GEMM模式,请参阅[references/patterns-gemm.md](references/patterns-gemm.md)。

Phase 4: Verify Correctness

阶段4:验证正确性

Run the companion verification script:
bash
python scripts/verify_kernel.py {output_dir}/kernel.py --rtol 1e-3 --atol 1e-3
Output:
json
{"correct": true, "max_abs_diff": 1.2e-7, "max_rel_diff": 3.4e-6, "details": "..."}
Stop if
correct: false
.
Fix the kernel before benchmarking.
Tolerance guide:
DtypertolatolNotes
float161e-31e-3
bfloat161e-21e-2
float321e-51e-5Element-wise ops
float32 (matmul)1e-21e-1TF32 accumulation order differs between Triton tiles and cuBLAS
运行配套验证脚本:
bash
python scripts/verify_kernel.py {output_dir}/kernel.py --rtol 1e-3 --atol 1e-3
输出:
json
{"correct": true, "max_abs_diff": 1.2e-7, "max_rel_diff": 3.4e-6, "details": "..."}
**如果
correct: false
,请停止。**在进行基准测试前修复内核。
容差指南:
数据类型rtolatol说明
float161e-31e-3
bfloat161e-21e-2
float321e-51e-5逐元素操作
float32(矩阵乘法)1e-21e-1Triton分块与cuBLAS的TF32累加顺序不同

Phase 5: Benchmark Performance (optional)

阶段5:性能基准测试(可选)

Only benchmark if the user explicitly requests performance numbers. Skip this phase for correctness-focused requests.
bash
python scripts/benchmark_kernel.py {output_dir}/kernel.py
Output:
json
{"kernel_time_ms": 0.45, "reference_time_ms": 1.23, "speedup": 2.73, "warmup_iters": 10, "benchmark_iters": 40}
仅当用户明确请求性能数据时才进行基准测试。对于以正确性为重点的请求,跳过此阶段。
bash
python scripts/benchmark_kernel.py {output_dir}/kernel.py
输出:
json
{"kernel_time_ms": 0.45, "reference_time_ms": 1.23, "speedup": 2.73, "warmup_iters": 10, "benchmark_iters": 40}

References (consult only when stuck)

参考资料(仅在遇到问题时查阅)

The skeletons and principles above cover element-wise, reduction, matmul, and fusion kernels. Do NOT read reference files for these common patterns.
Only consult
references/
when:
  • Implementing uncommon patterns (grouped GEMM, TMA, persistent matmul, extern functions)
  • Debugging a compile error or incorrect result not covered by the error table below
  • Needing API details for an unfamiliar
    tl.*
    operation
How to search: Grep for your keyword across
references/
. Read only the file Grep points to.
FileWhen to use
references/api-core.md
Unfamiliar
triton.autotune
/
triton.Config
options
references/api-language.md
Unfamiliar
tl.*
operations
references/patterns-gemm.md
Grouped GEMM, persistent matmul, TMA, MX formats
references/patterns-advanced.md
Flash attention details, backward passes, libdevice
references/troubleshooting.md
Debug ops, interpreter mode, env vars
上述骨架和原则涵盖了逐元素、归约、矩阵乘法和融合内核。对于这些常见模式,请勿查阅参考文件。
仅在以下情况时查阅
references/
  • 实现不常见模式(分组GEMM、TMA、持久化矩阵乘法、外部函数)
  • 调试未被以下错误表覆盖的编译错误或不正确结果
  • 需要不熟悉的
    tl.*
    操作的API细节
搜索方法:
references/
目录中搜索关键词。仅阅读搜索指向的文件。
文件使用场景
references/api-core.md
不熟悉
triton.autotune
/
triton.Config
选项
references/api-language.md
不熟悉
tl.*
操作
references/patterns-gemm.md
分组GEMM、持久化矩阵乘法、TMA、MX格式
references/patterns-advanced.md
Flash Attention细节、反向传播、libdevice
references/troubleshooting.md
调试操作、解释器模式、环境变量

Error Handling and Troubleshooting

错误处理与故障排除

Common Errors

常见错误

Error / SymptomCauseFix
"Type mismatch, store Float32 to Float16"Missing
.to(x.dtype)
before store
Cast fp32 result back
BLOCK_SIZE is not a constexpr
Block size passed as runtime valueAdd
: tl.constexpr
annotation
shape mismatch
in binary op
Tensor shapes don't broadcastCheck with
tl.static_print
; use
[:, None]
/
[None, :]
Large diffs everywhereWrong dtype in
tl.load
Check load dtype matches input
Matmul 3-8x slower than expected
input_precision="ieee"
on
tl.dot
Remove it; use TF32 default. Ensure
reference_fn
also uses TF32
Matmul ~0.01-0.1 abs diff vs referenceTF32 vs IEEE mismatchUse same precision in both kernel and reference (TF32 for both)
Diffs at boundariesMissing maskAdd mask to all load/store ops
Random diffsRace conditionCheck atomics and ordering
NaN/InfDivision by zero or fp16 overflowGuard with epsilon; use
tl.float32
accumulator
grid must be a tuple
Grid lambda returns int, not tupleReturn
(value,)
with trailing comma
expected constexpr
in
tl.arange
Non-constexpr argumentBoth args of
tl.arange(start, end)
must be constexpr
triton.OutOfResources
Register/shared memory pressureReduce BLOCK_SIZE or
num_stages
Kernel not updating after editStale compilation cache
rm -rf ~/.triton/cache/
Mismatched results vs PyTorchC integer division semanticsTriton uses truncation; see
references/concepts-semantics.md
For extended error table, interpreter mode issues, and environment variables, see references/troubleshooting.md.
错误/症状原因修复方案
"Type mismatch, store Float32 to Float16"存储前缺少
.to(x.dtype)
转换
将fp32结果转换回原始类型
BLOCK_SIZE is not a constexpr
块大小作为运行时参数传递添加
: tl.constexpr
注解
二元操作中
shape mismatch
张量形状无法广播使用
tl.static_print
检查;使用
[:, None]
/
[None, :]
调整形状
所有结果差异都很大
tl.load
中数据类型错误
检查加载数据类型与输入匹配
矩阵乘法比预期慢3-8倍
tl.dot
上设置了
input_precision="ieee"
移除该设置;使用默认的TF32。确保
reference_fn
也使用TF32
与参考实现的绝对差异约0.01-0.1TF32与IEEE精度不匹配在内核和参考实现中使用相同精度(均使用TF32)
边界处存在差异缺少掩码为所有加载/存储操作添加掩码
随机差异竞争条件检查原子操作和顺序
NaN/Inf除零或fp16溢出使用epsilon保护;使用
tl.float32
累加器
grid must be a tuple
网格lambda返回整数而非元组返回
(value,)
(带尾随逗号)
tl.arange
expected constexpr
参数不是常量表达式
tl.arange(start, end)
的两个参数必须都是常量表达式
triton.OutOfResources
寄存器/共享内存压力过大减小BLOCK_SIZE或
num_stages
修改后内核未更新编译缓存过期执行
rm -rf ~/.triton/cache/
与PyTorch结果不匹配C语言整数除法语义Triton使用截断取整;请参阅
references/concepts-semantics.md
有关扩展错误表、解释器模式问题和环境变量,请参阅references/troubleshooting.md

When to Abort

何时终止

Stop and report failure if:
  1. Not a good fit -- Pure matmul or complex control flow (Phase 0 should catch this).
  2. Verification fails after 3 attempts -- Numerical issues too severe to fix.
  3. No speedup -- Reference is already well-optimized (cuBLAS, cuDNN).
  4. Hardware mismatch -- Target GPU not available for testing.
如果出现以下情况,请停止并报告失败:
  1. 不适合——纯矩阵乘法或复杂控制流(阶段0应已识别)。
  2. 验证3次尝试后仍失败——数值问题过于严重无法修复。
  3. 无性能提升——参考实现已高度优化(cuBLAS、cuDNN)。
  4. 硬件不匹配——目标GPU不可用于测试。