from collections import defaultdict
import copy
import os
from dataclasses import dataclass, field
import random
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List

import torch
import torch.distributed as dist

from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType

import transformers
from torch.utils.data import Dataset
from transformers import Trainer, AddedToken, BitsAndBytesConfig, deepspeed

from fastchat.train.train_flant5 import (
    smart_tokenizer_and_embedding_resize,
    make_supervised_data_module,
)
from fastchat.train.train_lora import get_peft_state_maybe_zero_3
from fastchat.model.model_adapter import get_conversation_template

default_conversation = get_conversation_template("t5")


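# Label id ignored by the loss plus default special-token strings; this script
# reuses "</s>" for EOS, BOS, and UNK, and adds "[PAD]" as the pad token below.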
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"


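# LoRA / QLoRA hyperparameters; the default target modules ("q", "v") are the
# query/value projection names used in T5-style attention blocks.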
@dataclass
class LoraArguments:
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_target_modules: List[str] = field(default_factory=lambda: ["q", "v"])
    lora_weight_path: str = ""
    lora_bias: str = "none"
    q_lora: bool = False


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    lazy_preprocess: bool = False
    num_data: int = -1
    preprocessed_path: str = field(
        default=None, metadata={"help": "Path to the preprocessed training data."}
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )


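# Saving helper: copies tensors to CPU before writing so the checkpoint does not
# hold GPU memory, and only writes from the process that should save.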
def safe_save_model_for_hf_trainer(
    trainer: transformers.Trainer, output_dir: str, state_dict: dict
):
    """Collect the state dict and dump it to disk."""
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)


def train():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()

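    # QLoRA under plain DDP pins the whole model to this rank's GPU via device_map;
    # FSDP and DeepSpeed ZeRO-3 shard parameters and are warned against below.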
    device_map = None
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if lora_args.q_lora:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
            logging.warning(
                "FSDP and ZeRO3 are both currently incompatible with QLoRA."
            )

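    # Compute dtype follows the Trainer precision flags: fp16, else bf16, else fp32.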
    compute_dtype = (
        torch.float16
        if training_args.fp16
        else (torch.bfloat16 if training_args.bf16 else torch.float32)
    )

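    # Load the seq2seq base model; with --q_lora it is loaded in 4-bit NF4 with
    # double quantization through bitsandbytes, otherwise without quantization.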
    model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        device_map=device_map,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
        )
        if lora_args.q_lora
        else None,
    )

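    # LoRA adapter configuration for a sequence-to-sequence LM.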
    lora_config = LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias=lora_args.lora_bias,
        task_type=TaskType.SEQ_2_SEQ_LM,
    )

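    # For QLoRA, prepare the quantized model for training (freezes the base weights,
    # upcasts the remaining full-precision modules, and enables gradient
    # checkpointing when requested).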
    if lora_args.q_lora:
        model = prepare_model_for_kbit_training(
            model, use_gradient_checkpointing=training_args.gradient_checkpointing
        )
        if not ddp and torch.cuda.device_count() > 1:
            # Keep the Trainer from wrapping the model in its own DataParallel
            # when more than one GPU is visible.
            model.is_parallelizable = True
            model.model_parallel = True

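    # Wrap the base model with LoRA adapters; only the adapter weights (and the
    # configured bias terms) stay trainable.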
    model = get_peft_model(model, lora_config)
    if training_args.deepspeed is not None and training_args.local_rank == 0:
        model.print_trainable_parameters()

    if training_args.gradient_checkpointing:
        model.enable_input_require_grads()

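    # Slow (SentencePiece) T5 tokenizer with right padding and the training max length.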
    tokenizer = transformers.T5Tokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )

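    # Add the pad token and a few literal characters missing from the T5 vocab,
    # resizing the embeddings to match.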
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
        other_tokens=["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"],
        tokenizer=tokenizer,
        model=model,
    )

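    # Build the supervised dataset(s) (and collator) used by the Trainer.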
    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )

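    # Resume from the latest checkpoint if the output directory already contains one.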
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

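    # Under ZeRO-3 the weights are sharded across ranks, so gather a consolidated
    # 16-bit state dict (base model plus LoRA adapters) from the DeepSpeed engine.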
    if deepspeed.is_deepspeed_zero3_enabled():
        state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
        # The consolidated dict is only materialized on rank 0, which does the saving.
        if training_args.local_rank == 0:
            state_dict = state_dict_zero3
    else:
        # Otherwise extract the (possibly ZeRO-partitioned) LoRA weights the same
        # way fastchat.train.train_lora does.
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), lora_args.lora_bias
        )

    if training_args.local_rank == 0:
        safe_save_model_for_hf_trainer(
            trainer=trainer, output_dir=training_args.output_dir, state_dict=state_dict
        )


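# Illustrative launch only (script name, model, and paths are placeholders, and
# just a subset of the available flags is shown):
#   torchrun --nproc_per_node=4 train_lora_t5.py \
#       --model_name_or_path google/flan-t5-xl \
#       --data_path data.json \
#       --preprocessed_path data_preprocessed.json \
#       --output_dir ./t5-lora-ckpt \
#       --bf16 True --q_lora True --gradient_checkpointing True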
if __name__ == "__main__":
    train()