# MedicalGPT-main / dpo_training.py
# nengrenjie83's picture
# Upload 28 files
# b78b52f
# raw
# history blame
# 21 kB
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description: Train a model from SFT using DPO
"""
import os
from dataclasses import dataclass, field
from glob import glob
from typing import Dict, List, Optional

import torch
from datasets import load_dataset
from loguru import logger
from peft import LoraConfig, TaskType
from transformers import (
    AutoConfig,
    BloomForCausalLM,
    AutoModelForCausalLM,
    AutoModel,
    LlamaTokenizer,
    LlamaForCausalLM,
    BloomTokenizerFast,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    BitsAndBytesConfig,
)
from transformers.deepspeed import is_deepspeed_zero3_enabled
from trl import DPOTrainer
# Process-wide environment tweaks applied at import time:
# - TOKENIZERS_PARALLELISM disables the HF tokenizers library's internal parallelism
#   (presumably to avoid its fork warning once datasets.map spawns workers).
# - KMP_DUPLICATE_LIB_OK lets Intel's OpenMP runtime tolerate a second copy of itself.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "FALSE",
        "KMP_DUPLICATE_LIB_OK": "TRUE",
    }
)

# model_type -> (config class, model class, tokenizer class).
# Insertion order matters: it is the order listed in the --model_type help text.
MODEL_CLASSES = dict(
    bloom=(AutoConfig, BloomForCausalLM, BloomTokenizerFast),
    chatglm=(AutoConfig, AutoModel, AutoTokenizer),
    llama=(AutoConfig, LlamaForCausalLM, LlamaTokenizer),
    baichuan=(AutoConfig, AutoModelForCausalLM, AutoTokenizer),
    auto=(AutoConfig, AutoModelForCausalLM, AutoTokenizer),
)
@dataclass
class ScriptArguments:
    """
    Command-line arguments for fine-tuning a causal LM with DPO.

    Model-loading, dataset and training/PEFT options live in this single
    dataclass so that ``HfArgumentParser`` can expose them as CLI flags.
    ``__post_init__`` rejects a missing or unsupported ``model_type`` and a
    missing ``model_name_or_path`` before any expensive work starts.
    """
    # Model arguments
    model_type: str = field(
        default=None,
        metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
    )
    model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The model checkpoint for weights initialization."}
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The tokenizer for weights initialization."}
    )
    load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
    load_in_4bit: bool = field(default=False, metadata={"help": "Whether to load the model in 4bit mode or not."})
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    device_map: Optional[str] = field(
        default="auto",
        metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
    )
    trust_remote_code: bool = field(
        default=True,
        metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
    )
    # Dataset arguments
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file_dir: Optional[str] = field(default=None, metadata={"help": "The input jsonl data file folder."})
    validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."})
    # NOTE(review): template_name is not read anywhere in this script.
    template_name: Optional[str] = field(default="vicuna", metadata={"help": "The prompt template name."})
    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "Train batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "Eval batch size per device"})
    max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
    max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
    # NOTE(review): min_target_length is not read anywhere in this script.
    min_target_length: Optional[int] = field(default=4, metadata={"help": "Min length of output text"})
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=1,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=4, metadata={"help": "The number of processes to use for the preprocessing."},
    )
    # Training arguments
    use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
    qlora: bool = field(default=False, metadata={"help": "Whether to use qlora"})
    # Comma-separated LoRA target module names; the value "all" selects every linear layer.
    target_modules: Optional[str] = field(default=None)
    lora_rank: Optional[int] = field(default=8)
    lora_dropout: Optional[float] = field(default=0.05)
    lora_alpha: Optional[float] = field(default=16.0)
    # NOTE(review): peft_path is not read anywhere in this script.
    peft_path: Optional[str] = field(default=None)
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the validation set."})
    beta: Optional[float] = field(default=0.1, metadata={"help": "The beta parameter for DPO loss"})
    learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "Learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "The lr scheduler type"})
    warmup_steps: Optional[int] = field(default=100, metadata={"help": "The number of warmup steps"})
    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "The weight decay"})
    optim: Optional[str] = field(default="adamw_hf", metadata={"help": "The optimizer type"})
    fp16: Optional[bool] = field(default=True, metadata={"help": "Whether to use fp16"})
    bf16: Optional[bool] = field(default=False, metadata={"help": "Whether to use bf16"})
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "Whether to use gradient checkpointing"}
    )
    gradient_accumulation_steps: Optional[int] = field(
        default=4, metadata={"help": "The number of gradient accumulation steps"}
    )
    save_steps: Optional[int] = field(default=50, metadata={"help": "X steps to save the model"})
    eval_steps: Optional[int] = field(default=50, metadata={"help": "X steps to evaluate the model"})
    logging_steps: Optional[int] = field(default=1, metadata={"help": "X steps to log the model"})
    output_dir: Optional[str] = field(default="outputs-dpo", metadata={"help": "The output directory"})
    max_steps: Optional[int] = field(default=200, metadata={"help": "Number of steps to train"})
    eval_strategy: Optional[str] = field(default="steps", metadata={"help": "Evaluation strategy"})
    remove_unused_columns: Optional[bool] = field(
        default=False,
        metadata={"help": "Remove unused columns from the dataset if `datasets.Dataset` is used"},
    )
    report_to: Optional[str] = field(default="tensorboard", metadata={"help": "Report to wandb or tensorboard"})

    def __post_init__(self):
        """Validate required arguments.

        Raises:
            ValueError: if ``model_type`` is missing or not a key of ``MODEL_CLASSES``,
                or if ``model_name_or_path`` is missing.
        """
        if self.model_type is None:
            raise ValueError("You must specify a valid model_type to run training.")
        if self.model_type not in MODEL_CLASSES:
            # Fail here with a readable message instead of a bare KeyError later in main().
            raise ValueError(
                f"Unsupported model_type: {self.model_type!r}; "
                f"choose one of: {', '.join(MODEL_CLASSES.keys())}"
            )
        if self.model_name_or_path is None:
            raise ValueError("You must specify a valid model_name_or_path to run training.")
def print_trainable_parameters(model):
    """
    Print how many of ``model``'s parameters are trainable.

    Element counts are summed over ``model.named_parameters()``. Parameters whose
    ``requires_grad`` is set count as trainable. Under DeepSpeed ZeRO-3 a
    partitioned parameter may report ``numel() == 0`` on the local rank. In that
    case its ``ds_numel`` attribute (the full size) is used so that the totals stay
    meaningful. A model with no parameters prints a 0.0 percentage instead of
    raising ``ZeroDivisionError``.

    Args:
        model: any object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs, e.g. a ``torch.nn.Module``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        # ZeRO-3 keeps only an empty local shard; fall back to the full element count.
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )
def find_all_linear_names(peft_model, int4=False, int8=False):
    """
    Collect the leaf names of all linear layers in ``peft_model`` (QLoRA-paper recipe).

    Every module from ``named_modules()`` that is an instance of the selected linear
    class contributes the last component of its dotted name. Modules whose full name
    contains ``lm_head`` or ``output_layer`` are skipped, so the output head is never
    targeted.

    Args:
        peft_model: module whose sub-modules are scanned.
        int4: match ``bitsandbytes`` ``Linear4bit`` layers instead of ``torch.nn.Linear``.
        int8: match ``bitsandbytes`` ``Linear8bitLt`` layers; ignored when ``int4`` is set.

    Returns:
        Sorted list of unique leaf names, suitable for LoRA ``target_modules``.
    """
    linear_cls = torch.nn.Linear
    if int4 or int8:
        import bitsandbytes as bnb
        linear_cls = bnb.nn.Linear4bit if int4 else bnb.nn.Linear8bitLt

    excluded_markers = ('lm_head', 'output_layer')
    leaf_names = {
        qualified_name.split('.')[-1]
        for qualified_name, submodule in peft_model.named_modules()
        if isinstance(submodule, linear_cls)
        and not any(marker in qualified_name for marker in excluded_markers)
    }
    return sorted(leaf_names)
