# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments

from trl import SFTTrainer


tqdm.pandas()


# Command-line interface: every field below becomes a `--<field_name>` option.
@dataclass
class ScriptArguments:
    """
    Command-line options for fine-tuning a causal LM with SFTTrainer.

    The fields are parsed by `HfArgumentParser` and then forwarded to the model
    loader, `TrainingArguments`, `LoraConfig` and `SFTTrainer` further down.
    """

    # --- model / data selection ---
    model_name: Optional[str] = field(default="facebook/opt-350m", metadata={"help": "the model name"})
    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"}
    )
    dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"})
    # --- optimisation / logging ---
    log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
    learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
    batch_size: Optional[int] = field(default=8, metadata={"help": "the batch size"})  # 64 original
    seq_length: Optional[int] = field(default=512, metadata={"help": "Input sequence length"})
    gradient_accumulation_steps: Optional[int] = field(
        default=2, metadata={"help": "the number of gradient accumulation steps"}
    )
    # --- quantization / adapters ---
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
    use_peft: Optional[bool] = field(default=False, metadata={"help": "Wether to use PEFT or not to train adapters"})
    trust_remote_code: Optional[bool] = field(default=True, metadata={"help": "Enable `trust_remote_code`"})
    output_dir: Optional[str] = field(default="./", metadata={"help": "the output directory"})
    peft_lora_r: Optional[int] = field(default=8, metadata={"help": "the r parameter of the LoRA adapters"})
    peft_lora_alpha: Optional[int] = field(default=2, metadata={"help": "the alpha parameter of the LoRA adapters"})
    # --- run length / access ---
    logging_steps: Optional[int] = field(default=1, metadata={"help": "the number of logging steps"})
    use_auth_token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"})
    num_train_epochs: Optional[int] = field(default=2, metadata={"help": "the number of training epochs"})
    max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

# Step 1: Load the model
# Start from the unquantized defaults; they are overridden below only when one
# of the two quantization flags is set.
quantization_config = None
device_map = None
torch_dtype = None

# The two quantization modes are mutually exclusive.
if script_args.load_in_8bit and script_args.load_in_4bit:
    raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")

if script_args.load_in_8bit or script_args.load_in_4bit:
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
    )
    # The empty-string key places the entire model on GPU:0.
    device_map = {"": 0}
    torch_dtype = torch.bfloat16

model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,
    quantization_config=quantization_config,
    device_map=device_map,
    trust_remote_code=script_args.trust_remote_code,
    torch_dtype=torch_dtype,
    use_auth_token=script_args.use_auth_token,
)

# Step 2: Load the dataset (only the "train" split is used)
dataset = load_dataset(script_args.dataset_name, split="train")

# Step 3: Define the training arguments
training_args = TrainingArguments(
    output_dir=script_args.output_dir,
    per_device_train_batch_size=script_args.batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    learning_rate=script_args.learning_rate,
    logging_steps=script_args.logging_steps,
    num_train_epochs=script_args.num_train_epochs,
    max_steps=script_args.max_steps,
    report_to=script_args.log_with,
)

# Step 4: Define the LoraConfig (left as None when PEFT is not requested)
peft_config = (
    LoraConfig(
        r=script_args.peft_lora_r,
        lora_alpha=script_args.peft_lora_alpha,
        bias="none",
        task_type="CAUSAL_LM",
    )
    if script_args.use_peft
    else None
)

# Step 5: Define the Trainer and run training
trainer = SFTTrainer(
    model=model,
    args=training_args,
    max_seq_length=script_args.seq_length,
    train_dataset=dataset,
    dataset_text_field=script_args.dataset_text_field,
    peft_config=peft_config,
)

trainer.train()

# Step 6: Save the model into the configured output directory
trainer.save_model(script_args.output_dir)