rwkv-architecture


RWKV - Receptance Weighted Key Value

Quick start

RWKV (pronounced RwaKuv) combines Transformer parallelization (training) with RNN efficiency (inference).
Installation:
```bash
# Install PyTorch
pip install torch --upgrade --extra-index-url https://download.pytorch.org/whl/cu121

# Install dependencies
pip install pytorch-lightning==1.9.5 deepspeed wandb ninja --upgrade

# Install RWKV
pip install rwkv
```

**Basic usage** (GPT mode + RNN mode):
```python
import os
from rwkv.model import RWKV

os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1'  # Use CUDA kernel for speed

# Load model
model = RWKV(model='/path/to/RWKV-4-Pile-1B5-20220903-8040', strategy='cuda fp16')

# GPT mode (parallel processing)
out, state = model.forward([187, 510, 1563, 310, 247], None)
print(out.detach().cpu().numpy())  # Logits

# RNN mode (sequential processing, same result)
out, state = model.forward([187, 510], None)   # First 2 tokens
out, state = model.forward([1563], state)      # Next token
out, state = model.forward([310, 247], state)  # Last tokens
print(out.detach().cpu().numpy())  # Same logits as above!
```
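A quick way to verify the "same logits" claim, reusing `model` from above (a minimal sketch; the tolerance is an assumption, since fp16 arithmetic makes the two modes agree only approximately):

```python
import numpy as np

# Full-sequence (GPT-style) pass
out_gpt, _ = model.forward([187, 510, 1563, 310, 247], None)

# Chunked (RNN-style) passes over the same tokens
out_rnn, state = model.forward([187, 510], None)
out_rnn, state = model.forward([1563], state)
out_rnn, state = model.forward([310, 247], state)

# Final logits agree up to floating-point tolerance
print(np.allclose(out_gpt.float().cpu().numpy(),
                  out_rnn.float().cpu().numpy(), atol=1e-2))
```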

Common workflows

Workflow 1: Text generation (streaming)

Efficient token-by-token generation:
```python
from rwkv.model import RWKV
from rwkv.utils import PIPELINE

model = RWKV(model='RWKV-4-Pile-14B-20230313-ctx8192-test1050', strategy='cuda fp16')
pipeline = PIPELINE(model, "20B_tokenizer.json")

# Initial prompt
prompt = "The future of AI is"
state = None

# Feed the prompt; the state accumulates the context
out, state = pipeline.model.forward(pipeline.encode(prompt), state)

# Continue generation token by token
for _ in range(100):
    token = pipeline.sample_logits(out)
    print(pipeline.decode([token]), end='', flush=True)
    out, state = pipeline.model.forward([token], state)
```

**Key advantage**: Constant memory per token (no growing KV cache)
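To see the constant-memory claim directly: in the rwkv package the state comes back as a list of per-layer tensors, so its footprint is fixed no matter how many tokens have been processed (a sketch, reusing `pipeline` from above):

```python
def state_bytes(state):
    # Sum the storage of every per-layer state tensor
    return sum(t.numel() * t.element_size() for t in state)

_, short_state = pipeline.model.forward(pipeline.encode("Hi"), None)
_, long_state = pipeline.model.forward(pipeline.encode("A much longer prompt. " * 100), None)

print(state_bytes(short_state) == state_bytes(long_state))  # True: same size
```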

Workflow 2: Long context processing (infinite context)

Process million-token sequences:
```python
model = RWKV(model='RWKV-4-Pile-14B', strategy='cuda fp16')

# Process very long document
state = None
long_document = load_document()  # e.g., 1M tokens

# Stream through entire document
for chunk in chunks(long_document, chunk_size=1024):
    out, state = model.forward(chunk, state)

# State now contains information from entire 1M token document
# Memory usage: O(1) (constant, not O(n)!)
```
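`load_document` and `chunks` are placeholders in the snippet above; a minimal `chunks` implementation could look like this:

```python
def chunks(tokens, chunk_size=1024):
    """Yield consecutive fixed-size slices of a token list."""
    for i in range(0, len(tokens), chunk_size):
        yield tokens[i:i + chunk_size]
```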

Workflow 3: Fine-tuning RWKV

Standard fine-tuning workflow:
```python
# Training script
import pytorch_lightning as pl
from rwkv.model import RWKV
from rwkv.trainer import RWKVTrainer

# Configure model
config = {
    'n_layer': 24,
    'n_embd': 1024,
    'vocab_size': 50277,
    'ctx_len': 1024
}

# Setup trainer
trainer = pl.Trainer(
    accelerator='gpu',
    devices=8,
    precision='bf16',
    strategy='deepspeed_stage_2',
    max_epochs=1
)

# Train
model = RWKV(config)
trainer.fit(model, train_dataloader)
```
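`train_dataloader` is left undefined above; a hypothetical minimal version for causal language modeling might look like this (`TokenDataset`, `all_tokens`, and the batch size are illustrative, not part of the original):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class TokenDataset(Dataset):
    """Fixed-length token windows with next-token targets."""
    def __init__(self, tokens, ctx_len=1024):
        self.tokens = tokens
        self.ctx_len = ctx_len

    def __len__(self):
        return max(0, len(self.tokens) - self.ctx_len - 1)

    def __getitem__(self, i):
        x = torch.tensor(self.tokens[i : i + self.ctx_len], dtype=torch.long)
        y = torch.tensor(self.tokens[i + 1 : i + self.ctx_len + 1], dtype=torch.long)
        return x, y

# all_tokens: a pre-tokenized corpus (hypothetical)
train_dataloader = DataLoader(TokenDataset(all_tokens), batch_size=8, shuffle=True)
```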

Workflow 4: RWKV vs Transformer comparison

Memory comparison (1M token sequence):
```python
# Transformer (GPT)
# Memory: O(n²) for attention
# KV cache: 1M × hidden_dim × n_layers × 2 (keys + values)
# Example: 1M × 4096 × 24 × 2 = ~400GB (impractical!)

# RWKV
# Memory: O(1) per token
# State: hidden_dim × n_layers = 4096 × 24 = ~400KB
# 1,000,000× more efficient!
```
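A back-of-the-envelope check of those figures (assuming fp16 KV-cache entries and, for simplicity, a single fp32 state vector per layer; the real RWKV state holds a few vectors per layer, which does not change the order of magnitude):

```python
n_tokens, hidden, layers = 1_000_000, 4096, 24

kv_cache = n_tokens * hidden * layers * 2 * 2  # keys + values, 2 bytes each (fp16)
rwkv_state = hidden * layers * 4               # one fp32 vector per layer

print(f"KV cache: {kv_cache / 1e9:.0f} GB")      # ~393 GB
print(f"RWKV state: {rwkv_state / 1e3:.0f} KB")  # ~393 KB
```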


**Speed comparison** (inference):
```python

**速度对比**(推理阶段):
```python

Transformer: O(n) per token (quadratic overall)

Transformer: O(n) per token (quadratic overall)

First token: 1 computation

First token: 1 computation

Second token: 2 computations

Second token: 2 computations

...

...

1000th token: 1000 computations

1000th token: 1000 computations

RWKV: O(1) per token (linear overall)

RWKV: O(1) per token (linear overall)

Every token: 1 computation

Every token: 1 computation

1000th token: 1 computation (same as first!)

1000th token: 1 computation (same as first!)

undefined
undefined
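Summing those per-token counts makes the asymptotic gap concrete:

```python
n = 1000
transformer_total = n * (n + 1) // 2  # 1 + 2 + ... + n -> quadratic growth
rwkv_total = n                        # 1 per token     -> linear growth
print(transformer_total, rwkv_total)  # 500500 vs 1000
```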

When to use vs alternatives

Use RWKV when:
  • Need very long context (100K+ tokens)
  • Want constant memory usage
  • Building streaming applications
  • Need RNN efficiency with Transformer performance
  • Memory-constrained deployment
Key advantages:
  • Linear time: O(n) vs O(n²) for Transformers
  • No KV cache: Constant memory per token
  • Infinite context: No fixed window limit
  • Parallelizable training: Like GPT
  • Sequential inference: Like RNN
Use alternatives instead:
  • Transformers: Need absolute best performance, have compute
  • Mamba: Want state-space models
  • RetNet: Need retention mechanism
  • Hyena: Want convolution-based approach

Common issues

Issue: Out of memory during training
Use gradient checkpointing and DeepSpeed:
```python
trainer = pl.Trainer(
    strategy='deepspeed_stage_3',  # Full ZeRO-3
    precision='bf16'
)
```
Issue: Slow inference
Enable CUDA kernel:
```python
os.environ["RWKV_CUDA_ON"] = '1'
```
Issue: Model not loading
Check model path and strategy:
```python
model = RWKV(
    model='/absolute/path/to/model.pth',
    strategy='cuda fp16'  # Or 'cpu fp32' for CPU
)
```
Issue: State management in RNN mode
Always pass state between forward calls:
```python
# WRONG: State lost
out1, _ = model.forward(tokens1, None)
out2, _ = model.forward(tokens2, None)  # No context from tokens1!

# CORRECT: State preserved
out1, state = model.forward(tokens1, None)
out2, state = model.forward(tokens2, state)  # Has context from tokens1
```
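One way to make the state threading hard to forget is a small wrapper (a hypothetical convenience, not part of the rwkv package):

```python
class StatefulRWKV:
    """Carries the RNN state across forward calls automatically."""
    def __init__(self, model):
        self.model = model
        self.state = None

    def feed(self, tokens):
        out, self.state = self.model.forward(tokens, self.state)
        return out

    def reset(self):
        self.state = None

m = StatefulRWKV(model)
out1 = m.feed(tokens1)
out2 = m.feed(tokens2)  # Context from tokens1 is preserved
```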

Advanced topics

Time-mixing and channel-mixing: See references/architecture-details.md for the WKV operation, time-decay mechanism, and receptance gates (the RWKV-4 form of WKV is sketched below).
State management: See references/state-management.md for att_x_prev, att_kv, ffn_x_prev states, and numerical stability considerations.
RWKV-7 improvements: See references/rwkv7.md for latest architectural improvements (March 2025) and multimodal capabilities.
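For orientation, the WKV operation as given in the RWKV-4 paper, where $w$ is the per-channel time decay and $u$ a bonus for the current token (a sketch of the scalar per-channel form; see the reference file for details):

$$
wkv_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i}\, v_i \;+\; e^{u + k_t}\, v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} \;+\; e^{u + k_t}}
$$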

Hardware requirements

  • GPU: NVIDIA (CUDA 11.6+) or CPU
  • VRAM (FP16):
    • 169M model: 1GB
    • 430M model: 2GB
    • 1.5B model: 4GB
    • 3B model: 8GB
    • 7B model: 16GB
    • 14B model: 32GB
  • Inference: O(1) memory per token
  • Training: Parallelizable like GPT
Performance (vs Transformers):
  • Speed: Similar training, faster inference
  • Memory: 1000× less for long sequences
  • Scaling: Linear vs quadratic

Resources