kernel-triton-writing
Compare original and translation side by side
🇺🇸
Original
English🇨🇳
Translation
ChineseTriton Kernel Writing
Triton 内核编写
Principles
原则
Correctness First
正确性优先
- Never benchmark before verification passes.
- Always mask loads and stores for non-divisible shapes.
- Include ,
kernel_fn, andreference_fnexports for companion scripts.get_inputs() - Always run to validate against the reference.
scripts/verify_kernel.py
- 验证通过前切勿进行基准测试。
- 对于无法整除的形状,始终对加载和存储操作进行掩码处理。
- 导出、
kernel_fn和reference_fn供配套脚本使用。get_inputs() - 始终运行以对照参考实现验证正确性。
scripts/verify_kernel.py
FP16/BF16 Precision Rules (LOW FREEDOM -- follow exactly)
FP16/BF16 精度规则(自由度极低——严格遵循)
Transcendental functions (, , , ) require fp32 inputs.
tl.exptl.logtl.math.erftl.math.tanhpython
undefined超越函数(、、、)要求输入为fp32类型。
tl.exptl.logtl.math.erftl.math.tanhpython
undefinedWRONG -- compilation error or wrong results with fp16/bf16:
WRONG -- compilation error or wrong results with fp16/bf16:
result = tl.exp(x)
result = tl.exp(x)
CORRECT -- cast to fp32, compute, cast back:
CORRECT -- cast to fp32, compute, cast back:
x_fp32 = x.to(tl.float32)
result = tl.exp(x_fp32).to(x.dtype)
Rule: any math function beyond basic arithmetic (+, -, *, /) requires fp32 cast in, original dtype cast out.
Additional precision constraints:
- `tl.sigmoid()` is unavailable in some Triton versions. Use `1.0 / (1.0 + tl.exp(-x_fp32))`.
- Always cast back to `x.dtype` before `tl.store` -- mismatches cause "Type mismatch, store Float32 to Float16".
- Unlike PyTorch, Triton does NOT auto-promote fp16/bf16 to fp32 for accumulation. Always use `tl.float32` accumulators for `tl.dot`.
- **TF32 for matmul:** On Ampere+/Hopper, `tl.dot` uses TF32 by default for fp32 inputs (same as `torch.mm`). Do NOT add `input_precision="ieee"` — it is 3-8x slower. TF32 is the correct default. If verification fails due to TF32 precision (~0.01-0.1 abs diff), ensure `reference_fn` also uses TF32 (plain `torch.mm`, no `allow_tf32=False`).x_fp32 = x.to(tl.float32)
result = tl.exp(x_fp32).to(x.dtype)
规则:除基础算术运算(+、-、*、/)之外的任何数学函数,都需要先转换为fp32计算,再转换回原始数据类型。
额外精度约束:
- 部分Triton版本中`tl.sigmoid()`不可用。使用`1.0 / (1.0 + tl.exp(-x_fp32))`替代。
- 在执行`tl.store`前必须转换回`x.dtype`——类型不匹配会导致"Type mismatch, store Float32 to Float16"错误。
- 与PyTorch不同,Triton不会自动将fp16/bf16提升为fp32进行累加。对于`tl.dot`,始终使用`tl.float32`累加器。
- **矩阵乘法的TF32支持**:在Ampere+/Hopper架构上,`tl.dot`默认对fp32输入使用TF32(与`torch.mm`一致)。请勿添加`input_precision="ieee"`——这会使速度慢3-8倍。TF32是正确的默认选项。如果因TF32精度问题(绝对误差约0.01-0.1)导致验证失败,请确保`reference_fn`也使用TF32(使用普通`torch.mm`,不要设置`allow_tf32=False`)。CPU-GPU Sync Avoidance (LOW FREEDOM)
避免CPU-GPU同步(自由度极低)
Never call in kernel wrappers. It forces a CPU-GPU sync (~50-100us per call).
.item()| Pitfall | Fix |
|---|---|
| |
| Use tensor metadata for pseudo-random seed |
| Allocating output every call | Accept pre-allocated output as parameter |
| Python loops calling kernel | Batch operations |
在内核包装器中切勿调用。这会强制进行CPU-GPU同步(每次调用约50-100微秒)。
.item()| 陷阱 | 修复方案 |
|---|---|
使用 | 使用 |
使用 | 使用张量元数据生成伪随机种子 |
| 每次调用都分配输出张量 | 接受预分配的输出作为参数 |
| 使用Python循环调用内核 | 批量处理操作 |
C Integer Division Semantics (CRITICAL)
C语言整数除法语义(至关重要)
Triton uses C semantics (round toward zero) for and , NOT Python semantics (round toward negative infinity). This only matters when operands can be negative.
//%| Expression | Python | Triton/C |
|---|---|---|
| | |
| | |
Safe pattern: Ensure all index/offset values are non-negative. If negative values are possible, use .
(idx % BLOCK + BLOCK) % BLOCKSee references/concepts-semantics.md for full rules and scalar-only exception.
Triton对和使用C语言语义(向零取整),而非Python语义(向负无穷取整)。这仅在操作数可能为负数时才会产生影响。
//%| 表达式 | Python | Triton/C |
|---|---|---|
| | |
| | |
安全模式:确保所有索引/偏移值均为非负数。如果可能出现负值,使用。
(idx % BLOCK + BLOCK) % BLOCK有关完整规则和标量例外情况,请参阅references/concepts-semantics.md。
Kernel Design Mental Model
内核设计思维模型
- Parallelization axis: Element-wise kernels parallelize over flattened elements. Row-wise kernels (LayerNorm, softmax) parallelize over rows. Matmul kernels tile in 2D (M, N).
- Block size: Power-of-2 only (256, 512, 1024, 2048). Start with 1024 for H100, 512 for V100.
- Memory coalescing: Adjacent threads must access adjacent memory addresses. The compiler handles this automatically from block-level pointer arithmetic.
- Grid: Use . With autotune, grid must be a lambda:
triton.cdiv(n_elements, BLOCK_SIZE).lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),) - Decorator order: (outermost) ->
@triton.autotune->@triton.heuristics(innermost).@triton.jit - : Required for autotune on kernels that accumulate (e.g., matmul output). Without it, later configs see leftover values from earlier trials.
reset_to_zero
- 并行化轴:逐元素内核在扁平化元素上并行化。按行处理的内核(LayerNorm、softmax)在行上并行化。矩阵乘法内核在二维(M、N)上进行分块。
- 块大小:仅支持2的幂(256、512、1024、2048)。H100从1024开始,V100从512开始。
- 内存合并:相邻线程必须访问相邻的内存地址。编译器会通过块级指针运算自动处理此问题。
- 网格:使用。结合自动调优时,网格必须是一个lambda函数:
triton.cdiv(n_elements, BLOCK_SIZE)。lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),) - 装饰器顺序:(最外层)->
@triton.autotune->@triton.heuristics(最内层)。@triton.jit - :对于进行累加的内核(例如矩阵乘法输出),自动调优时需要此选项。如果没有它,后续配置会看到之前试验留下的值。
reset_to_zero
Workflow
工作流程
Fast path: If the user explicitly requests a Triton kernel (e.g., "Write a Triton kernel for X", "Implement softmax in Triton"), start at Phase 2. Only use Phase 0-1 when the request is ambiguous about whether Triton is appropriate.
快速路径:如果用户明确请求Triton内核(例如“编写一个用于X的Triton内核”、“用Triton实现softmax”),直接从阶段2开始。仅当请求对是否适合使用Triton存在歧义时,才使用阶段0-1。
Phase 0: Route the Operator (only for ambiguous requests)
阶段0:选择算子(仅针对歧义请求)
Skip this phase if the user explicitly asks for a Triton kernel. Only use when the request is ambiguous (e.g., "optimize this operation").
Triton wins when 2+ operations can share registers instead of writing/reading global memory. Quick rules:
| Pattern | Decision |
|---|---|
Single element-wise op ( | SKIP — PyTorch already optimal |
| Standalone matmul | SKIP — cuBLAS is optimized |
| Standard attention | SKIP — Use FlashAttention |
| Element-wise chain (2+ ops), reduction, matmul + epilogue | USE TRITON |
If SKIP, recommend the alternative and STOP. See references/operator-routing.md for edge cases.
如果用户明确要求Triton内核,跳过此阶段。仅当请求存在歧义时使用(例如“优化此操作”)。
当2个及以上操作可以共享寄存器,而非写入/读取全局内存时,Triton更具优势。快速判断规则:
| 模式 | 决策 |
|---|---|
单一逐元素操作( | 跳过——PyTorch已优化至最优 |
| 独立矩阵乘法 | 跳过——cuBLAS已优化 |
| 标准注意力机制 | 跳过——使用FlashAttention |
| 逐元素操作链(2个及以上)、归约操作、矩阵乘法+结尾操作 | 使用TRITON |
如果选择跳过,推荐替代方案并停止。有关边缘情况,请参阅references/operator-routing.md。
Phase 1: Analyze the Operator (only for ambiguous requests)
阶段1:分析算子(仅针对歧义请求)
From the user's request, identify: (1) operation type, (2) parallelization strategy, (3) input shapes and dtypes.
从用户的请求中确定:(1) 操作类型,(2) 并行化策略,(3) 输入形状和数据类型。
Phase 2: Design the Kernel
阶段2:设计内核
Pick the skeleton below that matches your operation. These skeletons are sufficient for element-wise, reduction, matmul, and fusion kernels — do NOT read reference files for these common patterns. Only consult when implementing uncommon patterns (grouped GEMM, TMA, extern functions) or debugging issues.
references/Element-wise skeleton (GELU, dropout, fused ops on flat tensors):
python
@triton.jit
def kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
# ... compute ...
tl.store(out_ptr + offsets, result, mask=mask)Row-wise skeleton (softmax, LayerNorm, RMSNorm — one program per row):
python
@triton.jit
def kernel(x_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
x = tl.load(x_ptr + row_idx * n_cols + col_offsets, mask=mask, other=0.0)
# ... reduce / normalize ...
tl.store(out_ptr + row_idx * n_cols + col_offsets, result, mask=mask)Tiled matmul skeleton (GEMM with 2D tiling, grouped ordering, and autotune):
python
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
],
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr, M, N, K,
stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
pid = tl.program_id(0)
num_m_blocks = tl.cdiv(M, BLOCK_M)
num_n_blocks = tl.cdiv(N, BLOCK_N)
# Grouped ordering for L2 cache locality
num_pid_in_group = GROUP_SIZE_M * num_n_blocks
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_m_blocks - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N)
a = tl.load(a_ptrs, mask=a_mask, other=0.0)
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
acc += tl.dot(a, b)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
offs_k += BLOCK_K
c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)选择与操作匹配的骨架。**这些骨架足以覆盖逐元素、归约、矩阵乘法和融合内核——对于这些常见模式,无需查阅参考文件。**仅在实现不常见模式(分组GEMM、TMA、外部函数)或调试问题时,才查阅目录。
references/逐元素骨架(GELU、dropout、扁平张量上的融合操作):
python
@triton.jit
def kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
# ... compute ...
tl.store(out_ptr + offsets, result, mask=mask)按行处理骨架(softmax、LayerNorm、RMSNorm——每行一个程序):
python
@triton.jit
def kernel(x_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
x = tl.load(x_ptr + row_idx * n_cols + col_offsets, mask=mask, other=0.0)
# ... reduce / normalize ...
tl.store(out_ptr + row_idx * n_cols + col_offsets, result, mask=mask)分块矩阵乘法骨架(带有二维分块、分组排序和自动调优的GEMM):
python
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
],
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr, M, N, K,
stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
pid = tl.program_id(0)
num_m_blocks = tl.cdiv(M, BLOCK_M)
num_n_blocks = tl.cdiv(N, BLOCK_N)
# Grouped ordering for L2 cache locality
num_pid_in_group = GROUP_SIZE_M * num_n_blocks
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_m_blocks - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N)
a = tl.load(a_ptrs, mask=a_mask, other=0.0)
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
acc += tl.dot(a, b)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
offs_k += BLOCK_K
c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)Phase 3: Write the Kernel
阶段3:编写内核
Create an output directory, then write the kernel file to .
{output_dir}/kernel.pyThe kernel file MUST include:
- decorated kernel function
@triton.jit - for production kernels (see references/api-core.md)
@triton.autotune - Python wrapper function (descriptive name for external import)
- Fixed contract exports (companion scripts rely on these exact names):
- — alias to the wrapper function
kernel_fn - — PyTorch reference with identical signature
reference_fn(*args) - — returns
get_inputs()of fresh CUDA tensors for testing/benchmarkinglist
Concise example (fused GELU + dropout):
python
import triton
import triton.language as tl
import torch
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 1024}, num_warps=4),
triton.Config({'BLOCK_SIZE': 2048}, num_warps=8),
],
key=['n_elements'],
)
@triton.jit
def fused_gelu_dropout_kernel(
x_ptr, out_ptr, n_elements, p, seed,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
x_fp32 = x.to(tl.float32)
x = (0.5 * x_fp32 * (1.0 + tl.math.erf(x_fp32 * 0.7071067811865476))).to(x.dtype)
random = tl.rand(seed, offsets)
x = tl.where(random > p, x / (1.0 - p), 0.0)
tl.store(out_ptr + offsets, x, mask=mask)
def fused_gelu_dropout_triton(x: torch.Tensor, p: float = 0.1) -> torch.Tensor:
n_elements = x.numel()
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
seed = (x.data_ptr() % (2**31)) ^ n_elements # sync-free seed
fused_gelu_dropout_kernel[grid](x, out, n_elements, p, seed)
return out创建输出目录,然后将内核文件写入。
{output_dir}/kernel.py内核文件必须包含:
- 带有装饰器的内核函数
@triton.jit - 生产级内核需添加(参阅references/api-core.md)
@triton.autotune - Python包装器函数(便于外部导入的描述性名称)
- 固定契约导出(配套脚本依赖这些确切名称):
- —— 包装器函数的别名
kernel_fn - —— 具有相同签名的PyTorch参考实现
reference_fn(*args) - —— 返回用于测试/基准测试的全新CUDA张量列表
get_inputs()
简洁示例(融合GELU + dropout):
python
import triton
import triton.language as tl
import torch
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 1024}, num_warps=4),
triton.Config({'BLOCK_SIZE': 2048}, num_warps=8),
],
key=['n_elements'],
)
@triton.jit
def fused_gelu_dropout_kernel(
x_ptr, out_ptr, n_elements, p, seed,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
x_fp32 = x.to(tl.float32)
x = (0.5 * x_fp32 * (1.0 + tl.math.erf(x_fp32 * 0.7071067811865476))).to(x.dtype)
random = tl.rand(seed, offsets)
x = tl.where(random > p, x / (1.0 - p), 0.0)
tl.store(out_ptr + offsets, x, mask=mask)
def fused_gelu_dropout_triton(x: torch.Tensor, p: float = 0.1) -> torch.Tensor:
n_elements = x.numel()
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
seed = (x.data_ptr() % (2**31)) ^ n_elements # sync-free seed
fused_gelu_dropout_kernel[grid](x, out, n_elements, p, seed)
return out--- Fixed contract (companion scripts rely on these names) ---
--- Fixed contract (companion scripts rely on these names) ---
kernel_fn = fused_gelu_dropout_triton
def reference_fn(x, p=0.1):
torch.manual_seed((x.data_ptr() % (2**31)) ^ x.numel())
return torch.nn.functional.dropout(
torch.nn.functional.gelu(x), p, training=True
)
def get_inputs():
return [torch.randn(128 * 1024 * 1024, device="cuda")]
For more patterns (SiLU+mul, RMSNorm, linear+GELU, add+LayerNorm), see [references/patterns-fusion.md](references/patterns-fusion.md). For GEMM patterns, see [references/patterns-gemm.md](references/patterns-gemm.md).kernel_fn = fused_gelu_dropout_triton
def reference_fn(x, p=0.1):
torch.manual_seed((x.data_ptr() % (2**31)) ^ x.numel())
return torch.nn.functional.dropout(
torch.nn.functional.gelu(x), p, training=True
)
def get_inputs():
return [torch.randn(128 * 1024 * 1024, device="cuda")]
更多模式(SiLU+mul、RMSNorm、linear+GELU、add+LayerNorm)请参阅[references/patterns-fusion.md](references/patterns-fusion.md)。有关GEMM模式,请参阅[references/patterns-gemm.md](references/patterns-gemm.md)。Phase 4: Verify Correctness
阶段4:验证正确性
Run the companion verification script:
bash
python scripts/verify_kernel.py {output_dir}/kernel.py --rtol 1e-3 --atol 1e-3Output:
json
{"correct": true, "max_abs_diff": 1.2e-7, "max_rel_diff": 3.4e-6, "details": "..."}Stop if . Fix the kernel before benchmarking.
correct: falseTolerance guide:
| Dtype | rtol | atol | Notes |
|---|---|---|---|
| float16 | 1e-3 | 1e-3 | |
| bfloat16 | 1e-2 | 1e-2 | |
| float32 | 1e-5 | 1e-5 | Element-wise ops |
| float32 (matmul) | 1e-2 | 1e-1 | TF32 accumulation order differs between Triton tiles and cuBLAS |
运行配套验证脚本:
bash
python scripts/verify_kernel.py {output_dir}/kernel.py --rtol 1e-3 --atol 1e-3输出:
json
{"correct": true, "max_abs_diff": 1.2e-7, "max_rel_diff": 3.4e-6, "details": "..."}**如果,请停止。**在进行基准测试前修复内核。
correct: false容差指南:
| 数据类型 | rtol | atol | 说明 |
|---|---|---|---|
| float16 | 1e-3 | 1e-3 | |
| bfloat16 | 1e-2 | 1e-2 | |
| float32 | 1e-5 | 1e-5 | 逐元素操作 |
| float32(矩阵乘法) | 1e-2 | 1e-1 | Triton分块与cuBLAS的TF32累加顺序不同 |
Phase 5: Benchmark Performance (optional)
阶段5:性能基准测试(可选)
Only benchmark if the user explicitly requests performance numbers. Skip this phase for correctness-focused requests.
bash
python scripts/benchmark_kernel.py {output_dir}/kernel.pyOutput:
json
{"kernel_time_ms": 0.45, "reference_time_ms": 1.23, "speedup": 2.73, "warmup_iters": 10, "benchmark_iters": 40}仅当用户明确请求性能数据时才进行基准测试。对于以正确性为重点的请求,跳过此阶段。
bash
python scripts/benchmark_kernel.py {output_dir}/kernel.py输出:
json
{"kernel_time_ms": 0.45, "reference_time_ms": 1.23, "speedup": 2.73, "warmup_iters": 10, "benchmark_iters": 40}References (consult only when stuck)
参考资料(仅在遇到问题时查阅)
The skeletons and principles above cover element-wise, reduction, matmul, and fusion kernels. Do NOT read reference files for these common patterns.
Only consult when:
references/- Implementing uncommon patterns (grouped GEMM, TMA, persistent matmul, extern functions)
- Debugging a compile error or incorrect result not covered by the error table below
- Needing API details for an unfamiliar operation
tl.*
How to search: Grep for your keyword across . Read only the file Grep points to.
references/| File | When to use |
|---|---|
| Unfamiliar |
| Unfamiliar |
| Grouped GEMM, persistent matmul, TMA, MX formats |
| Flash attention details, backward passes, libdevice |
| Debug ops, interpreter mode, env vars |
上述骨架和原则涵盖了逐元素、归约、矩阵乘法和融合内核。对于这些常见模式,请勿查阅参考文件。
仅在以下情况时查阅:
references/- 实现不常见模式(分组GEMM、TMA、持久化矩阵乘法、外部函数)
- 调试未被以下错误表覆盖的编译错误或不正确结果
- 需要不熟悉的操作的API细节
tl.*
搜索方法: 在目录中搜索关键词。仅阅读搜索指向的文件。
references/| 文件 | 使用场景 |
|---|---|
| 不熟悉 |
| 不熟悉 |
| 分组GEMM、持久化矩阵乘法、TMA、MX格式 |
| Flash Attention细节、反向传播、libdevice |
| 调试操作、解释器模式、环境变量 |
Error Handling and Troubleshooting
错误处理与故障排除
Common Errors
常见错误
| Error / Symptom | Cause | Fix |
|---|---|---|
| "Type mismatch, store Float32 to Float16" | Missing | Cast fp32 result back |
| Block size passed as runtime value | Add |
| Tensor shapes don't broadcast | Check with |
| Large diffs everywhere | Wrong dtype in | Check load dtype matches input |
| Matmul 3-8x slower than expected | | Remove it; use TF32 default. Ensure |
| Matmul ~0.01-0.1 abs diff vs reference | TF32 vs IEEE mismatch | Use same precision in both kernel and reference (TF32 for both) |
| Diffs at boundaries | Missing mask | Add mask to all load/store ops |
| Random diffs | Race condition | Check atomics and ordering |
| NaN/Inf | Division by zero or fp16 overflow | Guard with epsilon; use |
| Grid lambda returns int, not tuple | Return |
| Non-constexpr argument | Both args of |
| Register/shared memory pressure | Reduce BLOCK_SIZE or |
| Kernel not updating after edit | Stale compilation cache | |
| Mismatched results vs PyTorch | C integer division semantics | Triton uses truncation; see |
For extended error table, interpreter mode issues, and environment variables, see references/troubleshooting.md.
| 错误/症状 | 原因 | 修复方案 |
|---|---|---|
| "Type mismatch, store Float32 to Float16" | 存储前缺少 | 将fp32结果转换回原始类型 |
| 块大小作为运行时参数传递 | 添加 |
二元操作中 | 张量形状无法广播 | 使用 |
| 所有结果差异都很大 | | 检查加载数据类型与输入匹配 |
| 矩阵乘法比预期慢3-8倍 | | 移除该设置;使用默认的TF32。确保 |
| 与参考实现的绝对差异约0.01-0.1 | TF32与IEEE精度不匹配 | 在内核和参考实现中使用相同精度(均使用TF32) |
| 边界处存在差异 | 缺少掩码 | 为所有加载/存储操作添加掩码 |
| 随机差异 | 竞争条件 | 检查原子操作和顺序 |
| NaN/Inf | 除零或fp16溢出 | 使用epsilon保护;使用 |
| 网格lambda返回整数而非元组 | 返回 |
| 参数不是常量表达式 | |
| 寄存器/共享内存压力过大 | 减小BLOCK_SIZE或 |
| 修改后内核未更新 | 编译缓存过期 | 执行 |
| 与PyTorch结果不匹配 | C语言整数除法语义 | Triton使用截断取整;请参阅 |
有关扩展错误表、解释器模式问题和环境变量,请参阅references/troubleshooting.md。
When to Abort
何时终止
Stop and report failure if:
- Not a good fit -- Pure matmul or complex control flow (Phase 0 should catch this).
- Verification fails after 3 attempts -- Numerical issues too severe to fix.
- No speedup -- Reference is already well-optimized (cuBLAS, cuDNN).
- Hardware mismatch -- Target GPU not available for testing.
如果出现以下情况,请停止并报告失败:
- 不适合——纯矩阵乘法或复杂控制流(阶段0应已识别)。
- 验证3次尝试后仍失败——数值问题过于严重无法修复。
- 无性能提升——参考实现已高度优化(cuBLAS、cuDNN)。
- 硬件不匹配——目标GPU不可用于测试。