Accepts a Triton operator implementation, automatically invokes the corresponding Torch small-operator implementation (on CPU or NPU) for precision comparison, and generates a precision report. Use this skill when you need to verify the correctness and precision of a Triton operator implementation, compare it against the PyTorch implementation, and produce a standardized precision report.
```bash
npx skill4agent add ascend/agent-skills triton-operator-precision-eval
```

Workflow:

```
┌──────────────────────┐     ┌──────────────────────┐     ┌──────────────────────┐
│ Triton operator impl │────▶│  Generate test data  │────▶│ Run Torch reference  │
└──────────────────────┘     └──────────────────────┘     └──────────────────────┘
           ▲                            │                            │
           │                            ▼                            ▼
           │                 ┌──────────────────────┐     ┌──────────────────────┐
           │                 │   Run Triton impl    │     │ Compute error metrics│
           │                 └──────────────────────┘     └──────────────────────┘
           │                            │                            │
           └────────────────────────────┼────────────────────────────┘
                                        │
                                        ▼
                             ┌──────────────────────┐
                             │   Generate report    │
                             └──────────────────────┘
```

Test data is produced by `test_common.generate_numpy()`, and results are compared by `test_common.validate_cmp()`; kernels run on the NPU through `torch_npu`. A complete example, `test_abs.py`:

```python
import triton
import triton.language as tl
import numpy as np
import torch
import torch_npu  # registers the NPU device with torch
import pytest
import test_common


def torch_pointwise(x0):
    # Torch reference implementation corresponding to the Triton operator
    return torch.abs(x0)


@triton.jit
def triton_abs(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr):
    # Each program instance processes one XBLOCK-sized slice of the input,
    # iterating over it in XBLOCK_SUB-sized chunks.
    offset = tl.program_id(0) * XBLOCK
    base1 = tl.arange(0, XBLOCK_SUB)
    loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB
    for loop1 in range(loops1):
        x0 = offset + (loop1 * XBLOCK_SUB) + base1
        tmp0 = tl.load(in_ptr0 + (x0), None)
        tmp2 = tl.abs(tmp0)
        tl.store(out_ptr0 + (x0), tmp2, None)


@pytest.mark.parametrize('param_list',
                         [
                             ['float16', (2, 4096, 8), 32, 2048, 64],
                             ['float32', (2, 4096, 8), 32, 2048, 64],
                             ['int8', (2, 4096, 8), 32, 2048, 64],
                             ['uint8', (2, 4096, 8), 32, 2048, 64],
                         ])
def test_case(param_list):
    dtype, shape, ncore, xblock, xblock_sub = param_list
    np_x0 = test_common.generate_numpy(shape, dtype)
    x0 = torch.from_numpy(np_x0).to(getattr(torch, dtype)).npu()
    y_ref = torch_pointwise(x0)
    y_cal = torch.zeros(shape, dtype=getattr(torch, dtype)).npu()
    triton_abs[ncore, 1, 1](x0, y_cal, xblock, xblock_sub)
    test_common.validate_cmp(dtype, y_cal, y_ref)
```

```bash
# Run a single test file
pytest test_abs.py -v
# Run all test files
pytest ./examples/ -v
```

| Data Type | Verification Method | Error Threshold |
|---|---|---|
| float16 | Relative Error | rtol=1e-03, atol=1e-03 |
| float32 | Relative Error | rtol=1e-04, atol=1e-04 |
| bfloat16 | Relative Error | rtol=1e-02, atol=1e-02 |
| int32/int64/int16/int8 | Exact Match | - |
| uint32/uint64/uint16/uint8 | Exact Match | - |
| bool | Exact Match | - |
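
The `test_common` helpers used in the example ship with the skill; as a rough, NumPy-only sketch of their behavior per the table above (the real helpers operate on NPU torch tensors, and these bodies and signatures are assumptions):

```python
import numpy as np

# Hypothetical sketch of the test_common helpers; the names mirror the
# example above, but the implementations are assumptions based on the table.
_TOLERANCES = {
    "float16": (1e-3, 1e-3),   # (rtol, atol)
    "float32": (1e-4, 1e-4),
    "bfloat16": (1e-2, 1e-2),
}

def generate_numpy(shape, dtype):
    """Generate random test data of the given shape and dtype."""
    if dtype == "bool":
        return np.random.randint(0, 2, shape).astype(bool)
    if dtype.startswith("float") or dtype == "bfloat16":
        # NumPy has no bfloat16; generate as float32 in that case
        np_dtype = "float32" if dtype == "bfloat16" else dtype
        return np.random.uniform(-2.0, 2.0, shape).astype(np_dtype)
    # Integer types: exact values from a range valid for every dtype
    info = np.iinfo(dtype)
    low, high = max(info.min, -128), min(info.max, 127)
    return np.random.randint(low, high + 1, shape).astype(dtype)

def validate_cmp(dtype, y_cal, y_ref):
    """Compare computed vs. reference results per the precision table."""
    y_cal, y_ref = np.asarray(y_cal), np.asarray(y_ref)
    if dtype in _TOLERANCES:
        rtol, atol = _TOLERANCES[dtype]
        np.testing.assert_allclose(y_cal, y_ref, rtol=rtol, atol=atol)
    else:
        # Integer and bool dtypes must match exactly
        np.testing.assert_array_equal(y_cal, y_ref)
```

In the real harness, `y_cal` and `y_ref` would first be moved off the NPU (e.g. `.cpu().numpy()`) before comparison.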
A precision report is also written to `eco_report.txt`:

```
================================================================================
Triton Operator Precision Verification Report
--------------------------------------------------------------------------------
[Verification Configuration]:
  Data type: float32 (Single Precision)
  MERE threshold: 1.220703e-04
  MARE threshold: 1.220703e-03 (10 × MERE threshold)
  Small-value-domain threshold: 1.000000e-07
--------------------------------------------------------------------------------
[Precision Standards]:
  float16: relative error rtol=1e-03, atol=1e-03
  float32: relative error rtol=1e-04, atol=1e-04
  bfloat16: relative error rtol=1e-02, atol=1e-02
  int32/int64/int16/int8: exact match
  uint32/uint64/uint16/uint8: exact match
  bool: exact match
--------------------------------------------------------------------------------
[Verification Result]:
  Result: FAIL
  Total samples: 4096
--------------------------------------------------------------------------------
[Error Metrics]:
  Mean relative error (MERE): 6.642197e-03
  Threshold requirement: MERE < 1.220703e-04
  Maximum relative error (MARE): 3.458786e+00
  Threshold requirement: MARE < 1.220703e-03
--------------------------------------------------------------------------------
[Pass Criteria]:
  ✓ MERE < threshold: False
  ✓ MARE < 10 × threshold: False
  ✓ Overall result: False
================================================================================
```

| Problem | Possible Cause | Solution |
|---|---|---|
| Triton kernel compilation failed | Triton syntax error or version incompatibility | Check Triton syntax, ensure Triton version is compatible with the code |
| Precision verification failed | Incorrect operator implementation logic or precision loss | Check the operator implementation, adjust the algorithm to improve precision |
| NPU device unavailable | `torch_npu` or the NPU driver is not installed | Install `torch_npu` and the matching NPU driver, then verify the NPU environment |
| Insufficient memory | Test data is too large | Reduce test data scale or adjust parameter configuration |
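
The MERE/MARE figures in the report can be reproduced with a straightforward relative-error computation. A sketch follows; the exact formula the skill uses, including how the small-value-domain threshold of 1e-07 is applied, is an assumption (note that the float32 MERE threshold 1.220703e-04 equals 2**-13):

```python
import numpy as np

def relative_error_metrics(y_cal, y_ref, small=1e-7):
    """Return (MERE, MARE): mean and maximum relative error.

    Denominators smaller than `small` are clamped so near-zero references
    do not blow up the relative error; this is assumed to be the role of
    the report's small-value-domain threshold."""
    y_cal = np.asarray(y_cal, dtype=np.float64)
    y_ref = np.asarray(y_ref, dtype=np.float64)
    denom = np.maximum(np.abs(y_ref), small)
    rel = np.abs(y_cal - y_ref) / denom
    return float(rel.mean()), float(rel.max())

# Pass criteria from the report (float32 case):
# MERE < 2**-13 (~1.220703e-04) and MARE < 10 × that threshold.
def passes_float32(y_cal, y_ref, mere_thresh=2.0 ** -13):
    mere, mare = relative_error_metrics(y_cal, y_ref)
    return mere < mere_thresh and mare < 10 * mere_thresh
```

The failed report above follows this pattern: its MERE (6.64e-03) and MARE (3.46e+00) both exceed their thresholds, so the overall result is False.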