def return_prompt_and_responses(examples) -> Dict[str, List[str]]:
    """Convert a batch of paired preference rows into the DPO trainer format.

    Intended for ``datasets.Dataset.map(..., batched=True)``: ``examples`` is a
    column-oriented batch that must contain the columns ``question``,
    ``response_chosen`` and ``response_rejected``. Any other columns are ignored.

    Returns a dictionary with the following structure:
        {
            'prompt': List[str],
            'chosen': List[str],
            'rejected': List[str],
        }
    Each prompt is built as:
        "Question: " + <question> + "\n\nAnswer: "
    The chosen and rejected columns are passed through unchanged.
    """
    return {
        "prompt": ["Question: " + question + "\n\nAnswer: " for question in examples["question"]],
        "chosen": examples["response_chosen"],
        "rejected": examples["response_rejected"],
    }
def _glob_json_files(directory):
    """Return every ``*.json`` file, then every ``*.jsonl`` file, found recursively under ``directory``."""
    return glob(f'{directory}/**/*.json', recursive=True) + glob(f'{directory}/**/*.jsonl', recursive=True)


def _load_raw_datasets(args):
    """Load the raw preference data and guarantee a ``validation`` split.

    The Hub dataset ``args.dataset_name`` is used when set. Otherwise the
    json/jsonl files under ``args.train_file_dir`` / ``args.validation_file_dir``
    are loaded. If the result has no ``validation`` split, the first
    ``args.validation_split_percentage`` percent of ``train`` becomes ``validation``
    and the remainder replaces ``train``.
    """
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        source_args = (args.dataset_name, args.dataset_config_name)
        source_kwargs = {}
    else:
        data_files = {}
        if args.train_file_dir is not None and os.path.exists(args.train_file_dir):
            train_data_files = _glob_json_files(args.train_file_dir)
            logger.info(f"train files: {', '.join(train_data_files)}")
            data_files["train"] = train_data_files
        if args.validation_file_dir is not None and os.path.exists(args.validation_file_dir):
            eval_data_files = _glob_json_files(args.validation_file_dir)
            logger.info(f"eval files: {', '.join(eval_data_files)}")
            data_files["validation"] = eval_data_files
        source_args = ('json',)
        source_kwargs = {"data_files": data_files}

    def load(**split_kwargs):
        return load_dataset(*source_args, cache_dir=args.cache_dir, **source_kwargs, **split_kwargs)

    raw_datasets = load()
    # If no validation data is there, validation_split_percentage will be used to divide the dataset.
    if "validation" not in raw_datasets.keys():
        pct = args.validation_split_percentage
        raw_datasets["validation"] = load(split=f"train[:{pct}%]")
        raw_datasets["train"] = load(split=f"train[{pct}%:]")
    return raw_datasets


