#!/usr/bin/env python3
"""
PPO RLHF training script - align the Teacher model with human preferences
Input: SFT Teacher model + human preference data
Output: RLHF-aligned Teacher model
"""

import os
import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    logging,
)
from peft import LoraConfig, get_peft_model, TaskType
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
import wandb
import numpy as np
from typing import List
import warnings

warnings.filterwarnings("ignore")
logging.set_verbosity(logging.CRITICAL)

class RLHFConfig:
    """RLHF训练配置"""
    # 模型路径
    teacher_model_path = "./merged_model"  # 之前SFT训练的Teacher模型
    reward_model_name = "OpenAssistant/reward-model-deberta-v3-large-v2"  # 奖励模型
    
    # PPO training parameters
    learning_rate = 1e-5
    mini_batch_size = 1
    batch_size = 8
    gradient_accumulation_steps = 8
    ppo_epochs = 4
    max_grad_norm = 1.0
    
    # PPO-specific parameters (the KL terms keep the policy close to the SFT teacher)
    init_kl_coef = 0.02
    target_kl = 0.01
    adap_kl_ctrl = True
    clip_reward_value = 5.0  # reward scores are clamped to this range in the training loop
    cliprange = 0.2
    cliprange_value = 0.2
    gamma = 1.0
    lam = 0.95
    
    # Generation parameters
    max_new_tokens = 150
    temperature = 0.7
    top_p = 0.9
    do_sample = True
    
    # Training control
    total_episodes = 1000
    save_freq = 100
    eval_freq = 50
    output_dir = "./rlhf_teacher_model"
    
    # LoRA parameters (if LoRA is used for RLHF)
    use_lora = True
    lora_r = 16
    lora_alpha = 32
    lora_dropout = 0.1

class RewardModelWrapper:
    """奖励模型包装器"""
    
    def __init__(self, model_name: str, device: str = "cuda"):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.model.eval()
        
        # Make sure a pad token is set
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
    
    def get_reward(self, prompts: List[str], responses: List[str]) -> List[float]:
        """Score (prompt, response) pairs with the reward model."""
        # The OpenAssistant DeBERTa reward model is a cross-encoder: per its model card
        # it scores a (question, answer) pair, so prompts and responses are tokenized
        # as text / text_pair rather than as one concatenated string.
        with torch.no_grad():
            encoded = self.tokenizer(
                prompts,
                responses,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to(self.device)

            outputs = self.model(**encoded)
            rewards = outputs.logits.squeeze(-1).cpu().tolist()

        return rewards
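
# Illustrative usage of RewardModelWrapper (assumption: the reward model can be downloaded):
#   rm = RewardModelWrapper(RLHFConfig.reward_model_name)
#   scores = rm.get_reward(["Explain PPO briefly."], ["PPO is a clipped policy-gradient method."])
# `scores` is a list of floats, one scalar preference score per pair; higher means preferred.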

def load_preference_dataset():
    """加载偏好数据集"""
    print("📥 Loading preference dataset...")
    
    # Multiple data sources can be mixed
    datasets_config = [
        {
            "name": "Anthropic/hh-rlhf",
            "split": "train",
            "weight": 0.7
        },
        {
            "name": "OpenAssistant/oasst1", 
            "split": "train",
            "weight": 0.3
        }
    ]
    
    all_prompts = []
    
    for config in datasets_config:
        try:
            dataset = load_dataset(config["name"], split=config["split"])
            
            # Handle each dataset's format
            if config["name"] == "Anthropic/hh-rlhf":
                prompts = extract_prompts_from_hh(dataset)
            else:
                prompts = extract_prompts_from_oasst(dataset)
            
            # Keep a weighted fraction of the prompts (simple prefix slice)
            sample_size = int(len(prompts) * config["weight"])
            prompts = prompts[:sample_size]
            all_prompts.extend(prompts)
            
            print(f"✅ Loaded {len(prompts)} prompts from {config['name']}")
            
        except Exception as e:
            print(f"⚠️ Failed to load {config['name']}: {e}")
    
    # Build a Dataset object
    return Dataset.from_dict({"prompt": all_prompts})

def extract_prompts_from_hh(dataset):
    """从HH-RLHF数据集提取提示"""
    prompts = []
    for item in dataset:
        # Parse the HH-RLHF transcript and keep only the last Human turn
        text = item.get("chosen", "")
        if "Human:" in text:
            prompt = text.split("Human:")[-1].split("Assistant:")[0].strip()
            if len(prompt) > 10:  # filter out prompts that are too short
                prompts.append(prompt)
    return prompts

def extract_prompts_from_oasst(dataset):
    """从OpenAssistant数据集提取提示"""
    prompts = []
    for item in dataset:
        if item.get("role") == "prompter":
            prompt = item.get("text", "").strip()
            if len(prompt) > 10:
                prompts.append(prompt)
    return prompts
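
# Expected record shapes (abridged; based on the public datasets - verify before relying on them):
#   Anthropic/hh-rlhf:    {"chosen": "\n\nHuman: ...\n\nAssistant: ...", "rejected": "..."}
#   OpenAssistant/oasst1: {"role": "prompter" | "assistant", "text": "...", ...}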

def prepare_teacher_model(config: RLHFConfig):
    """准备Teacher模型用于RLHF"""
    print("🤖 Preparing teacher model for RLHF...")
    
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.teacher_model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # Load the base model
    model = AutoModelForCausalLM.from_pretrained(
        config.teacher_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    
    # Optionally attach LoRA adapters for RLHF
    if config.use_lora:
        print("🔧 Adding LoRA for RLHF training...")
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ]
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
    
    # Wrap the policy with a value head (TRL accepts an already-instantiated model here)
    model = AutoModelForCausalLMWithValueHead.from_pretrained(
        model,
        torch_dtype=torch.float16,
    )
    
    # Frozen reference model; PPOTrainer expects a value-head-wrapped model (or None)
    ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
        config.teacher_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    ref_model.eval()
    
    return model, ref_model, tokenizer
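
# Note: when the policy carries PEFT/LoRA adapters, recent TRL versions also accept
# ref_model=None and compute reference logits with the adapters disabled, which avoids
# keeping a second full copy of the teacher in memory.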

def create_ppo_trainer(model, ref_model, tokenizer, config: RLHFConfig):
    """创建PPO训练器"""
    print("🏋️ Creating PPO trainer...")
    
    ppo_config = PPOConfig(
        model_name=config.teacher_model_path,
        learning_rate=config.learning_rate,
        mini_batch_size=config.mini_batch_size,
        batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        ppo_epochs=config.ppo_epochs,
        max_grad_norm=config.max_grad_norm,
        init_kl_coef=config.init_kl_coef,
        target_kl=config.target_kl,
        adap_kl_ctrl=config.adap_kl_ctrl,
        # clip_reward_value is not a PPOConfig field; the reward scores are clamped
        # manually in the training loop instead.
        cliprange=config.cliprange,
        cliprange_value=config.cliprange_value,
        gamma=config.gamma,
        lam=config.lam,
        remove_unused_columns=False,
        log_with="wandb" if wandb.run else None,
    )
    
    trainer = PPOTrainer(
        config=ppo_config,
        model=model,
        ref_model=ref_model,
        tokenizer=tokenizer,
    )
    
    return trainer
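
# Sanity note on the batch shape above: batch_size (8) == mini_batch_size (1) * gradient_accumulation_steps (8),
# so each PPO batch splits cleanly into mini-batches across the accumulation steps.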

def format_prompt_for_generation(prompt: str) -> str:
    """格式化提示用于生成"""
    return f"### Human: {prompt}\n### Assistant:"

def run_ppo_training():
    """主要的PPO训练循环"""
    print("🚀 Starting PPO RLHF Training...")
    
    # 初始化wandb
    wandb.init(
        project="rlhf-teacher-training",
        config=vars(RLHFConfig),
        name="ppo-teacher-rlhf"
    )
    
    config = RLHFConfig()
    
    # Prepare the models
    model, ref_model, tokenizer = prepare_teacher_model(config)
    
    # Create the PPO trainer
    ppo_trainer = create_ppo_trainer(model, ref_model, tokenizer, config)
    
    # Load the reward model
    reward_model = RewardModelWrapper(config.reward_model_name)
    
    # Load the prompt dataset
    dataset = load_preference_dataset()
    
    print(f"📊 Training on {len(dataset)} prompts")
    print(f"🎯 Target episodes: {config.total_episodes}")
    
    # Training loop
    for episode in range(config.total_episodes):
        # Randomly sample a batch of prompts
        batch_prompts = np.random.choice(
            dataset["prompt"],
            size=config.batch_size,
            replace=False
        ).tolist()
        
        # Format the inputs
        formatted_prompts = [format_prompt_for_generation(p) for p in batch_prompts]
        
        # Tokenize prompts (PPOTrainer expects a list of 1-D query tensors)
        prompt_tensors = []
        for prompt in formatted_prompts:
            prompt_tensor = tokenizer.encode(
                prompt,
                return_tensors="pt",
                padding=False,
                truncation=True,
                max_length=256
            ).squeeze(0)
            prompt_tensors.append(prompt_tensor)
        
        # Generate a response for each prompt
        response_tensors = []
        with torch.no_grad():
            for prompt_tensor in prompt_tensors:
                query = prompt_tensor.to(ppo_trainer.accelerator.device)
                
                response = ppo_trainer.generate(
                    query,
                    max_new_tokens=config.max_new_tokens,
                    temperature=config.temperature,
                    top_p=config.top_p,
                    do_sample=config.do_sample,
                    pad_token_id=tokenizer.eos_token_id,
                )
                
                # Keep only the newly generated tokens (strip the prompt)
                response = response.squeeze()[query.shape[0]:]
                response_tensors.append(response)
        
        # Decode the responses
        responses = [
            tokenizer.decode(r, skip_special_tokens=True).strip()
            for r in response_tensors
        ]
        
        # Compute reward scores for each (prompt, response) pair
        rewards = reward_model.get_reward(batch_prompts, responses)
        rewards = [torch.tensor(r, dtype=torch.float) for r in rewards]
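        # The config's clip_reward_value is applied here by hand (it is not a TRL PPOConfig
        # field); clamping keeps outlier reward scores from dominating a single update.
        rewards = [torch.clamp(r, -config.clip_reward_value, config.clip_reward_value) for r in rewards]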
        
        # PPO optimization step
        stats = ppo_trainer.step(prompt_tensors, response_tensors, rewards)
        
        # Log statistics (log_stats expects a batch dict with "query"/"response" columns)
        ppo_trainer.log_stats(
            stats,
            {"query": batch_prompts, "response": responses},
            rewards,
        )
        
        # Print progress
        if episode % 10 == 0:
            mean_reward = np.mean([r.item() for r in rewards])
            print(f"📈 Episode {episode}: Mean Reward = {mean_reward:.4f}")
            
            # Log to wandb
            wandb.log({
                "episode": episode,
                "mean_reward": mean_reward,
                "kl_divergence": stats.get("objective/kl", 0),
                "policy_loss": stats.get("ppo/loss/policy", 0),
                "value_loss": stats.get("ppo/loss/value", 0),
            })
        
        # Periodically evaluate the model
        if episode % config.eval_freq == 0 and episode > 0:
            evaluate_model(ppo_trainer.model, tokenizer, episode)
        
        # Save a checkpoint
        if episode % config.save_freq == 0 and episode > 0:
            save_checkpoint(ppo_trainer.model, tokenizer, config.output_dir, episode)
    
    # Save the final model
    print("💾 Saving final RLHF model...")
    ppo_trainer.model.save_pretrained(config.output_dir)
    tokenizer.save_pretrained(config.output_dir)
    
    wandb.finish()
    print("✅ RLHF training completed!")

def evaluate_model(model, tokenizer, episode):
    """评估模型性能"""
    print(f"🧪 Evaluating model at episode {episode}...")
    
    test_prompts = [
        "Create an advertisement for a revolutionary smartphone with AI capabilities",
        "Write marketing copy for an eco-friendly clothing brand",
        "Generate a slogan for a fitness app targeting busy professionals",
    ]
    
    model.eval()
    results = []
    
    for prompt in test_prompts:
        formatted_prompt = format_prompt_for_generation(prompt)
        # The value-head wrapper may not expose .device directly, so read it from the parameters
        device = next(model.parameters()).device
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_text = response[len(formatted_prompt):].strip()
        
        results.append({
            "prompt": prompt,
            "response": generated_text
        })
        
        print(f"🔍 Prompt: {prompt}")
        print(f"📝 Response: {generated_text}")
        print("-" * 80)
    
    model.train()
    return results

def save_checkpoint(model, tokenizer, output_dir, episode):
    """保存训练检查点"""
    checkpoint_dir = f"{output_dir}/checkpoint-{episode}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    model.save_pretrained(checkpoint_dir)
    tokenizer.save_pretrained(checkpoint_dir)
    
    print(f"💾 Checkpoint saved to {checkpoint_dir}")

def load_checkpoint_and_continue(checkpoint_path):
    """从检查点继续训练"""
    print(f"📥 Loading checkpoint from {checkpoint_path}")
    
    # TODO: implement the checkpoint restore logic
    pass
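    # A minimal resume sketch (assumption: the checkpoint was written by save_checkpoint above):
    #   tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    #   model = AutoModelForCausalLMWithValueHead.from_pretrained(checkpoint_path, torch_dtype=torch.float16)
    #   ...then rebuild the PPOTrainer around this model and continue the episode loop.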

if __name__ == "__main__":
    # Environment variables
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"  # multi-GPU setup
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    
    # Check GPU resources
    if torch.cuda.is_available():
        print(f"🔥 Using {torch.cuda.device_count()} GPUs")
        for i in range(torch.cuda.device_count()):
            print(f"   GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        raise RuntimeError("❌ CUDA not available! RLHF requires GPU.")
    
    # Start training
    run_ppo_training()