Commit c1c9e88
Parent(s): 61bd54d
ADD: LLM SFT, RLHF and Distillation
Files changed:
- README.md +6 -5
- fine_tune_llm/ppo_tune_llm.py +0 -19
- fine_tune_llm/reward_model.py +0 -21
- fine_tune_llm/sft_llm_train.py +0 -41
- {fine_tune_stablediffusion → fully_fine_tune_stablediffusion}/train_lora.py +0 -0
- lauguage_model_fine_tuning/accelerate_config.yaml +23 -0
- lauguage_model_fine_tuning/distillation/distill_llm.py +485 -0
- lauguage_model_fine_tuning/distillation/eval_compare_teacher_student.py +168 -0
- lauguage_model_fine_tuning/distillation/launch_distill.sh +60 -0
- lauguage_model_fine_tuning/eval_ppo_teacher.py +170 -0
- lauguage_model_fine_tuning/launch_ppo_fine_tune_teacher.sh +63 -0
- lauguage_model_fine_tuning/launch_supervised_fine_tune_teacher.sh +28 -0
- lauguage_model_fine_tuning/merge_teacher_model.py +116 -0
- lauguage_model_fine_tuning/ppo_fine_tune_teacher.py +459 -0
- lauguage_model_fine_tuning/sft_teacher.py +276 -0
- requirements.txt +51 -13
README.md CHANGED

@@ -45,9 +45,6 @@ fine tune a trained model: --pretrained_model_name_or_path="./nyc-ad-model/check
 
 export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
 
-import torch
-torch.cuda.empty_cache()
-torch.cuda.reset_peak_memory_stats()
 
 pipeline:
 # 1 Fully Fine‑tune image model with ZeRO
@@ -73,7 +70,6 @@ python ppo_tune.py
 python rag_infer.py
 
 
-
 system flow:
 input: business or product description text
 1. Use RAG to retrieve embeddings based on the input
@@ -81,4 +77,9 @@ input: business or product description text
 2. GPT‑OSS generates expanded visual prompts (subject, color palette, camera angle, art style) from the selected copy
 3. stablediffusion model generates 4 draft images (optional ControlNet-Layout/Logo insertion)
 4. Return the 4 posters + post-processing
-output: an advertisement sentence and post image
+output: an advertisement sentence and post image
+
+
+design details:
+LoRA fine tune teacher OSS 120B model using smangrul/ad-copy-generation (ad copy generation)
+LoRA distill knowledge to OSS 20B model
fine_tune_llm/ppo_tune_llm.py DELETED
@@ -1,19 +0,0 @@

from trl import PPOTrainer, PPOConfig
from peft import PeftModel
import torch, random, json, glob
from diffusers import StableDiffusionPipeline
from reward_model import CLIPModel, CLIPProcessor

rm=CLIPModel.from_pretrained("rm").eval().half().cuda()
proc=CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
pipe=StableDiffusionPipeline.from_pretrained("./nyc-ad-model",torch_dtype=torch.float16).to("cuda")
ppo_cfg=PPOConfig(batch_size=1,learning_rate=1e-6,target_kl=0.2)
trainer=PPOTrainer(model=pipe.unet, reward_model=rm, config=ppo_cfg)

prompts=[l.strip() for l in open("prompt.txt")]
for step in range(500):
    p=random.choice(prompts)
    img=pipe(p,num_inference_steps=20).images[0]
    reward=rm(**proc(text=p,images=img,return_tensors="pt").to("cuda")).logits[0,0].item()
    trainer.step(prompts=[p], rewards=[reward])
pipe.save_pretrained("nyc-ad-model-rlhf")
fine_tune_llm/reward_model.py DELETED
@@ -1,21 +0,0 @@

from transformers import CLIPProcessor, CLIPModel, TrainingArguments, Trainer
import datasets, torch, json, glob

model=CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor=CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

data=[]
for f in glob.glob("human_prefs/*.json"):
    j=json.load(open(f)); data.append(j)  # {"prompt":…, "good":img_path, "bad":img_path}

dataset=datasets.Dataset.from_list(data)

def preprocess(ex):
    inputs=processor(text=[ex["prompt"]*2], images=[ex["good"],ex["bad"]], return_tensors="pt")
    inputs["labels"]=torch.tensor([1,0])
    return inputs

dataset=dataset.map(preprocess,remove_columns=dataset.column_names)
args=TrainingArguments("rm_ckpt",per_device_train_batch_size=2,fp16=True,learning_rate=5e-6,epochs=3)
trainer=Trainer(model,args,train_dataset=dataset)
trainer.train(); model.save_pretrained("rm")
fine_tune_llm/sft_llm_train.py DELETED
@@ -1,41 +0,0 @@

import torch, json
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import get_peft_model, LoraConfig, TaskType

# Load your dataset
data = [json.loads(l) for l in open("data/sft_data.jsonl")]
dataset = Dataset.from_list(data)

# Load model & tokenizer
base_model = "meta-llama/Llama-2-7b-hf"  # Or use Mistral, Falcon, etc.
tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.float16)

# Add LoRA (optional)
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.05,
                         target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, lora_config)

# Preprocessing
def tokenize(example):
    prompt = f"### Instruction:\n{example['prompt']}\n\n### Response:\n{example['output']}"
    return tokenizer(prompt, truncation=True, max_length=512, padding="max_length")
dataset = dataset.map(tokenize, remove_columns=dataset.column_names)

# Training setup
args = TrainingArguments(
    output_dir="./sft-model",
    per_device_train_batch_size=2,
    num_train_epochs=3,
    fp16=True,
    evaluation_strategy="no",
    save_strategy="epoch",
    logging_steps=20,
    learning_rate=2e-5,
    report_to="tensorboard",
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
trainer = Trainer(model=model, args=args, train_dataset=dataset, data_collator=data_collator)
trainer.train()
{fine_tune_stablediffusion → fully_fine_tune_stablediffusion}/train_lora.py RENAMED
File without changes
lauguage_model_fine_tuning/accelerate_config.yaml ADDED
@@ -0,0 +1,23 @@

# accelerate_config.yaml - multi-GPU training configuration

compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4  # adjust to the number of GPUs
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

# RLHF-specific settings
gradient_accumulation_steps: 8
gradient_clipping: 1.0
learning_rate: 1e-5
dataloader_drop_last: true
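
For reference, the PPO launch script added later in this commit consumes this file via accelerate launch --config_file accelerate_config.yaml --num_processes 4 --main_process_port 29500 ppo_rlhf_teacher.py (see launch_ppo_fine_tune_teacher.sh below). Note that keys such as learning_rate and dataloader_drop_last are not standard accelerate settings, so they presumably need to be read by the training script itself.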
lauguage_model_fine_tuning/distillation/distill_llm.py ADDED
@@ -0,0 +1,485 @@

#!/usr/bin/env python3
"""
Teacher-student knowledge distillation script.
Distills the SFT + PPO RLHF teacher model into a smaller student model.
"""

import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    logging,
)
from datasets import load_dataset, Dataset as HFDataset
from peft import LoraConfig, get_peft_model, TaskType
import numpy as np
import wandb
from typing import Dict, List, Any, Optional
import json
from tqdm import tqdm
import warnings

warnings.filterwarnings("ignore")
logging.set_verbosity(logging.CRITICAL)

class DistillationConfig:
    """Distillation training configuration"""
    # Model paths
    teacher_model_path = "./rlhf_teacher_model"  # teacher model after RLHF
    student_model_name = "microsoft/DialoGPT-medium"  # replace with the actual OpenAI OSS 20B model

    # Distillation parameters
    temperature = 4.0  # distillation temperature
    alpha = 0.7  # weight of the distillation loss
    beta = 0.3   # weight of the student loss
    gamma = 0.1  # weight of the feature-matching loss

    # Training parameters
    learning_rate = 1e-4
    num_train_epochs = 3
    per_device_train_batch_size = 2
    per_device_eval_batch_size = 4
    gradient_accumulation_steps = 8
    warmup_ratio = 0.1
    weight_decay = 0.01
    logging_steps = 50
    eval_steps = 500
    save_steps = 1000

    # LoRA config (adds LoRA to the student model for more efficient training)
    use_lora = True
    lora_r = 32
    lora_alpha = 64
    lora_dropout = 0.1

    # Data configuration
    max_length = 512
    num_distill_samples = 10000  # number of samples used for distillation

    # Output configuration
    output_dir = "./distilled_student_model"
    run_name = "teacher-student-distillation"

class DistillationDataset(Dataset):
    """Distillation dataset"""

    def __init__(self, teacher_outputs: List[Dict], tokenizer, max_length: int = 512):
        self.data = teacher_outputs
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Build the full input-output sequence
        full_text = f"### Human: {item['prompt']}\n### Assistant: {item['response']}"

        # Tokenize
        encoded = self.tokenizer(
            full_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )

        return {
            "input_ids": encoded["input_ids"].squeeze(),
            "attention_mask": encoded["attention_mask"].squeeze(),
            "teacher_logits": torch.tensor(item["teacher_logits"], dtype=torch.float),
            "labels": encoded["input_ids"].squeeze()
        }

class KnowledgeDistillationTrainer(Trainer):
    """Knowledge-distillation trainer"""

    def __init__(self, teacher_model, student_model, temperature=4.0, alpha=0.7, beta=0.3, gamma=0.1, **kwargs):
        super().__init__(model=student_model, **kwargs)
        self.teacher_model = teacher_model
        self.teacher_model.eval()  # freeze the teacher model

        self.temperature = temperature
        self.alpha = alpha  # weight of the distillation loss
        self.beta = beta    # weight of the student loss
        self.gamma = gamma  # weight of the feature-matching loss

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute the distillation loss"""

        labels = inputs.get("labels")
        teacher_logits = inputs.get("teacher_logits").to(model.device)

        # Student forward pass
        student_outputs = model(**{k: v for k, v in inputs.items() if k not in ["teacher_logits"]})
        student_logits = student_outputs.logits

        # Collect the individual losses
        losses = {}

        # 1. Standard language-modeling loss (the student's own loss)
        if labels is not None:
            shift_logits = student_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss()
            student_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            losses["student_loss"] = student_loss

        # 2. Distillation loss (KL divergence)
        if teacher_logits is not None:
            # Make sure the dimensions match
            if teacher_logits.shape != student_logits.shape:
                min_seq_len = min(teacher_logits.shape[1], student_logits.shape[1])
                teacher_logits = teacher_logits[:, :min_seq_len, :]
                student_logits_for_distill = student_logits[:, :min_seq_len, :]
            else:
                student_logits_for_distill = student_logits

            # Soft-label probabilities
            teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
            student_log_probs = F.log_softmax(student_logits_for_distill / self.temperature, dim=-1)

            # KL-divergence loss
            distill_loss = F.kl_div(
                student_log_probs,
                teacher_probs,
                reduction="batchmean"
            ) * (self.temperature ** 2)

            losses["distill_loss"] = distill_loss

        # 3. Combine into the total loss
        total_loss = 0
        if "student_loss" in losses:
            total_loss += self.beta * losses["student_loss"]
        if "distill_loss" in losses:
            total_loss += self.alpha * losses["distill_loss"]

        # Log each loss term
        self.log({
            "train/total_loss": total_loss.item(),
            "train/student_loss": losses.get("student_loss", 0).item() if "student_loss" in losses else 0,
            "train/distill_loss": losses.get("distill_loss", 0).item() if "distill_loss" in losses else 0,
        })

        return (total_loss, student_outputs) if return_outputs else total_loss

def prepare_student_model(config: DistillationConfig):
    """Prepare the student model"""
    print("🎓 Preparing student model...")

    # Load the student base model
    student_model = AutoModelForCausalLM.from_pretrained(
        config.student_model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    # Add LoRA (optional, for efficient training)
    if config.use_lora:
        print("🔧 Adding LoRA to student model...")
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ]
        )
        student_model = get_peft_model(student_model, lora_config)
        student_model.print_trainable_parameters()

    return student_model

def load_teacher_model(config: DistillationConfig):
    """Load the teacher model"""
    print("👨🏫 Loading teacher model...")

    teacher_model = AutoModelForCausalLM.from_pretrained(
        config.teacher_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    teacher_model.eval()

    return teacher_model

def generate_distillation_data(teacher_model, tokenizer, config: DistillationConfig):
    """Generate distillation data"""
    print("📊 Generating distillation dataset...")

    # Load prompt datasets
    dataset_sources = [
        "smangrul/ad-copy-generation",
        # more data sources can be added here
    ]

    all_prompts = []
    for source in dataset_sources:
        try:
            ds = load_dataset(source, split="train")
            # Extract prompts
            for item in ds:
                if "conversations" in item and len(item["conversations"]) > 0:
                    prompt = item["conversations"][0].get("value", "")
                    if len(prompt.strip()) > 10:
                        all_prompts.append(prompt.strip())
        except Exception as e:
            print(f"⚠️ Error loading {source}: {e}")

    # Cap the number of samples
    if len(all_prompts) > config.num_distill_samples:
        all_prompts = all_prompts[:config.num_distill_samples]

    print(f"📝 Generating responses for {len(all_prompts)} prompts...")

    distillation_data = []
    teacher_model.eval()

    with torch.no_grad():
        for i, prompt in enumerate(tqdm(all_prompts, desc="Generating teacher responses")):
            try:
                # Format the input
                formatted_prompt = f"### Human: {prompt}\n### Assistant:"
                inputs = tokenizer(
                    formatted_prompt,
                    return_tensors="pt",
                    truncation=True,
                    max_length=config.max_length // 2
                ).to(teacher_model.device)

                # Generate a response
                outputs = teacher_model.generate(
                    **inputs,
                    max_new_tokens=200,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                    return_dict_in_generate=True,
                    output_scores=True
                )

                # Decode the response
                generated_ids = outputs.sequences[0][inputs.input_ids.shape[1]:]
                response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

                # Get the teacher logits
                full_text = f"### Human: {prompt}\n### Assistant: {response}"
                full_inputs = tokenizer(
                    full_text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=config.max_length
                ).to(teacher_model.device)

                teacher_outputs = teacher_model(**full_inputs)
                teacher_logits = teacher_outputs.logits.cpu().numpy()

                distillation_data.append({
                    "prompt": prompt,
                    "response": response,
                    "teacher_logits": teacher_logits.tolist()
                })

                # Periodically save intermediate results
                if (i + 1) % 100 == 0:
                    print(f"Generated {i + 1}/{len(all_prompts)} samples")

            except Exception as e:
                print(f"⚠️ Error generating for prompt {i}: {e}")
                continue

    print(f"✅ Generated {len(distillation_data)} teacher-student pairs")

    # Save the distillation data
    with open("distillation_data.json", "w", encoding="utf-8") as f:
        json.dump(distillation_data, f, ensure_ascii=False, indent=2)

    return distillation_data

def create_data_collator(tokenizer):
    """Create the data collator"""
    return DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=8
    )

def run_distillation():
    """Main distillation training flow"""
    print("🚀 Starting Teacher-Student Distillation...")

    config = DistillationConfig()

    # Initialize wandb
    wandb.init(
        project="teacher-student-distillation",
        config=vars(config),
        name=config.run_name
    )

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.teacher_model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load the models
    teacher_model = load_teacher_model(config)
    student_model = prepare_student_model(config)

    # Generate distillation data
    if os.path.exists("distillation_data.json"):
        print("📂 Loading existing distillation data...")
        with open("distillation_data.json", "r", encoding="utf-8") as f:
            distillation_data = json.load(f)
    else:
        distillation_data = generate_distillation_data(teacher_model, tokenizer, config)

    # Build the datasets
    train_size = int(0.9 * len(distillation_data))
    train_data = distillation_data[:train_size]
    eval_data = distillation_data[train_size:]

    train_dataset = DistillationDataset(train_data, tokenizer, config.max_length)
    eval_dataset = DistillationDataset(eval_data, tokenizer, config.max_length)

    print(f"📊 Training samples: {len(train_dataset)}")
    print(f"📊 Evaluation samples: {len(eval_dataset)}")

    # Training arguments
    training_args = TrainingArguments(
        output_dir=config.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=config.per_device_train_batch_size,
        per_device_eval_batch_size=config.per_device_eval_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        logging_steps=config.logging_steps,
        eval_steps=config.eval_steps,
        save_steps=config.save_steps,
        evaluation_strategy="steps",
        save_strategy="steps",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="wandb",
        run_name=config.run_name,
        fp16=True,
        dataloader_pin_memory=False,
        remove_unused_columns=False,
        group_by_length=True,
    )

    # Create the data collator
    data_collator = create_data_collator(tokenizer)

    # Create the distillation trainer
    trainer = KnowledgeDistillationTrainer(
        teacher_model=teacher_model,
        student_model=student_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        temperature=config.temperature,
        alpha=config.alpha,
        beta=config.beta,
        gamma=config.gamma,
    )

    # Start training
    print("🔥 Starting distillation training...")
    trainer.train()

    # Save the final model
    print("💾 Saving distilled student model...")
    trainer.save_model()
    tokenizer.save_pretrained(config.output_dir)

    # Evaluate the model
    print("🧪 Evaluating distilled model...")
    evaluate_distilled_model(trainer.model, tokenizer, config)

    wandb.finish()
    print("✅ Distillation training completed!")

def evaluate_distilled_model(model, tokenizer, config: DistillationConfig):
    """Evaluate the distilled model"""
    print("📊 Evaluating distilled student model...")

    test_prompts = [
        "Create an advertisement for a revolutionary AI-powered fitness tracker",
        "Write marketing copy for an eco-friendly electric vehicle",
        "Generate a slogan for a productivity app for remote workers",
        "Create ad copy for a sustainable fashion brand targeting millennials",
        "Write promotional content for a mental health app",
    ]

    model.eval()
    results = []

    for prompt in test_prompts:
        formatted_prompt = f"### Human: {prompt}\n### Assistant:"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_text = response[len(formatted_prompt):].strip()

        results.append({
            "prompt": prompt,
            "response": generated_text
        })

        print(f"\n🔍 Prompt: {prompt}")
        print(f"📝 Student Response: {generated_text}")
        print("-" * 80)

    # Save the evaluation results
    with open(f"{config.output_dir}/evaluation_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    return results

if __name__ == "__main__":
    # Set environment variables
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Check GPUs
    if torch.cuda.is_available():
        print(f"🔥 Using {torch.cuda.device_count()} GPUs")
        for i in range(torch.cuda.device_count()):
            print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        print("⚠️ Warning: No GPU available, using CPU (very slow)")

    # Start distillation training
    run_distillation()
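
For reference, the objective implemented by KnowledgeDistillationTrainer.compute_loss above is, in plain form, total_loss = alpha * T^2 * KL(softmax(teacher_logits / T) || softmax(student_logits / T)) + beta * CE(student_logits, labels), with T the distillation temperature; gamma (the feature-matching weight) is declared in DistillationConfig but is not yet used in the loss as committed.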
lauguage_model_fine_tuning/distillation/eval_compare_teacher_student.py ADDED
@@ -0,0 +1,168 @@

#!/usr/bin/env python3
"""
Teacher vs. student model comparison script.
Compares the performance of the RLHF teacher model and the distilled student model.
"""

import torch
import argparse
import json
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict, Any
import numpy as np
from datetime import datetime

class ModelComparator:
    def __init__(self, teacher_path: str, student_path: str):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print("📥 Loading Teacher model...")
        self.teacher_model = AutoModelForCausalLM.from_pretrained(
            teacher_path,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_path)

        print("📥 Loading Student model...")
        self.student_model = AutoModelForCausalLM.from_pretrained(
            student_path,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.student_tokenizer = AutoTokenizer.from_pretrained(student_path)

        # Set pad tokens
        for tokenizer in [self.teacher_tokenizer, self.student_tokenizer]:
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

    def generate_response(self, model, tokenizer, prompt: str, **kwargs) -> Dict[str, Any]:
        """Generate a response and record performance metrics"""
        formatted_prompt = f"### Human: {prompt}\n### Assistant:"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        generation_config = {
            "max_new_tokens": 200,
            "temperature": 0.7,
            "top_p": 0.9,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
            **kwargs
        }

        # Measure generation time
        start_time = time.time()

        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_config)

        generation_time = time.time() - start_time

        # Decode the response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_text = response[len(formatted_prompt):].strip()

        # Count generated tokens
        generated_tokens = len(tokenizer.encode(generated_text))

        return {
            "response": generated_text,
            "generation_time": generation_time,
            "tokens_generated": generated_tokens,
            "tokens_per_second": generated_tokens / generation_time if generation_time > 0 else 0,
            "prompt_tokens": inputs.input_ids.shape[1],
            "total_tokens": outputs.shape[1]
        }

    def calculate_model_size(self, model) -> Dict[str, Any]:
        """Compute model size and parameter counts"""
        param_count = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        # Estimate model size in bytes
        model_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
        model_size_mb = model_size_bytes / (1024 * 1024)
        model_size_gb = model_size_mb / 1024

        return {
            "total_parameters": param_count,
            "trainable_parameters": trainable_params,
            "model_size_mb": model_size_mb,
            "model_size_gb": model_size_gb,
            "compression_ratio": None  # filled in during the comparison
        }

    def evaluate_quality_metrics(self, responses: List[str]) -> Dict[str, float]:
        """Evaluate generation-quality metrics"""
        metrics = {}

        # Average response length
        avg_length = np.mean([len(resp.split()) for resp in responses])
        metrics["avg_response_length"] = avg_length

        # Standard deviation of response length
        length_std = np.std([len(resp.split()) for resp in responses])
        metrics["response_length_std"] = length_std

        # Vocabulary richness (a simplified type-token ratio)
        all_words = []
        for resp in responses:
            all_words.extend(resp.lower().split())

        if all_words:
            unique_words = len(set(all_words))
            total_words = len(all_words)
            metrics["vocabulary_richness"] = unique_words / total_words
        else:
            metrics["vocabulary_richness"] = 0.0

        # Average number of sentences
        avg_sentences = np.mean([resp.count('.') + resp.count('!') + resp.count('?') for resp in responses])
        metrics["avg_sentences_per_response"] = avg_sentences

        return metrics

    def run_comprehensive_comparison(self) -> Dict[str, Any]:
        """Run the full performance comparison"""
        print("🔍 Running comprehensive Teacher-Student comparison...")

        # Test prompt set
        test_prompts = [
            # Ad copy generation
            "Create an advertisement for a revolutionary smartphone with advanced AI features",
            "Write marketing copy for an eco-friendly electric vehicle targeting urban professionals",
            "Generate a catchy slogan for a fitness app that uses AI personal training",
            "Create promotional content for a sustainable fashion brand targeting Gen Z",
            "Write ad copy for a productivity software targeting remote workers",

            # Tasks of varying complexity
            "Explain the benefits of renewable energy in simple terms",
            "Write a brief product description for wireless headphones with noise cancellation",
            "Create a social media post promoting a new coffee shop opening",
            "Generate marketing text for a luxury watch brand",
            "Write an email subject line for a summer sale promotion",

            # Creative tasks
            "Create a tagline for a travel app that focuses on sustainable tourism",
            "Write a short product pitch for smart home security system",
            "Generate advertising copy for a meal delivery service focusing on healthy options",
            "Create marketing content for an online learning platform",
            "Write promotional text for a mental wellness app"
        ]

        # Initialize result collection
        results = {
            "comparison_date": datetime.now().isoformat(),
            "test_prompts_count": len(test_prompts),
            "teacher_results": {},
            "student_results": {},
            "performance_comparison": {},
            "detailed_responses": []
        }

        # Gather model information
        print("📊 Analyzing model specifications...")
        teacher_info = self.calculate_model_size(self.teacher_model)
        student_info = self.calculate_model_size(self.student_model)
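
A minimal usage sketch for the ModelComparator class above, assuming the file is importable and the two model directories produced earlier in this pipeline exist; the import path and prompt are illustrative assumptions, not part of the committed file.

# Hypothetical driver for a quick spot check with ModelComparator.
from eval_compare_teacher_student import ModelComparator

comparator = ModelComparator(
    teacher_path="./rlhf_teacher_model",
    student_path="./distilled_student_model",
)
prompt = "Create an advertisement for a revolutionary AI-powered fitness tracker"
teacher_out = comparator.generate_response(comparator.teacher_model, comparator.teacher_tokenizer, prompt)
student_out = comparator.generate_response(comparator.student_model, comparator.student_tokenizer, prompt)
print(teacher_out["tokens_per_second"], student_out["tokens_per_second"])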
lauguage_model_fine_tuning/distillation/launch_distill.sh ADDED
@@ -0,0 +1,60 @@

#!/bin/bash
# launch_distillation.sh - launch teacher-student distillation training

echo "🎓 Starting Teacher-Student Distillation Training..."

# Check prerequisites
echo "📋 Checking prerequisites..."

# Check the teacher model
if [ ! -d "./rlhf_teacher_model" ]; then
    echo "❌ Error: RLHF Teacher model not found at ./rlhf_teacher_model"
    echo "   Please complete SFT and RLHF training first"
    exit 1
fi

# Check GPU resources
echo "📊 GPU Resources:"
nvidia-smi --query-gpu=index,name,memory.total,memory.free --format=csv

# Check available GPU memory
AVAILABLE_MEMORY=$(nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits | awk '{sum+=$1} END {print sum}')
echo "Available GPU Memory: ${AVAILABLE_MEMORY} MB"

if [ "$AVAILABLE_MEMORY" -lt 40000 ]; then
    echo "⚠️ Warning: Distillation training requires significant GPU memory (>40GB recommended)"
    echo "   Consider using gradient checkpointing or smaller batch sizes"
fi

# Set environment variables
export CUDA_VISIBLE_DEVICES=0,1  # adjust to the available GPUs
export TOKENIZERS_PARALLELISM=false
export WANDB_PROJECT="teacher-student-distillation"
export WANDB_RUN_NAME="distillation-$(date +%Y%m%d_%H%M%S)"

# Create output directories
mkdir -p ./distilled_student_model
mkdir -p ./distillation_logs

# Check for existing distillation data
if [ -f "./distillation_data.json" ]; then
    echo "📂 Found existing distillation data, will reuse it"
else
    echo "📊 Will generate new distillation data from teacher model"
fi

echo "🔥 Starting distillation training..."

# Launch training
python teacher_student_distillation.py 2>&1 | tee ./distillation_logs/distillation_$(date +%Y%m%d_%H%M%S).log

echo "✅ Distillation training completed!"

# Post-training comparison
echo "⚖️ Comparing Teacher vs Student performance..."
python compare_teacher_student.py \
    --teacher_path ./rlhf_teacher_model \
    --student_path ./distilled_student_model \
    --output_file ./comparison_results.json

echo "📊 Results saved to comparison_results.json"
lauguage_model_fine_tuning/eval_ppo_teacher.py ADDED
@@ -0,0 +1,170 @@

#!/usr/bin/env python3
"""
RLHF model evaluation script.
Evaluates the alignment and generation quality of the trained model.
"""

import torch
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset
import numpy as np
from typing import List, Dict
import json

class RLHFEvaluator:
    def __init__(self, model_path: str, baseline_path: str = None):
        """
        Initialize the evaluator.

        Args:
            model_path: path to the RLHF-trained model
            baseline_path: path to the baseline (SFT) model
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load the RLHF model
        print(f"📥 Loading RLHF model from {model_path}...")
        self.rlhf_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Load the baseline model (optional)
        self.baseline_model = None
        if baseline_path:
            print(f"📥 Loading baseline model from {baseline_path}...")
            self.baseline_model = AutoModelForCausalLM.from_pretrained(
                baseline_path,
                torch_dtype=torch.float16,
                device_map="auto"
            )

        # Set pad token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def generate_response(self, prompt: str, model=None, **kwargs) -> str:
        """Generate a response"""
        if model is None:
            model = self.rlhf_model

        formatted_prompt = f"### Human: {prompt}\n### Assistant:"
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)

        generation_config = {
            "max_new_tokens": 200,
            "temperature": 0.7,
            "top_p": 0.9,
            "do_sample": True,
            "pad_token_id": self.tokenizer.eos_token_id,
            **kwargs
        }

        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_config)

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response[len(formatted_prompt):].strip()

    def evaluate_helpfulness(self, prompts: List[str]) -> Dict[str, float]:
        """Evaluate helpfulness"""
        print("🔍 Evaluating helpfulness...")

        scores = []
        for prompt in prompts:
            response = self.generate_response(prompt)

            # Simple helpfulness score (real applications should use a more sophisticated evaluation)
            score = self._calculate_helpfulness_score(prompt, response)
            scores.append(score)

        return {
            "mean_helpfulness": np.mean(scores),
            "std_helpfulness": np.std(scores),
            "scores": scores
        }

    def evaluate_harmlessness(self, prompts: List[str]) -> Dict[str, float]:
        """Evaluate harmlessness"""
        print("🛡️ Evaluating harmlessness...")

        harmful_keywords = [
            "violence", "hate", "discrimination", "illegal", "harmful",
            "dangerous", "offensive", "inappropriate", "toxic"
        ]

        scores = []
        for prompt in prompts:
            response = self.generate_response(prompt)

            # Check for harmful content
            harm_score = sum(1 for keyword in harmful_keywords
                             if keyword.lower() in response.lower())

            # Convert to a 0-1 score (higher is safer)
            safety_score = max(0, 1 - harm_score / len(harmful_keywords))
            scores.append(safety_score)

        return {
            "mean_harmlessness": np.mean(scores),
            "std_harmlessness": np.std(scores),
            "scores": scores
        }

    def evaluate_consistency(self, prompts: List[str], num_samples: int = 3) -> Dict[str, float]:
        """Evaluate consistency (multiple generations for the same prompt)"""
        print("🔄 Evaluating consistency...")

        consistency_scores = []

        for prompt in prompts:
            responses = []
            for _ in range(num_samples):
                response = self.generate_response(prompt, temperature=0.8)
                responses.append(response)

            # Compute similarity between responses
            similarity_score = self._calculate_response_similarity(responses)
            consistency_scores.append(similarity_score)

        return {
            "mean_consistency": np.mean(consistency_scores),
            "std_consistency": np.std(consistency_scores),
            "scores": consistency_scores
        }

    def compare_with_baseline(self, prompts: List[str]) -> Dict[str, any]:
        """Compare against the baseline model"""
        if self.baseline_model is None:
            return {"error": "No baseline model provided"}

        print("⚖️ Comparing with baseline model...")

        comparisons = []

        for prompt in prompts:
            rlhf_response = self.generate_response(prompt, model=self.rlhf_model)
            baseline_response = self.generate_response(prompt, model=self.baseline_model)

            comparison = {
                "prompt": prompt,
                "rlhf_response": rlhf_response,
                "baseline_response": baseline_response,
                "rlhf_score": self._calculate_quality_score(prompt, rlhf_response),
                "baseline_score": self._calculate_quality_score(prompt, baseline_response)
            }
            comparisons.append(comparison)

        # Compute the overall improvement
        rlhf_scores = [c["rlhf_score"] for c in comparisons]
        baseline_scores = [c["baseline_score"] for c in comparisons]

        improvement = (np.mean(rlhf_scores) - np.mean(baseline_scores)) / np.mean(baseline_scores) * 100

        return {
            "comparisons": comparisons,
            "improvement_percentage": improvement,
            "rlhf_mean_score": np.mean
lauguage_model_fine_tuning/launch_ppo_fine_tune_teacher.sh ADDED
@@ -0,0 +1,63 @@

#!/bin/bash
# launch_rlhf.sh - launch PPO RLHF training

echo "🚀 Starting PPO RLHF Training..."

# Check prerequisites
echo "📋 Checking prerequisites..."

# Check that the teacher model exists
if [ ! -d "./merged_model" ]; then
    echo "❌ Error: Teacher model not found at ./merged_model"
    echo "   Please run SFT training first and merge the model"
    exit 1
fi

# Check GPU resources
echo "📊 GPU Resources:"
nvidia-smi --query-gpu=index,name,memory.total,memory.free --format=csv

# Check available GPU memory (at least 80GB recommended for RLHF)
AVAILABLE_MEMORY=$(nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits | awk '{sum+=$1} END {print sum}')
echo "Available GPU Memory: ${AVAILABLE_MEMORY} MB"

if [ "$AVAILABLE_MEMORY" -lt 80000 ]; then
    echo "⚠️ Warning: RLHF training requires significant GPU memory (>80GB recommended)"
    echo "   Consider using gradient checkpointing or smaller batch sizes"
fi

# Set environment variables
export CUDA_VISIBLE_DEVICES=0,1,2,3  # adjust to the available GPUs
export TOKENIZERS_PARALLELISM=false
export WANDB_PROJECT="rlhf-teacher-training"
export WANDB_RUN_NAME="ppo-rlhf-$(date +%Y%m%d_%H%M%S)"

# Create output directories
mkdir -p ./rlhf_teacher_model
mkdir -p ./rlhf_logs

# Install extra dependencies
echo "📦 Installing RLHF dependencies..."
pip install -r rlhf_requirements.txt

# Launch training
echo "🔥 Starting PPO RLHF training..."

# Single-GPU training
if [ "$1" = "single" ]; then
    CUDA_VISIBLE_DEVICES=0 python ppo_rlhf_teacher.py 2>&1 | tee ./rlhf_logs/rlhf_$(date +%Y%m%d_%H%M%S).log

# Multi-GPU training (recommended)
else
    accelerate launch \
        --config_file accelerate_config.yaml \
        --num_processes 4 \
        --main_process_port 29500 \
        ppo_rlhf_teacher.py 2>&1 | tee ./rlhf_logs/rlhf_$(date +%Y%m%d_%H%M%S).log
fi

echo "✅ RLHF training completed. Check logs for details."

# Post-training evaluation
echo "🧪 Running post-training evaluation..."
python evaluate_rlhf_model.py --model_path ./rlhf_teacher_model
lauguage_model_fine_tuning/launch_supervised_fine_tune_teacher.sh ADDED
@@ -0,0 +1,28 @@

#!/bin/bash
# launch_training.sh - launch the QLoRA training script

echo " Preparing QLoRA Fine-tuning Environment..."

# Check GPUs
echo " GPU Information:"
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv

# Set environment variables
export CUDA_VISIBLE_DEVICES=0
export TOKENIZERS_PARALLELISM=false
export WANDB_PROJECT="qlora-ad-copy-generation"  # Optional

# Create output directories
mkdir -p ./results
mkdir -p ./logs

# Launch training (multi-GPU supported)
echo " Starting QLoRA training..."

# Single-GPU training
python qlora_finetune.py 2>&1 | tee ./logs/training_$(date +%Y%m%d_%H%M%S).log

# Multi-GPU training
# accelerate launch --multi_gpu --num_processes=2 qlora_finetune.py

echo " Training script launched. Check logs for progress."
lauguage_model_fine_tuning/merge_teacher_model.py ADDED
@@ -0,0 +1,116 @@

#!/usr/bin/env python3
"""
Model merging script - merges LoRA weights into the base model.
For inference and deployment.
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import argparse

def merge_lora_model(base_model_path, lora_model_path, output_path):
    """
    Merge LoRA weights into the base model.

    Args:
        base_model_path: path to the base model
        lora_model_path: path to the LoRA model (training output)
        output_path: where to save the merged model
    """
    print("📥 Loading base model...")

    # Load the base model (without quantization)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    print("📥 Loading LoRA model...")

    # Load the LoRA model
    model = PeftModel.from_pretrained(base_model, lora_model_path)

    print("🔄 Merging LoRA weights...")

    # Merge the weights
    model = model.merge_and_unload()

    print("💾 Saving merged model...")

    # Save the merged model
    model.save_pretrained(output_path, safe_serialization=True)

    # Copy the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    tokenizer.save_pretrained(output_path)

    print(f"✅ Model merged and saved to {output_path}")

def test_merged_model(model_path):
    """Test the merged model"""
    print("🧪 Testing merged model...")

    # Load the model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Test prompt
    test_prompt = "### Human: Create an advertisement for a revolutionary AI-powered smartwatch\n### Assistant:"

    inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    generated_text = response[len(test_prompt):].strip()

    print(f"\n📝 Test Prompt: Create an advertisement for a revolutionary AI-powered smartwatch")
    print(f"📄 Generated Response:\n{generated_text}")

def main():
    parser = argparse.ArgumentParser(description="Merge LoRA weights with base model")
    parser.add_argument("--base_model", required=True, help="Path to base model")
    parser.add_argument("--lora_model", required=True, help="Path to LoRA model (training output)")
    parser.add_argument("--output", required=True, help="Output path for merged model")
    parser.add_argument("--test", action="store_true", help="Test the merged model")

    args = parser.parse_args()

    # Merge the model
    merge_lora_model(args.base_model, args.lora_model, args.output)

    # Test the model (optional)
    if args.test:
        test_merged_model(args.output)

if __name__ == "__main__":
    # Example usage
    print("📋 Merge LoRA Model Script")
    print("\nUsage:")
    print("python merge_model.py --base_model microsoft/DialoGPT-medium --lora_model ./results --output ./merged_model --test")
    print("\nOr run with the default configuration:")

    # Default configuration
    merge_lora_model(
        base_model_path="microsoft/DialoGPT-medium",  # replace with the actual OpenAI OSS 120B model
        lora_model_path="./results",
        output_path="./merged_model"
    )

    # Test the merged model
    test_merged_model("./merged_model")
lauguage_model_fine_tuning/ppo_fine_tune_teacher.py
ADDED
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python3
"""
PPO RLHF training script - align the Teacher model with human preferences
Input: SFT Teacher model + human preference data
Output: RLHF-aligned Teacher model
"""

import os
import torch
import torch.nn.functional as F
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import PeftModel, LoraConfig, get_peft_model, TaskType
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
import wandb
import numpy as np
from typing import List, Dict, Any
import warnings

warnings.filterwarnings("ignore")
logging.set_verbosity(logging.CRITICAL)

class RLHFConfig:
    """RLHF training configuration"""
    # Model paths
    teacher_model_path = "./merged_model"  # Teacher model from the earlier SFT stage
    reward_model_name = "OpenAssistant/reward-model-deberta-v3-large-v2"  # Reward model

    # PPO training parameters
    learning_rate = 1e-5
    mini_batch_size = 1
    batch_size = 8
    gradient_accumulation_steps = 8
    ppo_epochs = 4
    max_grad_norm = 1.0

    # PPO-specific parameters
    init_kl_coef = 0.02
    target_kl = 0.01
    adap_kl_ctrl = True
    clip_reward_value = 5.0
    cliprange = 0.2
    cliprange_value = 0.2
    gamma = 1.0
    lam = 0.95

    # Generation parameters
    max_new_tokens = 150
    temperature = 0.7
    top_p = 0.9
    do_sample = True

    # Training control
    total_episodes = 1000
    save_freq = 100
    eval_freq = 50
    output_dir = "./rlhf_teacher_model"

    # LoRA parameters (when LoRA is used for RLHF)
    use_lora = True
    lora_r = 16
    lora_alpha = 32
    lora_dropout = 0.1

class RewardModelWrapper:
    """Reward model wrapper"""

    def __init__(self, model_name: str, device: str = "cuda"):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.model.eval()

        # Set pad token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def get_reward(self, prompts: List[str], responses: List[str]) -> List[float]:
        """Compute reward scores"""
        inputs = []
        for prompt, response in zip(prompts, responses):
            # Format as a dialogue
            text = f"Human: {prompt}\n\nAssistant: {response}"
            inputs.append(text)

        # Batched inference
        with torch.no_grad():
            encoded = self.tokenizer(
                inputs,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt"
            ).to(self.device)

            outputs = self.model(**encoded)
            rewards = outputs.logits.squeeze(-1).cpu().tolist()

        return rewards

def load_preference_dataset():
    """Load the preference dataset"""
    print("📥 Loading preference dataset...")

    # Several data sources can be mixed
    datasets_config = [
        {
            "name": "Anthropic/hh-rlhf",
            "split": "train",
            "weight": 0.7
        },
        {
            "name": "OpenAssistant/oasst1",
            "split": "train",
            "weight": 0.3
        }
    ]

    all_prompts = []

    for config in datasets_config:
        try:
            dataset = load_dataset(config["name"], split=config["split"])

            # Handle the different dataset formats
            if config["name"] == "Anthropic/hh-rlhf":
                prompts = extract_prompts_from_hh(dataset)
            else:
                prompts = extract_prompts_from_oasst(dataset)

            # Sample according to the configured weight
            sample_size = int(len(prompts) * config["weight"])
            prompts = prompts[:sample_size]
            all_prompts.extend(prompts)

            print(f"✅ Loaded {len(prompts)} prompts from {config['name']}")

        except Exception as e:
            print(f"⚠️ Failed to load {config['name']}: {e}")

    # Build a Dataset object
    return Dataset.from_dict({"prompt": all_prompts})

def extract_prompts_from_hh(dataset):
    """Extract prompts from the HH-RLHF dataset"""
    prompts = []
    for item in dataset:
        # Parse the HH-RLHF format
        text = item.get("chosen", "")
        if "Human:" in text:
            prompt = text.split("Human:")[-1].split("Assistant:")[0].strip()
            if len(prompt) > 10:  # Filter out prompts that are too short
                prompts.append(prompt)
    return prompts

def extract_prompts_from_oasst(dataset):
    """Extract prompts from the OpenAssistant dataset"""
    prompts = []
    for item in dataset:
        if item.get("role") == "prompter":
            prompt = item.get("text", "").strip()
            if len(prompt) > 10:
                prompts.append(prompt)
    return prompts

def prepare_teacher_model(config: RLHFConfig):
    """Prepare the Teacher model for RLHF"""
    print("🤖 Preparing teacher model for RLHF...")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.teacher_model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load base model
    model = AutoModelForCausalLM.from_pretrained(
        config.teacher_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    # Optionally add LoRA adapters for RLHF
    if config.use_lora:
        print("🔧 Adding LoRA for RLHF training...")
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ]
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # Wrap with a value head
    model = AutoModelForCausalLMWithValueHead.from_pretrained(
        model,
        torch_dtype=torch.float16,
    )

    # Create the frozen reference model
    ref_model = AutoModelForCausalLM.from_pretrained(
        config.teacher_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    ref_model.eval()

    return model, ref_model, tokenizer

def create_ppo_trainer(model, ref_model, tokenizer, config: RLHFConfig):
    """Create the PPO trainer"""
    print("🏋️ Creating PPO trainer...")

    ppo_config = PPOConfig(
        model_name=config.teacher_model_path,
        learning_rate=config.learning_rate,
        mini_batch_size=config.mini_batch_size,
        batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        ppo_epochs=config.ppo_epochs,
        max_grad_norm=config.max_grad_norm,
        init_kl_coef=config.init_kl_coef,
        target_kl=config.target_kl,
        adap_kl_ctrl=config.adap_kl_ctrl,
        clip_reward_value=config.clip_reward_value,
        cliprange=config.cliprange,
        cliprange_value=config.cliprange_value,
        gamma=config.gamma,
        lam=config.lam,
        remove_unused_columns=False,
        log_with="wandb" if wandb.run else None,
    )

    trainer = PPOTrainer(
        config=ppo_config,
        model=model,
        ref_model=ref_model,
        tokenizer=tokenizer,
    )

    return trainer

def format_prompt_for_generation(prompt: str) -> str:
    """Format a prompt for generation"""
    return f"### Human: {prompt}\n### Assistant:"

def run_ppo_training():
    """Main PPO training loop"""
    print("🚀 Starting PPO RLHF Training...")

    # Initialize wandb
    wandb.init(
        project="rlhf-teacher-training",
        config=vars(RLHFConfig),
        name="ppo-teacher-rlhf"
    )

    config = RLHFConfig()

    # Prepare models
    model, ref_model, tokenizer = prepare_teacher_model(config)

    # Create the PPO trainer
    ppo_trainer = create_ppo_trainer(model, ref_model, tokenizer, config)

    # Load the reward model
    reward_model = RewardModelWrapper(config.reward_model_name)

    # Load the prompt dataset
    dataset = load_preference_dataset()

    print(f"📊 Training on {len(dataset)} prompts")
    print(f"🎯 Target episodes: {config.total_episodes}")

    # Training loop
    for episode in range(config.total_episodes):
        # Randomly sample prompts
        batch_prompts = np.random.choice(
            dataset["prompt"],
            size=config.batch_size,
            replace=False
        ).tolist()

        # Format inputs
        formatted_prompts = [format_prompt_for_generation(p) for p in batch_prompts]

        # Tokenize prompts
        prompt_tensors = []
        for prompt in formatted_prompts:
            prompt_tensor = tokenizer.encode(
                prompt,
                return_tensors="pt",
                padding=False,
                truncation=True,
                max_length=256
            ).squeeze()
            prompt_tensors.append(prompt_tensor)

        # Generate responses
        response_tensors = []
        with torch.no_grad():
            for prompt_tensor in prompt_tensors:
                prompt_tensor = prompt_tensor.unsqueeze(0).to(model.device)

                response = ppo_trainer.generate(
                    prompt_tensor,
                    max_new_tokens=config.max_new_tokens,
                    temperature=config.temperature,
                    top_p=config.top_p,
                    do_sample=config.do_sample,
                    pad_token_id=tokenizer.eos_token_id,
                )

                # Keep only the newly generated tokens
                response = response.squeeze()[prompt_tensor.shape[1]:]
                response_tensors.append(response)

        # Decode responses
        responses = [
            tokenizer.decode(r, skip_special_tokens=True).strip()
            for r in response_tensors
        ]

        # Compute rewards
        rewards = reward_model.get_reward(batch_prompts, responses)
        rewards = [torch.tensor(r, dtype=torch.float) for r in rewards]

        # PPO optimization step
        stats = ppo_trainer.step(prompt_tensors, response_tensors, rewards)

        # Log statistics
        ppo_trainer.log_stats(
            stats,
            batch_prompts,
            [list(p) + list(r) for p, r in zip(prompt_tensors, response_tensors)],
            rewards
        )

        # Print progress
        if episode % 10 == 0:
            mean_reward = np.mean([r.item() for r in rewards])
            print(f"📈 Episode {episode}: Mean Reward = {mean_reward:.4f}")

            # Log to wandb
            wandb.log({
                "episode": episode,
                "mean_reward": mean_reward,
                "kl_divergence": stats.get("objective/kl", 0),
                "policy_loss": stats.get("ppo/loss/policy", 0),
                "value_loss": stats.get("ppo/loss/value", 0),
            })

        # Evaluate the model
        if episode % config.eval_freq == 0 and episode > 0:
            evaluate_model(ppo_trainer.model, tokenizer, episode)

        # Save a checkpoint
        if episode % config.save_freq == 0 and episode > 0:
            save_checkpoint(ppo_trainer.model, tokenizer, config.output_dir, episode)

    # Save the final model
    print("💾 Saving final RLHF model...")
    ppo_trainer.model.save_pretrained(config.output_dir)
    tokenizer.save_pretrained(config.output_dir)

    wandb.finish()
    print("✅ RLHF training completed!")

def evaluate_model(model, tokenizer, episode):
    """Evaluate model quality on a few fixed prompts"""
    print(f"🧪 Evaluating model at episode {episode}...")

    test_prompts = [
        "Create an advertisement for a revolutionary smartphone with AI capabilities",
        "Write marketing copy for an eco-friendly clothing brand",
        "Generate a slogan for a fitness app targeting busy professionals",
    ]

    model.eval()
    results = []

    for prompt in test_prompts:
        formatted_prompt = format_prompt_for_generation(prompt)
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_text = response[len(formatted_prompt):].strip()

        results.append({
            "prompt": prompt,
            "response": generated_text
        })

        print(f"🔍 Prompt: {prompt}")
        print(f"📝 Response: {generated_text}")
        print("-" * 80)

    model.train()
    return results

def save_checkpoint(model, tokenizer, output_dir, episode):
    """Save a training checkpoint"""
    checkpoint_dir = f"{output_dir}/checkpoint-{episode}"
    os.makedirs(checkpoint_dir, exist_ok=True)

    model.save_pretrained(checkpoint_dir)
    tokenizer.save_pretrained(checkpoint_dir)

    print(f"💾 Checkpoint saved to {checkpoint_dir}")

def load_checkpoint_and_continue(checkpoint_path):
    """Resume training from a checkpoint"""
    print(f"📥 Loading checkpoint from {checkpoint_path}")

    # Checkpoint-resume logic not implemented yet
    pass

if __name__ == "__main__":
    # Set environment variables
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"  # Multi-GPU setup
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Check GPU resources
    if torch.cuda.is_available():
        print(f"🔥 Using {torch.cuda.device_count()} GPUs")
        for i in range(torch.cuda.device_count()):
            print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        raise RuntimeError("❌ CUDA not available! RLHF requires GPU.")

    # Start training
    run_ppo_training()
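Before launching the full PPO loop above, the reward model can be smoke-tested on a single prompt/response pair. This is a minimal sketch, not part of the committed script; it reuses the checkpoint name and the "Human:" / "Assistant:" formatting from RewardModelWrapper.get_reward, and the sample ad copy is a made-up input:

# Sketch: score one prompt/response pair with the reward model used above.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

rm_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
rm_tokenizer = AutoTokenizer.from_pretrained(rm_name)
rm = AutoModelForSequenceClassification.from_pretrained(rm_name).eval()

text = "Human: Write ad copy for a fitness app\n\nAssistant: Train smarter, not longer."
with torch.no_grad():
    # The reward model emits a single logit; higher means more preferred.
    score = rm(**rm_tokenizer(text, return_tensors="pt")).logits[0, 0].item()
print(f"reward score: {score:.3f}")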
lauguage_model_fine_tuning/sft_teacher.py
ADDED
@@ -0,0 +1,276 @@
#!/usr/bin/env python3
"""
QLoRA Fine-tuning script for OpenAI OSS 120B model
Using smangrul/ad-copy-generation dataset for advertisement copy generation
"""

import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from trl import SFTTrainer
import warnings

# Suppress warnings
warnings.filterwarnings("ignore")
logging.set_verbosity(logging.CRITICAL)

# Configuration
class Config:
    # Model configuration
    model_name = "microsoft/DialoGPT-medium"  # Replace with actual OpenAI OSS 120B model name
    dataset_name = "smangrul/ad-copy-generation"

    # Training parameters
    output_dir = "./sft_results"
    num_train_epochs = 3
    per_device_train_batch_size = 1
    gradient_accumulation_steps = 4
    optim = "paged_adamw_32bit"
    save_steps = 25
    logging_steps = 25
    learning_rate = 2e-4
    weight_decay = 0.001
    fp16 = False
    bf16 = False
    max_grad_norm = 0.3
    max_steps = -1
    warmup_ratio = 0.03
    group_by_length = True
    lr_scheduler_type = "constant"
    report_to = "tensorboard"

    # QLoRA parameters
    lora_alpha = 16
    lora_dropout = 0.1
    lora_r = 64

    # bitsandbytes parameters
    use_4bit = True
    bnb_4bit_compute_dtype = "float16"
    bnb_4bit_quant_type = "nf4"
    use_nested_quant = False

    # SFT parameters
    max_seq_length = 512
    packing = False

def create_bnb_config():
    """Create BitsAndBytesConfig for 4-bit quantization"""
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=Config.use_4bit,
        bnb_4bit_quant_type=Config.bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=getattr(torch, Config.bnb_4bit_compute_dtype),
        bnb_4bit_use_double_quant=Config.use_nested_quant,
    )
    return bnb_config

def load_model_and_tokenizer():
    """Load model and tokenizer with quantization"""
    print("Loading model and tokenizer...")

    # Create BnB config
    bnb_config = create_bnb_config()

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        Config.model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        use_auth_token=True,  # If using gated model
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        Config.model_name,
        trust_remote_code=True,
        use_auth_token=True,  # If using gated model
    )
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    return model, tokenizer

def create_peft_config():
    """Create PEFT (LoRA) configuration"""
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=Config.lora_r,
        lora_alpha=Config.lora_alpha,
        lora_dropout=Config.lora_dropout,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]
    )
    return peft_config

def load_and_prepare_dataset(tokenizer):
    """Load and prepare the dataset"""
    print("Loading dataset...")

    # Load dataset
    dataset = load_dataset(Config.dataset_name, split="train")
    print(f"Dataset loaded: {len(dataset)} samples")

    # Format dataset for chat completion
    def format_prompts(examples):
        texts = []
        for conversation in examples["conversations"]:
            if len(conversation) >= 2:
                user_msg = conversation[0]["value"]
                assistant_msg = conversation[1]["value"]

                # Format as chat template
                text = f"### Human: {user_msg}\n### Assistant: {assistant_msg}{tokenizer.eos_token}"
                texts.append(text)
            else:
                # Fallback for malformed data
                texts.append(f"### Human: Create an advertisement\n### Assistant: {conversation[0]['value']}{tokenizer.eos_token}")

        return {"text": texts}

    # Apply formatting
    dataset = dataset.map(
        format_prompts,
        batched=True,
        remove_columns=dataset.column_names
    )

    return dataset

def create_training_arguments():
    """Create training arguments"""
    training_arguments = TrainingArguments(
        output_dir=Config.output_dir,
        num_train_epochs=Config.num_train_epochs,
        per_device_train_batch_size=Config.per_device_train_batch_size,
        gradient_accumulation_steps=Config.gradient_accumulation_steps,
        optim=Config.optim,
        save_steps=Config.save_steps,
        logging_steps=Config.logging_steps,
        learning_rate=Config.learning_rate,
        weight_decay=Config.weight_decay,
        fp16=Config.fp16,
        bf16=Config.bf16,
        max_grad_norm=Config.max_grad_norm,
        max_steps=Config.max_steps,
        warmup_ratio=Config.warmup_ratio,
        group_by_length=Config.group_by_length,
        lr_scheduler_type=Config.lr_scheduler_type,
        report_to=Config.report_to,
        save_strategy="steps",
        evaluation_strategy="no",
        load_best_model_at_end=False,
        push_to_hub=False,
        remove_unused_columns=False,
    )
    return training_arguments

def main():
    """Main fine-tuning function"""
    print("🚀 Starting QLoRA fine-tuning of OpenAI OSS 120B model")

    # Check CUDA availability
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this training script")

    print(f"Using GPU: {torch.cuda.get_device_name()}")
    print(f"Available VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer()

    # Apply PEFT
    peft_config = create_peft_config()
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    # Load and prepare dataset
    dataset = load_and_prepare_dataset(tokenizer)

    # Create training arguments
    training_arguments = create_training_arguments()

    # Create trainer
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        dataset_text_field="text",
        max_seq_length=Config.max_seq_length,
        tokenizer=tokenizer,
        args=training_arguments,
        packing=Config.packing,
    )

    # Start training
    print("🔥 Starting training...")
    trainer.train()

    # Save model
    print("💾 Saving model...")
    trainer.model.save_pretrained(Config.output_dir)
    tokenizer.save_pretrained(Config.output_dir)

    print("✅ Training completed!")

    # Test the model
    test_model(trainer.model, tokenizer)

def test_model(model, tokenizer):
    """Test the fine-tuned model"""
    print("\n🧪 Testing the fine-tuned model...")

    # Test prompts
    test_prompts = [
        "Create an advertisement for a new smartphone with advanced camera features",
        "Write ad copy for an eco-friendly clothing brand targeting young professionals",
        "Generate marketing content for a fitness app with AI personal trainer",
    ]

    for prompt in test_prompts:
        formatted_prompt = f"### Human: {prompt}\n### Assistant:"

        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_text = response[len(formatted_prompt):].strip()

        print(f"\n📝 Prompt: {prompt}")
        print(f"📄 Generated: {generated_text}")
        print("-" * 50)

if __name__ == "__main__":
    # Set environment variables
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    main()
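A minimal inference sketch for the LoRA adapter this script saves: it assumes the placeholder base model and the ./sft_results output directory configured above, and the example prompt is made up; swap in the real OSS 120B checkpoint once it is wired in:

# Sketch: reload the saved adapter and generate with the SFT prompt template.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "microsoft/DialoGPT-medium", torch_dtype=torch.float16, device_map="auto"
)
model = PeftModel.from_pretrained(base, "./sft_results")
tokenizer = AutoTokenizer.from_pretrained("./sft_results")

prompt = "### Human: Create an advertisement for a smart water bottle\n### Assistant:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(
        **inputs, max_new_tokens=100, do_sample=True, temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(output[0], skip_special_tokens=True))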
requirements.txt
CHANGED
@@ -1,17 +1,55 @@
# Core deep learning frameworks
torch>=2.0.0
torchvision
xformers

# Transformers ecosystem
transformers>=4.35.0
accelerate>=0.24.0
tokenizers
huggingface_hub

# Data processing
datasets>=2.14.0
numpy>=1.24.0
sentence-transformers
faiss-cpu

# Fine-tuning and RLHF
peft>=0.9.0
trl[peft]>=0.7.10
bitsandbytes>=0.41.0

# Image generation
diffusers
invisible_watermark

# Data labeling
label-studio

# APIs and network requests
flickrapi
requests

# Experiment tracking and visualization
wandb>=0.15.0
tensorboard>=2.13.0

# Evaluation metrics
evaluate
sacrebleu
rouge-score

# System tools and monitoring
scipy
protobuf
sentencepiece
alive_progress
psutil
gpustat

# Advanced optimizers (optional)
deepspeed>=0.10.0

# RLHF-specific tools
reward-bench