def _build_preference_split(dataset, split, max_samples, args, full_max_length, shuffle=False):
    """Turn one raw split into DPO ``prompt``/``chosen``/``rejected`` rows.

    The split is truncated to ``max_samples`` rows when that is a positive int and
    optionally shuffled. Each batch is then mapped through
    ``return_prompt_and_responses``, which drops the raw columns. Only pairs whose
    prompt+response lengths lie in ``(0, full_max_length]`` are kept.

    Returns:
        ``(processed_dataset, num_selected)``. ``num_selected`` is the row count after
        ``max_samples`` truncation and before length filtering.

    Raises:
        ValueError: if no rows survive the length filter.
    """
    num_selected = len(dataset)
    if max_samples is not None and max_samples > 0:
        num_selected = min(len(dataset), max_samples)
        dataset = dataset.select(range(num_selected))
    if len(dataset) > 0:
        logger.debug(f"Example {split}_dataset[0]: {dataset[0]}")
    if shuffle:
        dataset = dataset.shuffle()
    dataset = dataset.map(
        return_prompt_and_responses,
        batched=True,
        num_proc=args.preprocessing_num_workers,
        remove_columns=dataset.column_names,
        load_from_cache_file=not args.overwrite_cache,
        desc="Running tokenizer on dataset",
    )
    # NOTE(review): len() counts characters here, while full_max_length is also handed to
    # DPOTrainer as max_length (presumably a token budget), so this is only a coarse pre-filter.
    dataset = dataset.filter(
        lambda x: 0 < len(x['prompt'] + x['chosen']) <= full_max_length
        and 0 < len(x['prompt'] + x['rejected']) <= full_max_length
    )
    logger.debug(f"Num {split}_samples: {len(dataset)}")
    if len(dataset) == 0:
        raise ValueError(
            f"No {split} samples left after filtering pairs longer than {full_max_length} characters; "
            "check the data or raise max_source_length/max_target_length."
        )
    logger.debug(f"First {split} example:")
    logger.debug(dataset[0]['prompt'] + dataset[0]['chosen'])
    return dataset, num_selected


def _build_quantization_config(args, torch_dtype):
    """Return the bitsandbytes config implied by ``--qlora``/``--load_in_4bit``/``--load_in_8bit``, or None."""
    # BitsAndBytesConfig cannot take the string "auto" as a compute dtype; None selects its default.
    compute_dtype = torch_dtype if isinstance(torch_dtype, torch.dtype) else None
    if args.qlora:
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
        )
    if args.load_in_4bit:
        return BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=compute_dtype)
    if args.load_in_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)
    return None


def _load_causal_lm(model_class, args, config, torch_dtype):
    """Load ``model_class`` from ``args.model_name_or_path`` with the requested dtype and quantization."""
    return model_class.from_pretrained(
        args.model_name_or_path,
        config=config,
        torch_dtype=torch_dtype,
        cache_dir=args.cache_dir,
        low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
        device_map=args.device_map,
        trust_remote_code=args.trust_remote_code,
        # A fresh config object per call, so the policy and reference loads never share one.
        quantization_config=_build_quantization_config(args, torch_dtype),
    )


