# MedicalGPT-main / rl_training.py
# nengrenjie83's picture
# Upload 28 files
# b78b52f
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description: Train a model from SFT using PPO
"""
import os
from dataclasses import dataclass, field
from glob import glob
from typing import Optional
import torch
from datasets import load_dataset
from loguru import logger
from peft import LoraConfig, TaskType
from tqdm import tqdm
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
BloomForCausalLM,
AutoModelForCausalLM,
AutoModel,
LlamaTokenizer,
LlamaForCausalLM,
BloomTokenizerFast,
AutoTokenizer,
HfArgumentParser,
)
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
from supervised_finetuning import get_conv_template
# Turn off the HF `tokenizers` library's internal parallelism and let duplicate
# Intel OpenMP (KMP) runtimes coexist in one process.
os.environ.update({"TOKENIZERS_PARALLELISM": "FALSE", "KMP_DUPLICATE_LIB_OK": "TRUE"})

# Lookup table: model_type -> (config class, causal-LM class, tokenizer class).
# Keys are the accepted values of `--model_type`.
MODEL_CLASSES = dict(
    bloom=(AutoConfig, BloomForCausalLM, BloomTokenizerFast),
    chatglm=(AutoConfig, AutoModel, AutoTokenizer),
    llama=(AutoConfig, LlamaForCausalLM, LlamaTokenizer),
    baichuan=(AutoConfig, AutoModelForCausalLM, AutoTokenizer),
    auto=(AutoConfig, AutoModelForCausalLM, AutoTokenizer),
)
@dataclass
class ScriptArguments:
    """
    Command-line arguments for fine-tuning a causal LM with PPO.

    Parsed with ``HfArgumentParser``. ``__post_init__`` rejects a missing or
    unsupported ``model_type`` and a missing policy or reward model path.
    """
    # Model arguments
    model_type: str = field(
        default=None,
        metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
    )
    model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The model checkpoint for weights initialization."}
    )
    reward_model_name_or_path: Optional[str] = field(default=None, metadata={"help": "The reward model name"})
    tokenizer_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The tokenizer for weights initialization."}
    )
    load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    device_map: Optional[str] = field(
        default="auto",
        metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
    )
    trust_remote_code: bool = field(
        default=True,
        metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
    )
    # Dataset arguments
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file_dir: Optional[str] = field(default=None, metadata={"help": "The input jsonl data file folder."})
    validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."})
    template_name: Optional[str] = field(default="vicuna", metadata={"help": "The template name."})
    batch_size: Optional[int] = field(default=8, metadata={"help": "Batch size"})
    mini_batch_size: Optional[int] = field(default=1, metadata={"help": "PPO minibatch size"})
    max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
    max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
    # NOTE(review): min_target_length is not read anywhere in this file.
    min_target_length: Optional[int] = field(default=4, metadata={"help": "Min length of output text"})
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    # NOTE(review): max_eval_samples is not read anywhere in this file.
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=1,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None, metadata={"help": "The number of processes to use for the preprocessing."},
    )
    # Training arguments
    use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
    # LoRA settings, used to build the peft LoraConfig.
    target_modules: Optional[str] = field(default=None)
    lora_rank: Optional[int] = field(default=8)
    lora_dropout: Optional[float] = field(default=0.05)
    lora_alpha: Optional[float] = field(default=32.0)
    # NOTE(review): modules_to_save, peft_path and do_eval are not read anywhere in this file.
    modules_to_save: Optional[str] = field(default=None)
    peft_path: Optional[str] = field(default=None)
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the validation set."})
    # PPO / KL-control settings, forwarded to trl's PPOConfig.
    early_stopping: Optional[bool] = field(default=False, metadata={"help": "Whether to early stop"})
    target_kl: Optional[float] = field(default=0.1, metadata={"help": "The kl target for early stopping"})
    reward_baseline: Optional[float] = field(
        default=0.0, metadata={"help": "Baseline value that is subtracted from the reward"},
    )
    init_kl_coef: Optional[float] = field(
        default=0.2, metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},
    )
    adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
    learning_rate: Optional[float] = field(default=1.5e-5, metadata={"help": "Learning rate"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
    save_steps: Optional[int] = field(default=50, metadata={"help": "X steps to save the model"})
    output_dir: Optional[str] = field(default="outputs-rl", metadata={"help": "The output directory"})
    seed: Optional[int] = field(default=0, metadata={"help": "Seed"})
    max_steps: Optional[int] = field(default=200, metadata={"help": "Number of steps to train"})
    report_to: Optional[str] = field(default="tensorboard", metadata={"help": "Report to wandb or tensorboard"})

    def __post_init__(self):
        """
        Validate required arguments.

        :raises ValueError: if ``model_type`` is missing or not a key of
            ``MODEL_CLASSES``, or a model / reward model path is missing.
        """
        if self.model_type is None:
            raise ValueError("You must specify a valid model_type to run training.")
        # Fail fast with a readable message; otherwise an unknown type only
        # surfaces later as a bare KeyError from the MODEL_CLASSES lookup.
        if self.model_type not in MODEL_CLASSES:
            raise ValueError(
                f"Unsupported model_type: {self.model_type!r}, "
                f"must be one of: {', '.join(MODEL_CLASSES.keys())}"
            )
        if self.model_name_or_path is None:
            raise ValueError("You must specify a valid model_name_or_path to run training.")
        if self.reward_model_name_or_path is None:
            raise ValueError("You must specify a valid reward_model_name_or_path to run training.")
def print_trainable_parameters(model):
    """
    Print how many of ``model``'s parameters are trainable.

    Sums ``numel()`` over ``model.named_parameters()``, both in total and for
    the parameters with ``requires_grad`` set, and prints the counts and the
    trainable percentage.

    :param model: any object exposing ``named_parameters()`` (e.g. a torch ``nn.Module``).
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    # Avoid ZeroDivisionError for a model that has no parameters at all.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )
def get_reward_model_output(reward_model, reward_tokenizer, question, answer, device):
    """
    Get the reward score for a given question and answer pair.

    The pair is tokenized together with ``reward_tokenizer``, moved to
    ``device`` and scored by ``reward_model``. The reward model is only used for
    inference here, so the forward pass runs under ``torch.no_grad()`` to avoid
    building an autograd graph on every call.

    :param reward_model: model whose output exposes ``.logits``.
    :param reward_tokenizer: tokenizer called as ``tok(question, answer, return_tensors='pt')``.
    :param question: prompt text.
    :param answer: generated response text.
    :param device: device the tokenized inputs are moved to.
    :return: ``logits[0]`` (the first batch row), detached and on the CPU.
    """
    inputs = reward_tokenizer(question, answer, return_tensors='pt').to(device)
    with torch.no_grad():
        logits = reward_model(**inputs).logits
    return logits[0].detach().cpu()
