fine-tuning

Compare original and translation side by side

🇺🇸

Original

English
🇨🇳

Translation

Chinese

Fine-Tuning

LLM微调

Adapt LLMs to specific tasks and domains efficiently.
高效将大语言模型(LLM)适配到特定任务与领域。

Quick Start

快速开始

LoRA Fine-Tuning with PEFT

基于PEFT的LoRA微调

python
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
from trl import SFTTrainer
python
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
from trl import SFTTrainer

Load base model

Load base model

# Load the base model and tokenizer.
model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Llama ships without a pad token; reuse EOS so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token
# Load the base model and tokenizer.
model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Llama ships without a pad token; reuse EOS so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token

Configure LoRA

Configure LoRA

# In the collapsed one-line form, the inline "# Rank" comment swallowed the
# rest of the call; restored to the intended multi-line configuration.
lora_config = LoraConfig(
    r=16,                      # LoRA rank (adapter bottleneck width)
    lora_alpha=32,             # scaling factor (commonly 2x the rank)
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)
# In the collapsed one-line form, the inline "# Rank" comment swallowed the
# rest of the call; restored to the intended multi-line configuration.
lora_config = LoraConfig(
    r=16,                      # LoRA rank (adapter bottleneck width)
    lora_alpha=32,             # scaling factor (commonly 2x the rank)
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

Apply LoRA

Apply LoRA

# Wrap the base model with LoRA adapters and report the trainable fraction.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Wrap the base model with LoRA adapters and report the trainable fraction.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.06%

trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.06%

Training arguments

Training arguments

training_args = TrainingArguments( output_dir="./output", num_train_epochs=3, per_device_train_batch_size=4, gradient_accumulation_steps=4, learning_rate=2e-4, warmup_ratio=0.03, logging_steps=10, save_strategy="epoch" )
training_args = TrainingArguments( output_dir="./output", num_train_epochs=3, per_device_train_batch_size=4, gradient_accumulation_steps=4, learning_rate=2e-4, warmup_ratio=0.03, logging_steps=10, save_strategy="epoch" )

Train

Train

# Build the supervised fine-tuning trainer and run training.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    max_seq_length=512,
)
trainer.train()
# Build the supervised fine-tuning trainer and run training.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    max_seq_length=512,
)
trainer.train()

QLoRA (4-bit Quantized LoRA)

QLoRA(4位量化LoRA)

python
from transformers import BitsAndBytesConfig
import torch
python
from transformers import BitsAndBytesConfig
import torch

Quantization config

Quantization config

bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True )
bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True )

Load quantized model

Load quantized model

model = AutoModelForCausalLM.from_pretrained( model_name, quantization_config=bnb_config, device_map="auto" )
model = AutoModelForCausalLM.from_pretrained( model_name, quantization_config=bnb_config, device_map="auto" )

Apply LoRA on top of quantized model

Apply LoRA on top of quantized model

model = get_peft_model(model, lora_config)
model = get_peft_model(model, lora_config)

Dataset Preparation

数据集准备

Instruction Dataset Format

指令数据集格式

python
python

Alpaca format

Alpaca format

instruction_format = { "instruction": "Summarize the following text.", "input": "The quick brown fox jumps over the lazy dog...", "output": "A fox jumps over a dog." }
instruction_format = { "instruction": "Summarize the following text.", "input": "The quick brown fox jumps over the lazy dog...", "output": "A fox jumps over a dog." }

ChatML format

ChatML format

chat_format = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Summarize this text: ..."}, {"role": "assistant", "content": "Summary: ..."} ]
chat_format = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Summarize this text: ..."}, {"role": "assistant", "content": "Summary: ..."} ]

Formatting function

Formatting function

def format_instruction(sample):
    """Render one record as an Alpaca-style instruction prompt.

    The original f-string was garbled by the page renderer (its `###`
    section markers were promoted to document headers); reconstructed here.

    Args:
        sample: mapping with 'instruction', 'input' and 'output' keys.

    Returns:
        A single prompt string with Instruction / Input / Response sections.
    """
    return f"""### Instruction:
{sample['instruction']}

### Input:
{sample['input']}

### Response:
{sample['output']}"""

