# usage: accelerate launch ./scripts/finetune.py 2-PKTDC-llama-13B-gptq-lora-24gb.yml
#
# base model settings (local path or Hugging Face repo id)
base_model: PocketDoc/llama-13b-gptq-4bit-128g
base_model_config: PocketDoc/llama-13b-gptq-4bit-128g
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code:
# wandb configuration
wandb_project: llama-13b-gptq-4bit-128g-lora
wandb_watch:
wandb_run_id:
wandb_log_model:
# where to save the finished model
output_dir: ./llama-13b-gptq-4bit-128g-lora
# dataset settings (local path or Hugging Face repo id)
datasets:
  - path: dansmeth.json
    type: pygmalion
dataset_prepared_path: data/last_run_prepared
# fraction of the dataset to hold out for evaluation (0.02 = 2%)
val_set_size: 0.02
# max token length per prompt
sequence_len: 2048
# max sequence length to concatenate (pack) training samples together up to;
# inspired by StackLLaMA, see https://huggingface.co/blog/stackllama#supervised-fine-tuning
max_packed_sequence_len: 2048
# quantized model loading settings
gptq: true
gptq_groupsize: 128  # group size
gptq_model_v1: false  # v1 or v2
strict: false
# this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
# NOTE(review): base_model is already a GPTQ 4-bit checkpoint (gptq: true);
# confirm your axolotl version honours load_in_8bit alongside gptq.
load_in_8bit: true
load_in_4bit:
# use CUDA bf16
bf16: false
# use CUDA fp16
fp16: true
# use CUDA tf32
tf32: true
# training hyperparameters
gradient_accumulation_steps: 30
micro_batch_size: 6
eval_batch_size: 6
num_epochs: 12
warmup_steps: 10
learning_rate: 0.000004
logging_steps: 1
eval_steps: 5
save_steps: 10
# stop training once the evaluation metric has failed to improve for this many
# evaluations in a row (see transformers EarlyStoppingCallback)
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience:
# learning-rate scheduler to use with the optimizer
# NOTE(review): the original comment said "only one_cycle is supported currently",
# which contradicts `linear` below; confirm your axolotl version accepts it.
lr_scheduler: linear
# optimizer to use
optimizer: paged_adamw_8bit
# weight decay
weight_decay: 0.0001
# if you already have a trained LoRA adapter you want to load, put its directory here
lora_model_dir:
# LoRA hyperparameters
adapter: lora  # leave blank for a full finetune
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_linear:
lora_target_modules:
  - q_proj
  - v_proj
  # - k_proj
  # - o_proj
  # - gate_proj
  # - down_proj
  # - up_proj
lora_modules_to_save:
  # - embed_tokens
  # - lm_head
lora_out_dir:
lora_fan_in_fan_out: false
# whether to mask out (false) or include (true) the human's prompt in the training labels
train_on_inputs: false
# don't use this, leads to wonky training (according to someone on the internet)
# NOTE(review): enabled here despite the warning above; confirm this is intended.
group_by_length: true
# does not work with current implementation of 4-bit LoRA
# NOTE(review): enabled here even though this config trains a LoRA on a GPTQ
# 4-bit model (gptq: true); confirm it works with your axolotl version.
gradient_checkpointing: true
# whether to use the xformers attention patch (https://github.com/facebookresearch/xformers)
xformers_attention: true
# whether to use the flash attention patch (https://github.com/HazyResearch/flash-attention)
flash_attention:  # requires an A100 for llama
# whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
# resume from a specific checkpoint directory
resume_from_checkpoint:
# if resume_from_checkpoint isn't set and you simply want training to pick up where it left off;
# be careful leaving this on when switching between different models
auto_resume_from_checkpoints: false
# don't mess with this, it's here for accelerate and torchrun
local_rank:
# add or change special tokens
special_tokens:
  # sys_role_token: "<|system|>"
  # user_role_token: "<|user|>"
  # model_role_token: "<|model|>"
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
# add extra tokens
tokens:
# FSDP
fsdp:
fsdp_config:
# DeepSpeed
deepspeed:
# TODO
torchdistx_path:
# debug mode
debug: