# flake8: noqa
"""
pip install -U transformers accelerate trl wandb wheel packaging peft bitsandbytes liger-kernel flash_attn

python sft.py \
    --run_name="llama3.1-8b-continued2" \
    --model_name_or_path="meta-llama/Meta-Llama-3.1-8B" \
    --dataset_name="mlfoundations/dclm-baseline-1.0-parquet,mlabonne/FineTome-100k" \
    --report_to="wandb" \
    --optim="adamw_torch_fused" \
    --lr_scheduler_type="cosine" \
    --max_steps=10000000 \
    --max_seq_length=64000 \
    --learning_rate=0.0001 \
    --attn_implementation="flash_attention_2" \
    --save_strategy="steps" \
    --save_steps 50 \
    --save_total_limit=10 \
    --per_device_train_batch_size=1 \
    --gradient_accumulation_steps=8 \
    --logging_steps=1 \
    --num_train_epochs=1 \
    --load_in_4bit \
    --push_to_hub \
    --hub_model_id="ericflo/Llama-3.1-8B-ContinuedTraining2-LoRA" \
    --hub_strategy="all_checkpoints" \
    --gradient_checkpointing \
    --use_peft \
    --lora_r=128 \
    --lora_alpha=256 \
    --lora_dropout=0.05 \
    --use_liger=true \
    --packing=true \
    --torch_dtype="bfloat16" \
    --output_dir="continuedtraining2_output"
"""

import logging
import os
import random
from contextlib import nullcontext

from trl.commands.cli_utils import init_zero_verbose, SFTScriptArguments, TrlParser
from trl.env_utils import strtobool

TRL_USE_RICH = strtobool(os.getenv("TRL_USE_RICH", "0"))

if TRL_USE_RICH:
    init_zero_verbose()
    FORMAT = "%(message)s"

    from rich.console import Console
    from rich.logging import RichHandler

import torch
from datasets import load_dataset, interleave_datasets
from tqdm.rich import tqdm
from transformers import AutoTokenizer

from trl import (
    ModelConfig,
    RichProgressCallback,
    SFTConfig,
    SFTTrainer,
    get_peft_config,
    get_quantization_config,
    get_kbit_device_map,
)

tqdm.pandas()

if TRL_USE_RICH:
    logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)

# Two tokenizers whose chat templates are used to render conversation rows; one is picked at
# random per example (see formatting_prompts_func). Loaded once at import time.
print("Loading tokenizers...")
METAML_TOK = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
CHATML_TOK = AutoTokenizer.from_pretrained("NousResearch/Hermes-3-Llama-3.1-8B")
print("Tokenizers loaded.")

# Maps the speaker names used in `from`/`value` conversation turns to chat-template roles.
# An unknown speaker raises KeyError.
_ROLE_BY_SPEAKER = {"system": "system", "gpt": "assistant", "human": "user"}


def formatting_prompts_func(example):
    """Render one row from any of the interleaved training sources as a single training string.

    The source of a row is recognised by which fields carry values, not by which keys exist:
    ``interleave_datasets`` merges the features of all sources and fills the columns a source
    lacks with ``None``, so every row has every key.

    Args:
        example: one dataset row (mapping of column name to value).

    Returns:
        The text to train on.

    Raises:
        ValueError: if the row matches none of the known source layouts.
        KeyError: if a conversation turn uses a speaker not in ``_ROLE_BY_SPEAKER``.
    """
    try:
        language = example.get("language")
        url = example.get("url")
        text = example.get("text")
        title = example.get("title")
        conversations = example.get("conversations")
        source = example.get("source")
        repo_name = example.get("max_stars_repo_name")
        repo_path = example.get("max_stars_repo_path")
        star_count = example.get("max_stars_count")
        content = example.get("content")

        if language and url and text:
            # mlfoundations/dclm-baseline-1.0-parquet
            return f"{language} {url} {text}"
        if title and url and text:
            # wikimedia/wikipedia
            return f"{title} {url} {text}"
        if conversations:
            # mlabonne/FineTome-100k
            rows = [
                {"role": _ROLE_BY_SPEAKER[turn["from"]], "content": turn["value"]}
                for turn in conversations
            ]
            # A template is chosen at random per example, presumably so the model sees both
            # chat formats during training.
            tok = random.choice([METAML_TOK, CHATML_TOK])
            return f"{source} {tok.apply_chat_template(rows, tokenize=False)}"
        if repo_name is not None and content is not None:
            # bigcode/starcoderdata. Checked by value: a `"max_stars_repo_name" in example`
            # test is always true on interleaved rows and would turn every unrecognised row
            # into "None None None None".
            return f"{repo_name} {repo_path} {star_count} {content}"
        raise ValueError(f"Unknown example: {example}")
    except Exception as e:
        # Make the failing row visible in the log (errors raised inside dataset iteration can
        # be hard to trace), then re-raise with the original traceback intact.
        print(e)
        raise


if __name__ == "__main__":
    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
    args, training_args, model_config = parser.parse_args_and_config()

    # Force use our print callback
    if TRL_USE_RICH:
        training_args.disable_tqdm = True
        console = Console()

    ################
    # Model init kwargs & Tokenizer
    ################
    # NOTE: this unconditionally overrides any --lora_target_modules given on the command line.
    # The names are the attention (q/k/v/o) and MLP (gate/up/down) projections in Llama naming.
    model_config.lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=model_config.torch_dtype,
        # Disable the KV cache when gradient checkpointing is on; transformers treats the two
        # as incompatible.
        use_cache=not training_args.gradient_checkpointing,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    # SFTTrainer receives the model as a name (below) and instantiates it with these kwargs.
    training_args.model_init_kwargs = model_kwargs
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
    )
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    ################
    # Dataset
    ################
    # --dataset_name may list several comma-separated Hub datasets; each is streamed.
    dataset_names = args.dataset_name.split(",")
    train_datasets = [load_dataset(name, split="train", streaming=True) for name in dataset_names]
    # Extra sources that are always mixed in, regardless of --dataset_name.
    train_datasets.append(load_dataset("bigcode/starcoderdata", data_dir="python", split="train", streaming=True))
    for wiki_config in ("20231101.en", "20231101.es", "20231101.fr"):
        train_datasets.append(load_dataset("wikimedia/wikipedia", wiki_config, split="train", streaming=True))

    # No probabilities are passed, so sources are cycled in list order.
    # NOTE(review): the default stopping_strategy is "first_exhausted", i.e. the combined stream
    # ends as soon as the smallest source runs out — confirm that is intended for this mix.
    interleaved_dataset = interleave_datasets(train_datasets)
    # Hold out the first 100 interleaved rows for evaluation and train on everything after them.
    eval_dataset = interleaved_dataset.take(100)
    train_dataset = interleaved_dataset.skip(100)
    print(train_dataset)
    print(eval_dataset)

    ################
    # Optional rich context managers
    ###############
    init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the SFTTrainer...")
    save_context = (
        nullcontext()
        if not TRL_USE_RICH
        else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
    )

    ################
    # Training
    ################
    with init_context:
        trainer = SFTTrainer(
            model=model_config.model_name_or_path,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            peft_config=get_peft_config(model_config),
            callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
            formatting_func=formatting_prompts_func,
        )

    trainer.train()

    with save_context:
        trainer.save_model(training_args.output_dir)