Spaces:
Configuration error
Configuration error
# -*- coding: utf-8 -*- | |
""" | |
@author:XuMing([email protected]) | |
@description: Train a model from SFT using PPO | |
""" | |
import os | |
from dataclasses import dataclass, field | |
from glob import glob | |
from typing import Optional | |
import torch | |
from datasets import load_dataset | |
from loguru import logger | |
from peft import LoraConfig, TaskType | |
from tqdm import tqdm | |
from transformers import ( | |
AutoConfig, | |
AutoModelForSequenceClassification, | |
BloomForCausalLM, | |
AutoModelForCausalLM, | |
AutoModel, | |
LlamaTokenizer, | |
LlamaForCausalLM, | |
BloomTokenizerFast, | |
AutoTokenizer, | |
HfArgumentParser, | |
) | |
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed | |
from supervised_finetuning import get_conv_template | |
# Process-wide environment tweaks applied at import time: disable HF tokenizers
# parallelism and allow duplicate OpenMP runtimes to be loaded.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "FALSE",
        "KMP_DUPLICATE_LIB_OK": "TRUE",
    }
)

# Registry keyed by the `--model_type` CLI value.
# Each entry is a (config class, model class, tokenizer class) triple.
MODEL_CLASSES = dict(
    bloom=(AutoConfig, BloomForCausalLM, BloomTokenizerFast),
    chatglm=(AutoConfig, AutoModel, AutoTokenizer),
    llama=(AutoConfig, LlamaForCausalLM, LlamaTokenizer),
    baichuan=(AutoConfig, AutoModelForCausalLM, AutoTokenizer),
    auto=(AutoConfig, AutoModelForCausalLM, AutoTokenizer),
)
@dataclass
class ScriptArguments:
    """
    Command-line arguments for fine-tuning a causal LM with PPO.

    The class must be a dataclass: ``HfArgumentParser`` builds the CLI from its fields,
    and ``__post_init__`` validates the required arguments right after parsing.

    Raises ``ValueError`` (from ``__post_init__``) when ``model_type``,
    ``model_name_or_path`` or ``reward_model_name_or_path`` is missing, or when
    ``model_type`` is not a key of ``MODEL_CLASSES``.
    """
    # Model arguments
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
    )
    model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The model checkpoint for weights initialization."}
    )
    reward_model_name_or_path: Optional[str] = field(default=None, metadata={"help": "The reward model name"})
    tokenizer_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The tokenizer for weights initialization."}
    )
    load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    device_map: Optional[str] = field(
        default="auto",
        metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
    )
    trust_remote_code: bool = field(
        default=True,
        metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
    )
    # Dataset arguments
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file_dir: Optional[str] = field(default=None, metadata={"help": "The input jsonl data file folder."})
    validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."}, )
    template_name: Optional[str] = field(default="vicuna", metadata={"help": "The template name."})
    batch_size: Optional[int] = field(default=8, metadata={"help": "Batch size"})
    mini_batch_size: Optional[int] = field(default=1, metadata={"help": "PPO minibatch size"})
    max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
    max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
    # NOTE(review): min_target_length is not read by main() in this file.
    min_target_length: Optional[int] = field(default=4, metadata={"help": "Min length of output text"})
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    # NOTE(review): max_eval_samples is not read by main() in this file.
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=1,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None, metadata={"help": "The number of processes to use for the preprocessing."},
    )
    # Training arguments
    use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
    # LoRA settings, forwarded to peft's LoraConfig when use_peft is True.
    target_modules: Optional[str] = field(default=None)
    lora_rank: Optional[int] = field(default=8)
    lora_dropout: Optional[float] = field(default=0.05)
    lora_alpha: Optional[float] = field(default=32.0)
    # NOTE(review): modules_to_save and peft_path are not read by main() in this file.
    modules_to_save: Optional[str] = field(default=None)
    peft_path: Optional[str] = field(default=None)
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    # NOTE(review): do_eval is not read by main() in this file.
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the validation set."})
    early_stopping: Optional[bool] = field(default=False, metadata={"help": "Whether to early stop"})
    target_kl: Optional[float] = field(default=0.1, metadata={"help": "The kl target for early stopping"})
    reward_baseline: Optional[float] = field(
        default=0.0, metadata={"help": "Baseline value that is subtracted from the reward"},
    )
    init_kl_coef: Optional[float] = field(
        default=0.2, metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},
    )
    adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
    learning_rate: Optional[float] = field(default=1.5e-5, metadata={"help": "Learning rate"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
    save_steps: Optional[int] = field(default=50, metadata={"help": "X steps to save the model"})
    output_dir: Optional[str] = field(default="outputs-rl", metadata={"help": "The output directory"})
    seed: Optional[int] = field(default=0, metadata={"help": "Seed"})
    max_steps: Optional[int] = field(default=200, metadata={"help": "Number of steps to train"})
    report_to: Optional[str] = field(default="tensorboard", metadata={"help": "Report to wandb or tensorboard"})

    def __post_init__(self):
        """Validate required arguments immediately after parsing."""
        if self.model_type is None:
            raise ValueError("You must specify a valid model_type to run training.")
        # Fail here with a readable message instead of a bare KeyError when main()
        # looks the type up in MODEL_CLASSES.
        if self.model_type not in MODEL_CLASSES:
            raise ValueError(
                f"Unsupported model_type: {self.model_type}. "
                f"Choose one of: {', '.join(MODEL_CLASSES.keys())}"
            )
        if self.model_name_or_path is None:
            raise ValueError("You must specify a valid model_name_or_path to run training.")
        if self.reward_model_name_or_path is None:
            raise ValueError("You must specify a valid reward_model_name_or_path to run training.")
def print_trainable_parameters(model):
    """
    Print how many of the model's parameters are trainable.

    Counts the elements (``numel``) of every tensor yielded by ``model.named_parameters()``.
    Those with ``requires_grad=True`` are also counted separately. The result is printed as
    ``trainable params: T || all params: A || trainable%: P``.

    :param model: any object exposing ``named_parameters()``, e.g. a ``torch.nn.Module``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    # A model without parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )
def get_reward_model_output(reward_model, reward_tokenizer, question, answer, device):
    """
    Score one (question, answer) pair with the reward model.

    The pair is tokenized as a text pair, moved to ``device`` and passed to
    ``reward_model``. The logits of the first sequence in the batch are returned.

    :param reward_model: model whose output exposes ``.logits``.
    :param reward_tokenizer: tokenizer matching ``reward_model``.
    :param question: prompt text.
    :param answer: generated response text.
    :param device: device the reward model's inputs are moved to.
    :return: detached CPU tensor holding the logits for this pair.
    """
    inputs = reward_tokenizer(question, answer, return_tensors='pt').to(device)
    # Pure inference: don't build an autograd graph (and keep its activations alive)
    # only to detach the result afterwards.
    with torch.no_grad():
        logits = reward_model(**inputs).logits[0]
    return logits.cpu().detach()
def calculate_rewards(reward_score_outputs, reward_baseline=0):
    """
    Reduce raw reward-model scores to scalar PPO rewards.

    Each score is collapsed to one number. A single-element score is used as is, and a
    multi-element score is averaged. ``reward_baseline`` is then subtracted.

    :param reward_score_outputs: iterable of scores; each may be a tensor of any shape
        (any numeric dtype) or a plain Python number.
    :param reward_baseline: value subtracted from every reduced score.
    :return: list with one 0-dim float tensor per input score.
    """
    rewards = []
    for score in reward_score_outputs:
        score_tensor = torch.as_tensor(score)
        # torch.mean rejects integer dtypes, so promote them to float first.
        if not torch.is_floating_point(score_tensor):
            score_tensor = score_tensor.float()
        # The mean of a one-element tensor is that element, so a single path covers both cases.
        reward_value = score_tensor.mean().item() - reward_baseline
        rewards.append(torch.tensor(reward_value))
    return rewards
def _find_json_files(data_dir):
    """Return every ``*.json`` and ``*.jsonl`` file found recursively under ``data_dir``."""
    return (
        glob(f'{data_dir}/**/*.json', recursive=True)
        + glob(f'{data_dir}/**/*.jsonl', recursive=True)
    )


def _load_raw_datasets(args):
    """
    Load the raw train/validation datasets described by ``args``.

    The source is either ``args.dataset_name`` from the hub, or every json/jsonl file
    found under ``args.train_file_dir`` / ``args.validation_file_dir``. When the loaded
    data has no ``validation`` split, the first ``args.validation_split_percentage``
    percent of ``train`` becomes ``validation`` and the remainder becomes ``train``.
    """
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        load_args = (args.dataset_name, args.dataset_config_name)
        load_kwargs = {"cache_dir": args.cache_dir}
    else:
        data_files = {}
        if args.train_file_dir is not None and os.path.exists(args.train_file_dir):
            train_data_files = _find_json_files(args.train_file_dir)
            logger.info(f"train files: {', '.join(train_data_files)}")
            data_files["train"] = train_data_files
        if args.validation_file_dir is not None and os.path.exists(args.validation_file_dir):
            eval_data_files = _find_json_files(args.validation_file_dir)
            logger.info(f"eval files: {', '.join(eval_data_files)}")
            data_files["validation"] = eval_data_files
        load_args = ('json',)
        load_kwargs = {"data_files": data_files, "cache_dir": args.cache_dir}
    raw_datasets = load_dataset(*load_args, **load_kwargs)
    # If no validation data is there, validation_split_percentage will be used to divide the dataset.
    if "validation" not in raw_datasets.keys():
        pct = args.validation_split_percentage
        raw_datasets["validation"] = load_dataset(*load_args, split=f"train[:{pct}%]", **load_kwargs)
        raw_datasets["train"] = load_dataset(*load_args, split=f"train[{pct}%:]", **load_kwargs)
    return raw_datasets


def _load_reward_model(args, tokenizer_kwargs, device):
    """
    Load the reward model and its tokenizer from ``args.reward_model_name_or_path``.

    The model is loaded with its own checkpoint config, not the policy model's, so its
    architecture and classification head match the reward checkpoint.

    :return: ``(reward_model, reward_tokenizer)``, with the model placed on ``device``.
    """
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        args.reward_model_name_or_path,
        load_in_8bit=args.load_in_8bit,
        # transformers does not allow `.to()` on bitsandbytes 8-bit models, so those are
        # placed on the device at load time instead.
        device_map={"": device} if args.load_in_8bit else None,
        trust_remote_code=args.trust_remote_code,
        cache_dir=args.cache_dir,
    )
    if not args.load_in_8bit:
        reward_model.to(device)
    reward_tokenizer = AutoTokenizer.from_pretrained(
        args.reward_model_name_or_path, **tokenizer_kwargs
    )
    return reward_model, reward_tokenizer


def main():
    """
    Fine-tune an SFT model with PPO against a reward model.

    Steps:
    - parse ``ScriptArguments``;
    - load the policy (value head, optionally LoRA) and the reward model;
    - turn the ``conversations`` column into tokenized queries;
    - run at most ``total_ppo_epochs`` PPO steps, saving a checkpoint every
      ``save_steps`` steps and the final model to ``output_dir``.
    """
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]
    logger.info(f"Parse args: {args}")

    # The model class from MODEL_CLASSES is unused: the policy is loaded through trl below.
    config_class, _, tokenizer_class = MODEL_CLASSES[args.model_type]
    if args.model_type == 'bloom':
        args.use_fast_tokenizer = True

    # Load tokenizer
    tokenizer_kwargs = {
        "cache_dir": args.cache_dir,
        "use_fast": args.use_fast_tokenizer,
        "trust_remote_code": args.trust_remote_code,
    }
    tokenizer_name_or_path = args.tokenizer_name_or_path or args.model_name_or_path
    tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = 0  # id 0; the <unk> token in LLaMA-style vocabularies (verify for others)

    logger.info("Load model")
    target_modules = args.target_modules
    if target_modules and "," in target_modules:
        # Accept a comma-separated list of module names. A single string is passed
        # through unchanged, as before.
        target_modules = [name.strip() for name in target_modules.split(",") if name.strip()]
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=target_modules,
        inference_mode=False,
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
    )
    torch_dtype = (
        args.torch_dtype
        if args.torch_dtype in ["auto", None]
        else getattr(torch, args.torch_dtype)
    )
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    # LOCAL_RANK may be unset in single-process runs; default to device 0.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if world_size > 1:
        args.device_map = {"": local_rank}
    config = config_class.from_pretrained(
        args.model_name_or_path,
        torch_dtype=torch_dtype,
        trust_remote_code=args.trust_remote_code,
        cache_dir=args.cache_dir
    )
    model = AutoModelForCausalLMWithValueHead.from_pretrained(
        args.model_name_or_path,
        config=config,
        # The dtype must be given to the model load itself; setting it on the config
        # alone does not change the dtype the weights are loaded in.
        torch_dtype=torch_dtype,
        load_in_8bit=args.load_in_8bit,
        device_map=args.device_map,
        trust_remote_code=args.trust_remote_code,
        cache_dir=args.cache_dir,
        peft_config=peft_config if args.use_peft else None,
    )
    print_trainable_parameters(model)

    # Load reward model on this process's GPU (cuda:0 in single-process runs).
    device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
    reward_model, reward_tokenizer = _load_reward_model(args, tokenizer_kwargs, device)

    # Get datasets
    raw_datasets = _load_raw_datasets(args)
    logger.info(f"Raw datasets: {raw_datasets}")

    # Preprocessing the datasets
    max_source_length = args.max_source_length
    max_target_length = args.max_target_length
    prompt_template = get_conv_template(args.template_name)

    def preprocess_function(examples):
        """
        Turn a batch of ``conversations`` into PPO queries.

        Each conversation is a list of ``{"from": ..., "value": ...}`` turns. A leading
        non-human turn is dropped, and the conversation is skipped if it has an unknown
        role or an odd/too-small number of kept turns. Every query of the resulting
        dialog is tokenized and truncated to ``max_source_length``.
        """
        new_examples = {
            "query": [],
            "input_ids": [],
        }
        roles = ["human", "gpt"]

        def iter_prompts(examples):
            for i, source in enumerate(examples['conversations']):
                if len(source) < 2:
                    continue
                data_role = source[0].get("from", "")
                if data_role != roles[0]:
                    # Skip the first one if it is not from human
                    source = source[1:]
                if len(source) < 2:
                    continue
                messages = []
                for j, sentence in enumerate(source):
                    data_role = sentence.get("from", "")
                    if data_role not in roles:
                        logger.warning(f"unknown role: {data_role}, {i}. (ignored)")
                        break
                    if data_role == roles[j % 2]:
                        messages.append(sentence["value"])
                if len(messages) < 2 or len(messages) % 2 != 0:
                    continue
                # Convert the list to pairs of elements
                history_messages = [[messages[k], messages[k + 1]] for k in range(0, len(messages), 2)]
                # NOTE(review): the caller indexes this result as alternating query/response
                # entries, i.e. it assumes a sequence of 2*n strings. Confirm against the
                # conversation template in supervised_finetuning.
                yield prompt_template.get_prompt(history_messages)

        for prompt in iter_prompts(examples):
            for i in range(len(prompt) // 2):
                source_txt = prompt[2 * i]
                # Queries are generated one at a time in the training loop, so they are
                # not padded. Trailing pad tokens would sit between prompt and response.
                tokenized_question = tokenizer(
                    source_txt, truncation=True, max_length=max_source_length, return_tensors="pt"
                )
                new_examples["query"].append(source_txt)
                new_examples["input_ids"].append(tokenized_question["input_ids"])
        return new_examples

    # Preprocess the dataset
    train_dataset = None
    if args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets['train']
        if args.max_train_samples is not None and args.max_train_samples > 0:
            max_train_samples = min(len(train_dataset), args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
        tokenized_dataset = train_dataset.shuffle().map(
            preprocess_function,
            batched=True,
            num_proc=args.preprocessing_num_workers,
            remove_columns=train_dataset.column_names,
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
        train_dataset = tokenized_dataset.filter(
            lambda x: len(x['input_ids']) > 0
        )
        logger.debug(f"Num train_samples: {len(train_dataset)}")

    def collator(data):
        # Column-wise batching: {key: [value of each example]}.
        return dict((key, [d[key] for d in data]) for key in data[0])

    output_dir = args.output_dir
    ppo_config = PPOConfig(
        steps=args.max_steps,
        model_name=args.model_name_or_path,
        learning_rate=args.learning_rate,
        log_with=args.report_to,
        batch_size=args.batch_size,
        mini_batch_size=args.mini_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        optimize_cuda_cache=True,
        early_stopping=args.early_stopping,
        target_kl=args.target_kl,
        seed=args.seed,
        init_kl_coef=args.init_kl_coef,
        adap_kl_ctrl=args.adap_kl_ctrl,
        project_kwargs={"logging_dir": output_dir},
    )
    # Set seed before initializing value head for deterministic eval
    set_seed(ppo_config.seed)

    # We then build the PPOTrainer, passing the model, the reference model, the tokenizer
    trainer = PPOTrainer(
        ppo_config,
        model,
        ref_model=None,
        tokenizer=tokenizer,
        dataset=train_dataset,
        data_collator=collator,
    )

    # These arguments are passed to the `generate` function of the PPOTrainer
    generation_kwargs = {
        "max_new_tokens": max_target_length,
        "temperature": 1.0,
        "repetition_penalty": 1.0,
        "top_p": 1.0,
        "do_sample": True,
    }

    def save_model(save_dir):
        # Only the main process writes, so ranks do not overwrite each other's files.
        if not trainer.accelerator.is_main_process:
            return
        trainer.accelerator.unwrap_model(trainer.model).save_pretrained(save_dir)
        trainer.tokenizer.save_pretrained(save_dir)

    # Training
    if args.do_train:
        logger.info("*** Train ***")
        total_steps = ppo_config.total_ppo_epochs
        progress = tqdm(
            enumerate(trainer.dataloader),
            total=min(total_steps, len(trainer.dataloader)),
        )
        for step, batch in progress:
            if step >= total_steps:
                break
            question_tensors = [
                torch.LongTensor(i).to(device).squeeze(0) for i in batch["input_ids"]
            ]
            responses = []
            response_tensors = []
            for q_tensor in question_tensors:
                response_tensor = trainer.generate(
                    q_tensor,
                    return_prompt=False,
                    **generation_kwargs,
                )
                r = tokenizer.batch_decode(response_tensor, skip_special_tokens=True)[0]
                responses.append(r)
                response_tensors.append(response_tensor.squeeze(0))
            batch["response"] = responses

            # Compute reward score
            score_outputs = [
                get_reward_model_output(reward_model, reward_tokenizer, q, r, device)
                for q, r in zip(batch["query"], batch["response"])
            ]
            rewards = calculate_rewards(score_outputs, args.reward_baseline)

            # Run PPO step
            try:
                stats = trainer.step(question_tensors, response_tensors, rewards)
                trainer.log_stats(stats, batch, rewards)
                logger.debug(f"Step {step}/{total_steps}: reward score:{score_outputs}")
            except ValueError as e:
                logger.warning(f"Failed to log stats for step {step}, because of {e}")

            if step and step % args.save_steps == 0:
                save_dir = os.path.join(output_dir, f"checkpoint-{step}")
                save_model(save_dir)
        # Save final model
        save_model(output_dir)


if __name__ == "__main__":
    main()