flashkda-delta-attention

FlashKDA Delta Attention Skill

Skill by ara.so (Daily 2026 Skills collection).
FlashKDA provides high-performance CUDA kernels for Kimi Delta Attention (KDA) built on CUTLASS. It targets SM90+ GPUs (H100/H20 class) and integrates as a drop-in backend for flash-linear-attention's chunk_kda operation.

Requirements


  • GPU: SM90+ (H100, H20, or newer)
  • CUDA 12.9+
  • PyTorch 2.4+
  • Python 3.8+
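
A quick runtime check against these requirements (an illustrative sketch using standard torch introspection; the thresholds are the ones listed above):
python
import sys
import torch

# Compare the local environment against the documented minimums
print("Python :", sys.version.split()[0])  # needs 3.8+
print("PyTorch:", torch.__version__)       # needs 2.4+
print("CUDA   :", torch.version.cuda)      # needs 12.9+
cap = torch.cuda.get_device_capability()
print("GPU    :", f"SM{cap[0]}{cap[1]}")   # needs SM90+ (H100, H20, or newer)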

Installation


bash
git clone https://github.com/MoonshotAI/FlashKDA.git flash-kda
cd flash-kda
git submodule update --init --recursive
pip install -v .
Install the FLA integration (optional but recommended):
bash
pip install -U flash-linear-attention  # >= 0.5.0
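
A quick import check after building (a minimal sketch; it only confirms that both packages import, not that the CUDA kernels will dispatch):
python
import flash_kda                    # the CUDA extension built above
from fla.ops.kda import chunk_kda   # the optional FLA integration point
print("flash_kda and flash-linear-attention imported successfully")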

Core Kernel API


flash_kda.fwd


The primary low-level kernel call:
python
import torch
import flash_kda

flash_kda.fwd(
    q, k, v, g, beta, scale, out,
    A_log, dt_bias, lower_bound,
    initial_state=None,
    final_state=None,
    cu_seqlens=None
)
Tensor shapes and dtypes:

| Parameter | Dtype | Shape | Notes |
| --- | --- | --- | --- |
| q | bf16 | [B, T, H, K] | Query; K must be 128 |
| k | bf16 | [B, T, H, K] | Key; K must be 128 |
| v | bf16 | [B, T, H, V] | Value; V must be 128 |
| g | bf16 | [B, T, H, K] | Gate logits (sigmoid/activation applied internally) |
| beta | bf16 | [B, T, H] | Beta logits (sigmoid applied internally) |
| scale | float | scalar | Attention scale factor |
| out | bf16 | [B, T, H, V] | Pre-allocated output tensor |
| A_log | fp32 | [H] | Per-head log-gate parameter |
| dt_bias | fp32 | [H, K] | Per-head gate bias |
| lower_bound | float | scalar | Gate lower bound, range [-5.0, 0] |
| initial_state | bf16/fp32/None | [B, H, V, K] or [N, H, V, K] | Optional initial recurrent state |
| final_state | bf16/fp32/None | [B, H, V, K] or [N, H, V, K] | Optional output final state |
| cu_seqlens | int64 | [N+1] | Optional cumulative sequence lengths for varlen |
Constraints:
  • K == V == 128 is required
  • When cu_seqlens is provided, B must be 1 and T is the total number of tokens across all sequences
  • initial_state and final_state dtypes must match when both are provided
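
A small pre-call check for these constraints (an illustrative sketch; validate_kda_inputs is a hypothetical helper, not part of FlashKDA):
python
import torch

def validate_kda_inputs(q, k, v, g, beta,
                        initial_state=None, final_state=None, cu_seqlens=None):
    # Hypothetical helper enforcing the documented constraints before flash_kda.fwd
    B, T, H, K = q.shape
    V = v.shape[-1]
    assert K == 128 and V == 128, "FlashKDA requires K == V == 128"
    if cu_seqlens is not None:
        assert B == 1, "varlen mode requires B == 1"
        assert cu_seqlens.dtype == torch.int64, "cu_seqlens must be int64"
        assert int(cu_seqlens[-1]) == T, "T must equal the total tokens across all sequences"
    if initial_state is not None and final_state is not None:
        assert initial_state.dtype == final_state.dtype, "state dtypes must match"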

Usage via flash-linear-attention Backend (Recommended)


FlashKDA auto-dispatches from FLA's chunk_kda when installed:
python
import torch
import logging
from fla.ops.kda import chunk_kda

# Optional: see dispatch decisions
logging.basicConfig(level=logging.INFO)

B, T, H, K, V = 2, 2048, 16, 128, 128
q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
g = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
beta = torch.randn(B, T, H, dtype=torch.bfloat16, device='cuda')
A_log = torch.randn(H, dtype=torch.float32, device='cuda')
dt_bias = torch.zeros(H, K, dtype=torch.float32, device='cuda')
h0 = torch.zeros(B, H, V, K, dtype=torch.float32, device='cuda')

scale = K ** -0.5
lower_bound = -5.0

with torch.inference_mode():
    out, final_state = chunk_kda(
        q=q, k=k, v=v, g=g, beta=beta,
        scale=scale,
        initial_state=h0,
        output_final_state=True,
        use_gate_in_kernel=True,
        use_qk_l2norm_in_kernel=True,
        use_beta_sigmoid_in_kernel=True,
        safe_gate=True,
        A_log=A_log,
        dt_bias=dt_bias,
        lower_bound=lower_bound,
        transpose_state_layout=True,
    )

# out: [B, T, H, V], final_state: [B, H, V, K]

Direct Low-Level Kernel Usage


python
import torch
import flash_kda

def run_flash_kda(
    q, k, v, g, beta,
    A_log, dt_bias,
    lower_bound=-5.0,
    initial_state=None,
):
    B, T, H, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5

    out = torch.empty(B, T, H, V, dtype=torch.bfloat16, device=q.device)
    final_state = torch.zeros(B, H, V, K, dtype=torch.float32, device=q.device)

    flash_kda.fwd(
        q, k, v, g, beta,
        scale, out,
        A_log, dt_bias, lower_bound,
        initial_state=initial_state,
        final_state=final_state,
        cu_seqlens=None,
    )
    return out, final_state


B, T, H, K = 1, 4096, 8, 128
device = 'cuda'
dtype  = torch.bfloat16

q       = torch.randn(B, T, H, K,   device=device, dtype=dtype)
k       = torch.randn(B, T, H, K,   device=device, dtype=dtype)
v       = torch.randn(B, T, H, K,   device=device, dtype=dtype)  # V==K==128
g       = torch.randn(B, T, H, K,   device=device, dtype=dtype)
beta    = torch.randn(B, T, H,      device=device, dtype=dtype)
A_log   = torch.full((H,), -0.1,    device=device, dtype=torch.float32)
dt_bias = torch.zeros(H, K,         device=device, dtype=torch.float32)

with torch.inference_mode():
    out, state = run_flash_kda(q, k, v, g, beta, A_log, dt_bias)

print(out.shape)    # [1, 4096, 8, 128]
print(state.shape)  # [1, 8, 128, 128]

Variable-Length (Packed) Batching


Use cu_seqlens for variable-length sequences packed into a single batch dimension:
python
import torch
import flash_kda

# Two sequences of lengths 512 and 768, packed together
seq_lens = [512, 768]
T_total = sum(seq_lens)
N = len(seq_lens)
H, K, V = 16, 128, 128
cu_seqlens = torch.tensor([0, 512, 1280], dtype=torch.int64, device='cuda')

# B must be 1 for varlen mode
q    = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device='cuda')
k    = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device='cuda')
v    = torch.randn(1, T_total, H, V, dtype=torch.bfloat16, device='cuda')
g    = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device='cuda')
beta = torch.randn(1, T_total, H,    dtype=torch.bfloat16, device='cuda')

A_log   = torch.zeros(H,    dtype=torch.float32, device='cuda')
dt_bias = torch.zeros(H, K, dtype=torch.float32, device='cuda')

out = torch.empty(1, T_total, H, V, dtype=torch.bfloat16, device='cuda')

# State shape is [N, H, V, K] in varlen mode
final_state = torch.zeros(N, H, V, K, dtype=torch.float32, device='cuda')
scale = K ** -0.5

with torch.inference_mode():
    flash_kda.fwd(
        q, k, v, g, beta, scale, out,
        A_log, dt_bias, lower_bound=-5.0,
        initial_state=None,
        final_state=final_state,
        cu_seqlens=cu_seqlens,
    )

print(out.shape)          # [1, 1280, 16, 128]
print(final_state.shape)  # [2, 16, 128, 128]
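
The cu_seqlens tensor above is written out by hand; it can also be derived from the per-sequence lengths with standard PyTorch ops (a minimal sketch):
python
import torch

seq_lens = [512, 768]
# Cumulative offsets [0, 512, 1280]; the kernel expects int64
cu_seqlens = torch.cat([
    torch.zeros(1, dtype=torch.int64, device='cuda'),
    torch.tensor(seq_lens, dtype=torch.int64, device='cuda').cumsum(0),
])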

Stateful Inference (Multi-turn / Streaming)


Pass initial_state from a previous call to maintain recurrent state across chunks:
python
import torch
import flash_kda

H, K, V = 16, 128, 128
B = 2
scale = K ** -0.5

def inference_step(q, k, v, g, beta, A_log, dt_bias, state=None):
    T = q.shape[1]
    out = torch.empty(B, T, H, V, dtype=torch.bfloat16, device='cuda')
    new_state = torch.zeros(B, H, V, K, dtype=torch.float32, device='cuda')
    flash_kda.fwd(
        q, k, v, g, beta, scale, out,
        A_log, dt_bias, lower_bound=-5.0,
        initial_state=state,
        final_state=new_state,
        cu_seqlens=None,
    )
    return out, new_state

A_log   = torch.zeros(H,    dtype=torch.float32, device='cuda')
dt_bias = torch.zeros(H, K, dtype=torch.float32, device='cuda')

state = None
for chunk_idx in range(4):
    q    = torch.randn(B, 256, H, K, dtype=torch.bfloat16, device='cuda')
    k    = torch.randn(B, 256, H, K, dtype=torch.bfloat16, device='cuda')
    v    = torch.randn(B, 256, H, V, dtype=torch.bfloat16, device='cuda')
    g    = torch.randn(B, 256, H, K, dtype=torch.bfloat16, device='cuda')
    beta = torch.randn(B, 256, H,    dtype=torch.bfloat16, device='cuda')

    with torch.inference_mode():
        out, state = inference_step(q, k, v, g, beta, A_log, dt_bias, state)
    print(f"Chunk {chunk_idx}: out={out.shape}, state={state.shape}")

Configuration & Environment Variables


| Variable | Values | Effect |
| --- | --- | --- |
| FLA_FLASH_KDA | 0 / 1 | Set to 0 to force the Triton fallback in FLA |

bash
# Disable FlashKDA, use the Triton path
FLA_FLASH_KDA=0 python your_script.py
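
The same switch can be applied from Python by setting the variable in the process environment (a sketch; it assumes the variable is consulted when FLA dispatches, so setting it before importing fla is the safe order):
python
import os

# Force the Triton fallback; set before importing fla so the backend check sees it
os.environ["FLA_FLASH_KDA"] = "0"

from fla.ops.kda import chunk_kda  # subsequent chunk_kda calls use the Triton path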

Running Tests


bash
bash tests/test.sh
  • tests/test_fwd.py: correctness tests against the PyTorch reference and flash-linear-attention

Common Patterns & Troubleshooting


Check dispatch logging


python
import logging
logging.basicConfig(level=logging.INFO)

# Successful: [FLA Backend] kda.chunk_kda -> flashkda
# Rejected:   [FLA Backend] kda.chunk_kda rejected: <reason>

Verify GPU compatibility


python
import torch
cap = torch.cuda.get_device_capability()
assert cap >= (9, 0), f"FlashKDA requires SM90+, got SM{cap[0]}{cap[1]}"

K and V must be 128


python
# WRONG: will error
q = torch.randn(1, 512, 8, 64, ...)   # K=64 not supported

# CORRECT
q = torch.randn(1, 512, 8, 128, ...)  # K=128 required

Use torch.inference_mode(), not torch.no_grad()

python
# FlashKDA requires inference_mode for FLA dispatch
with torch.inference_mode():
    out, state = chunk_kda(...)

State dtype consistency


python
# initial_state and final_state must have matching dtypes
initial = torch.zeros(B, H, V, K, dtype=torch.float32, device='cuda')
final   = torch.zeros(B, H, V, K, dtype=torch.float32, device='cuda')  # must match

# bf16 initial + fp32 final → error

lower_bound valid range

python
lower_bound = -5.0   # valid: range is [-5.0, 0]
lower_bound = -2.5   # valid
lower_bound = 0.0    # valid boundary
lower_bound = -10.0  # out of spec — use -5.0 as safe minimum

IntelliSense / clangd setup for development


bash
bash setup_clangd.sh

Generates .clangd with correct include paths for CUDA/CUTLASS sources