def calculate_rewards(reward_score_outputs, reward_baseline=0):
    """
    Convert raw reward-model scores into per-sample PPO rewards.

    Each score is reduced to a scalar (a multi-element score is averaged) and
    ``reward_baseline`` is subtracted from it.

    :param reward_score_outputs: iterable of scores; each may be a tensor of any
        shape and dtype, or a plain Python number.
    :param reward_baseline: value subtracted from every reduced score.
    :return: list of 0-d floating-point tensors, one per input score.
    """
    rewards = []
    for score in reward_score_outputs:
        score_tensor = torch.as_tensor(score)
        # `mean()` is undefined for integer tensors; promote them so integer
        # scores are averaged instead of raising, and rewards stay floating point.
        if not score_tensor.is_floating_point():
            score_tensor = score_tensor.float()
        # A single-element score averages to itself, so one path covers both cases.
        reward_value = score_tensor.mean().item() - reward_baseline
        rewards.append(torch.tensor(reward_value))
    return rewards
def _load_raw_datasets(args):
    """
    Load the raw train/validation datasets described by ``args``.

    Either loads ``args.dataset_name`` (with ``args.dataset_config_name``) via
    ``datasets.load_dataset``, or reads every ``*.json`` / ``*.jsonl`` file found
    recursively under ``args.train_file_dir`` / ``args.validation_file_dir``.
    When the result has no ``validation`` split, the first
    ``validation_split_percentage`` percent of ``train`` becomes validation and
    the rest becomes train.
    """
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        load_args = (args.dataset_name, args.dataset_config_name)
        load_kwargs = {}
    else:
        data_files = {}
        if args.train_file_dir is not None and os.path.exists(args.train_file_dir):
            train_data_files = glob(f'{args.train_file_dir}/**/*.json', recursive=True) + glob(
                f'{args.train_file_dir}/**/*.jsonl', recursive=True)
            logger.info(f"train files: {', '.join(train_data_files)}")
            data_files["train"] = train_data_files
        if args.validation_file_dir is not None and os.path.exists(args.validation_file_dir):
            eval_data_files = glob(f'{args.validation_file_dir}/**/*.json', recursive=True) + glob(
                f'{args.validation_file_dir}/**/*.jsonl', recursive=True)
            logger.info(f"eval files: {', '.join(eval_data_files)}")
            data_files["validation"] = eval_data_files
        load_args = ('json',)
        load_kwargs = {"data_files": data_files}
    raw_datasets = load_dataset(*load_args, cache_dir=args.cache_dir, **load_kwargs)
    # If no validation data is there, validation_split_percentage will be used to divide the dataset.
    if "validation" not in raw_datasets.keys():
        pct = args.validation_split_percentage
        raw_datasets["validation"] = load_dataset(
            *load_args, split=f"train[:{pct}%]", cache_dir=args.cache_dir, **load_kwargs
        )
        raw_datasets["train"] = load_dataset(
            *load_args, split=f"train[{pct}%:]", cache_dir=args.cache_dir, **load_kwargs
        )
    return raw_datasets


