Spaces:
Running
Running
#!/usr/bin/env python3 | |
""" | |
Teacher-Student模型性能比较脚本 | |
比较RLHF Teacher模型和蒸馏后的Student模型的性能 | |
""" | |
import torch | |
import argparse | |
import json | |
import time | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from typing import List, Dict, Any | |
import numpy as np | |
from datetime import datetime | |
class ModelComparator: | |
def __init__(self, teacher_path: str, student_path: str):
    """Load the teacher and student causal LMs together with their tokenizers.

    Args:
        teacher_path: Model id or local path of the teacher (RLHF) model,
            passed to ``from_pretrained``.
        student_path: Model id or local path of the distilled student model,
            passed to ``from_pretrained``.
    """
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision is only a safe default on GPU: some CPU kernels lack
    # float16 support (or are very slow with it), so use float32 on CPU.
    dtype = torch.float16 if self.device == "cuda" else torch.float32

    print("📥 Loading Teacher model...")
    self.teacher_model = AutoModelForCausalLM.from_pretrained(
        teacher_path,
        torch_dtype=dtype,
        device_map="auto",
    )
    self.teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_path)

    print("📥 Loading Student model...")
    self.student_model = AutoModelForCausalLM.from_pretrained(
        student_path,
        torch_dtype=dtype,
        device_map="auto",
    )
    self.student_tokenizer = AutoTokenizer.from_pretrained(student_path)

    # Set pad tokens: tokenizers without a dedicated pad token reuse EOS.
    for tokenizer in (self.teacher_tokenizer, self.student_tokenizer):
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
def generate_response(self, model, tokenizer, prompt: str, **kwargs) -> Dict[str, Any]:
    """Generate a reply to *prompt* and record timing/throughput metrics.

    The prompt is wrapped in a ``### Human: ...\\n### Assistant:`` template
    before tokenization. The default sampling settings below can be
    overridden via ``**kwargs``, which are forwarded to ``model.generate``.

    Returns:
        A dict with:
          - ``response``: the decoded, stripped reply (new tokens only).
          - ``generation_time``: wall-clock seconds spent in ``generate``.
          - ``tokens_generated``: number of new token ids the model produced.
          - ``tokens_per_second``: ``tokens_generated / generation_time``
            (0 if the elapsed time is not positive).
          - ``prompt_tokens``: number of prompt token ids.
          - ``total_tokens``: length of the returned sequence (prompt + new).
    """
    formatted_prompt = f"### Human: {prompt}\n### Assistant:"
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs.input_ids.shape[1]

    generation_config = {
        "max_new_tokens": 200,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
        **kwargs,
    }

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_config)
    generation_time = time.perf_counter() - start_time

    # Decode only the newly generated ids. Slicing the fully decoded text by
    # len(formatted_prompt) breaks whenever decode(encode(prompt)) does not
    # reproduce the prompt character-for-character.
    new_token_ids = outputs[0][prompt_len:]
    generated_text = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    # Count the ids the model actually produced; re-encoding the decoded text
    # may add special tokens (e.g. BOS) and drift from the real count.
    generated_tokens = int(new_token_ids.shape[0])

    return {
        "response": generated_text,
        "generation_time": generation_time,
        "tokens_generated": generated_tokens,
        "tokens_per_second": generated_tokens / generation_time if generation_time > 0 else 0,
        "prompt_tokens": prompt_len,
        "total_tokens": outputs.shape[1],
    }
def calculate_model_size(self, model) -> Dict[str, Any]:
    """Summarise the parameter counts and weight memory footprint of *model*.

    Only ``model.parameters()`` is inspected (buffers are not counted). The
    byte size is each parameter's element count times its element size.
    ``compression_ratio`` is returned as ``None``; the caller fills it in
    when comparing two models.
    """
    total = 0
    trainable = 0
    size_in_bytes = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
        size_in_bytes += count * param.element_size()

    size_mb = size_in_bytes / (1024 * 1024)
    return {
        "total_parameters": total,
        "trainable_parameters": trainable,
        "model_size_mb": size_mb,
        "model_size_gb": size_mb / 1024,
        "compression_ratio": None,  # populated later, during the comparison
    }
def evaluate_quality_metrics(self, responses: List[str]) -> Dict[str, float]:
    """Compute simple surface-level statistics over a list of responses.

    All metrics use whitespace tokenisation:
      - ``avg_response_length``: mean word count per response.
      - ``response_length_std``: population standard deviation of word counts.
      - ``vocabulary_richness``: distinct lower-cased words divided by total
        words, pooled across all responses (a simplified type-token ratio).
      - ``avg_sentences_per_response``: mean number of '.', '!' and '?'
        characters per response, a rough sentence proxy (so "..." counts 3).

    An empty ``responses`` list yields 0.0 for every metric instead of NaN.
    """
    if not responses:
        # np.mean / np.std of an empty list return NaN (with a RuntimeWarning),
        # which would otherwise propagate into downstream results.
        return {
            "avg_response_length": 0.0,
            "response_length_std": 0.0,
            "vocabulary_richness": 0.0,
            "avg_sentences_per_response": 0.0,
        }

    # Word counts are needed for both the mean and the std; split once.
    lengths = [len(resp.split()) for resp in responses]

    all_words = [word for resp in responses for word in resp.lower().split()]
    vocabulary_richness = len(set(all_words)) / len(all_words) if all_words else 0.0

    sentence_counts = [resp.count(".") + resp.count("!") + resp.count("?") for resp in responses]

    return {
        "avg_response_length": float(np.mean(lengths)),
        "response_length_std": float(np.std(lengths)),
        "vocabulary_richness": vocabulary_richness,
        "avg_sentences_per_response": float(np.mean(sentence_counts)),
    }
def run_comprehensive_comparison(self) -> Dict[str, Any]: | |
"""运行全面的性能比较""" | |
print("🔍 Running comprehensive Teacher-Student comparison...") | |
# 测试提示词集合 | |
test_prompts = [ | |
# 广告文案生成 | |
"Create an advertisement for a revolutionary smartphone with advanced AI features", | |
"Write marketing copy for an eco-friendly electric vehicle targeting urban professionals", | |
"Generate a catchy slogan for a fitness app that uses AI personal training", | |
"Create promotional content for a sustainable fashion brand targeting Gen Z", | |
"Write ad copy for a productivity software targeting remote workers", | |
# 不同复杂度的任务 | |
"Explain the benefits of renewable energy in simple terms", | |
"Write a brief product description for wireless headphones with noise cancellation", | |
"Create a social media post promoting a new coffee shop opening", | |
"Generate marketing text for a luxury watch brand", | |
"Write an email subject line for a summer sale promotion", | |
# 创意任务 | |
"Create a tagline for a travel app that focuses on sustainable tourism", | |
"Write a short product pitch for smart home security system", | |
"Generate advertising copy for a meal delivery service focusing on healthy options", | |
"Create marketing content for an online learning platform", | |
"Write promotional text for a mental wellness app" | |
] | |
# 初始化结果收集 | |
results = { | |
"comparison_date": datetime.now().isoformat(), | |
"test_prompts_count": len(test_prompts), | |
"teacher_results": {}, | |
"student_results": {}, | |
"performance_comparison": {}, | |
"detailed_responses": [] | |
} | |
# 获取模型信息 | |
print("📊 Analyzing model specifications...") | |
teacher_info = self.calculate_model_size(self.teacher_model) | |
student_info = self.calculate_model_size(self.student_model) |