AdGPT / lauguage_model_fine_tuning /distillation /eval_compare_teacher_student.py
goodmodeler's picture
ADD: LLM SFT, RLHF and Distillation
c1c9e88
#!/usr/bin/env python3
"""
Teacher-Student模型性能比较脚本
比较RLHF Teacher模型和蒸馏后的Student模型的性能
"""
import torch
import argparse
import json
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict, Any
import numpy as np
from datetime import datetime
class ModelComparator:
def __init__(self, teacher_path: str, student_path: str):
    """Load the teacher and student causal-LM checkpoints and their tokenizers.

    Args:
        teacher_path: Model id or local path of the RLHF teacher model.
        student_path: Model id or local path of the distilled student model.
    """
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision only pays off on GPU; on CPU many kernels either lack
    # fp16 support or run far slower than fp32, which would both risk errors
    # and skew the speed comparison. Fall back to fp32 without CUDA.
    dtype = torch.float16 if self.device == "cuda" else torch.float32

    print("📥 Loading Teacher model...")
    self.teacher_model = AutoModelForCausalLM.from_pretrained(
        teacher_path,
        torch_dtype=dtype,
        device_map="auto",
    )
    self.teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_path)

    print("📥 Loading Student model...")
    self.student_model = AutoModelForCausalLM.from_pretrained(
        student_path,
        torch_dtype=dtype,
        device_map="auto",
    )
    self.student_tokenizer = AutoTokenizer.from_pretrained(student_path)

    # Set pad tokens: tokenizers without a pad token reuse their EOS token.
    for tokenizer in (self.teacher_tokenizer, self.student_tokenizer):
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
def generate_response(self, model, tokenizer, prompt: str, **kwargs) -> Dict[str, Any]:
    """Generate a completion for *prompt* and record performance metrics.

    The prompt is wrapped in the ``### Human: ...\\n### Assistant:`` template
    before tokenization. Default sampling settings (200 new tokens,
    temperature 0.7, top_p 0.9, sampling on, EOS as pad id) can be
    overridden via ``kwargs``, which are forwarded to ``model.generate``.

    Returns:
        Dict with keys:
            response: decoded completion text (prompt excluded), stripped.
            generation_time: wall-clock seconds spent in ``model.generate``.
            tokens_generated: number of new token ids produced.
            tokens_per_second: tokens_generated / generation_time (0 if the
                measured time is 0).
            prompt_tokens: number of prompt token ids.
            total_tokens: length of the returned sequence (prompt + new).
    """
    formatted_prompt = f"### Human: {prompt}\n### Assistant:"
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    generation_config = {
        "max_new_tokens": 200,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
        **kwargs,
    }

    # CUDA work is queued asynchronously; synchronize around the timed region
    # so the measurement covers the generation itself. perf_counter is the
    # monotonic clock intended for measuring durations.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_config)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    generation_time = time.perf_counter() - start_time

    # The returned sequence starts with the prompt ids. Drop them by token
    # position rather than by prompt string length: decode(encode(s)) does not
    # always reproduce s exactly, so string slicing can cut into the answer or
    # leave prompt fragments behind.
    new_token_ids = outputs[0][prompt_len:]
    generated_text = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
    # Count what the model actually produced; re-encoding the decoded text
    # would add special tokens (e.g. BOS) and may segment differently.
    generated_tokens = int(new_token_ids.shape[0])

    return {
        "response": generated_text,
        "generation_time": generation_time,
        "tokens_generated": generated_tokens,
        "tokens_per_second": generated_tokens / generation_time if generation_time > 0 else 0,
        "prompt_tokens": prompt_len,
        "total_tokens": outputs.shape[1],
    }