def main():
    """
    Fine-tune an SFT causal LM with PPO against a reward model.

    Parses ``ScriptArguments``, then:
    1. loads the policy tokenizer and the policy (wrapped with a value head,
       optionally with LoRA),
    2. loads the reward model and its tokenizer,
    3. turns ``conversations`` records into prompt queries,
    4. runs trl's ``PPOTrainer`` loop (generate -> score -> PPO step), saving a
       checkpoint every ``save_steps`` steps and the final model to ``output_dir``.
    """
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]
    logger.info(f"Parse args: {args}")
    # The policy itself is loaded through AutoModelForCausalLMWithValueHead, so
    # the model class from the table is not needed here.
    config_class, _, tokenizer_class = MODEL_CLASSES[args.model_type]
    if args.model_type == 'bloom':
        args.use_fast_tokenizer = True

    # Load tokenizer
    tokenizer_kwargs = {
        "cache_dir": args.cache_dir,
        "use_fast": args.use_fast_tokenizer,
        "trust_remote_code": args.trust_remote_code,
    }
    tokenizer_name_or_path = args.tokenizer_name_or_path or args.model_name_or_path
    tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = 0  # set as the <unk> token

    logger.info("Load model")
    # `target_modules` arrives as a single CLI string. peft treats a bare string
    # as a regex that must match the *full* module name (so "q_proj" would match
    # nothing), whereas a list matches by name suffix; split "q_proj,v_proj".
    target_modules = (
        [name.strip() for name in args.target_modules.split(',') if name.strip()]
        if args.target_modules else None
    )
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=target_modules,
        inference_mode=False,
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
    )
    torch_dtype = (
        args.torch_dtype
        if args.torch_dtype in ["auto", None]
        else getattr(torch, args.torch_dtype)
    )

    # Device for tensors/modules this script places itself (reward model, queries).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if world_size > 1:
        # One process per GPU: pin the whole policy to this process's GPU.
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        args.device_map = {"": local_rank}
        if torch.cuda.is_available():
            # A bare "cuda" means the *current* CUDA device, which is still cuda:0
            # in every process at this point, so name the rank's GPU explicitly
            # to keep the reward model and queries next to the policy.
            device = f"cuda:{local_rank}"

    model_config = config_class.from_pretrained(
        args.model_name_or_path,
        torch_dtype=torch_dtype,
        trust_remote_code=args.trust_remote_code,
        cache_dir=args.cache_dir
    )
    model = AutoModelForCausalLMWithValueHead.from_pretrained(
        args.model_name_or_path,
        config=model_config,
        # The dtype must be given to from_pretrained itself; setting it on the
        # config alone does not change the dtype the weights are loaded in.
        torch_dtype=torch_dtype,
        load_in_8bit=args.load_in_8bit,
        device_map=args.device_map,
        trust_remote_code=args.trust_remote_code,
        cache_dir=args.cache_dir,
        peft_config=peft_config if args.use_peft else None,
    )
    print_trainable_parameters(model)

    # Load reward model. It is a separate checkpoint, so let transformers read its
    # own config rather than forcing the policy's config onto it.
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        args.reward_model_name_or_path,
        torch_dtype=torch_dtype,
        load_in_8bit=args.load_in_8bit,
        trust_remote_code=args.trust_remote_code,
        cache_dir=args.cache_dir,
    )
    if not args.load_in_8bit:
        # bitsandbytes 8-bit models are placed at load time and recent
        # transformers versions reject `.to()` on them.
        reward_model.to(device)
    reward_tokenizer = AutoTokenizer.from_pretrained(
        args.reward_model_name_or_path, **tokenizer_kwargs
    )

    # Get datasets
    raw_datasets = _load_raw_datasets(args)
    logger.info(f"Raw datasets: {raw_datasets}")

    # Preprocessing the datasets
    max_source_length = args.max_source_length
    max_target_length = args.max_target_length
    prompt_template = get_conv_template(args.template_name)

    def preprocess_function(examples):
        """Map a batch of ``conversations`` records to ``query`` text and ``input_ids``."""
        new_examples = {
            "query": [],
            "input_ids": [],
        }
        roles = ["human", "gpt"]

        def get_prompt(examples):
            """Yield one formatted prompt per conversation with complete human/gpt pairs."""
            for i, source in enumerate(examples['conversations']):
                if len(source) < 2:
                    continue
                data_role = source[0].get("from", "")
                if data_role != roles[0]:
                    # Skip the first one if it is not from human
                    source = source[1:]
                if len(source) < 2:
                    continue
                messages = []
                for j, sentence in enumerate(source):
                    data_role = sentence.get("from", "")
                    if data_role not in roles:
                        logger.warning(f"unknown role: {data_role}, {i}. (ignored)")
                        break
                    if data_role == roles[j % 2]:
                        messages.append(sentence["value"])
                if len(messages) < 2 or len(messages) % 2 != 0:
                    continue
                # Convert the list to pairs of elements
                history_messages = [[messages[k], messages[k + 1]] for k in range(0, len(messages), 2)]
                yield prompt_template.get_prompt(history_messages)

        # NOTE(review): the loop below indexes the template output as an
        # alternating [query, response, ...] sequence; confirm that
        # get_conv_template(...).get_prompt returns a list, not a joined string.
        for prompt in get_prompt(examples):
            for i in range(len(prompt) // 2):
                source_txt = prompt[2 * i]
                # No padding here: queries are generated one at a time below, so
                # right-padding to max_length would make the policy continue
                # generating after a run of pad tokens. trl's PPOTrainer.step pads
                # and masks the list of query tensors itself.
                tokenized_question = tokenizer(
                    source_txt, truncation=True, max_length=max_source_length,
                    return_tensors="pt"
                )
                new_examples["query"].append(source_txt)
                new_examples["input_ids"].append(tokenized_question["input_ids"])
        return new_examples

    # Preprocess the dataset
    train_dataset = None
    if args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets['train']
        if args.max_train_samples is not None and args.max_train_samples > 0:
            max_train_samples = min(len(train_dataset), args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
        # Seed the shuffle so --seed makes the data order reproducible.
        tokenized_dataset = train_dataset.shuffle(seed=args.seed).map(
            preprocess_function,
            batched=True,
            num_proc=args.preprocessing_num_workers,
            remove_columns=train_dataset.column_names,
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
        train_dataset = tokenized_dataset.filter(
            lambda x: len(x['input_ids']) > 0
        )
        logger.debug(f"Num train_samples: {len(train_dataset)}")

    def collator(data):
        """Collate a list of row dicts into a dict of per-key lists (no padding)."""
        return {key: [d[key] for d in data] for key in data[0]}

    output_dir = args.output_dir
    ppo_config = PPOConfig(
        steps=args.max_steps,
        model_name=args.model_name_or_path,
        learning_rate=args.learning_rate,
        log_with=args.report_to,
        batch_size=args.batch_size,
        mini_batch_size=args.mini_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        optimize_cuda_cache=True,
        early_stopping=args.early_stopping,
        target_kl=args.target_kl,
        seed=args.seed,
        init_kl_coef=args.init_kl_coef,
        adap_kl_ctrl=args.adap_kl_ctrl,
        project_kwargs={"logging_dir": output_dir},
    )
    # Set seed before initializing value head for deterministic eval
    set_seed(ppo_config.seed)

    # We then build the PPOTrainer, passing the model, the reference model, the tokenizer
    trainer = PPOTrainer(
        ppo_config,
        model,
        ref_model=None,
        tokenizer=tokenizer,
        dataset=train_dataset,
        data_collator=collator,
    )

    # These arguments are passed to the `generate` function of the PPOTrainer
    generation_kwargs = {
        "max_new_tokens": max_target_length,
        "temperature": 1.0,
        "repetition_penalty": 1.0,
        "top_p": 1.0,
        "do_sample": True,
    }

    def save_model(save_dir):
        """Save the unwrapped policy and the tokenizer to ``save_dir``."""
        trainer.accelerator.unwrap_model(trainer.model).save_pretrained(save_dir)
        trainer.tokenizer.save_pretrained(save_dir)

    # Training
    if args.do_train:
        logger.info("*** Train ***")
        total_steps = ppo_config.total_ppo_epochs
        for step, batch in tqdm(enumerate(trainer.dataloader)):
            if step >= total_steps:
                break
            # Each stored input_ids entry is a [1, seq_len] nested list; drop the
            # leading batch dim to get the 1-D query tensors trl expects.
            question_tensors = [torch.LongTensor(ids).to(device).squeeze(0) for ids in batch["input_ids"]]
            responses = []
            response_tensors = []
            for q_tensor in question_tensors:
                response_tensor = trainer.generate(
                    q_tensor,
                    return_prompt=False,
                    **generation_kwargs,
                )
                r = tokenizer.batch_decode(response_tensor, skip_special_tokens=True)[0]
                responses.append(r)
                response_tensors.append(response_tensor.squeeze(0))
            batch["response"] = responses

            # Compute reward score
            score_outputs = [
                get_reward_model_output(reward_model, reward_tokenizer, q, r, device)
                for q, r in zip(batch["query"], batch["response"])
            ]
            rewards = calculate_rewards(score_outputs, args.reward_baseline)

            # Run PPO step
            try:
                stats = trainer.step(question_tensors, response_tensors, rewards)
                trainer.log_stats(stats, batch, rewards)
                logger.debug(f"Step {step}/{total_steps}: reward score:{score_outputs}")
            except ValueError as e:
                logger.warning(f"Failed to log stats for step {step}, because of {e}")

            if step and step % args.save_steps == 0:
                save_dir = os.path.join(output_dir, f"checkpoint-{step}")
                save_model(save_dir)
        # Save final model
        save_model(output_dir)
if __name__ == "__main__":
    # Run the PPO training entry point only when executed as a script, not on import.
    main()