Data Preparation Pipeline

数据准备流水线

python
from datasets import Dataset
import json

class DatasetPreparer:
    """Turn a JSON instruction file into a tokenized HF Dataset.

    Records are expected to carry 'instruction', 'input' and 'output'
    keys (Alpaca-style) and are rendered with the Llama-2 [INST] template.
    """

    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def prepare(self, data_path: str) -> Dataset:
        """Load raw records from *data_path*, format and tokenize them."""
        with open(data_path) as handle:
            records = json.load(handle)

        texts = [self._format_sample(record) for record in records]
        dataset = Dataset.from_dict({"text": texts})

        # Tokenization replaces the raw "text" column with token ids.
        return dataset.map(
            self._tokenize,
            batched=True,
            remove_columns=["text"],
        )

    def _format_sample(self, sample):
        # Llama-2 instruct prompt template, one record per prompt.
        return f"""<s>[INST] {sample['instruction']}

{sample['input']} [/INST] {sample['output']}</s>"""

    def _tokenize(self, examples):
        # Pad/truncate every example to a fixed max_length-token width.
        return self.tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )
python
from datasets import Dataset
import json

class DatasetPreparer:
    """Prepare an instruction dataset: load JSON, render prompts, tokenize."""

    def __init__(self, tokenizer, max_length=512):
        # Tokenizer used for encoding; max_length bounds every example.
        self.tokenizer = tokenizer
        self.max_length = max_length

    def prepare(self, data_path: str) -> Dataset:
        """Load raw JSON records from *data_path*, format and tokenize them."""
        # Load raw data (expects a JSON array of records)
        with open(data_path) as f:
            raw_data = json.load(f)

        # Render each record into a single prompt string
        formatted = [self._format_sample(s) for s in raw_data]

        # Wrap prompts in an HF Dataset with one "text" column
        dataset = Dataset.from_dict({"text": formatted})

        # Tokenize; the raw "text" column is dropped after encoding
        return dataset.map(
            self._tokenize,
            batched=True,
            remove_columns=["text"]
        )

    def _format_sample(self, sample):
        # Llama-2 [INST] prompt template.
        # Assumes keys 'instruction', 'input', 'output' — TODO confirm schema.
        return f"""<s>[INST] {sample['instruction']}

{sample['input']} [/INST] {sample['output']}</s>"""

    def _tokenize(self, examples):
        # Pad/truncate every example to exactly max_length tokens.
        return self.tokenizer(
            examples["text"],
            truncation=True,
            max_length=self.max_length,
            padding="max_length"
        )

Fine-Tuning Methods Comparison

微调方法对比

MethodVRAMSpeedQualityUse Case
Full Fine-Tune60GB+SlowBestUnlimited resources
LoRA16GBFastVery GoodMost applications
QLoRA8GBMediumGoodConsumer GPUs
Prefix Tuning8GBFastGoodFixed tasks
Prompt Tuning4GBVery FastModerateSimple adaptation
方法显存速度效果适用场景
全量微调60GB+慢最佳资源不受限场景
LoRA16GB快优秀大多数应用场景
QLoRA8GB中等良好消费级GPU场景
Prefix Tuning8GB快良好固定任务场景
Prompt Tuning4GB极快一般简单适配场景

LoRA Hyperparameters

LoRA超参数

yaml
rank (r):
  small (4-8): Simple tasks, less capacity
  medium (16-32): General fine-tuning
  large (64-128): Complex domain adaptation

alpha:
  rule: Usually 2x rank
  effect: Higher = more influence from LoRA weights

target_modules:
  attention: [q_proj, k_proj, v_proj, o_proj]
  mlp: [gate_proj, up_proj, down_proj]
  all: Maximum adaptation, more VRAM

dropout:
  typical: 0.05-0.1
  effect: Regularization, prevents overfitting
