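#!/usr/bin/env python3
"""
Teacher-Student知识蒸馏脚本
将经过SFT+PPO RLHF的Teacher模型蒸馏到更小的Student模型

Teacher-student knowledge distillation script: distills a teacher model that
was trained with SFT followed by PPO-based RLHF into a smaller student model.
"""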
import json
import os
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from datasets import Dataset as HFDataset
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    logging,
)
# Keep console output quiet: transformers' logger only reports CRITICAL
# messages, and every Python warning category is ignored.
logging.set_verbosity(logging.CRITICAL)
warnings.simplefilter("ignore")
@dataclass
class DistillationConfig:
    """Paths and hyper-parameters for teacher -> student distillation.

    Declared as a dataclass so every instance carries its own copy of the
    settings. As a result, ``vars(DistillationConfig())`` returns all fields;
    the plain class-attribute version returned ``{}``. Values can also be
    overridden at construction, e.g. ``DistillationConfig(temperature=2.0)``.
    Class-level access such as ``DistillationConfig.alpha`` keeps working
    because dataclass defaults remain class attributes.
    """

    # Model paths
    teacher_model_path: str = "./rlhf_teacher_model"  # teacher model after RLHF
    student_model_name: str = "microsoft/DialoGPT-medium"  # placeholder: replace with the actual OpenAI OSS 20B model

    # Distillation parameters
    temperature: float = 4.0  # softmax temperature applied to teacher/student logits
    alpha: float = 0.7  # weight of the distillation (KL) loss
    beta: float = 0.3  # weight of the student's own language-modeling loss
    gamma: float = 0.1  # feature-matching loss weight; NOTE(review): no visible loss term uses it

    # Training parameters
    learning_rate: float = 1e-4
    num_train_epochs: int = 3
    per_device_train_batch_size: int = 2
    per_device_eval_batch_size: int = 4
    gradient_accumulation_steps: int = 8
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    logging_steps: int = 50
    eval_steps: int = 500
    save_steps: int = 1000

    # LoRA configuration (adds LoRA adapters to the student for cheaper training)
    use_lora: bool = True
    lora_r: int = 32
    lora_alpha: int = 64
    lora_dropout: float = 0.1

    # Data configuration
    max_length: int = 512  # tokens per training example (pad/truncate length)
    num_distill_samples: int = 10000  # maximum number of prompts used for distillation

    # Output configuration
    output_dir: str = "./distilled_student_model"
    run_name: str = "teacher-student-distillation"
class DistillationDataset(Dataset):
    """Map-style dataset of (prompt, response, teacher_logits) records.

    Each record is rendered with the "### Human: ...\\n### Assistant: ..."
    template, tokenized to exactly ``max_length`` tokens (truncated and padded),
    and returned together with its stored teacher logits.

    Args:
        teacher_outputs: list of dicts with keys ``"prompt"``, ``"response"``
            and ``"teacher_logits"`` (nested float lists).
        tokenizer: callable tokenizer returning ``input_ids``/``attention_mask``
            tensors. Presumably the teacher's tokenizer, so token ids line up
            with the teacher logits; confirm at the call site.
        max_length: sequence length every example is padded/truncated to.
    """

    def __init__(self, teacher_outputs: List[Dict], tokenizer, max_length: int = 512):
        self.data = teacher_outputs
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return ``input_ids``, ``attention_mask``, ``teacher_logits`` and ``labels`` for ``idx``.

        ``labels`` equals ``input_ids`` except at padding positions
        (``attention_mask == 0``), which are set to -100 so cross-entropy
        ignores them.
        """
        item = self.data[idx]
        # Build the full prompt + response sequence.
        full_text = f"### Human: {item['prompt']}\n### Assistant: {item['response']}"
        encoded = self.tokenizer(
            full_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Index the batch dim explicitly; .squeeze() would also drop a length-1 sequence dim.
        input_ids = encoded["input_ids"][0]
        attention_mask = encoded["attention_mask"][0]

        # Padding tokens must not be trained on: -100 is ignored by cross-entropy.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        teacher_logits = torch.tensor(item["teacher_logits"], dtype=torch.float)
        # Older data files stored logits with a leading batch dim of 1: (1, seq, vocab).
        if teacher_logits.dim() == 3 and teacher_logits.size(0) == 1:
            teacher_logits = teacher_logits[0]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "teacher_logits": teacher_logits,
            "labels": labels,
        }
class KnowledgeDistillationTrainer(Trainer):
    """Hugging Face ``Trainer`` that trains a student with an LM + distillation objective.

    The total loss is
    ``beta * student_loss + alpha * distill_loss``, where:

    * ``student_loss`` is the student's next-token cross-entropy on ``labels``
      (positions labelled -100 are ignored);
    * ``distill_loss`` is the temperature-scaled KL(teacher || student) against
      the precomputed ``teacher_logits`` carried in each batch. It is averaged
      over non-padding tokens.

    The teacher model is stored and put in eval mode, but ``compute_loss`` only
    uses the precomputed ``teacher_logits`` and never runs the teacher.
    ``gamma`` (feature-matching weight) is stored, but no feature-matching term
    is computed.
    """

    def __init__(self, teacher_model, student_model, temperature=4.0, alpha=0.7, beta=0.3, gamma=0.1, **kwargs):
        super().__init__(model=student_model, **kwargs)
        self.teacher_model = teacher_model
        self.teacher_model.eval()  # the teacher is frozen / inference only
        self.temperature = temperature
        self.alpha = alpha  # distillation loss weight
        self.beta = beta  # student LM loss weight
        self.gamma = gamma  # feature-matching loss weight (currently unused)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the weighted distillation loss for one batch.

        Args:
            model: the student model being trained.
            inputs: batch dict. ``teacher_logits`` is consumed here and not
                forwarded to the model. ``labels`` and ``attention_mask`` are
                optional.
            return_outputs: if True, also return the student's forward outputs.
            num_items_in_batch: passed by newer ``Trainer`` versions; accepted
                for compatibility and not used.

        Returns:
            ``total_loss``, or ``(total_loss, student_outputs)`` when
            ``return_outputs`` is True.

        Raises:
            ValueError: if the batch has neither ``labels`` nor ``teacher_logits``.
        """
        labels = inputs.get("labels")
        teacher_logits = inputs.get("teacher_logits")
        attention_mask = inputs.get("attention_mask")

        # Student forward pass (teacher_logits is not a model input).
        student_outputs = model(**{k: v for k, v in inputs.items() if k != "teacher_logits"})
        student_logits = student_outputs.logits

        losses = {}

        # 1. Standard causal-LM loss: position t predicts token t+1.
        if labels is not None:
            shift_logits = student_logits[..., :-1, :].contiguous()
            # With device_map="auto" the logits may live on a different device than the inputs.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            losses["student_loss"] = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)).float(),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        # 2. Distillation loss (KL divergence to the teacher's soft targets).
        if teacher_logits is not None:
            losses["distill_loss"] = self._distillation_loss(student_logits, teacher_logits, attention_mask)

        if not losses:
            raise ValueError("compute_loss needs 'labels' and/or 'teacher_logits' in the batch")

        # 3. Weighted total.
        total_loss = 0.0
        if "student_loss" in losses:
            total_loss = total_loss + self.beta * losses["student_loss"]
        if "distill_loss" in losses:
            total_loss = total_loss + self.alpha * losses["distill_loss"]

        # compute_loss also runs during evaluation; only report "train/*" metrics while training.
        if model.training:
            self.log({
                "train/total_loss": total_loss.item(),
                "train/student_loss": losses["student_loss"].item() if "student_loss" in losses else 0,
                "train/distill_loss": losses["distill_loss"].item() if "distill_loss" in losses else 0,
            })

        return (total_loss, student_outputs) if return_outputs else total_loss

    def _distillation_loss(self, student_logits, teacher_logits, attention_mask=None):
        """Return temperature-scaled KL(teacher || student), averaged over real tokens.

        Assumes both logit tensors are (batch, seq, vocab) and come from the
        same tokenizer; TODO confirm for the configured model pair. The sequence
        and vocabulary dims are cropped to their common prefix, because a shared
        tokenizer can still yield different vocab sizes when one model pads its
        embedding matrix. Positions where ``attention_mask`` is 0 are excluded.
        """
        teacher_logits = teacher_logits.to(student_logits.device)
        if teacher_logits.dim() == 4 and teacher_logits.size(1) == 1:
            # Legacy data stored (1, seq, vocab) per example, which batches to (B, 1, seq, vocab).
            teacher_logits = teacher_logits.squeeze(1)

        seq_len = min(teacher_logits.size(1), student_logits.size(1))
        vocab_size = min(teacher_logits.size(-1), student_logits.size(-1))

        # Softmax in fp32: fp16 (log-)softmax of temperature-scaled logits is numerically fragile.
        teacher_probs = F.softmax(
            teacher_logits[:, :seq_len, :vocab_size].float() / self.temperature, dim=-1
        )
        student_log_probs = F.log_softmax(
            student_logits[:, :seq_len, :vocab_size].float() / self.temperature, dim=-1
        )

        # Per-token KL (sum over vocab). "batchmean" would instead sum over every
        # sequence position, padding included, making the loss scale with length.
        per_token_kl = F.kl_div(student_log_probs, teacher_probs, reduction="none").sum(dim=-1)

        if attention_mask is None:
            mask = torch.ones_like(per_token_kl)
        else:
            mask = attention_mask[:, :seq_len].to(device=per_token_kl.device, dtype=per_token_kl.dtype)

        mean_kl = (per_token_kl * mask).sum() / mask.sum().clamp(min=1.0)
        # The T^2 factor keeps gradient magnitudes comparable across temperatures.
        return mean_kl * (self.temperature ** 2)
def prepare_student_model(config: DistillationConfig):
    """Load the student model and, if ``config.use_lora``, wrap it with LoRA adapters.

    LoRA targets whichever of the LLaMA-style projection names
    (q/k/v/o_proj, gate/up/down_proj) actually exist in the loaded model.
    GPT-2-style models, such as the default DialoGPT student, name their layers
    ``c_attn``/``c_proj``/``c_fc`` instead. When none of the candidates match,
    ``target_modules`` is left as ``None`` so PEFT applies its built-in default
    for the model type instead of raising "target modules not found".

    Training runs with fp16 AMP, which cannot unscale fp16 gradients.
    Fully fine-tuned weights are therefore loaded in fp32, and LoRA adapter
    weights are upcast to fp32, while the frozen LoRA base stays fp16.

    Args:
        config: distillation settings (student name and LoRA hyper-parameters).

    Returns:
        The student model; a PEFT-wrapped model when LoRA is enabled.
    """
    print("🎓 Preparing student model...")
    # Load the student base model.
    student_model = AutoModelForCausalLM.from_pretrained(
        config.student_model_name,
        torch_dtype=torch.float16 if config.use_lora else torch.float32,
        device_map="auto",
        trust_remote_code=True,
    )

    # Optionally add LoRA for parameter-efficient training.
    if config.use_lora:
        print("🔧 Adding LoRA to student model...")
        candidate_modules = [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ]
        leaf_names = {name.rsplit(".", 1)[-1] for name, _ in student_model.named_modules()}
        target_modules = [m for m in candidate_modules if m in leaf_names] or None

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=target_modules,
        )
        student_model = get_peft_model(student_model, lora_config)

        # Trainable (adapter) weights must be fp32 for fp16 AMP's gradient unscaling.
        for param in student_model.parameters():
            if param.requires_grad:
                param.data = param.data.float()

        student_model.print_trainable_parameters()

    return student_model
def load_teacher_model(config: DistillationConfig):
    """Load the teacher checkpoint at ``config.teacher_model_path`` in eval mode.

    Weights are loaded as float16 with ``device_map="auto"`` and
    ``trust_remote_code=True``. ``eval()`` returns the module itself, so the
    call result is returned directly.
    """
    print("👨🏫 Loading teacher model...")
    return AutoModelForCausalLM.from_pretrained(
        config.teacher_model_path,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.float16,
    ).eval()
def _collect_prompts(dataset_sources, limit):
    """Return up to ``limit`` stripped first-turn prompts longer than 10 chars from the sources.

    Reads the ``"value"`` of the first entry of each record's
    ``"conversations"`` list. A source that fails to load or iterate is
    reported and skipped; prompts already collected from it are kept.
    """
    prompts = []
    for source in dataset_sources:
        try:
            ds = load_dataset(source, split="train")
            for item in ds:
                # Stop as soon as enough prompts exist instead of scanning the whole dataset.
                if len(prompts) >= limit:
                    return prompts
                if "conversations" in item and len(item["conversations"]) > 0:
                    prompt = item["conversations"][0].get("value", "")
                    if len(prompt.strip()) > 10:
                        prompts.append(prompt.strip())
        except Exception as e:
            print(f"⚠️ Error loading {source}: {e}")
    return prompts[:limit]


def generate_distillation_data(teacher_model, tokenizer, config: DistillationConfig):
    """Sample teacher responses for dataset prompts and record the teacher's logits.

    For each prompt (at most ``config.num_distill_samples``), the teacher
    samples a response (``max_new_tokens=200``, temperature 0.7, top-p 0.9).
    It then scores "### Human: {prompt}\\n### Assistant: {response}" in one
    forward pass, truncated to ``config.max_length``. The logits of that pass
    are stored per token, without the batch dimension. Per-prompt failures are
    printed and skipped. The result is written to ``distillation_data.json``
    and returned.

    NOTE(review): full-vocabulary logits for every token are stored as JSON,
    roughly seq_len x vocab floats per sample. At ``num_distill_samples`` scale
    this can exhaust RAM/disk; consider storing top-k logits instead.

    Returns:
        List of dicts with keys ``"prompt"``, ``"response"``, ``"teacher_logits"``.
    """
    print("📊 Generating distillation dataset...")
    dataset_sources = [
        "smangrul/ad-copy-generation",
        # More prompt sources can be added here.
    ]
    all_prompts = _collect_prompts(dataset_sources, config.num_distill_samples)

    print(f"📝 Generating responses for {len(all_prompts)} prompts...")
    distillation_data = []
    teacher_model.eval()
    with torch.no_grad():
        for i, prompt in enumerate(tqdm(all_prompts, desc="Generating teacher responses")):
            try:
                formatted_prompt = f"### Human: {prompt}\n### Assistant:"
                # NOTE(review): right-side truncation of a very long prompt also cuts the
                # trailing "### Assistant:" marker.
                inputs = tokenizer(
                    formatted_prompt,
                    return_tensors="pt",
                    truncation=True,
                    max_length=config.max_length // 2
                ).to(teacher_model.device)

                # Per-step scores are not needed, so output_scores is not requested (saves memory).
                outputs = teacher_model.generate(
                    **inputs,
                    max_new_tokens=200,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                    return_dict_in_generate=True,
                )

                # Decode only the newly generated tokens.
                generated_ids = outputs.sequences[0][inputs.input_ids.shape[1]:]
                response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

                # Teacher logits over the full prompt + response sequence.
                full_text = f"### Human: {prompt}\n### Assistant: {response}"
                full_inputs = tokenizer(
                    full_text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=config.max_length
                ).to(teacher_model.device)
                teacher_outputs = teacher_model(**full_inputs)
                # Drop the batch dim of 1 so each sample stores (seq, vocab).
                teacher_logits = teacher_outputs.logits[0].float().cpu().tolist()

                distillation_data.append({
                    "prompt": prompt,
                    "response": response,
                    "teacher_logits": teacher_logits,
                })

                # Progress report only; data is saved once, after the loop.
                if (i + 1) % 100 == 0:
                    print(f"Generated {i + 1}/{len(all_prompts)} samples")
            except Exception as e:
                # Best-effort: skip prompts that fail (e.g. generation errors) and continue.
                print(f"⚠️ Error generating for prompt {i}: {e}")
                continue

    print(f"✅ Generated {len(distillation_data)} teacher-student pairs")
    with open("distillation_data.json", "w", encoding="utf-8") as f:
        json.dump(distillation_data, f, ensure_ascii=False, indent=2)
    return distillation_data
def create_data_collator(tokenizer):
    """Build the batch collator for distillation examples.

    ``DataCollatorForLanguageModeling`` is unsuitable here for two reasons.
    It only pads token fields, so it cannot batch the variable-length
    ``teacher_logits`` tensors. It also rebuilds ``labels`` from
    ``input_ids``, discarding the dataset's labels.

    The returned callable right-pads every field along its first (sequence)
    dimension to the longest example in the batch and stacks it:

    * ``input_ids`` with the tokenizer's ``pad_token_id`` (0 if none);
    * ``attention_mask`` with 0;
    * ``labels`` with -100;
    * ``teacher_logits`` with 0.0.

    ``labels`` is created from ``input_ids`` if absent. Every label whose
    attention mask is 0 is set to -100, so padding never enters the LM loss.
    A leading singleton batch dim on ``teacher_logits`` (``(1, seq, vocab)``)
    is dropped first.

    Args:
        tokenizer: object exposing ``pad_token_id`` (may be ``None``); read at
            collate time.

    Returns:
        A callable mapping a list of feature dicts to a dict of batched tensors.
    """
    from torch.nn.utils.rnn import pad_sequence

    def collate(features):
        pad_token_id = getattr(tokenizer, "pad_token_id", None)
        padding_values = {
            "input_ids": 0 if pad_token_id is None else pad_token_id,
            "attention_mask": 0,
            "labels": -100,
            "teacher_logits": 0.0,
        }
        batch = {}
        for key in features[0]:
            tensors = [torch.as_tensor(feature[key]) for feature in features]
            if key == "teacher_logits":
                tensors = [t[0] if t.dim() == 3 and t.size(0) == 1 else t for t in tensors]
            batch[key] = pad_sequence(tensors, batch_first=True, padding_value=padding_values.get(key, 0))

        if "labels" not in batch and "input_ids" in batch:
            batch["labels"] = batch["input_ids"].clone()
        if "labels" in batch and "attention_mask" in batch:
            batch["labels"] = batch["labels"].masked_fill(batch["attention_mask"] == 0, -100)
        return batch

    return collate
def run_distillation():
    """Run the end-to-end teacher -> student distillation pipeline.

    Steps:
    1. Initialise wandb.
    2. Load the tokenizer (from the teacher path) and both models.
    3. Reuse ``distillation_data.json`` if it exists, otherwise generate it.
    4. Split the data 90/10 into train/eval.
    5. Train with ``KnowledgeDistillationTrainer``.
    6. Save the model and tokenizer to ``config.output_dir``.
    7. Run a qualitative evaluation.

    Raises:
        RuntimeError: if fewer than two distillation samples are available,
            since the 90/10 split would leave a split empty.
    """
    import inspect

    print("🚀 Starting Teacher-Student Distillation...")
    config = DistillationConfig()

    # vars(config) only sees *instance* attributes and is empty when settings are
    # class attributes; collect every public, non-callable attribute instead.
    config_dict = {
        name: getattr(config, name)
        for name in dir(config)
        if not name.startswith("_") and not callable(getattr(config, name))
    }
    wandb.init(
        project="teacher-student-distillation",
        config=config_dict,
        name=config.run_name
    )

    # Tokenizer (shared by teacher and student in this pipeline).
    tokenizer = AutoTokenizer.from_pretrained(config.teacher_model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Models.
    teacher_model = load_teacher_model(config)
    student_model = prepare_student_model(config)

    # Distillation data: reuse the cached file if present.
    if os.path.exists("distillation_data.json"):
        print("📂 Loading existing distillation data...")
        with open("distillation_data.json", "r", encoding="utf-8") as f:
            distillation_data = json.load(f)
    else:
        distillation_data = generate_distillation_data(teacher_model, tokenizer, config)

    if len(distillation_data) < 2:
        raise RuntimeError(
            f"Need at least 2 distillation samples for a train/eval split, got {len(distillation_data)}"
        )

    # Datasets (90% train / 10% eval).
    train_size = int(0.9 * len(distillation_data))
    train_data = distillation_data[:train_size]
    eval_data = distillation_data[train_size:]
    train_dataset = DistillationDataset(train_data, tokenizer, config.max_length)
    eval_dataset = DistillationDataset(eval_data, tokenizer, config.max_length)
    print(f"📊 Training samples: {len(train_dataset)}")
    print(f"📊 Evaluation samples: {len(eval_dataset)}")

    # Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`.
    training_arg_names = inspect.signature(TrainingArguments.__init__).parameters
    strategy_key = "eval_strategy" if "eval_strategy" in training_arg_names else "evaluation_strategy"

    training_args = TrainingArguments(
        output_dir=config.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=config.per_device_train_batch_size,
        per_device_eval_batch_size=config.per_device_eval_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        logging_steps=config.logging_steps,
        eval_steps=config.eval_steps,
        save_steps=config.save_steps,
        save_strategy="steps",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="wandb",
        run_name=config.run_name,
        # fp16 mixed precision requires CUDA; fall back to fp32 on CPU-only machines.
        fp16=torch.cuda.is_available(),
        dataloader_pin_memory=False,
        remove_unused_columns=False,  # keep `teacher_logits` in the batch
        # group_by_length stays off: every example is padded to max_length, so
        # grouping has no effect but would force a pass over the dataset to
        # compute lengths.
        **{strategy_key: "steps"},
    )

    data_collator = create_data_collator(tokenizer)

    trainer = KnowledgeDistillationTrainer(
        teacher_model=teacher_model,
        student_model=student_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        temperature=config.temperature,
        alpha=config.alpha,
        beta=config.beta,
        gamma=config.gamma,
    )

    print("🔥 Starting distillation training...")
    trainer.train()

    print("💾 Saving distilled student model...")
    trainer.save_model()
    tokenizer.save_pretrained(config.output_dir)

    print("🧪 Evaluating distilled model...")
    evaluate_distilled_model(trainer.model, tokenizer, config)
    wandb.finish()
    print("✅ Distillation training completed!")
def evaluate_distilled_model(model, tokenizer, config: DistillationConfig):
    """Generate sample responses from the distilled student and save them.

    Each fixed test prompt is wrapped in the "### Human: ...\\n### Assistant:"
    template. A response is sampled (``max_new_tokens=150``, temperature 0.7,
    top-p 0.9) and printed. All results are written to
    ``<config.output_dir>/evaluation_results.json``, and the directory is
    created if missing.

    Returns:
        List of ``{"prompt": ..., "response": ...}`` dicts.
    """
    print("📊 Evaluating distilled student model...")
    test_prompts = [
        "Create an advertisement for a revolutionary AI-powered fitness tracker",
        "Write marketing copy for an eco-friendly electric vehicle",
        "Generate a slogan for a productivity app for remote workers",
        "Create ad copy for a sustainable fashion brand targeting millennials",
        "Write promotional content for a mental health app",
    ]
    model.eval()
    results = []
    for prompt in test_prompts:
        formatted_prompt = f"### Human: {prompt}\n### Assistant:"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Slice off the prompt *tokens*: decoded text need not reproduce the prompt
        # string character-for-character, so slicing by len(formatted_prompt) is unreliable.
        prompt_length = inputs["input_ids"].shape[1]
        generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
        results.append({
            "prompt": prompt,
            "response": generated_text
        })
        print(f"\n🔍 Prompt: {prompt}")
        print(f"📝 Student Response: {generated_text}")
        print("-" * 80)

    # Save evaluation results.
    os.makedirs(config.output_dir, exist_ok=True)
    results_path = os.path.join(config.output_dir, "evaluation_results.json")
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    return results
if __name__ == "__main__":
    # Default to GPUs 0 and 1 without overriding a device selection the caller already made.
    # This must happen before the first CUDA call below (torch initialises CUDA lazily).
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Report available GPUs.
    if torch.cuda.is_available():
        print(f"🔥 Using {torch.cuda.device_count()} GPUs")
        for i in range(torch.cuda.device_count()):
            print(f"   GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        print("⚠️ Warning: No GPU available, using CPU (very slow)")

    run_distillation()