from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments

from trl import SFTTrainer

tqdm.pandas()
|
|
@dataclass
class ScriptArguments:
    """
    The name of the causal language model we wish to fine-tune with SFTTrainer.
    """

    model_name: Optional[str] = field(default="facebook/opt-350m", metadata={"help": "the model name"})
    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"}
    )
    dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"})
    log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
    learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
    batch_size: Optional[int] = field(default=8, metadata={"help": "the batch size"})
    seq_length: Optional[int] = field(default=512, metadata={"help": "the input sequence length"})
    gradient_accumulation_steps: Optional[int] = field(
        default=2, metadata={"help": "the number of gradient accumulation steps"}
    )
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8-bit precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4-bit precision"})
    use_peft: Optional[bool] = field(default=False, metadata={"help": "whether to use PEFT to train adapters"})
    trust_remote_code: Optional[bool] = field(default=True, metadata={"help": "enable `trust_remote_code`"})
    output_dir: Optional[str] = field(default="./", metadata={"help": "the output directory"})
    peft_lora_r: Optional[int] = field(default=8, metadata={"help": "the r parameter of the LoRA adapters"})
    peft_lora_alpha: Optional[int] = field(default=2, metadata={"help": "the alpha parameter of the LoRA adapters"})
    logging_steps: Optional[int] = field(default=1, metadata={"help": "the number of logging steps"})
    use_auth_token: Optional[bool] = field(default=True, metadata={"help": "use the HF auth token to access the model"})
    num_train_epochs: Optional[int] = field(default=2, metadata={"help": "the number of training epochs"})
    max_steps: Optional[int] = field(default=-1, metadata={"help": "the maximum number of training steps (-1 for no limit)"})
|
|
# Parse the command-line arguments into a ScriptArguments instance
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
|
|
# Configure quantized loading (8-bit or 4-bit) via bitsandbytes when requested
if script_args.load_in_8bit and script_args.load_in_4bit:
    raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
elif script_args.load_in_8bit or script_args.load_in_4bit:
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
    )
    # Quantized weights must live on a single device, so pin everything to GPU 0
    device_map = {"": 0}
    torch_dtype = torch.bfloat16
else:
    device_map = None
    quantization_config = None
    torch_dtype = None
|
|
# Load the base causal language model (optionally quantized)
model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,
    quantization_config=quantization_config,
    device_map=device_map,
    trust_remote_code=script_args.trust_remote_code,
    torch_dtype=torch_dtype,
    use_auth_token=script_args.use_auth_token,
)
|
|
# Load the training split of the dataset
dataset = load_dataset(script_args.dataset_name, split="train")
|
|
# Hugging Face TrainingArguments built from the script arguments
training_args = TrainingArguments(
    output_dir=script_args.output_dir,
    per_device_train_batch_size=script_args.batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    learning_rate=script_args.learning_rate,
    logging_steps=script_args.logging_steps,
    num_train_epochs=script_args.num_train_epochs,
    max_steps=script_args.max_steps,
    report_to=script_args.log_with,
)
|
|
# Build a LoRA configuration when PEFT training is enabled
if script_args.use_peft:
    peft_config = LoraConfig(
        r=script_args.peft_lora_r,
        lora_alpha=script_args.peft_lora_alpha,
        bias="none",
        task_type="CAUSAL_LM",
    )
else:
    peft_config = None
|
|
# TRL's SFTTrainer tokenizes `dataset_text_field` up to `max_seq_length` and,
# when a `peft_config` is given, wraps the model with LoRA adapters before training
trainer = SFTTrainer(
    model=model,
    args=training_args,
    max_seq_length=script_args.seq_length,
    train_dataset=dataset,
    dataset_text_field=script_args.dataset_text_field,
    peft_config=peft_config,
)
|
|
trainer.train()

# Save the trained model (or just the adapters, when using PEFT) to the output directory
trainer.save_model(script_args.output_dir)
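# Example invocation (illustrative sketch only): the file name `sft_trainer.py` and the
# flag values below are assumptions; the flags themselves are generated by
# HfArgumentParser from the ScriptArguments fields above.
#
#   python sft_trainer.py \
#       --model_name facebook/opt-350m \
#       --dataset_name timdettmers/openassistant-guanaco \
#       --load_in_4bit \
#       --use_peft \
#       --batch_size 4 \
#       --gradient_accumulation_steps 2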