yaml
rank (r):
  small (4-8): Simple tasks, less capacity
  medium (16-32): General fine-tuning
  large (64-128): Complex domain adaptation

alpha:
  rule: Usually 2x rank
  effect: Higher = more influence from LoRA weights

target_modules:
  attention: [q_proj, k_proj, v_proj, o_proj]
  mlp: [gate_proj, up_proj, down_proj]
  all: Maximum adaptation, more VRAM

dropout:
  typical: 0.05-0.1
  effect: Regularization, prevents overfitting

Training Best Practices

训练最佳实践

Learning Rate Schedule

学习率调度

python
from transformers import get_cosine_schedule_with_warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=total_steps
)
python
from transformers import get_cosine_schedule_with_warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=total_steps
)

Gradient Checkpointing

梯度检查点

python
python

Save memory by recomputing activations

Save memory by recomputing activations

model.gradient_checkpointing_enable()
model.gradient_checkpointing_enable()

Also enable for LoRA

Also enable for LoRA

model.enable_input_require_grads()
model.enable_input_require_grads()

Evaluation During Training

训练过程中的评估

python
def compute_metrics(eval_pred):
    """Compute evaluation perplexity for a causal language model.

    Args:
        eval_pred: (predictions, labels) pair from the Trainer; predictions
            are per-token logits of shape (batch, seq_len, vocab_size).

    Returns:
        dict with a single float entry, "perplexity".
    """
    predictions, labels = eval_pred
    # The Trainer hands back plain arrays; CrossEntropyLoss needs tensors.
    predictions = torch.as_tensor(predictions)
    labels = torch.as_tensor(labels)

    # Shift so token t is predicted from tokens < t (causal LM).
    predictions = predictions[:, :-1]
    labels = labels[:, 1:]

    # Infer vocab size from the logits; the original referenced an
    # undefined global `vocab_size` (NameError at runtime).
    vocab_size = predictions.size(-1)

    # Calculate perplexity = exp(mean cross-entropy).
    loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')
    # reshape (not view): the shifted slice may be non-contiguous.
    loss = loss_fct(predictions.reshape(-1, vocab_size), labels.reshape(-1))
    perplexity = torch.exp(loss)

    return {"perplexity": perplexity.item()}
python
def compute_metrics(eval_pred):
    """Compute evaluation perplexity for a causal language model.

    Args:
        eval_pred: (predictions, labels) pair from the Trainer; predictions
            are per-token logits of shape (batch, seq_len, vocab_size).

    Returns:
        dict with a single float entry, "perplexity".
    """
    predictions, labels = eval_pred
    # The Trainer hands back plain arrays; CrossEntropyLoss needs tensors.
    predictions = torch.as_tensor(predictions)
    labels = torch.as_tensor(labels)

    # Shift so token t is predicted from tokens < t (causal LM).
    predictions = predictions[:, :-1]
    labels = labels[:, 1:]

    # Infer vocab size from the logits; the original referenced an
    # undefined global `vocab_size` (NameError at runtime).
    vocab_size = predictions.size(-1)

    # Calculate perplexity = exp(mean cross-entropy).
    loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')
    # reshape (not view): the shifted slice may be non-contiguous.
    loss = loss_fct(predictions.reshape(-1, vocab_size), labels.reshape(-1))
    perplexity = torch.exp(loss)

    return {"perplexity": perplexity.item()}

Merging and Deploying

模型合并与部署

Merge LoRA Weights

合并LoRA权重

python
python

After training, merge LoRA into base model

After training, merge LoRA into base model

merged_model = model.merge_and_unload()
merged_model = model.merge_and_unload()

Save merged model

Save merged model

# Persist the merged model together with its tokenizer.
merged_model.save_pretrained("./merged_model")
tokenizer.save_pretrained("./merged_model")
# Persist the merged model together with its tokenizer.
merged_model.save_pretrained("./merged_model")
tokenizer.save_pretrained("./merged_model")

Multiple LoRA Adapters

多LoRA适配器

python
from peft import PeftModel
python
from peft import PeftModel

Load base model

