# MoE Training: Mixture of Experts
## When to Use This Skill
Use MoE Training when you need to:
- Train larger models with limited compute (5× cost reduction vs dense models)
- Scale model capacity without proportional compute increase
- Achieve better performance per compute budget than dense models
- Specialize experts for different domains/tasks/languages
- Reduce inference latency with sparse activation (only 13B/47B params active in Mixtral)
- Implement SOTA models like Mixtral 8x7B, DeepSeek-V3, Switch Transformers
Notable MoE Models: Mixtral 8x7B (Mistral AI), DeepSeek-V3, Switch Transformers (Google), GLaM (Google), NLLB-MoE (Meta)
## Installation
```bash
# DeepSpeed with MoE support (quote the spec so the shell doesn't treat > as a redirect)
pip install "deepspeed>=0.6.0"

# Megatron-DeepSpeed for large-scale training
git clone https://github.com/microsoft/Megatron-DeepSpeed
cd Megatron-DeepSpeed
pip install -r requirements.txt

# Alternative: HuggingFace Transformers
pip install transformers accelerate
```
## Quick Start
### Basic MoE Architecture
```python
import torch
import torch.nn as nn


class MoELayer(nn.Module):
    """Sparse Mixture of Experts layer."""

    def __init__(self, hidden_size, num_experts=8, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        # Expert networks (FFN)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, 4 * hidden_size),
                nn.GELU(),
                nn.Linear(4 * hidden_size, hidden_size)
            )
            for _ in range(num_experts)
        ])
        # Gating network (router)
        self.gate = nn.Linear(hidden_size, num_experts)

    def forward(self, x):
        # x shape: (batch_size, seq_len, hidden_size)
        batch_size, seq_len, hidden_size = x.shape
        # Flatten for routing
        x_flat = x.view(-1, hidden_size)  # (batch_size * seq_len, hidden_size)
        # Compute gate scores
        gate_logits = self.gate(x_flat)  # (batch_size * seq_len, num_experts)
        # Top-k routing
        gate_scores = torch.softmax(gate_logits, dim=-1)
        topk_scores, topk_indices = torch.topk(gate_scores, self.top_k, dim=-1)
        # Normalize top-k scores
        topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)
        # Dispatch and combine expert outputs
        output = torch.zeros_like(x_flat)
        for i in range(self.top_k):
            expert_idx = topk_indices[:, i]
            expert_scores = topk_scores[:, i].unsqueeze(-1)
            # Route tokens to experts
            for expert_id in range(self.num_experts):
                mask = (expert_idx == expert_id)
                if mask.any():
                    expert_input = x_flat[mask]
                    expert_output = self.experts[expert_id](expert_input)
                    output[mask] += expert_scores[mask] * expert_output
        # Reshape back
        return output.view(batch_size, seq_len, hidden_size)
```
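Before wiring `MoELayer` into a model, the routing math in `forward` can be smoke-tested on its own. A minimal sketch with random weights and hypothetical sizes (8 tokens, 4 experts, top-2), checking that each token's normalized routing weights sum to 1:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size, num_experts, top_k = 8, 4, 2
x_flat = torch.randn(2 * 4, hidden_size)  # 2 sequences of 4 tokens, flattened
gate = nn.Linear(hidden_size, num_experts)

# Same steps as in MoELayer.forward
gate_scores = torch.softmax(gate(x_flat), dim=-1)
topk_scores, topk_indices = torch.topk(gate_scores, top_k, dim=-1)
topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)

print(topk_indices.shape)  # torch.Size([8, 2]) – two expert ids per token
print(torch.allclose(topk_scores.sum(dim=-1), torch.ones(8)))  # True
```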
### DeepSpeed MoE Training
```bash
# Training script with MoE
deepspeed pretrain_gpt_moe.py \
    --num-layers 24 \
    --hidden-size 1024 \
    --num-attention-heads 16 \
    --seq-length 2048 \
    --max-position-embeddings 2048 \
    --micro-batch-size 4 \
    --global-batch-size 256 \
    --train-iters 500000 \
    --lr 0.0001 \
    --min-lr 0.00001 \
    --lr-decay-style cosine \
    --num-experts 128 \
    --moe-expert-parallel-size 4 \
    --moe-loss-coeff 0.01 \
    --moe-train-capacity-factor 1.25 \
    --moe-eval-capacity-factor 2.0 \
    --fp16 \
    --deepspeed_config ds_config.json
```
## Core Concepts
### 1. MoE Architecture
**Key Components:**
- Experts: Multiple specialized FFN networks (typically 8-128)
- Router/Gate: Learned network that selects which experts to use
- Top-k Routing: Activate only k experts per token (k=1 or k=2)
- Load Balancing: Ensure even expert utilization
```
Input Token
    ↓
Router (Gate Network)
    ↓
Top-k Expert Selection (e.g., 2 out of 8)
    ↓
Expert 1 (weight: 0.6) + Expert 5 (weight: 0.4)
    ↓
Weighted Combination
    ↓
Output
```
### 2. Routing Mechanisms
**Top-1 Routing (Switch Transformer):**
```python
# Simplest routing: one expert per token
gate_logits = router(x)  # (batch, seq_len, num_experts)
expert_idx = torch.argmax(gate_logits, dim=-1)  # Hard routing
```

**Top-2 Routing (Mixtral):**
```python
# Top-2: two experts per token
gate_scores = torch.softmax(router(x), dim=-1)
top2_scores, top2_indices = torch.topk(gate_scores, k=2, dim=-1)

# Normalize scores
top2_scores = top2_scores / top2_scores.sum(dim=-1, keepdim=True)

# Combine expert outputs
output = (top2_scores[:, :, 0:1] * expert_outputs[top2_indices[:, :, 0]] +
          top2_scores[:, :, 1:2] * expert_outputs[top2_indices[:, :, 1]])
```

**Expert Choice Routing:**
```python
# Experts choose their top-k tokens (instead of tokens choosing experts)
# Guarantees perfect load balancing
expert_scores = router(x).transpose(-1, -2)  # (batch, num_experts, seq_len)
topk_tokens = torch.topk(expert_scores, k=capacity_per_expert, dim=-1)
```
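To make the expert-choice idea concrete, here is a tiny tensor-level sketch (hypothetical sizes: 6 tokens, 3 experts, capacity 2). Because each expert selects its own top-`capacity` tokens, every expert processes exactly the same number of tokens, though some tokens may be picked by no expert at all:

```python
import torch

torch.manual_seed(0)
batch, seq_len, num_experts, capacity = 1, 6, 3, 2
# Stand-in for router(x).transpose(-1, -2): one score per (expert, token) pair
scores = torch.randn(batch, num_experts, seq_len)

top_scores, top_tokens = torch.topk(scores, k=capacity, dim=-1)
print(top_tokens.shape)  # torch.Size([1, 3, 2]) – each expert takes exactly 2 tokens
```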
### 3. Load Balancing
**Auxiliary Loss:**
```python
def load_balancing_loss(gate_logits, expert_indices, num_experts):
    """Encourage uniform expert usage."""
    # Fraction of tokens routed to each expert
    expert_counts = torch.bincount(expert_indices.flatten(), minlength=num_experts)
    expert_fraction = expert_counts.float() / expert_indices.numel()
    # Gate probability for each expert (average across tokens)
    gate_probs = torch.softmax(gate_logits, dim=-1).mean(dim=0)
    # Auxiliary loss: penalize deviation from uniform usage
    aux_loss = num_experts * (expert_fraction * gate_probs).sum()
    return aux_loss

# Add to main loss
total_loss = language_model_loss + 0.01 * load_balancing_loss(...)
```

**Router Z-Loss (Stability):**
```python
def router_z_loss(logits):
    """Penalize large router logits to keep the gate numerically stable."""
    z_loss = torch.logsumexp(logits, dim=-1).pow(2).mean()
    return z_loss

total_loss = lm_loss + 0.01 * aux_loss + 0.001 * router_z_loss(gate_logits)
```
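A quick numerical sanity check of the auxiliary loss (a standalone sketch that repeats the definition so it runs on its own): with 4 experts, perfectly balanced routing under a uniform gate gives a loss of 1.0, while collapsing all tokens onto one expert pushes it toward `num_experts`:

```python
import torch

def load_balancing_loss(gate_logits, expert_indices, num_experts):
    # Same auxiliary loss as above, reproduced for a standalone check
    expert_counts = torch.bincount(expert_indices.flatten(), minlength=num_experts)
    expert_fraction = expert_counts.float() / expert_indices.numel()
    gate_probs = torch.softmax(gate_logits, dim=-1).mean(dim=0)
    return num_experts * (expert_fraction * gate_probs).sum()

num_experts, tokens = 4, 8
# Balanced: uniform gate, tokens spread evenly over experts
balanced = load_balancing_loss(torch.zeros(tokens, num_experts),
                               torch.arange(tokens) % num_experts, num_experts)
# Collapsed: gate and assignments both favor expert 0
skewed_logits = torch.full((tokens, num_experts), -10.0)
skewed_logits[:, 0] = 10.0
collapsed = load_balancing_loss(skewed_logits,
                                torch.zeros(tokens, dtype=torch.long), num_experts)
print(balanced.item(), collapsed.item())  # 1.0 (balanced) vs ≈ 4.0 (collapsed)
```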
### 4. Expert Parallelism
```jsonc
// DeepSpeed configuration
{
  "train_batch_size": 256,
  "fp16": {"enabled": true},
  "moe": {
    "enabled": true,
    "num_experts": 128,
    "expert_parallel_size": 8,  // Distribute 128 experts across 8 GPUs
    "capacity_factor": 1.25,    // Expert capacity = tokens_per_batch * capacity_factor / num_experts
    "drop_tokens": true,        // Drop tokens exceeding capacity
    "use_residual": false
  }
}
```
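The arithmetic behind `expert_parallel_size` is worth spelling out (pure Python, using the numbers from the config above):

```python
num_experts = 128
expert_parallel_size = 8  # GPUs that jointly host the expert weights

# Experts are sharded evenly across ranks, so the count must divide cleanly
assert num_experts % expert_parallel_size == 0
experts_per_gpu = num_experts // expert_parallel_size
print(experts_per_gpu)  # 16 experts resident on each GPU
```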
## Training Configuration
### DeepSpeed MoE Config
```json
{
  "train_batch_size": 256,
  "gradient_accumulation_steps": 1,
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.0001,
      "betas": [0.9, 0.999],
      "eps": 1e-8
    }
  },
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "initial_scale_power": 16
  },
  "moe": {
    "enabled": true,
    "num_experts": 128,
    "expert_parallel_size": 8,
    "moe_loss_coeff": 0.01,
    "train_capacity_factor": 1.25,
    "eval_capacity_factor": 2.0,
    "min_capacity": 4,
    "drop_tokens": true,
    "use_residual": false,
    "use_tutel": false
  },
  "zero_optimization": {
    "stage": 1
  }
}
```
### Training Script
```bash
#!/bin/bash
# Mixtral-style MoE training
deepspeed --num_gpus 8 pretrain_moe.py \
    --model-parallel-size 1 \
    --num-layers 32 \
    --hidden-size 4096 \
    --num-attention-heads 32 \
    --seq-length 2048 \
    --max-position-embeddings 4096 \
    --micro-batch-size 2 \
    --global-batch-size 256 \
    --train-iters 500000 \
    --save-interval 5000 \
    --eval-interval 1000 \
    --eval-iters 100 \
    --lr 0.0001 \
    --min-lr 0.00001 \
    --lr-decay-style cosine \
    --lr-warmup-iters 2000 \
    --clip-grad 1.0 \
    --weight-decay 0.1 \
    --num-experts 8 \
    --moe-expert-parallel-size 4 \
    --moe-loss-coeff 0.01 \
    --moe-train-capacity-factor 1.25 \
    --moe-eval-capacity-factor 2.0 \
    --disable-moe-token-dropping \
    --fp16 \
    --deepspeed \
    --deepspeed_config ds_config_moe.json \
    --data-path /path/to/data \
    --vocab-file /path/to/vocab.json \
    --merge-file /path/to/merges.txt
```
## Advanced Patterns
### Mixtral 8x7B Architecture
```python
class MixtralMoEBlock(nn.Module):
    """Mixtral-style MoE block with 8 experts, top-2 routing."""

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts  # 8
        self.top_k = config.num_experts_per_tok      # 2
        # 8 expert FFNs
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.hidden_dim, self.ffn_dim, bias=False),
                nn.SiLU(),
                nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
            )
            for _ in range(self.num_experts)
        ])
        # Router
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

    def forward(self, hidden_states):
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # Flatten
        hidden_states = hidden_states.view(-1, hidden_dim)
        # Router logits
        router_logits = self.gate(hidden_states)  # (batch * seq_len, num_experts)
        # Softmax and top-2
        routing_weights = torch.softmax(router_logits, dim=1)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        # Normalize routing weights
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # Initialize output
        final_hidden_states = torch.zeros_like(hidden_states)
        # Route to experts
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(selected_experts == expert_idx)
            if idx.shape[0] == 0:
                continue
            # Current expert tokens
            current_hidden_states = hidden_states[idx]
            # Expert forward
            current_hidden_states = expert_layer(current_hidden_states)
            # Weighted by routing scores
            current_hidden_states *= routing_weights[idx, top_x, None]
            # Accumulate
            final_hidden_states.index_add_(0, idx, current_hidden_states)
        # Reshape
        return final_hidden_states.view(batch_size, sequence_length, hidden_dim)
```
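The `torch.where` dispatch inside the loop is the subtle part of this block. A tiny standalone illustration with a hypothetical top-2 assignment of 4 tokens to 3 experts:

```python
import torch

# Each row lists the two experts selected for one token (hypothetical assignment)
selected_experts = torch.tensor([[0, 2],
                                 [1, 0],
                                 [2, 1],
                                 [0, 1]])

# Which tokens chose expert 0, and in which top-k slot?
idx, top_x = torch.where(selected_experts == 0)
print(idx.tolist())    # [0, 1, 3] – token rows routed to expert 0
print(top_x.tolist())  # [0, 1, 0] – slot whose routing weight applies to each
```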
### PR-MoE (Pyramid-Residual-MoE)
```bash
# DeepSpeed PR-MoE: 3x better parameter efficiency
deepspeed pretrain_gpt_moe.py \
    --num-layers 24 \
    --hidden-size 1024 \
    --num-attention-heads 16 \
    --num-experts "[128, 64, 32, 16]" \
    --mlp-type residual \
    --moe-expert-parallel-size 4 \
    --moe-loss-coeff 0.01 \
    --fp16
```
## Best Practices
### 1. Expert Count Selection
```python
# Rule of thumb: more experts = more capacity, but diminishing returns

# Typical configurations:
# - Small models (1B-7B): 8-16 experts
# - Medium models (7B-30B): 8-64 experts
# - Large models (30B+): 64-256 experts

# Example: Mixtral 8x7B
# Total params: 47B (8 experts x 7B each)
# Active params: 13B (2 experts x 7B, top-2 routing)
# Efficiency: 47B capacity with 13B compute
```
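Those Mixtral figures can be checked with back-of-envelope arithmetic. Treating each expert as a standalone 7B block overcounts slightly, because attention and embedding weights are shared across experts, which is why the published totals (47B / 13B) come in below these upper bounds:

```python
num_experts, top_k = 8, 2
params_per_expert = 7e9  # rough per-expert size; an approximation, not the exact count

total_upper_bound = num_experts * params_per_expert   # all experts resident
active_upper_bound = top_k * params_per_expert        # only top-2 run per token
print(f"{total_upper_bound / 1e9:.0f}B total, {active_upper_bound / 1e9:.0f}B active")
# → "56B total, 14B active" (upper bounds; weight sharing brings the real totals down)
```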
### 2. Capacity Factor Tuning
```python
# Capacity = (tokens_per_batch / num_experts) * capacity_factor

# Training: lower capacity (faster, drops some tokens)
train_capacity_factor = 1.25  # 25% buffer

# Evaluation: higher capacity (no dropping)
eval_capacity_factor = 2.0  # 100% buffer

# Formula:
expert_capacity = int((seq_len * batch_size / num_experts) * capacity_factor)
```
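Plugging the training script's settings into that formula makes the train/eval difference concrete (pure Python; seq_len 2048, micro-batch 4, 128 experts):

```python
seq_len, batch_size, num_experts = 2048, 4, 128
tokens_per_batch = seq_len * batch_size  # 8192

capacities = {cf: int((tokens_per_batch / num_experts) * cf)
              for cf in (1.25, 2.0)}
print(capacities)  # {1.25: 80, 2.0: 128} – slots per expert (train vs. eval)
```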
### 3. Learning Rate Guidelines
```python
# MoE models need a lower LR than dense models
# - Dense model: lr = 6e-4
# - MoE model: lr = 1e-4 (3-6x lower)

# Also extend the decay schedule
dense_lr_decay_iters = 300000
moe_lr_decay_iters = 500000  # 1.5-2x longer
```
### 4. Loss Coefficient Tuning
```python
# Start with standard values
moe_loss_coeff = 0.01        # Auxiliary loss (load balancing)
router_z_loss_coeff = 0.001  # Router z-loss (stability)

# If load imbalance persists, increase the aux loss
if max_expert_usage / min_expert_usage > 2.0:
    moe_loss_coeff = 0.1  # Stronger load balancing

# If training is unstable, increase the z-loss
if grad_norm > 10.0:
    router_z_loss_coeff = 0.01
```
### 5. Avoid Common Pitfalls
```python
# ❌ Bad: using the same LR as a dense model
optimizer = Adam(model.parameters(), lr=6e-4)

# ✅ Good: lower LR for MoE parameters
optimizer = Adam([
    {'params': model.non_moe_params, 'lr': 6e-4},
    {'params': model.moe_params, 'lr': 1e-4}
])

# ❌ Bad: no load balancing
loss = lm_loss

# ✅ Good: add auxiliary losses
loss = lm_loss + 0.01 * aux_loss + 0.001 * z_loss

# ❌ Bad: too many experts for a small dataset
num_experts = 128  # Overfitting risk

# ✅ Good: match expert count to data diversity
num_experts = 8  # Better for small datasets
```
## Inference Optimization
### Sparse Inference
```python
# Only activate top-k experts (large memory savings)
@torch.no_grad()
def moe_inference(x, model, top_k=2):
    """Sparse MoE inference: only load k experts."""
    # Router
    gate_logits = model.gate(x)
    topk_scores, topk_indices = torch.topk(
        torch.softmax(gate_logits, dim=-1),
        k=top_k,
        dim=-1
    )
    # Load and run only the top-k experts
    output = torch.zeros_like(x)
    for i in range(top_k):
        expert_idx = topk_indices[:, i]
        # Load expert from disk/offload if needed
        expert = model.load_expert(expert_idx)
        output += topk_scores[:, i:i+1] * expert(x)
    return output
```
## Resources
- DeepSpeed MoE Tutorial: https://www.deepspeed.ai/tutorials/mixture-of-experts-nlg/
- Mixtral Paper: https://arxiv.org/abs/2401.04088
- Switch Transformers: https://arxiv.org/abs/2101.03961
- HuggingFace MoE Guide: https://huggingface.co/blog/moe
- NVIDIA MoE Blog: https://developer.nvidia.com/blog/applying-mixture-of-experts-in-llm-architectures/
## See Also

- `references/architectures.md` - MoE model architectures (Mixtral, Switch, DeepSeek-V3)
- `references/training.md` - Advanced training techniques and optimization
- `references/inference.md` - Production deployment and serving patterns