def main():
    """Parse CLI arguments, build DPO preference data, then train and/or evaluate.

    The steps run in this order:
    1. Load the tokenizer.
    2. Load the raw datasets.
    3. Build (prompt, chosen, rejected) pairs.
    4. Load the policy model and the reference model.
    5. Train with ``trl.DPOTrainer``.
    6. Save the model and tokenizer.
    7. Evaluate.
    """
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]
    logger.info(f"Parse args: {args}")

    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    if args.model_type == 'bloom':
        # The bloom entry maps to BloomTokenizerFast, so always request the fast tokenizer.
        args.use_fast_tokenizer = True

    # Load tokenizer; fall back to the model path when no separate tokenizer is given.
    tokenizer_kwargs = {
        "cache_dir": args.cache_dir,
        "use_fast": args.use_fast_tokenizer,
        "trust_remote_code": args.trust_remote_code,
    }
    tokenizer_name_or_path = args.tokenizer_name_or_path or args.model_name_or_path
    tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
    if tokenizer.pad_token_id is None:
        # NOTE(review): id 0 is <unk> for LLaMA-style vocabularies; confirm for other model types.
        tokenizer.pad_token_id = 0

    # Get datasets
    raw_datasets = _load_raw_datasets(args)
    logger.info(f"Raw datasets: {raw_datasets}")

    # Preprocess the datasets; prompt + response must fit this combined budget.
    full_max_length = args.max_source_length + args.max_target_length

    train_dataset = None
    max_train_samples = 0
    if args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset, max_train_samples = _build_preference_split(
            raw_datasets['train'], "train", args.max_train_samples, args, full_max_length, shuffle=True,
        )

    eval_dataset = None
    max_eval_samples = 0
    if args.do_eval:
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset, max_eval_samples = _build_preference_split(
            raw_datasets["validation"], "eval", args.max_eval_samples, args, full_max_length,
        )

    logger.info("Loading model")
    # "auto" and None are forwarded unchanged; named dtypes become torch.dtype objects.
    torch_dtype = (
        args.torch_dtype
        if args.torch_dtype in ["auto", None]
        else getattr(torch, args.torch_dtype)
    )
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        # Put the whole model on this process's local device.
        args.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
    if args.qlora and is_deepspeed_zero3_enabled():
        logger.warning("ZeRO3 and QLoRA are currently incompatible.")

    config = config_class.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=args.trust_remote_code,
        torch_dtype=torch_dtype,
        cache_dir=args.cache_dir
    )
    # The policy (trained) and the reference model passed to DPOTrainer come from the same checkpoint.
    model = _load_causal_lm(model_class, args, config, torch_dtype)
    model_ref = _load_causal_lm(model_class, args, config, torch_dtype)

    # Initialize our Trainer
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        model.config.use_cache = False
    else:
        model.config.use_cache = True

    training_args = TrainingArguments(
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        max_steps=args.max_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=args.gradient_checkpointing,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        evaluation_strategy=args.eval_strategy,
        eval_steps=args.eval_steps,
        output_dir=args.output_dir,
        report_to=args.report_to,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.warmup_steps,
        optim=args.optim,
        bf16=args.bf16,
        fp16=args.fp16,
        remove_unused_columns=args.remove_unused_columns,
        run_name=f"dpo_{args.model_type}",
    )

    # Initialize DPO trainer.
    # --target_modules is comma-separated; "all" expands to every linear layer name.
    target_modules = args.target_modules.split(',') if args.target_modules else None
    if target_modules and 'all' in target_modules:
        target_modules = find_all_linear_names(
            model, int4=args.load_in_4bit or args.qlora, int8=args.load_in_8bit
        )
    logger.info(f"Peft target_modules: {target_modules}")
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=target_modules,
        inference_mode=False,
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
    )
    trainer = DPOTrainer(
        model,
        model_ref,
        args=training_args,
        beta=args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        peft_config=peft_config if args.use_peft else None,
        max_prompt_length=args.max_source_length,
        max_length=full_max_length,
    )
    print_trainable_parameters(trainer.model)

    # Training
    if args.do_train:
        logger.info("*** Train ***")
        train_result = trainer.train()
        metrics = train_result.metrics
        metrics["train_samples"] = max_train_samples
        logger.debug(f"Training metrics: {metrics}")
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        logger.info(f"Saving model checkpoint to {args.output_dir}")
        trainer.save_model(args.output_dir)
        # Extra artifacts are written once, by the main process, to avoid concurrent writes.
        if trainer.is_world_process_zero():
            tokenizer.save_pretrained(args.output_dir)
            trainer.model.save_pretrained(args.output_dir)

    # Evaluation
    if args.do_eval:
        logger.info("*** Evaluate ***")
        # Every rank must call evaluate(), because it runs distributed collectives;
        # gating it on rank 0 would hang multi-process runs. log_metrics/save_metrics
        # are documented to write only from the main process.
        metrics = trainer.evaluate()
        metrics["eval_samples"] = max_eval_samples
        logger.debug(f"Eval metrics: {metrics}")
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()