Load base model

base_model = AutoModelForCausalLM.from_pretrained("base_model")
base_model = AutoModelForCausalLM.from_pretrained("base_model")

Load and switch between adapters

Load and switch between adapters

# Attach a first adapter, then register a second one under the name "code".
model = PeftModel.from_pretrained(base_model, "adapter_1")
model.load_adapter("adapter_2", adapter_name="code")
# Attach a first adapter, then register a second one under the name "code".
model = PeftModel.from_pretrained(base_model, "adapter_1")
model.load_adapter("adapter_2", adapter_name="code")

Switch adapters at runtime

Switch adapters at runtime

# In the collapsed one-line form the first inline comment swallowed the
# second statement; restored as two lines.
model.set_adapter("code")     # use the code adapter
model.set_adapter("default")  # switch back to the default adapter
# In the collapsed one-line form the first inline comment swallowed the
# second statement; restored as two lines.
model.set_adapter("code")     # use the code adapter
model.set_adapter("default")  # switch back to the default adapter

Common Issues

常见问题

IssueCauseSolution
Loss not decreasingLR too low/highAdjust learning rate
OOM errorsBatch too largeReduce batch, use gradient accumulation
OverfittingToo many epochsEarly stopping, more data
Catastrophic forgettingToo aggressive LRLower LR, shorter training
Poor qualityData issuesClean and validate dataset
问题原因解决方案
损失值不下降学习率过高/过低调整学习率
显存不足(OOM)错误批量大小过大减小批量,使用梯度累积
过拟合训练轮次过多提前停止训练,增加数据集
灾难性遗忘学习率过于激进降低学习率,缩短训练时长
模型效果差数据集问题清洗并验证数据集

Error Handling & Recovery

错误处理与恢复

python
from transformers import TrainerCallback

class CheckpointCallback(TrainerCallback):
    """Trainer callback that forces an epoch-level save when no
    best-metric tracking is configured."""

    def on_save(self, args, state, control, **kwargs):
        # Placeholder: checkpoint rotation (keep the last 3) not implemented.
        pass

    def on_epoch_end(self, args, state, control, **kwargs):
        # With a tracked best metric the Trainer decides when to save;
        # without one, checkpoint at every epoch boundary.
        if state.best_metric is not None:
            return
        control.should_save = True
python
from transformers import TrainerCallback

class CheckpointCallback(TrainerCallback):
    """Trainer callback controlling when checkpoints are written."""

    def on_save(self, args, state, control, **kwargs):
        # TODO: prune old checkpoints, always keeping the last 3.
        pass

    def on_epoch_end(self, args, state, control, **kwargs):
        if state.best_metric is None:
            # No best metric tracked -> save a checkpoint every epoch.
            control.should_save = True

Troubleshooting

问题排查

SymptomCauseSolution
NaN lossLR too highLower to 1e-5
No improvementLR too lowIncrease 10x
OOM mid-trainingBatch too largeEnable gradient checkpointing
症状原因解决方案
损失值为NaN学习率过高降低至1e-5
无性能提升学习率过低提升10倍
训练中途显存不足批量大小过大启用梯度检查点

Unit Test Template

单元测试模板

python
def test_lora_config():
    """A rank-16 LoRA adapter should train well under 1% of all parameters."""
    config = LoraConfig(r=16, lora_alpha=32)
    model = get_peft_model(base_model, config)
    # print_trainable_parameters() only prints (returns None), and `< 1%`
    # was invalid Python syntax; count the parameters explicitly instead.
    trainable, total = model.get_nb_trainable_parameters()
    assert trainable / total < 0.01
python
def test_lora_config():
    """A rank-16 LoRA adapter should train well under 1% of all parameters."""
    config = LoraConfig(r=16, lora_alpha=32)
    model = get_peft_model(base_model, config)
    # print_trainable_parameters() only prints (returns None), and `< 1%`
    # was invalid Python syntax; count the parameters explicitly instead.
    trainable, total = model.get_nb_trainable_parameters()
    assert trainable / total < 0.01