PyTorch Model Recovery

This skill provides guidance for tasks involving PyTorch model architecture recovery from state dictionaries, selective layer training, and TorchScript export.

When to Use This Skill

This skill applies when:
  • Reconstructing a model architecture from a state dictionary (a `.pt` or `.pth` file containing weights)
  • Training or fine-tuning specific layers while keeping others frozen
  • Converting a recovered model to TorchScript format
  • Debugging model loading issues or architecture mismatches

Approach Overview

Model recovery tasks require a systematic, incremental approach with verification at each step. The key phases are:
  1. Architecture Analysis - Infer model structure from state dictionary keys
  2. Architecture Implementation - Build the model class to match the state dict
  3. Verification - Confirm weights load correctly before any training
  4. Training - Fine-tune specific layers with appropriate hyperparameters
  5. Export - Save to required format (often TorchScript)

Phase 1: Architecture Analysis

Examining the State Dictionary

To understand the model architecture, first load and inspect the state dictionary:

```python
import torch

weights = torch.load('model_weights.pt', map_location='cpu')

# Print all keys with shapes
for key, value in weights.items():
    print(f"{key}: {value.shape}")
```

Key Patterns to Identify

Common patterns in state dictionary keys:

| Key Pattern | Indicates |
| --- | --- |
| `encoder.layers.N.*` | Transformer encoder with N+1 layers |
| `decoder.layers.N.*` | Transformer decoder with N+1 layers |
| `embedding.weight` | Embedding layer |
| `pos_encoder.pe` | Positional encoding (often a buffer) |
| `output_layer.weight/bias` | Final linear projection |
| `*.in_proj_weight` | Combined QKV projection in attention |
| `*.self_attn.*` | Self-attention component |
| `*.linear1/linear2.*` | Feed-forward network layers |
| `*.norm1/norm2.*` | Layer normalization |
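As a quick first pass over these patterns, the keys can be grouped by their top-level module name to get a coarse picture of the architecture. A minimal sketch (the `summarize_state_dict` helper and the sample keys below are illustrative, not part of the original code):

```python
from collections import defaultdict

def summarize_state_dict(keys):
    """Count state-dict entries per top-level module name."""
    groups = defaultdict(int)
    for key in keys:
        groups[key.split('.')[0]] += 1
    return dict(groups)

sample_keys = [
    'embedding.weight',
    'pos_encoder.pe',
    'encoder.layers.0.self_attn.in_proj_weight',
    'encoder.layers.1.self_attn.in_proj_weight',
    'output_layer.weight',
    'output_layer.bias',
]
print(summarize_state_dict(sample_keys))
# {'embedding': 1, 'pos_encoder': 1, 'encoder': 2, 'output_layer': 2}
```

The per-module counts immediately reveal which top-level components the model class must define.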

Inferring Dimensions

Extract model dimensions from weight shapes:

```python
# Example: inferring transformer dimensions
# Note: in_proj_weight has shape [3*d_model, d_model] for combined QKV
d_model = weights['encoder.layers.0.self_attn.in_proj_weight'].shape[1]
assert weights['encoder.layers.0.self_attn.in_proj_weight'].shape[0] == 3 * d_model

vocab_size = weights['embedding.weight'].shape[0]
num_layers = max(int(k.split('.')[2]) for k in weights if 'encoder.layers' in k) + 1

# nhead cannot be read directly from weight shapes: the shapes only
# constrain d_model to be divisible by nhead. Try common values (4, 8)
# and verify with a forward pass if the original configuration is unknown.
```
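The same inference can be exercised without PyTorch by standing in synthetic shape tuples for the tensor `.shape` values (all numbers below are illustrative, not taken from a real checkpoint):

```python
# Synthetic shapes standing in for tensor .shape values
shapes = {
    'embedding.weight': (1000, 128),
    'encoder.layers.0.self_attn.in_proj_weight': (384, 128),
    'encoder.layers.1.self_attn.in_proj_weight': (384, 128),
}

# in_proj_weight is [3*d_model, d_model], so the second dimension is d_model
d_model = shapes['encoder.layers.0.self_attn.in_proj_weight'][1]
assert shapes['encoder.layers.0.self_attn.in_proj_weight'][0] == 3 * d_model

vocab_size = shapes['embedding.weight'][0]
num_layers = max(int(k.split('.')[2]) for k in shapes if 'encoder.layers' in k) + 1

print(d_model, vocab_size, num_layers)  # 128 1000 2
```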

Phase 2: Architecture Implementation

Building the Model Class

When implementing the model class:
  1. Match the exact layer names used in the state dictionary
  2. Use the same PyTorch module types (e.g., `nn.TransformerEncoder` vs custom)
  3. Register buffers for non-learnable tensors (e.g., positional encodings)
```python
import torch.nn as nn

class RecoveredModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, dim_feedforward):
        super().__init__()
        # Ensure attribute names match state dict keys exactly
        self.embedding = nn.Embedding(vocab_size, d_model)

        # For positional encoding stored as a buffer
        # (PositionalEncoding must register 'pe' as a buffer to match 'pos_encoder.pe')
        self.pos_encoder = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True  # Check if original used batch_first
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)
```

Common Architecture Mistakes

  • Incorrect layer naming: `self.fc` vs `self.output_layer` - names must match exactly
  • Missing buffers: positional encodings are often registered as buffers, not parameters
  • Wrong module types: custom attention vs `nn.MultiheadAttention`
  • Batch dimension mismatch: `batch_first=True` vs `batch_first=False`
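Matching names alone is not enough: a key can exist in both the model and the checkpoint with different shapes. A small torch-free sketch of that check (`find_shape_mismatches` is a hypothetical helper; shapes are shown as plain tuples standing in for tensor shapes):

```python
def find_shape_mismatches(model_shapes, weight_shapes):
    """Return keys present in both dicts whose shapes disagree."""
    shared = model_shapes.keys() & weight_shapes.keys()
    return {k: (model_shapes[k], weight_shapes[k])
            for k in shared if model_shapes[k] != weight_shapes[k]}

model_shapes = {'output_layer.weight': (1000, 128), 'embedding.weight': (1000, 128)}
weight_shapes = {'output_layer.weight': (1000, 256), 'embedding.weight': (1000, 128)}
print(find_shape_mismatches(model_shapes, weight_shapes))
# {'output_layer.weight': ((1000, 128), (1000, 256))}
```

With real tensors the same idea applies by comparing `tuple(t.shape)` for each shared key.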

Phase 3: Verification (Critical)

Verify Architecture Before Training

Always verify the model loads weights correctly before any training:

```python
model = RecoveredModel(...)

# This will raise an error if keys don't match
model.load_state_dict(weights, strict=True)
print("Weights loaded successfully!")

# Verify a forward pass works
with torch.no_grad():
    dummy_input = torch.randint(0, vocab_size, (1, 10))
    output = model(dummy_input)
    print(f"Output shape: {output.shape}")
```

Handling Key Mismatches

If `load_state_dict` fails, compare the two key sets:

```python
model_keys = set(model.state_dict().keys())
weight_keys = set(weights.keys())

missing = weight_keys - model_keys      # keys the checkpoint has but the model lacks
unexpected = model_keys - weight_keys   # keys the model defines but the checkpoint lacks

print(f"In weights but not in model: {missing}")
print(f"In model but not in weights: {unexpected}")
```
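One common cause of wholesale key mismatches is a wrapper prefix: checkpoints saved from a model wrapped in `nn.DataParallel` prefix every key with `module.`. A sketch of stripping such a prefix (the `strip_prefix` helper is illustrative; it is demonstrated on a plain dict, but works on a real state dict the same way):

```python
def strip_prefix(state_dict, prefix='module.'):
    """Remove a wrapper prefix (e.g. from DataParallel checkpoints) from every key."""
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()}

wrapped = {'module.embedding.weight': 0, 'module.output_layer.bias': 1}
print(strip_prefix(wrapped))
# {'embedding.weight': 0, 'output_layer.bias': 1}
```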

Verify TorchScript Compatibility Early

If TorchScript export is required, test it early:

```python
# Test scripting works before investing time in training
try:
    scripted = torch.jit.script(model)
    print("TorchScript scripting successful")
except Exception as e:
    print(f"Scripting failed: {e}")
    # Fall back to tracing
    traced = torch.jit.trace(model, dummy_input)
    print("TorchScript tracing successful")
```

Phase 4: Training Specific Layers

Freezing Layers

To train only specific layers, freeze all others:

```python
# Freeze all parameters first
for param in model.parameters():
    param.requires_grad = False

# Unfreeze only target layers
for param in model.output_layer.parameters():
    param.requires_grad = True

# Verify freeze status
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} parameters")
```
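Freezing by attribute works when the target is a single module; for finer control, parameters can be toggled by name prefix. A sketch (`set_requires_grad` is a hypothetical helper; it relies only on the standard `named_parameters()` interface, so the demo below uses stub objects rather than real `nn.Module`/`nn.Parameter` instances):

```python
def set_requires_grad(model, prefixes, value):
    """Set requires_grad on parameters whose name starts with any of the prefixes."""
    for name, param in model.named_parameters():
        if any(name.startswith(p) for p in prefixes):
            param.requires_grad = value

# Stub objects standing in for nn.Module / nn.Parameter in this demo
class StubParam:
    def __init__(self):
        self.requires_grad = False

class StubModel:
    def __init__(self):
        self.params = {'encoder.weight': StubParam(), 'output_layer.weight': StubParam()}
    def named_parameters(self):
        return self.params.items()

model = StubModel()
set_requires_grad(model, ['output_layer'], True)
print(model.params['output_layer.weight'].requires_grad)  # True
print(model.params['encoder.weight'].requires_grad)       # False
```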

Computing Baseline Loss

Before training, establish a baseline:

```python
model.eval()
with torch.no_grad():
    outputs = model(inputs)
    original_loss = criterion(outputs, targets)
    print(f"Original loss: {original_loss.item()}")
```

Training Loop Considerations

```python
# Create the optimizer over only the trainable parameters
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=0.001
)

# Training with progress tracking
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    outputs = model(inputs)
    loss = criterion(outputs, targets)

    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}: Loss = {loss.item():.6f}")
```

Alternative: Closed-Form Solution for Linear Layers

When retraining only a linear output layer, consider a closed-form least-squares solution for efficiency:

```python
# Pre-compute frozen-layer outputs
model.eval()
with torch.no_grad():
    # Get features just before the output layer
    features = model.get_features(inputs)  # Shape: [N, d_model]

# Solve the linear regression features @ W.T = targets, i.e.
# W.T = pinv(features) @ targets = (features.T @ features)^-1 @ features.T @ targets
solution = torch.linalg.lstsq(features, targets).solution  # Shape: [d_model, out_dim]
model.output_layer.weight.data = solution.T

# Note: if output_layer has a bias, append a column of ones to features
# and take the extra row of the solution as the bias.
```

Phase 5: TorchScript Export

Saving the Model

```python
# Ensure the model is in eval mode
model.eval()

# Script the model (preferred for models with control flow)
scripted_model = torch.jit.script(model)
scripted_model.save('/app/model.pt')

# Or trace the model (for simpler models)
traced_model = torch.jit.trace(model, example_input)
traced_model.save('/app/model.pt')
```

Verify Saved Model

```python
# Reload and verify
loaded = torch.jit.load('/app/model.pt')
loaded.eval()

with torch.no_grad():
    original_out = model(test_input)
    loaded_out = loaded(test_input)

diff = (original_out - loaded_out).abs().max()
print(f"Max difference: {diff.item()}")
assert diff < 1e-5, "Model outputs don't match!"
```

Environment Considerations

Handling Slow Environments

When operating in resource-constrained environments:
  1. Benchmark first: test basic operations before committing to a full solution
     ```python
     import time
     start = time.time()
     _ = model(torch.randint(0, vocab_size, (1, 10)))
     print(f"Single forward pass: {time.time() - start:.2f}s")
     ```
  2. Reduce batch size: process samples individually if needed
  3. Set realistic timeouts: base them on benchmarks, not arbitrary values
  4. Use incremental checkpoints: save progress periodically
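The one-off timing above can be wrapped in a small helper that averages a few repeats, which gives a steadier basis for choosing timeouts. A minimal sketch (the `benchmark` helper is illustrative; the workload below is a stand-in for a forward pass):

```python
import time

def benchmark(fn, repeats=3):
    """Return the average wall-clock seconds of fn() over a few repeats."""
    times = []
    for _ in range(repeats):
        start = time.time()
        fn()
        times.append(time.time() - start)
    return sum(times) / len(times)

avg = benchmark(lambda: sum(i * i for i in range(10_000)))
print(f"Average: {avg:.4f}s")
```

A timeout of, say, 10x the measured average is then grounded in observed behavior rather than guesswork.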

Memory Management

```python
# Clear GPU cache between operations
torch.cuda.empty_cache()

# Use gradient checkpointing for large models
from torch.utils.checkpoint import checkpoint

# Process in smaller batches
for batch in torch.split(data, batch_size):
    process(batch)
```

Common Pitfalls

  1. Not verifying architecture match before training - always test `load_state_dict` first
  2. Arbitrary hyperparameters - justify choices based on task characteristics
  3. Ignoring TorchScript compatibility - test export early, not after training
  4. Syntax errors in edits - review code changes carefully, especially string formatting
  5. Incomplete state dict mapping - verify all keys are accounted for
  6. Not establishing baseline metrics - compute the original loss before training
  7. Missing `torch.no_grad()` for inference - use the context manager for evaluation
  8. Forgetting to set `model.eval()` - required for consistent behavior in eval/export
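The last two pitfalls can be guarded together with a small context manager that puts the model in eval mode and restores its previous mode on exit. A sketch (`eval_scope` is a hypothetical helper relying only on the standard `training`/`train()`/`eval()` interface, demonstrated on a stub rather than a real `nn.Module`; in real use you would also wrap the body in `torch.no_grad()`):

```python
from contextlib import contextmanager

@contextmanager
def eval_scope(model):
    """Temporarily put the model in eval mode, restoring its previous mode on exit."""
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        if was_training:
            model.train()

# Stub standing in for an nn.Module in this demo
class StubModel:
    def __init__(self):
        self.training = True
    def train(self):
        self.training = True
    def eval(self):
        self.training = False

model = StubModel()
with eval_scope(model):
    print(model.training)  # False
print(model.training)      # True
```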

Verification Checklist

Before considering the task complete:
  • State dictionary keys fully analyzed and documented
  • Model architecture matches the state dict exactly (verified with `load_state_dict`)
  • Forward pass produces valid output
  • Baseline loss/metric computed
  • Target layers correctly unfrozen, others frozen
  • Training improves loss over baseline
  • TorchScript export succeeds
  • Exported model produces the same outputs as the original
  • Model saved to the required path