#!/usr/bin/env python3
"""
PPO RLHF training script - align the SFT Teacher model with human preferences.
Input:  SFT Teacher model + human preference data
Output: RLHF-aligned Teacher model
"""
import os
import torch
import torch.nn.functional as F
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import PeftModel, LoraConfig, get_peft_model, TaskType
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
import wandb
import numpy as np
from typing import List, Dict, Any
import warnings
warnings.filterwarnings("ignore")
logging.set_verbosity(logging.CRITICAL)

class RLHFConfig:
    """RLHF training configuration"""
    # Model paths
    teacher_model_path = "./merged_model"  # Teacher model from the earlier SFT stage
    reward_model_name = "OpenAssistant/reward-model-deberta-v3-large-v2"  # Reward model

    # PPO training parameters
    learning_rate = 1e-5
    mini_batch_size = 1
    batch_size = 8
    gradient_accumulation_steps = 8
    ppo_epochs = 4
    max_grad_norm = 1.0

    # PPO-specific parameters
    init_kl_coef = 0.02
    target_kl = 0.01
    adap_kl_ctrl = True
    clip_reward_value = 5.0
    cliprange = 0.2
    cliprange_value = 0.2
    gamma = 1.0
    lam = 0.95

    # Generation parameters
    max_new_tokens = 150
    temperature = 0.7
    top_p = 0.9
    do_sample = True

    # Training control
    total_episodes = 1000
    save_freq = 100
    eval_freq = 50
    output_dir = "./rlhf_teacher_model"

    # LoRA parameters (if using LoRA for RLHF)
    use_lora = True
    lora_r = 16
    lora_alpha = 32
    lora_dropout = 0.1
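
# Sanity note (assumption about TRL's PPOTrainer): mini_batch_size *
# gradient_accumulation_steps is generally expected to divide batch_size evenly;
# with the values above, 1 * 8 == 8 == batch_size, so the settings are consistent.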

class RewardModelWrapper:
    """Reward model wrapper"""
    def __init__(self, model_name: str, device: str = "cuda"):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.model.eval()

        # Set pad token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def get_reward(self, prompts: List[str], responses: List[str]) -> List[float]:
        """Compute reward scores"""
        inputs = []
        for prompt, response in zip(prompts, responses):
            # Format as a Human/Assistant dialogue, as expected by the reward model
            text = f"Human: {prompt}\n\nAssistant: {response}"
            inputs.append(text)

        # Batched inference
        with torch.no_grad():
            encoded = self.tokenizer(
                inputs,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt"
            ).to(self.device)
            outputs = self.model(**encoded)
            rewards = outputs.logits.squeeze(-1).cpu().tolist()
        return rewards
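
# Example usage (hypothetical score, shown only to illustrate the call shape):
#   rm = RewardModelWrapper(RLHFConfig.reward_model_name)
#   rm.get_reward(["Write a slogan for a fitness app"], ["Train smarter, not longer."])
#   -> [0.83]  # one scalar per (prompt, response) pair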

def load_preference_dataset():
    """Load the preference dataset"""
    print("📥 Loading preference dataset...")

    # Multiple data sources can be mixed
    datasets_config = [
        {
            "name": "Anthropic/hh-rlhf",
            "split": "train",
            "weight": 0.7
        },
        {
            "name": "OpenAssistant/oasst1",
            "split": "train",
            "weight": 0.3
        }
    ]

    all_prompts = []
    for config in datasets_config:
        try:
            dataset = load_dataset(config["name"], split=config["split"])

            # Handle the different dataset formats
            if config["name"] == "Anthropic/hh-rlhf":
                prompts = extract_prompts_from_hh(dataset)
            else:
                prompts = extract_prompts_from_oasst(dataset)

            # Subsample according to the configured weight
            sample_size = int(len(prompts) * config["weight"])
            prompts = prompts[:sample_size]
            all_prompts.extend(prompts)
            print(f"✅ Loaded {len(prompts)} prompts from {config['name']}")
        except Exception as e:
            print(f"⚠️ Failed to load {config['name']}: {e}")

    # Create a Dataset object
    return Dataset.from_dict({"prompt": all_prompts})

def extract_prompts_from_hh(dataset):
    """Extract prompts from the HH-RLHF dataset"""
    prompts = []
    for item in dataset:
        # Parse the HH-RLHF conversation format; keep the final human turn
        text = item.get("chosen", "")
        if "Human:" in text:
            prompt = text.split("Human:")[-1].split("Assistant:")[0].strip()
            if len(prompt) > 10:  # Filter out prompts that are too short
                prompts.append(prompt)
    return prompts

def extract_prompts_from_oasst(dataset):
    """Extract prompts from the OpenAssistant (oasst1) dataset"""
    prompts = []
    for item in dataset:
        if item.get("role") == "prompter":
            prompt = item.get("text", "").strip()
            if len(prompt) > 10:
                prompts.append(prompt)
    return prompts

def prepare_teacher_model(config: RLHFConfig):
    """Prepare the Teacher model for RLHF"""
    print("🤖 Preparing teacher model for RLHF...")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.teacher_model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load the base (SFT) model
    model = AutoModelForCausalLM.from_pretrained(
        config.teacher_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    # Optionally attach LoRA adapters for RLHF
    if config.use_lora:
        print("🔧 Adding LoRA for RLHF training...")
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ]
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # Wrap the policy with a value head for PPO
    model = AutoModelForCausalLMWithValueHead.from_pretrained(
        model,
        torch_dtype=torch.float16,
    )
    # Create the frozen reference model (wrapped with a value head as well,
    # since PPOTrainer expects a value-head model or None here)
    ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
        config.teacher_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    ref_model.eval()

    return model, ref_model, tokenizer
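
# Note (TRL behaviour, version-dependent): when the policy carries LoRA adapters,
# PPOTrainer can also be given ref_model=None, in which case it computes reference
# logits with the adapters disabled instead of holding a second full copy of the
# base model in memory. The explicit ref_model above is kept for clarity.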

def create_ppo_trainer(model, ref_model, tokenizer, config: RLHFConfig):
    """Create the PPO trainer"""
    print("🏋️ Creating PPO trainer...")

    # Note: reward clipping (config.clip_reward_value) is applied manually in the
    # training loop rather than passed to PPOConfig.
    ppo_config = PPOConfig(
        model_name=config.teacher_model_path,
        learning_rate=config.learning_rate,
        mini_batch_size=config.mini_batch_size,
        batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        ppo_epochs=config.ppo_epochs,
        max_grad_norm=config.max_grad_norm,
        init_kl_coef=config.init_kl_coef,
        target_kl=config.target_kl,
        adap_kl_ctrl=config.adap_kl_ctrl,
        cliprange=config.cliprange,
        cliprange_value=config.cliprange_value,
        gamma=config.gamma,
        lam=config.lam,
        remove_unused_columns=False,
        log_with="wandb" if wandb.run else None,
    )

    trainer = PPOTrainer(
        config=ppo_config,
        model=model,
        ref_model=ref_model,
        tokenizer=tokenizer,
    )
    return trainer

def format_prompt_for_generation(prompt: str) -> str:
    """Format a prompt for generation"""
    return f"### Human: {prompt}\n### Assistant:"
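
# Example: format_prompt_for_generation("Write a slogan")
# returns "### Human: Write a slogan\n### Assistant:"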

def run_ppo_training():
    """Main PPO training loop"""
    print("🚀 Starting PPO RLHF Training...")

    # Initialize wandb (drop dunder attributes from the config class)
    wandb.init(
        project="rlhf-teacher-training",
        config={k: v for k, v in vars(RLHFConfig).items() if not k.startswith("_")},
        name="ppo-teacher-rlhf"
    )

    config = RLHFConfig()

    # Prepare the policy, reference model and tokenizer
    model, ref_model, tokenizer = prepare_teacher_model(config)

    # Create the PPO trainer
    ppo_trainer = create_ppo_trainer(model, ref_model, tokenizer, config)
    device = next(ppo_trainer.model.parameters()).device

    # Load the reward model
    reward_model = RewardModelWrapper(config.reward_model_name)

    # Load the prompt dataset
    dataset = load_preference_dataset()

    print(f"📊 Training on {len(dataset)} prompts")
    print(f"🎯 Target episodes: {config.total_episodes}")

    # Training loop
    for episode in range(config.total_episodes):
        # Randomly sample a batch of prompts
        batch_prompts = np.random.choice(
            dataset["prompt"],
            size=config.batch_size,
            replace=False
        ).tolist()

        # Format inputs
        formatted_prompts = [format_prompt_for_generation(p) for p in batch_prompts]

        # Tokenize prompts
        prompt_tensors = []
        for prompt in formatted_prompts:
            prompt_tensor = tokenizer.encode(
                prompt,
                return_tensors="pt",
                padding=False,
                truncation=True,
                max_length=256
            ).squeeze(0)
            prompt_tensors.append(prompt_tensor)

        # Generate one response per prompt
        response_tensors = []
        with torch.no_grad():
            for prompt_tensor in prompt_tensors:
                query_tensor = prompt_tensor.to(device)
                response = ppo_trainer.generate(
                    query_tensor,
                    max_new_tokens=config.max_new_tokens,
                    temperature=config.temperature,
                    top_p=config.top_p,
                    do_sample=config.do_sample,
                    pad_token_id=tokenizer.eos_token_id,
                )
                # Keep only the newly generated tokens (strip the prompt)
                response = response.squeeze(0)[query_tensor.shape[0]:]
                response_tensors.append(response)

        # Decode responses
        responses = [
            tokenizer.decode(r, skip_special_tokens=True).strip()
            for r in response_tensors
        ]

        # Compute rewards and clip them to the configured range
        rewards = reward_model.get_reward(batch_prompts, responses)
        rewards = [
            torch.clamp(
                torch.tensor(r, dtype=torch.float),
                -config.clip_reward_value,
                config.clip_reward_value,
            )
            for r in rewards
        ]

        # PPO optimization step
        stats = ppo_trainer.step(prompt_tensors, response_tensors, rewards)

        # Log statistics (PPOTrainer expects a batch dict with query/response text)
        ppo_trainer.log_stats(
            stats,
            {"query": batch_prompts, "response": responses},
            rewards,
        )

        # Print progress
        if episode % 10 == 0:
            mean_reward = np.mean([r.item() for r in rewards])
            print(f"📈 Episode {episode}: Mean Reward = {mean_reward:.4f}")

            # Log to wandb
            wandb.log({
                "episode": episode,
                "mean_reward": mean_reward,
                "kl_divergence": stats.get("objective/kl", 0),
                "policy_loss": stats.get("ppo/loss/policy", 0),
                "value_loss": stats.get("ppo/loss/value", 0),
            })

        # Periodic evaluation
        if episode % config.eval_freq == 0 and episode > 0:
            evaluate_model(ppo_trainer.model, tokenizer, episode)

        # Periodic checkpointing
        if episode % config.save_freq == 0 and episode > 0:
            save_checkpoint(ppo_trainer.model, tokenizer, config.output_dir, episode)

    # Save the final model
    print("💾 Saving final RLHF model...")
    ppo_trainer.model.save_pretrained(config.output_dir)
    tokenizer.save_pretrained(config.output_dir)

    wandb.finish()
    print("✅ RLHF training completed!")

def evaluate_model(model, tokenizer, episode):
    """Evaluate model quality on a few fixed prompts"""
    print(f"🧪 Evaluating model at episode {episode}...")

    test_prompts = [
        "Create an advertisement for a revolutionary smartphone with AI capabilities",
        "Write marketing copy for an eco-friendly clothing brand",
        "Generate a slogan for a fitness app targeting busy professionals",
    ]

    model.eval()
    device = next(model.parameters()).device
    results = []

    for prompt in test_prompts:
        formatted_prompt = format_prompt_for_generation(prompt)
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_text = response[len(formatted_prompt):].strip()

        results.append({
            "prompt": prompt,
            "response": generated_text
        })

        print(f"🔍 Prompt: {prompt}")
        print(f"📝 Response: {generated_text}")
        print("-" * 80)

    model.train()
    return results

def save_checkpoint(model, tokenizer, output_dir, episode):
    """Save a training checkpoint"""
    checkpoint_dir = f"{output_dir}/checkpoint-{episode}"
    os.makedirs(checkpoint_dir, exist_ok=True)

    model.save_pretrained(checkpoint_dir)
    tokenizer.save_pretrained(checkpoint_dir)
    print(f"💾 Checkpoint saved to {checkpoint_dir}")

def load_checkpoint_and_continue(checkpoint_path):
    """Resume training from a checkpoint (not implemented yet)"""
    print(f"📥 Loading checkpoint from {checkpoint_path}")
    # TODO: implement checkpoint-resumption logic
    pass
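
# A minimal, hypothetical sketch of checkpoint resumption, assuming checkpoints were
# written by save_checkpoint() above (value-head policy + tokenizer in one directory).
# It is not wired into run_ppo_training(); shown only as a starting point.
def _load_rlhf_checkpoint_sketch(checkpoint_path):
    """Hypothetical helper: reload a policy + tokenizer saved by save_checkpoint()."""
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    model = AutoModelForCausalLMWithValueHead.from_pretrained(
        checkpoint_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    # A resumed run would then rebuild the PPO trainer around the reloaded policy,
    # e.g. create_ppo_trainer(model, ref_model, tokenizer, RLHFConfig()).
    return model, tokenizer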

if __name__ == "__main__":
    # Set environment variables (note: CUDA_VISIBLE_DEVICES should ideally be set
    # before the process starts, i.e. before torch initializes CUDA)
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"  # Multi-GPU setup
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Check GPU resources
    if torch.cuda.is_available():
        print(f"🔥 Using {torch.cuda.device_count()} GPUs")
        for i in range(torch.cuda.device_count()):
            print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        raise RuntimeError("❌ CUDA not available! RLHF requires GPU.")

    # Start training
    run_ppo_training()