flashkda-delta-attention
FlashKDA Delta Attention Skill
Skill by ara.so — Daily 2026 Skills collection.
FlashKDA provides high-performance CUDA kernels for Kimi Delta Attention (KDA) built on CUTLASS. It targets SM90+ GPUs (H100/H20 class) and integrates as a drop-in backend for flash-linear-attention's chunk_kda operation.

Requirements
- GPU: SM90+ (H100, H20, or newer)
- CUDA 12.9+
- PyTorch 2.4+
- Python 3.8+
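A quick way to sanity-check these requirements from a Python session (a minimal sketch using standard torch and sys queries; adjust as needed for your build):

```python
import sys
import torch

# SM90+ GPU (H100 / H20 class)
cap = torch.cuda.get_device_capability()
assert cap >= (9, 0), f"need SM90+, got SM{cap[0]}{cap[1]}"

# CUDA toolkit PyTorch was built against, PyTorch itself, and Python
print("CUDA:", torch.version.cuda)        # expect 12.9+
print("PyTorch:", torch.__version__)      # expect 2.4+
print("Python:", sys.version.split()[0])  # expect 3.8+
```
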
Installation
```bash
git clone https://github.com/MoonshotAI/FlashKDA.git flash-kda
cd flash-kda
git submodule update --init --recursive
pip install -v .
```

Install the FLA integration (optional but recommended):

```bash
pip install -U flash-linear-attention  # >= 0.5.0
```

Core Kernel API
flash_kda.fwd

The primary low-level kernel call:
```python
import torch
import flash_kda

flash_kda.fwd(
    q, k, v, g, beta, scale, out,
    A_log, dt_bias, lower_bound,
    initial_state=None,
    final_state=None,
    cu_seqlens=None
)
```

Tensor shapes and dtypes:
| Parameter | Dtype | Shape | Notes |
|---|---|---|---|
| q | bf16 | [B, T, H, K] | Query; K must be 128 |
| k | bf16 | [B, T, H, K] | Key; K must be 128 |
| v | bf16 | [B, T, H, V] | Value; V must be 128 |
| g | bf16 | [B, T, H, K] | Gate logits (sigmoid/activation applied internally) |
| beta | bf16 | [B, T, H] | Beta logits (sigmoid applied internally) |
| scale | float | scalar | Attention scale factor |
| out | bf16 | [B, T, H, V] | Pre-allocated output tensor |
| A_log | fp32 | [H] | Per-head log-gate parameter |
| dt_bias | fp32 | [H, K] | Per-head gate bias |
| lower_bound | float | scalar | Gate lower bound, range [-5.0, 0] |
| initial_state | bf16/fp32/None | [B, H, V, K] | Optional initial recurrent state |
| final_state | bf16/fp32/None | [B, H, V, K] | Optional output final state |
| cu_seqlens | int64 | [N+1] | Optional cumulative seq lengths for varlen (N sequences) |
Constraints:

- K == V == 128 required
- When cu_seqlens is provided, B must be 1 and T is the total number of tokens across all sequences
- initial_state and final_state dtypes must match when both are provided

A small pre-flight check mirroring these constraints is sketched below.
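For example (a minimal sketch; the helper name and error messages are illustrative, not part of the FlashKDA API):

```python
import torch

def check_kda_inputs(q, k, v, initial_state=None, final_state=None, cu_seqlens=None):
    """Illustrative pre-flight checks mirroring the documented constraints."""
    K, V = q.shape[-1], v.shape[-1]
    assert K == 128 and V == 128, f"K and V must be 128, got K={K}, V={V}"
    for name, t in (("q", q), ("k", k), ("v", v)):
        assert t.dtype == torch.bfloat16, f"{name} must be bf16, got {t.dtype}"
    if cu_seqlens is not None:
        assert q.shape[0] == 1, "varlen mode requires B == 1"
    if initial_state is not None and final_state is not None:
        assert initial_state.dtype == final_state.dtype, "state dtypes must match"
```
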
Usage via flash-linear-attention Backend (Recommended)
FlashKDA auto-dispatches from FLA's chunk_kda when installed:

```python
import torch
import logging
from fla.ops.kda import chunk_kda

# Optional: see dispatch decisions
logging.basicConfig(level=logging.INFO)
B, T, H, K, V = 2, 2048, 16, 128, 128
q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
g = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
beta = torch.randn(B, T, H, dtype=torch.bfloat16, device='cuda')
A_log = torch.randn(H, dtype=torch.float32, device='cuda')
dt_bias = torch.zeros(H, K, dtype=torch.float32, device='cuda')
h0 = torch.zeros(B, H, V, K, dtype=torch.float32, device='cuda')
scale = K ** -0.5
lower_bound = -5.0
with torch.inference_mode():
    out, final_state = chunk_kda(
        q=q, k=k, v=v, g=g, beta=beta,
        scale=scale,
        initial_state=h0,
        output_final_state=True,
        use_gate_in_kernel=True,
        use_qk_l2norm_in_kernel=True,
        use_beta_sigmoid_in_kernel=True,
        safe_gate=True,
        A_log=A_log,
        dt_bias=dt_bias,
        lower_bound=lower_bound,
        transpose_state_layout=True,
    )

# out: [B, T, H, V], final_state: [B, H, V, K]
```
Direct Low-Level Kernel Usage
```python
import torch
import flash_kda

def run_flash_kda(
    q, k, v, g, beta,
    A_log, dt_bias,
    lower_bound=-5.0,
    initial_state=None,
):
    B, T, H, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5
    out = torch.empty(B, T, H, V, dtype=torch.bfloat16, device=q.device)
    final_state = torch.zeros(B, H, V, K, dtype=torch.float32, device=q.device)
    flash_kda.fwd(
        q, k, v, g, beta,
        scale, out,
        A_log, dt_bias, lower_bound,
        initial_state=initial_state,
        final_state=final_state,
        cu_seqlens=None,
    )
    return out, final_state

B, T, H, K = 1, 4096, 8, 128
device = 'cuda'
dtype = torch.bfloat16

q = torch.randn(B, T, H, K, device=device, dtype=dtype)
k = torch.randn(B, T, H, K, device=device, dtype=dtype)
v = torch.randn(B, T, H, K, device=device, dtype=dtype)  # V == K == 128
g = torch.randn(B, T, H, K, device=device, dtype=dtype)
beta = torch.randn(B, T, H, device=device, dtype=dtype)
A_log = torch.full((H,), -0.1, device=device, dtype=torch.float32)
dt_bias = torch.zeros(H, K, device=device, dtype=torch.float32)

with torch.inference_mode():
    out, state = run_flash_kda(q, k, v, g, beta, A_log, dt_bias)

print(out.shape)    # [1, 4096, 8, 128]
print(state.shape)  # [1, 8, 128, 128]
```
Variable-Length (Packed) Batching
Use cu_seqlens for variable-length sequences packed into a single batch dimension:

```python
import torch
import flash_kda

# Two sequences of lengths 512 and 768, packed together
seq_lens = [512, 768]
T_total = sum(seq_lens)
N = len(seq_lens)
H, K, V = 16, 128, 128
cu_seqlens = torch.tensor([0, 512, 1280], dtype=torch.int64, device='cuda')

# B must be 1 for varlen mode
q = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device='cuda')
k = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device='cuda')
v = torch.randn(1, T_total, H, V, dtype=torch.bfloat16, device='cuda')
g = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device='cuda')
beta = torch.randn(1, T_total, H, dtype=torch.bfloat16, device='cuda')
A_log = torch.zeros(H, dtype=torch.float32, device='cuda')
dt_bias = torch.zeros(H, K, dtype=torch.float32, device='cuda')
out = torch.empty(1, T_total, H, V, dtype=torch.bfloat16, device='cuda')

# State shape is [N, H, V, K] in varlen mode
final_state = torch.zeros(N, H, V, K, dtype=torch.float32, device='cuda')

scale = K ** -0.5
with torch.inference_mode():
    flash_kda.fwd(
        q, k, v, g, beta,
        scale, out,
        A_log, dt_bias, lower_bound=-5.0,
        initial_state=None,
        final_state=final_state,
        cu_seqlens=cu_seqlens,
    )

print(out.shape)          # [1, 1280, 16, 128]
print(final_state.shape)  # [2, 16, 128, 128]
```
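If per-sequence lengths are only known at runtime, the packed tensor and its cu_seqlens can be built with standard torch ops; a sketch (the pack_varlen helper below is illustrative, not part of FlashKDA):

```python
import torch

def pack_varlen(tensors):
    """Concatenate per-sequence time-major tensors (e.g. [T_i, H, K]) into a
    packed [1, sum(T_i), ...] tensor and build the matching int64 cu_seqlens."""
    lens = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64)
    cu_seqlens = torch.zeros(len(tensors) + 1, dtype=torch.int64)
    cu_seqlens[1:] = torch.cumsum(lens, dim=0)
    packed = torch.cat(tensors, dim=0).unsqueeze(0)  # add the B == 1 dim
    return packed, cu_seqlens.to(tensors[0].device)

# Example: two sequences of lengths 512 and 768, as above
H, K = 16, 128
seqs = [torch.randn(L, H, K, dtype=torch.bfloat16, device='cuda') for L in (512, 768)]
q_packed, cu_seqlens = pack_varlen(seqs)
print(q_packed.shape)  # [1, 1280, 16, 128]
print(cu_seqlens)      # tensor([0, 512, 1280], device='cuda:0')
```
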
Stateful Inference (Multi-turn / Streaming)
Pass initial_state from a previous call to maintain recurrent state across chunks:

```python
import torch
import flash_kda

H, K, V = 16, 128, 128
B = 2
scale = K ** -0.5

def inference_step(q, k, v, g, beta, A_log, dt_bias, state=None):
    T = q.shape[1]
    out = torch.empty(B, T, H, V, dtype=torch.bfloat16, device='cuda')
    new_state = torch.zeros(B, H, V, K, dtype=torch.float32, device='cuda')
    flash_kda.fwd(
        q, k, v, g, beta, scale, out,
        A_log, dt_bias, lower_bound=-5.0,
        initial_state=state,
        final_state=new_state,
        cu_seqlens=None,
    )
    return out, new_state

A_log = torch.zeros(H, dtype=torch.float32, device='cuda')
dt_bias = torch.zeros(H, K, dtype=torch.float32, device='cuda')

state = None
for chunk_idx in range(4):
    q = torch.randn(B, 256, H, K, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(B, 256, H, K, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(B, 256, H, V, dtype=torch.bfloat16, device='cuda')
    g = torch.randn(B, 256, H, K, dtype=torch.bfloat16, device='cuda')
    beta = torch.randn(B, 256, H, dtype=torch.bfloat16, device='cuda')
    with torch.inference_mode():
        out, state = inference_step(q, k, v, g, beta, A_log, dt_bias, state)
    print(f"Chunk {chunk_idx}: out={out.shape}, state={state.shape}")
```
Configuration & Environment Variables
| Variable | Values | Effect |
|---|---|---|
| FLA_FLASH_KDA | 0 / 1 | Set to 0 to disable FlashKDA and fall back to the Triton path |

```bash
# Disable FlashKDA, use Triton path
FLA_FLASH_KDA=0 python your_script.py
```
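The flag can also be set from Python, for example to A/B a run against the Triton path (a sketch; the assumption is that the variable is read when FLA dispatches, so set it before the first chunk_kda call):

```python
import os

# Assumption: takes effect if set before FLA dispatches; remove the variable
# (or leave it unset) to keep FlashKDA enabled
os.environ["FLA_FLASH_KDA"] = "0"
```
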
Running Tests

```bash
bash tests/test.sh
```

- tests/test_fwd.py — correctness tests against the PyTorch reference and flash-linear-attention
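To run only the correctness file, a plain pytest invocation should work if the suite is pytest-based (an assumption; tests/test.sh is the documented entry point):

```bash
pytest -q tests/test_fwd.py
```
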
Common Patterns & Troubleshooting
Check dispatch logging
```python
import logging
logging.basicConfig(level=logging.INFO)

# Successful: [FLA Backend] kda.chunk_kda -> flashkda
# Rejected:   [FLA Backend] kda.chunk_kda rejected: <reason>
```
Verify GPU compatibility
```python
import torch
cap = torch.cuda.get_device_capability()
assert cap >= (9, 0), f"FlashKDA requires SM90+, got SM{cap[0]}{cap[1]}"
```

K and V must be 128
```python
# WRONG — will error
q = torch.randn(1, 512, 8, 64, ...)   # K=64 not supported

# CORRECT
q = torch.randn(1, 512, 8, 128, ...)  # K=128 required
```

Use torch.inference_mode(), not torch.no_grad()
```python
# FlashKDA requires inference_mode for FLA dispatch
with torch.inference_mode():
    out, state = chunk_kda(...)
```

State dtype consistency
```python
# initial_state and final_state must have matching dtypes
initial = torch.zeros(B, H, V, K, dtype=torch.float32, device='cuda')
final = torch.zeros(B, H, V, K, dtype=torch.float32, device='cuda')  # must match
# bf16 initial + fp32 final → error
```

lower_bound valid range
```python
lower_bound = -5.0   # valid: range is [-5.0, 0]
lower_bound = -2.5   # valid
lower_bound = 0.0    # valid boundary
lower_bound = -10.0  # out of spec — use -5.0 as safe minimum
```

IntelliSense / clangd setup for development
```bash
bash setup_clangd.sh
# Generates .clangd with correct include paths for CUDA/CUTLASS sources
```