def calculate_model_size(self, model) -> Dict[str, Any]:
    """Summarize parameter counts and the parameter memory footprint of *model*.

    The byte size is estimated from parameters only (``numel * element_size``
    per tensor); buffers are not included. ``compression_ratio`` is returned
    as ``None`` and is meant to be computed later during the comparison.
    """
    total = 0
    trainable = 0
    size_bytes = 0
    # Single pass over the parameters, accumulating all three tallies.
    for tensor in model.parameters():
        count = tensor.numel()
        total += count
        if tensor.requires_grad:
            trainable += count
        size_bytes += count * tensor.element_size()

    size_mb = size_bytes / (1024 * 1024)
    return {
        "total_parameters": total,
        "trainable_parameters": trainable,
        "model_size_mb": size_mb,
        "model_size_gb": size_mb / 1024,
        "compression_ratio": None,  # filled in later, during the comparison
    }
def evaluate_quality_metrics(self, responses: List[str]) -> Dict[str, float]:
    """Compute simple surface-level statistics over a list of generated texts.

    Metrics (all returned as plain floats):
        avg_response_length: mean number of whitespace-separated words.
        response_length_std: population standard deviation (ddof=0) of the
            word counts.
        vocabulary_richness: simplified type-token ratio, i.e. unique
            lower-cased words divided by total words across all responses.
        avg_sentences_per_response: mean number of sentence endings, where a
            run of ``.``/``!``/``?`` (optionally followed by closing quotes or
            brackets) counts once and only if it ends the text or is followed
            by whitespace.

    An empty ``responses`` list yields 0.0 for every metric instead of NaN.
    """
    import re

    # "..." or "?!" end a single sentence; requiring whitespace/end-of-text
    # afterwards keeps decimals such as "9.99" from being counted.
    sentence_end = re.compile(r"""[.!?]+["')\]]*(?=\s|$)""")

    if not responses:
        # np.mean / np.std of an empty list return NaN with a RuntimeWarning.
        return {
            "avg_response_length": 0.0,
            "response_length_std": 0.0,
            "vocabulary_richness": 0.0,
            "avg_sentences_per_response": 0.0,
        }

    metrics: Dict[str, float] = {}

    # Response length statistics (in words).
    lengths = [len(resp.split()) for resp in responses]
    metrics["avg_response_length"] = float(np.mean(lengths))
    metrics["response_length_std"] = float(np.std(lengths))

    # Vocabulary richness: simplified type-token ratio.
    all_words = [word for resp in responses for word in resp.lower().split()]
    metrics["vocabulary_richness"] = (
        len(set(all_words)) / len(all_words) if all_words else 0.0
    )

    # Average number of sentences per response.
    sentence_counts = [len(sentence_end.findall(resp)) for resp in responses]
    metrics["avg_sentences_per_response"] = float(np.mean(sentence_counts))

    return metrics
def run_comprehensive_comparison(self) -> Dict[str, Any]:
"""运行全面的性能比较"""
print("🔍 Running comprehensive Teacher-Student comparison...")
# 测试提示词集合
test_prompts = [
# 广告文案生成
"Create an advertisement for a revolutionary smartphone with advanced AI features",
"Write marketing copy for an eco-friendly electric vehicle targeting urban professionals",
"Generate a catchy slogan for a fitness app that uses AI personal training",
"Create promotional content for a sustainable fashion brand targeting Gen Z",
"Write ad copy for a productivity software targeting remote workers",
# 不同复杂度的任务
"Explain the benefits of renewable energy in simple terms",
"Write a brief product description for wireless headphones with noise cancellation",
"Create a social media post promoting a new coffee shop opening",
"Generate marketing text for a luxury watch brand",
"Write an email subject line for a summer sale promotion",
# 创意任务
"Create a tagline for a travel app that focuses on sustainable tourism",
"Write a short product pitch for smart home security system",
"Generate advertising copy for a meal delivery service focusing on healthy options",
"Create marketing content for an online learning platform",
"Write promotional text for a mental wellness app"
]
# 初始化结果收集
results = {
"comparison_date": datetime.now().isoformat(),
"test_prompts_count": len(test_prompts),
"teacher_results": {},
"student_results": {},
"performance_comparison": {},
"detailed_responses": []
}
# 获取模型信息
print("📊 Analyzing model specifications...")
teacher_info = self.calculate_model_size(self.teacher_model)
student_info = self.calculate_model_size(self.student_model)