Crystalcareai committed on
Commit 0b25161 · verified · 1 Parent(s): 461cea9

Create schedulefree.py

Files changed (1)
  schedulefree.py +103 -0
schedulefree.py ADDED
@@ -0,0 +1,103 @@
+ import signal
+ import sys
+ import torch
+ from datasets import load_dataset
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
+ from trl import SFTTrainer
+ from peft import LoraConfig
+ from schedulefree import AdamWScheduleFree
+
+ # Signal handler so Ctrl+C exits cleanly
+ def signal_handler(sig, frame):
+     print('You pressed Ctrl+C! Exiting...')
+     sys.exit(0)
+
+ # Register signal handler
+ signal.signal(signal.SIGINT, signal_handler)
+
+ dataset = load_dataset("Crystalcareai/Orca-Reka")['train']
+
+ def chatml_format(example):
+     """Format the dataset for training, accounting for empty columns."""
+     return {
+         "instruction": example['instruction'] if 'instruction' in example else " \n",
+         "input": example['input'] if 'input' in example else " \n",
+         "system": example['system'] if 'system' in example else " \n",
+         "output": example['output'] if 'output' in example else " \n",
+     }
+
+ # Format dataset
+ dataset = dataset.map(chatml_format, remove_columns=dataset.column_names)
+
+ # Load model and tokenizer
+ model_id = "your/base-model-id"  # placeholder: model_id was undefined in the original; set it to the checkpoint being fine-tuned
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     attn_implementation="flash_attention_2",
+     torch_dtype=torch.bfloat16,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model_id)  # load from the model id, not the model object
+ tokenizer.padding_side = 'right'  # to prevent warnings
+
+ peft_config = LoraConfig(
+     lora_alpha=16,
+     lora_dropout=0.05,
+     r=32,
+     bias="none",
+     target_modules=[
+         "0.w1",
+         "0.w2",
+         "0.w3",
+         "q_proj",
+         "v_proj",
+         "k_proj",
+         "o_proj"
+     ],
+     task_type="CAUSAL_LM",
+     use_dora=False,  # DoRA disabled; plain LoRA
+ )
+
+ args = TrainingArguments(
+     output_dir="./out",              # directory to save checkpoints (and repository id if pushing)
+     num_train_epochs=3,              # number of training epochs
+     per_device_train_batch_size=4,   # batch size per device during training
+     gradient_checkpointing=True,     # use gradient checkpointing to save memory
+     optim="adamw_hf",                # ignored here: a custom optimizer is passed to the trainer below
+     logging_steps=2,
+     save_strategy="steps",
+     save_steps=300,
+     bf16=True,                       # use bfloat16 precision
+     tf32=True,                       # use tf32 precision
+     ### peft specific arguments ###
+     learning_rate=2e-4,
+     max_grad_norm=0.3,
+     warmup_ratio=0.00,               # schedule-free: no warmup schedule
+     lr_scheduler_type="constant",    # schedule-free: keep the learning rate constant
+     report_to="wandb",
+     push_to_hub=False,               # do not push the model to the hub
+ )
+
+ max_seq_length = 2048  # max sequence length for model and packing of the dataset
+
+ # Create the schedule-free optimizer (AdamWScheduleFree takes a betas tuple, not a single beta)
+ optimizer = AdamWScheduleFree(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999))
+
+ trainer = SFTTrainer(
+     model=model,
+     args=args,
+     train_dataset=dataset,
+     ### peft specific arguments ###
+     peft_config=peft_config,
+     max_seq_length=max_seq_length,
+     tokenizer=tokenizer,
+     packing=False,
+     optimizers=(optimizer, None),  # pass the schedule-free optimizer; the trainer builds the (constant) scheduler
+ )
+
+ # Schedule-free optimizers must be switched to train mode before stepping
+ optimizer.train()
+
+ # start training; checkpoints are written to output_dir (push_to_hub is disabled)
+ trainer.train()
+
+ # switch the optimizer to eval mode so the averaged weights are saved, then save the model
+ optimizer.eval()
+ trainer.save_model()
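
Note: after the `map` call the dataset only contains `instruction`, `input`, `system`, and `output` columns, while the `SFTTrainer` above is built without a `dataset_text_field` or `formatting_func`, so it has no single text field to tokenize. Below is a minimal sketch of a batched `formatting_func` that could be passed to `SFTTrainer`; the function name and the Alpaca-style prompt template are illustrative assumptions, not part of the original commit.

```python
# Hypothetical helper (not in the original commit): build one prompt string per
# example so SFTTrainer has text to tokenize. The template is an assumption.
def formatting_func(batch):
    texts = []
    for system, instruction, inp, output in zip(
        batch["system"], batch["instruction"], batch["input"], batch["output"]
    ):
        texts.append(
            f"{system}\n"
            f"### Instruction:\n{instruction}\n"
            f"### Input:\n{inp}\n"
            f"### Response:\n{output}"
        )
    return texts  # SFTTrainer expects a list of strings from a batched element

# Usage: pass it alongside the other trainer arguments, e.g.
# trainer = SFTTrainer(..., formatting_func=formatting_func, ...)
```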