from unsloth import FastLanguageModel
import torch
from transformers import AutoTokenizer

max_seq_length = 4096
dtype = torch.bfloat16
load_in_4bit = True

model_name = '../out/pretrain-base'
output_dir = '../out/cpt-base'

# Load the pretrained base model (optionally 4-bit quantized) together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

# Discard the tokenizer returned by FastLanguageModel.from_pretrained and load the
# project tokenizer from the repository root instead.
print('Ignoring the tokenizer loaded by FastLanguageModel.from_pretrained; using AutoTokenizer.from_pretrained instead')
tokenizer = AutoTokenizer.from_pretrained('..', trust_remote_code=True, use_fast=True)

print(f'{model=}')
print(f'{tokenizer=}')

# Wrap the model with LoRA adapters for continued pretraining.
model = FastLanguageModel.get_peft_model(
    model,
    r=64,  # Choose any number > 0; suggested values: 8, 16, 32, 64, 128
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        "embed_tokens", "lm_head",  # embed_tokens and lm_head are added for continual pretraining
    ],
    lora_alpha=16,
    lora_dropout=0,  # Supports any value, but 0 is optimized
    bias='none',     # Supports any value, but "none" is optimized
    # "unsloth" gradient checkpointing uses ~30% less VRAM and fits 2x larger batch sizes.
    use_gradient_checkpointing='unsloth',  # True or "unsloth" for very long context
    random_state=23,
    use_rslora=True,    # Rank-stabilized LoRA
    loftq_config=None,  # LoftQ is also supported
)

print(f'{model=}')

from datasets import concatenate_datasets

from cpt_base_datasets import cpt_base_datasets
from cpt_instruct_datasets import cpt_instruct_datasets
from unsloth_utils import load_text_dataset, load_chat_dataset

# Build the continued-pretraining corpus from the configured plain-text datasets.
core_datasets = []

for dataset_config in cpt_base_datasets:
    dataset = load_text_dataset(tokenizer, **dataset_config)
    print(f'{dataset=}')
    core_datasets.append(dataset)

# Instruction-style datasets are currently excluded from continued pretraining.
# for dataset_config in cpt_instruct_datasets:
#     dataset = load_chat_dataset(tokenizer, **dataset_config)
#     print(f'{dataset=}')
#     core_datasets.append(dataset)

final_dataset = concatenate_datasets(core_datasets)
print(f'{final_dataset=}')

from unsloth import is_bfloat16_supported
from unsloth import UnslothTrainer, UnslothTrainingArguments

trainer = UnslothTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=final_dataset,
    dataset_text_field='text',
    max_seq_length=max_seq_length,
    dataset_num_proc=32,

    args=UnslothTrainingArguments(
        per_device_train_batch_size=8,
        gradient_accumulation_steps=8,

        warmup_ratio=0.1,
        num_train_epochs=1,

        # Use a lower learning rate for the embedding and lm_head matrices.
        learning_rate=5e-5,
        embedding_learning_rate=5e-6,

        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim='adamw_8bit',
        weight_decay=0.01,
        lr_scheduler_type='cosine',
        seed=23,
        output_dir=output_dir,
        report_to='wandb',
    ),
)

trainer_stats = trainer.train()
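
# Note: explicit saving is not part of the original script (the trainer only writes
# checkpoints to output_dir). The two lines below are a minimal sketch, assuming the
# standard PEFT/Transformers save_pretrained API; for the LoRA-wrapped model this
# persists only the adapter weights, and reusing output_dir as the destination is an
# assumption, not the project's confirmed layout.
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)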