# Step-by-step tutorial for adding a heavyweight AOT CUDA/C++ kernel to sgl-kernel (including tests & benchmarks)
npx skill4agent add sgl-project/sglang add-sgl-kernelsgl-kernelscale(x, factor) = x * factorxfactorx * factorouttorch.float16torch.bfloat16torch.float32DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16sgl-kernel/include/utils.hpython/sglang/jit_kernelsgl-kernelflashinferflashinferjit_kernelsgl-kernel/csrc/elementwise/scale.cusgl-kernel/include/sgl_kernel_ops.hsgl-kernel/csrc/common_extension.ccsgl-kernel/CMakeLists.txtset(SOURCES ...)sgl-kernel/python/sgl_kernel/sgl-kernel/python/sgl_kernel/__init__.pysgl-kernel/tests/test_scale.pysgl-kernel/benchmark/bench_scale.pycsrc/csrc/elementwise/csrc/gemm/csrc/attention/csrc/moe/sgl-kernel/csrc/elementwise/scale.cu#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "utils.h" // DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16
// scale_kernel: out[i] = input[i] * factor
// Supports float, half (__half), __nv_bfloat16 via template T
template <typename T>
__global__ void scale_kernel(T* __restrict__ out,
                             const T* __restrict__ input,
                             float factor,
                             int64_t n) {
  // One thread per element; 64-bit index so huge tensors don't overflow.
  const int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (tid >= n) {
    return;  // grid tail: threads past the end of the tensor do nothing
  }
  // Scale in float precision, then cast back to T (half/bfloat16 safe).
  const float scaled = static_cast<float>(input[tid]) * factor;
  out[tid] = static_cast<T>(scaled);
}
// Launcher for scale_kernel: out = input * factor (element-wise, on GPU).
// Validates both tensors are contiguous CUDA tensors of matching
// shape/dtype/device, then dispatches over float32 / float16 / bfloat16.
void scale(at::Tensor& out, const at::Tensor& input, double factor) {
  TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(out.is_cuda(), "out must be a CUDA tensor");
  TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
  TORCH_CHECK(out.sizes() == input.sizes(), "out and input must have the same shape");
  TORCH_CHECK(out.scalar_type() == input.scalar_type(),
              "out and input must have the same dtype");
  // Writing to `out` through the stream of input's device is only valid
  // when both tensors live on that device.
  TORCH_CHECK(out.device() == input.device(),
              "out and input must be on the same device");

  const int64_t n = input.numel();
  // A 0-block grid is a CUDA launch error; an empty tensor is a no-op.
  if (n == 0) {
    return;
  }

  // Activate input's device BEFORE querying the stream: on a multi-GPU
  // host, getCurrentCUDAStream() returns the current device's stream, so
  // fetching it before the guard would launch on the wrong device's stream.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int threads = 256;
  // Ceil-div in 64-bit, then narrow explicitly (grid.x max is 2^31 - 1).
  const int blocks = static_cast<int>((n + threads - 1) / threads);

  // Dispatches over float, float16, bfloat16
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    scale_kernel<c_type><<<blocks, threads, 0, stream>>>(
        static_cast<c_type*>(out.data_ptr()),
        static_cast<const c_type*>(input.data_ptr()),
        static_cast<float>(factor),
        n);
    // Kernel launches don't return errors; check for launch-config failures.
    cudaError_t status = cudaGetLastError();
    TORCH_CHECK(status == cudaSuccess,
                "scale_kernel launch failed: ", cudaGetErrorString(status));
    return true;
  });
}at::TensorTORCH_CHECKat::cuda::getCurrentCUDAStream()DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16floathalf__nv_bfloat16TORCH_CHECKinclude/sgl_kernel_ops.hsgl-kernel/include/sgl_kernel_ops.hvoid scale(at::Tensor& out, const at::Tensor& input, double factor);csrc/common_extension.ccsgl-kernel/csrc/common_extension.ccTORCH_LIBRARY_FRAGMENT(sgl_kernel, m)// From csrc/elementwise
// Schema note: `Tensor!` marks `out` as mutated in place; schema `float`
// maps to C++ `double` — presumably why the launcher takes `double factor`.
m.def("scale(Tensor! out, Tensor input, float factor) -> ()");
m.impl("scale", torch::kCUDA, &scale);Tensor!torch.compilefloatdoubleCMakeLists.txtsgl-kernel/CMakeLists.txtset(SOURCES ...)csrc/elementwise/scale.cusgl-kernel/python/sgl_kernel/sgl-kernel/python/sgl_kernel/elementwise.pysgl-kernel/python/sgl_kernel/__init__.pysgl-kernel/python/sgl_kernel/elementwise.pyimport torch
def scale(
input: torch.Tensor,
factor: float,
out: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Element-wise scale: out = input * factor.
Supported dtypes: torch.float16, torch.bfloat16, torch.float32.
Parameters
----------
input : CUDA input tensor
factor : scale factor (float)
out : optional pre-allocated CUDA output tensor (same shape/dtype as input)
"""
if out is None:
out = torch.empty_like(input)
torch.ops.sgl_kernel.scale.default(out, input, factor)
return outsgl-kernel/python/sgl_kernel/__init__.pysgl-kernel/tests/test_scale.pyimport pytest
import torch
import sgl_kernel
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("size", [128, 1024, 4096, 65536])
@pytest.mark.parametrize("factor", [0.5, 1.0, 2.0])
def test_scale_correctness(dtype, size, factor):
    """Kernel output must match the eager reference ``input * factor``."""
    x = torch.randn(size, dtype=dtype, device="cuda")
    dst = torch.empty_like(x)
    returned = sgl_kernel.scale(x, factor, out=dst)
    # The wrapper must hand back the caller-provided buffer, not a copy.
    assert returned is dst
    reference = x * factor
    # Loose tolerances for the 16-bit dtypes, tight for float32.
    if dtype == torch.float32:
        rtol, atol = 1e-5, 1e-6
    else:
        rtol, atol = 1e-2, 1e-2
    torch.testing.assert_close(dst, reference, rtol=rtol, atol=atol)
def test_scale_shape_mismatch():
    """A shape mismatch between `input` and `out` must raise a RuntimeError."""
    src = torch.randn(128, dtype=torch.float16, device="cuda")
    dst = torch.empty(256, dtype=torch.float16, device="cuda")
    with pytest.raises(RuntimeError, match="same shape"):
        sgl_kernel.scale(src, 2.0, out=dst)
def test_scale_cpu_input():
    """Passing CPU tensors must be rejected with a CUDA-related error."""
    src = torch.randn(128, dtype=torch.float16)  # deliberately on CPU
    dst = torch.empty_like(src)
    with pytest.raises(RuntimeError, match="CUDA"):
        sgl_kernel.scale(src, 2.0, out=dst)
# Allow running this test file directly without invoking pytest explicitly.
if __name__ == "__main__":
pytest.main([__file__, "-q"])sgl-kernel/benchmark/bench_scale.pyimport itertools
import os
import torch
import triton
import triton.testing
import sgl_kernel
# Keep the sweep tiny in CI so the benchmark job finishes quickly.
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
# Full sweep locally: three dtypes x sizes 2^10 .. 2^19.
dtypes = [torch.float16] if IS_CI else [torch.float16, torch.bfloat16, torch.float32]
sizes = [4096] if IS_CI else [2**n for n in range(10, 20)] # 1K … 512K
# NOTE(review): `factors` is not folded into `configs` below and benchmark()
# hard-codes factor = 2.0 — confirm whether a factor sweep was intended.
factors = [2.0]
configs = list(itertools.product(dtypes, sizes))
def torch_scale(input: torch.Tensor, factor: float) -> torch.Tensor:
    """Eager-PyTorch reference implementation of scale (baseline provider)."""
    return torch.mul(input, factor)
# Triton perf_report sweep: one curve per provider over every (dtype, size)
# pair in `configs`; results are plotted/printed as "scale-performance".
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["dtype", "size"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["sglang", "torch"],
        line_names=["SGL Kernel", "PyTorch"],
        styles=[("green", "-"), ("red", "--")],
        ylabel="µs (median)",
        plot_name="scale-performance",
        args={},
    )
)
def benchmark(dtype, size, provider):
    """Time sgl_kernel.scale vs. eager PyTorch for one (dtype, size) point.

    Returns a (median, max, min) triple scaled by 1000 — presumably
    do_bench_cudagraph reports milliseconds and the ylabel is µs; confirm
    against the installed triton version.
    """
    input = torch.randn(size, dtype=dtype, device="cuda")
    out = torch.empty_like(input)
    factor = 2.0
    if provider == "sglang":
        fn = lambda: sgl_kernel.scale(input, factor, out=out)
    else:
        fn = lambda: torch_scale(input, factor)
    # quantiles=[0.5, 0.2, 0.8] -> values come back in that order:
    # median, 20th percentile (min_ms), 80th percentile (max_ms).
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        fn, quantiles=[0.5, 0.2, 0.8]
    )
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
# Run the full sweep and print the results table when executed directly.
if __name__ == "__main__":
benchmark.run(print_data=True)cd sgl-kernel
make build -j16cd sgl-kernel
make build -j1 MAX_JOBS=2 CMAKE_ARGS="-DSGL_KERNEL_COMPILE_THREADS=1"pytest sgl-kernel/tests/test_scale.py -q
python sgl-kernel/benchmark/bench_scale.pyCUDA_LAUNCH_BLOCKING=1compute-sanitizer --tool memcheck python ...MAX_JOBSSGL_KERNEL_COMPILE_THREADSsgl-kernel/analyze_whl_kernel_sizes.py.cuSOURCESsgl-kernel/README.mdsgl-kernel/include/sgl_kernel_ops.hsgl-kernel/csrc/common_extension.ccsgl-kernel/CMakeLists.txtsgl-kernel/include/utils.hDISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16sgl-kernel/csrc/elementwise/activation.cusgl-kernel/csrc/elementwise/scale.cu # NEW: CUDA kernel + launcher
sgl-kernel/include/sgl_kernel_ops.h # MODIFIED: C++ declaration
sgl-kernel/csrc/common_extension.cc # MODIFIED: schema + dispatch registration
sgl-kernel/CMakeLists.txt # MODIFIED: add source file (alphabetical)
sgl-kernel/python/sgl_kernel/elementwise.py # MODIFIED: Python wrapper
sgl-kernel/python/sgl_kernel/__init__.py # MODIFIED: re-export Python API
sgl-kernel/tests/test_scale.py # NEW: tests
sgl-kernel/benchmark/bench_scale.py # NEW: benchmark