goodmodeler's picture
ADD: LLM SFT, RLHF and Distillation
c1c9e88
#!/usr/bin/env python3
"""
Teacher-Student知识蒸馏脚本
将经过SFT+PPO RLHF的Teacher模型蒸馏到更小的Student模型
"""
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
logging,
)
from datasets import load_dataset, Dataset as HFDataset
from peft import LoraConfig, get_peft_model, TaskType
import numpy as np
import wandb
from typing import Dict, List, Any, Optional
import json
from tqdm import tqdm
import warnings
# Suppress every Python warning and raise the transformers logger threshold to
# CRITICAL (``logging`` here is ``transformers.logging``, imported above).
# NOTE(review): presumably meant to keep console output limited to this
# script's own prints; it also hides deprecation notices — confirm intent.
warnings.filterwarnings("ignore")
logging.set_verbosity(logging.CRITICAL)
class DistillationConfig:
    """Settings for the teacher-to-student distillation run.

    Every setting is a plain class attribute, so values can be read either
    from the class itself or from an instance.
    """

    # --- Model locations ---
    # Teacher checkpoint produced by the RLHF stage.
    teacher_model_path: str = "./rlhf_teacher_model"
    # Placeholder student; replace with the actual OpenAI OSS 20B model.
    student_model_name: str = "microsoft/DialoGPT-medium"

    # --- Distillation hyperparameters ---
    temperature: float = 4.0  # softmax temperature used for distillation
    alpha: float = 0.7  # weight of the distillation loss
    beta: float = 0.3  # weight of the student's own loss
    gamma: float = 0.1  # weight of the feature-matching loss

    # --- Optimisation schedule ---
    learning_rate: float = 1e-4
    num_train_epochs: int = 3
    per_device_train_batch_size: int = 2
    per_device_eval_batch_size: int = 4
    gradient_accumulation_steps: int = 8
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    logging_steps: int = 50
    eval_steps: int = 500
    save_steps: int = 1000

    # --- LoRA (added to the student for cheaper training) ---
    use_lora: bool = True
    lora_r: int = 32
    lora_alpha: int = 64
    lora_dropout: float = 0.1

    # --- Data ---
    max_length: int = 512
    num_distill_samples: int = 10000  # number of samples used for distillation

    # --- Output ---
    output_dir: str = "./distilled_student_model"
    run_name: str = "teacher-student-distillation"
class DistillationDataset(Dataset):
    """Dataset pairing tokenized prompt/response text with cached teacher logits.

    Each record in ``teacher_outputs`` must provide the keys ``"prompt"``,
    ``"response"`` and ``"teacher_logits"``. The logits may be shaped
    ``(seq, vocab)`` or carry a leading batch axis of size one
    ``(1, seq, vocab)``.
    """

    def __init__(self, teacher_outputs: List[Dict], tokenizer, max_length: int = 512):
        """Store the records, the tokenizer and the fixed padded sequence length."""
        self.data = teacher_outputs
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of cached teacher records."""
        return len(self.data)

    def _align_teacher_logits(self, raw_logits, seq_len: int) -> torch.Tensor:
        """Return teacher logits as a ``(seq_len, vocab)`` float tensor.

        The cached logits cover only the unpadded tokens, while the tokenized
        sample is padded to ``seq_len``. Rows are truncated or zero-filled on
        the tokenizer's padding side so row ``i`` matches token ``i``.

        Raises:
            ValueError: if the logits are not 2-D after dropping a singleton
                batch axis.
        """
        logits = torch.as_tensor(raw_logits, dtype=torch.float)
        if logits.dim() == 3 and logits.size(0) == 1:
            logits = logits[0]
        if logits.dim() != 2:
            raise ValueError(
                f"teacher_logits must be (seq, vocab) or (1, seq, vocab), got shape {tuple(logits.shape)}"
            )

        logits = logits[:seq_len]
        aligned = logits.new_zeros((seq_len, logits.size(1)))
        if getattr(self.tokenizer, "padding_side", "right") == "left":
            aligned[seq_len - logits.size(0):] = logits
        else:
            aligned[:logits.size(0)] = logits
        return aligned

    def __getitem__(self, idx):
        """Tokenize record ``idx`` and return model inputs, labels and teacher logits.

        Returns a dict with ``input_ids``, ``attention_mask`` and ``labels``
        (each of length ``max_length``) plus ``teacher_logits`` with one row
        per token position.
        """
        item = self.data[idx]
        # Build the full prompt + response sequence.
        full_text = f"### Human: {item['prompt']}\n### Assistant: {item['response']}"
        encoded = self.tokenizer(
            full_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        # squeeze(0) drops only the batch axis; a bare squeeze() would also
        # collapse a length-1 sequence axis.
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)

        # Independent copy so later in-place edits cannot corrupt input_ids;
        # padded positions are excluded from the loss via the ignore index.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "teacher_logits": self._align_teacher_logits(item["teacher_logits"], input_ids.size(0)),
            "labels": labels,
        }
class KnowledgeDistillationTrainer(Trainer):
    """Trainer whose loss blends student cross-entropy with KL distillation.

    The total loss is ``beta * student_loss + alpha * distill_loss``. Teacher
    logits are read from the ``"teacher_logits"`` entry of each batch; the
    teacher model is stored (frozen) on the trainer but is not run inside
    ``compute_loss``. ``gamma`` (feature-matching weight) is stored, but no
    feature-matching term is computed by this class.
    """

    def __init__(self, teacher_model, student_model, temperature=4.0, alpha=0.7, beta=0.3, gamma=0.1, **kwargs):
        """Create the trainer around ``student_model`` and freeze ``teacher_model``.

        Args:
            teacher_model: model kept in eval mode with gradients disabled.
            student_model: model handed to ``Trainer`` as the trained model.
            temperature: softmax temperature applied to both logit sets.
            alpha: weight of the distillation (KL) loss.
            beta: weight of the student's cross-entropy loss.
            gamma: feature-matching weight (stored only).
            **kwargs: forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(model=student_model, **kwargs)
        self.teacher_model = teacher_model
        self.teacher_model.eval()
        # eval() only disables dropout and similar layers; turning off
        # requires_grad is what actually freezes the teacher.
        for param in self.teacher_model.parameters():
            param.requires_grad_(False)
        self.temperature = temperature
        self.alpha = alpha  # distillation loss weight
        self.beta = beta  # student loss weight
        self.gamma = gamma  # feature-matching loss weight (unused here)

    def _distillation_loss(self, student_logits, teacher_logits, attention_mask=None):
        """Return the temperature-scaled KL divergence averaged over real tokens.

        Raises:
            ValueError: if teacher and student vocabulary sizes differ.
        """
        teacher_logits = teacher_logits.to(student_logits.device)
        # Cached logits may keep a per-sample batch axis: (batch, 1, seq, vocab).
        if teacher_logits.dim() == student_logits.dim() + 1 and teacher_logits.size(1) == 1:
            teacher_logits = teacher_logits.squeeze(1)
        if teacher_logits.size(-1) != student_logits.size(-1):
            raise ValueError(
                "Teacher and student vocabulary sizes differ "
                f"({teacher_logits.size(-1)} vs {student_logits.size(-1)}); "
                "logit distillation needs a shared vocabulary."
            )

        seq_len = min(teacher_logits.size(1), student_logits.size(1))
        # Softmax over a large vocabulary is unstable in half precision.
        teacher_scaled = teacher_logits[:, :seq_len, :].float() / self.temperature
        student_scaled = student_logits[:, :seq_len, :].float() / self.temperature

        token_kl = F.kl_div(
            F.log_softmax(student_scaled, dim=-1),
            F.softmax(teacher_scaled, dim=-1),
            reduction="none",
        ).sum(dim=-1)

        if attention_mask is not None:
            mask = attention_mask[:, :seq_len].to(device=token_kl.device, dtype=token_kl.dtype)
        else:
            mask = torch.ones_like(token_kl)

        # Average per real token: "batchmean" on a 3-D tensor divides by the
        # batch size only, so the term would grow with sequence length and
        # include padded positions.
        mean_kl = (token_kl * mask).sum() / mask.sum().clamp(min=1.0)
        return mean_kl * (self.temperature ** 2)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the combined distillation loss for one batch.

        Args:
            model: the student model being trained.
            inputs: batch dict; ``"labels"`` enables the cross-entropy term and
                ``"teacher_logits"`` enables the KL term.
            return_outputs: when true, return ``(loss, student_outputs)``.
            num_items_in_batch: accepted so Trainer versions that pass this
                argument can call the override; not used.

        Raises:
            ValueError: if the batch has neither labels nor teacher logits.
        """
        labels = inputs.get("labels")
        teacher_logits = inputs.get("teacher_logits")
        attention_mask = inputs.get("attention_mask")

        # Student forward pass; teacher logits are not a model argument.
        student_outputs = model(**{k: v for k, v in inputs.items() if k != "teacher_logits"})
        student_logits = student_outputs.logits

        losses = {}

        # 1. Standard next-token language-modelling loss of the student.
        if labels is not None:
            shift_logits = student_logits[..., :-1, :].float().contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            losses["student_loss"] = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        # 2. Distillation loss (KL divergence against the teacher's soft targets).
        if teacher_logits is not None:
            losses["distill_loss"] = self._distillation_loss(student_logits, teacher_logits, attention_mask)

        if not losses:
            raise ValueError("compute_loss needs 'labels' and/or 'teacher_logits' in the batch")

        # 3. Weighted combination.
        weighted_terms = []
        if "student_loss" in losses:
            weighted_terms.append(self.beta * losses["student_loss"])
        if "distill_loss" in losses:
            weighted_terms.append(self.alpha * losses["distill_loss"])
        total_loss = torch.stack(weighted_terms).sum()

        # Log only while training and only on logging-step boundaries:
        # .item() forces a device sync, and these keys are train metrics.
        log_every = max(1, int(self.args.logging_steps))
        if model.training and self.state.global_step % log_every == 0:
            self.log({
                "train/total_loss": total_loss.detach().item(),
                "train/student_loss": losses["student_loss"].detach().item() if "student_loss" in losses else 0,
                "train/distill_loss": losses["distill_loss"].detach().item() if "distill_loss" in losses else 0,
            })

        return (total_loss, student_outputs) if return_outputs else total_loss
def _select_lora_target_modules(model, preferred):
    """Return the names in ``preferred`` that exist in ``model``, else GPT-2 style fallbacks."""
    leaf_names = {name.rsplit(".", 1)[-1] for name, _ in model.named_modules()}
    matched = [name for name in preferred if name in leaf_names]
    if matched:
        return matched
    gpt2_style = [name for name in ("c_attn", "c_proj", "c_fc") if name in leaf_names]
    # If nothing is recognised, pass the original list through so PEFT
    # reports its own "target modules not found" error.
    return gpt2_style or list(preferred)


def prepare_student_model(config: DistillationConfig):
    """Load the student model and, when ``config.use_lora`` is set, wrap it with LoRA.

    Args:
        config: supplies ``student_model_name``, ``use_lora`` and the LoRA
            hyperparameters ``lora_r``, ``lora_alpha`` and ``lora_dropout``.

    Returns:
        The loaded model, wrapped by ``get_peft_model`` when LoRA is enabled.
    """
    print("🎓 Preparing student model...")

    # fp16 AMP cannot unscale gradients of fp16 parameters. With LoRA the
    # frozen base can stay in half precision; without LoRA every weight is
    # trainable, so the model is loaded in float32.
    student_model = AutoModelForCausalLM.from_pretrained(
        config.student_model_name,
        torch_dtype=torch.float16 if config.use_lora else torch.float32,
        device_map="auto",
        trust_remote_code=True,
    )

    # Optional LoRA adapters for parameter-efficient training.
    if config.use_lora:
        print("🔧 Adding LoRA to student model...")
        preferred_targets = [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ]
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            # The preferred names do not exist in every architecture; keep
            # only the ones this model actually contains.
            target_modules=_select_lora_target_modules(student_model, preferred_targets),
        )
        student_model = get_peft_model(student_model, lora_config)
        # Keep the trainable adapter weights in float32 so the fp16 gradient
        # scaler can operate on them.
        for param in student_model.parameters():
            if param.requires_grad:
                param.data = param.data.float()
        student_model.print_trainable_parameters()

    return student_model
def load_teacher_model(config: DistillationConfig):
    """Load the teacher checkpoint from ``config.teacher_model_path`` in eval mode.

    Returns:
        The causal-LM teacher, loaded in float16 with automatic device placement.
    """
    print("👨‍🏫 Loading teacher model...")
    load_options = {
        "torch_dtype": torch.float16,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    teacher = AutoModelForCausalLM.from_pretrained(config.teacher_model_path, **load_options)
    # Inference-only usage: switch off dropout and other train-time behaviour.
    teacher.eval()
    return teacher
def _collect_prompts(dataset_sources, limit):
    """Return up to ``limit`` stripped first-turn prompts from the given datasets."""
    prompts = []
    for source in dataset_sources:
        if len(prompts) >= limit:
            break
        try:
            ds = load_dataset(source, split="train")
            # Use the first conversation turn as the prompt.
            for item in ds:
                if "conversations" in item and len(item["conversations"]) > 0:
                    prompt = item["conversations"][0].get("value", "")
                    if len(prompt.strip()) > 10:
                        prompts.append(prompt.strip())
                        # Stop scanning once enough prompts are collected.
                        if len(prompts) >= limit:
                            break
        except Exception as e:
            # Best effort: a failing source must not abort the others.
            print(f"⚠️ Error loading {source}: {e}")
    return prompts


def generate_distillation_data(teacher_model, tokenizer, config: DistillationConfig):
    """Sample teacher responses and cache the teacher's logits for each one.

    For every prompt the teacher generates a response, then a second forward
    pass over the formatted prompt + response text yields the logits stored
    with the sample. The result is also written to ``distillation_data.json``.

    Args:
        teacher_model: causal LM used for generation and for the logits.
        tokenizer: tokenizer matching ``teacher_model``.
        config: supplies ``num_distill_samples`` and ``max_length``.

    Returns:
        A list of dicts with keys ``"prompt"``, ``"response"`` and
        ``"teacher_logits"`` (nested lists shaped ``(1, seq, vocab)``).

    NOTE(review): full-vocabulary logits stored as JSON grow with
    ``samples * seq * vocab``; confirm this is feasible for the chosen
    sample count before a full run.
    """
    print("📊 Generating distillation dataset...")

    # Prompt datasets; more sources can be appended here.
    dataset_sources = [
        "smangrul/ad-copy-generation",
    ]
    all_prompts = _collect_prompts(dataset_sources, config.num_distill_samples)
    if not all_prompts:
        print("⚠️ No prompts were collected; the distillation dataset will be empty")

    print(f"📝 Generating responses for {len(all_prompts)} prompts...")
    distillation_data = []
    teacher_model.eval()

    with torch.no_grad():
        for i, prompt in enumerate(tqdm(all_prompts, desc="Generating teacher responses")):
            try:
                formatted_prompt = f"### Human: {prompt}\n### Assistant:"
                inputs = tokenizer(
                    formatted_prompt,
                    return_tensors="pt",
                    truncation=True,
                    max_length=config.max_length // 2,
                ).to(teacher_model.device)

                # Only the sampled token ids are needed, so per-step scores
                # are not requested and generate() returns the id tensor.
                sequences = teacher_model.generate(
                    **inputs,
                    max_new_tokens=200,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                )

                # Decode only the continuation after the prompt tokens.
                prompt_len = inputs["input_ids"].shape[1]
                response = tokenizer.decode(sequences[0][prompt_len:], skip_special_tokens=True).strip()

                # Teacher logits over the full prompt + response text.
                full_text = f"### Human: {prompt}\n### Assistant: {response}"
                full_inputs = tokenizer(
                    full_text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=config.max_length,
                ).to(teacher_model.device)
                teacher_logits = teacher_model(**full_inputs).logits.cpu().numpy()

                distillation_data.append({
                    "prompt": prompt,
                    "response": response,
                    "teacher_logits": teacher_logits.tolist(),
                })

                # Periodic progress report; nothing is saved until the end.
                if (i + 1) % 100 == 0:
                    print(f"Generated {i + 1}/{len(all_prompts)} samples")
            except Exception as e:
                # Best effort: skip prompts that fail and keep going.
                print(f"⚠️ Error generating for prompt {i}: {e}")
                continue

    print(f"✅ Generated {len(distillation_data)} teacher-student pairs")

    # Written without indentation: pretty-printing would put every logit
    # value on its own line and multiply the file size.
    with open("distillation_data.json", "w", encoding="utf-8") as f:
        json.dump(distillation_data, f, ensure_ascii=False)

    return distillation_data
def create_data_collator(tokenizer):
    """Build the batch collator for distillation samples.

    The returned callable takes a list of feature dicts and returns a dict of
    stacked tensors:

    * ``input_ids`` / ``attention_mask`` / ``labels`` are padded on the
      tokenizer's padding side to the longest sample, rounded up to a
      multiple of 8;
    * ``labels`` are taken from the features (or copied from ``input_ids``)
      and set to ``-100`` wherever the attention mask is 0;
    * ``teacher_logits`` (optional) lose a leading singleton batch axis and
      are trimmed or zero-padded to the same sequence length;
    * any other key is stacked as-is.

    Masking by attention mask rather than by ``pad_token_id`` keeps genuine
    EOS tokens as targets when the pad token is the EOS token.

    Args:
        tokenizer: read for ``pad_token_id`` and ``padding_side`` only.
    """
    pad_multiple = 8
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    left_padded = getattr(tokenizer, "padding_side", "right") == "left"

    def _pad_rows(tensor, length, fill):
        """Trim or pad ``tensor`` along dim 0 to ``length`` rows using ``fill``."""
        missing = length - tensor.size(0)
        if missing <= 0:
            return tensor[:length]
        filler = tensor.new_full((missing,) + tuple(tensor.shape[1:]), fill)
        return torch.cat((filler, tensor) if left_padded else (tensor, filler), dim=0)

    def _token_logits(raw_logits):
        """Return logits as a float tensor without a singleton batch axis."""
        logits = torch.as_tensor(raw_logits, dtype=torch.float)
        if logits.dim() == 3 and logits.size(0) == 1:
            logits = logits[0]
        return logits

    def collate(features):
        ids = [torch.as_tensor(f["input_ids"]).reshape(-1) for f in features]
        longest = max(t.size(0) for t in ids)
        target_len = -(-longest // pad_multiple) * pad_multiple  # ceil to a multiple

        masks = []
        labels = []
        for feature, token_ids in zip(features, ids):
            if feature.get("attention_mask") is not None:
                masks.append(torch.as_tensor(feature["attention_mask"]).reshape(-1))
            elif pad_token_id is not None:
                # No mask supplied: fall back to treating pad ids as padding.
                masks.append((token_ids != pad_token_id).long())
            else:
                masks.append(torch.ones_like(token_ids))
            if feature.get("labels") is not None:
                labels.append(torch.as_tensor(feature["labels"]).reshape(-1).clone())
            else:
                labels.append(token_ids.clone())

        id_fill = pad_token_id if pad_token_id is not None else 0
        batch = {
            "input_ids": torch.stack([_pad_rows(t, target_len, id_fill) for t in ids]),
            "attention_mask": torch.stack([_pad_rows(m, target_len, 0) for m in masks]),
        }
        label_batch = torch.stack([_pad_rows(lab, target_len, -100) for lab in labels])
        label_batch[batch["attention_mask"] == 0] = -100
        batch["labels"] = label_batch

        if "teacher_logits" in features[0]:
            batch["teacher_logits"] = torch.stack(
                [_pad_rows(_token_logits(f["teacher_logits"]), target_len, 0.0) for f in features]
            )

        handled = {"input_ids", "attention_mask", "labels", "teacher_logits"}
        for key in features[0]:
            if key not in handled:
                batch[key] = torch.stack([torch.as_tensor(f[key]) for f in features])
        return batch

    return collate
def run_distillation():
    """Run the full distillation pipeline end to end.

    Steps: start a W&B run, load the tokenizer and both models, load or
    generate the cached teacher data, split it 90/10 into train/eval sets,
    train the student with ``KnowledgeDistillationTrainer``, save the model
    and tokenizer, then run the qualitative evaluation.

    Raises:
        ValueError: if fewer than two distillation samples are available,
            since both the train and the eval split need at least one.
    """
    import inspect  # local: only used for the library-version checks below

    print("🚀 Starting Teacher-Student Distillation...")
    config = DistillationConfig()

    # The settings are class attributes, so vars(config) alone is empty.
    # Collect the public class attributes and let instance overrides win.
    hyperparameters = {
        name: value
        for name, value in vars(type(config)).items()
        if not name.startswith("_")
    }
    hyperparameters.update(vars(config))

    wandb.init(
        project="teacher-student-distillation",
        config=hyperparameters,
        name=config.run_name,
    )
    try:
        # Tokenizer comes from the teacher checkpoint.
        tokenizer = AutoTokenizer.from_pretrained(config.teacher_model_path)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        teacher_model = load_teacher_model(config)
        student_model = prepare_student_model(config)

        # Reuse cached teacher outputs when present, otherwise generate them.
        if os.path.exists("distillation_data.json"):
            print("📂 Loading existing distillation data...")
            with open("distillation_data.json", "r", encoding="utf-8") as f:
                distillation_data = json.load(f)
        else:
            distillation_data = generate_distillation_data(teacher_model, tokenizer, config)

        if len(distillation_data) < 2:
            raise ValueError(
                f"Need at least 2 distillation samples for a train/eval split, got {len(distillation_data)}"
            )

        train_size = int(0.9 * len(distillation_data))
        train_dataset = DistillationDataset(distillation_data[:train_size], tokenizer, config.max_length)
        eval_dataset = DistillationDataset(distillation_data[train_size:], tokenizer, config.max_length)
        print(f"📊 Training samples: {len(train_dataset)}")
        print(f"📊 Evaluation samples: {len(eval_dataset)}")

        # group_by_length is deliberately not enabled: every sample is padded
        # to max_length, so grouping gains nothing but would need a pass over
        # the whole dataset to measure lengths.
        training_kwargs = dict(
            output_dir=config.output_dir,
            overwrite_output_dir=True,
            num_train_epochs=config.num_train_epochs,
            per_device_train_batch_size=config.per_device_train_batch_size,
            per_device_eval_batch_size=config.per_device_eval_batch_size,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
            warmup_ratio=config.warmup_ratio,
            logging_steps=config.logging_steps,
            eval_steps=config.eval_steps,
            save_steps=config.save_steps,
            save_strategy="steps",
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            report_to="wandb",
            run_name=config.run_name,
            fp16=True,
            dataloader_pin_memory=False,
            # Keep "teacher_logits", which is not a model forward argument.
            remove_unused_columns=False,
        )
        # Newer Transformers releases renamed evaluation_strategy to
        # eval_strategy; use whichever name this installation accepts.
        accepted_args = inspect.signature(TrainingArguments.__init__).parameters
        strategy_key = "eval_strategy" if "eval_strategy" in accepted_args else "evaluation_strategy"
        training_kwargs[strategy_key] = "steps"
        training_args = TrainingArguments(**training_kwargs)

        data_collator = create_data_collator(tokenizer)

        # Likewise, Trainer's tokenizer argument became processing_class.
        trainer_params = inspect.signature(Trainer.__init__).parameters
        tokenizer_key = "processing_class" if "processing_class" in trainer_params else "tokenizer"
        trainer = KnowledgeDistillationTrainer(
            teacher_model=teacher_model,
            student_model=student_model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
            temperature=config.temperature,
            alpha=config.alpha,
            beta=config.beta,
            gamma=config.gamma,
            **{tokenizer_key: tokenizer},
        )

        print("🔥 Starting distillation training...")
        trainer.train()

        print("💾 Saving distilled student model...")
        trainer.save_model()
        tokenizer.save_pretrained(config.output_dir)

        print("🧪 Evaluating distilled model...")
        evaluate_distilled_model(trainer.model, tokenizer, config)
    finally:
        # Close the W&B run even when a step above fails.
        wandb.finish()

    print("✅ Distillation training completed!")
def evaluate_distilled_model(model, tokenizer, config: DistillationConfig):
    """Generate sample responses with the student and save them as JSON.

    Each fixed test prompt is formatted like the training text, sampled from
    ``model``, printed, and collected. The results are written to
    ``evaluation_results.json`` inside ``config.output_dir``.

    Returns:
        A list of ``{"prompt": ..., "response": ...}`` dicts, one per prompt.
    """
    print("📊 Evaluating distilled student model...")
    test_prompts = [
        "Create an advertisement for a revolutionary AI-powered fitness tracker",
        "Write marketing copy for an eco-friendly electric vehicle",
        "Generate a slogan for a productivity app for remote workers",
        "Create ad copy for a sustainable fashion brand targeting millennials",
        "Write promotional content for a mental health app",
    ]

    model.eval()
    results = []
    for prompt in test_prompts:
        formatted_prompt = f"### Human: {prompt}\n### Assistant:"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Drop the prompt by token count, as generate_distillation_data does.
        # Slicing the decoded string by len(formatted_prompt) breaks whenever
        # decoding does not reproduce the prompt text exactly.
        prompt_len = inputs["input_ids"].shape[1]
        generated_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

        results.append({
            "prompt": prompt,
            "response": generated_text,
        })
        print(f"\n🔍 Prompt: {prompt}")
        print(f"📝 Student Response: {generated_text}")
        print("-" * 80)

    # The directory may not exist when this function runs on its own.
    os.makedirs(config.output_dir, exist_ok=True)
    results_path = os.path.join(config.output_dir, "evaluation_results.json")
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    return results
if __name__ == "__main__":
    # Default environment for the run. setdefault keeps any value already
    # exported by the user or a launcher instead of overwriting it.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    # Report the visible GPUs before starting.
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        print(f"🔥 Using {gpu_count} GPUs")
        for i in range(gpu_count):
            print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        print("⚠️ Warning: No GPU available, using CPU (very slow)")

    # Start the distillation run.
    run_distillation()