goku6045 committed
Commit abb9e91 · verified · Parent: b4dc74a

Upload 2 files

Files changed (2):
  1. app.py +367 -0
  2. config.py +200 -0
app.py ADDED
@@ -0,0 +1,367 @@
"""
This module launches Axolotl with user-defined configurations.
"""

import gradio as gr
import yaml

from config import config

example_yml = """
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
"""


def yml_config(yml_text):
    """
    Parse the user's YAML text, write it to config.yml, and return the
    normalized dump for display.
    """
    parsed = yaml.safe_load(yml_text)
    with open("config.yml", "w", encoding="utf-8") as file:
        yaml.dump(parsed, file)
    return yaml.dump(parsed)


with gr.Blocks(title="Axolotl Launcher") as demo:
    gr.Markdown("""
    # Axolotl Launcher
    Fill out the required fields below to create a training run.
    """)
    with gr.Tab("Base Model & Tokenizer"):
        with gr.Column():
            with gr.Row():
                base_model = gr.Textbox(label="Base Model")
                base_model_ignore_patterns = gr.Textbox(
                    label="Base Model Ignore Patterns")
                base_model_config = gr.Textbox(label="Base Model Config")
                model_revision = gr.Textbox(label="Model Revision")
            with gr.Row():
                tokenizer_config = gr.Textbox(label="Tokenizer Config")
                model_type = gr.Textbox(label="Model Type")
                tokenizer_type = gr.Textbox(label="Tokenizer Type")
            with gr.Row():
                trust_remote_code = gr.Checkbox(label="Trust Remote Code",
                                                value=False)
                tokenizer_use_fast = gr.Checkbox(label="Use Fast Tokenizer",
                                                 value=True)
                tokenizer_legacy = gr.Checkbox(label="Use Legacy Tokenizer",
                                               value=False)
                resize_token_embeddings_to_32x = gr.Checkbox(
                    label="Resize Token Embeddings to 32x", value=False)

    with gr.Accordion("Adv. Config", open=False):
        with gr.Tab("Model Derivation & Configuration Overrides"):
            with gr.Column():
                is_falcon_derived_model = gr.Checkbox(
                    label="Is Falcon Derived Model", value=False)
                is_llama_derived_model = gr.Checkbox(
                    label="Is Llama Derived Model", value=False)
                is_mistral_derived_model = gr.Checkbox(
                    label="Is Mistral Derived Model", value=False)
                is_qwen_derived_model = gr.Checkbox(
                    label="Is Qwen Derived Model", value=False)
                model_config = gr.TextArea(label="Model Config Overrides",
                                           placeholder="YAML or JSON format")
                bnb_config_kwargs = gr.TextArea(label="BnB Config KWArgs",
                                                placeholder="YAML or JSON format")

        with gr.Tab("Quantization & Precision"):
            with gr.Column():
                with gr.Row():
                    gptq = gr.Checkbox(label="GPTQ", value=False)
                    gptq_groupsize = gr.Number(label="GPTQ Groupsize", value=128)
                    gptq_model_v1 = gr.Checkbox(label="GPTQ Model V1", value=False)
                    load_in_8bit = gr.Checkbox(label="Load in 8-bit", value=False)
                    load_in_4bit = gr.Checkbox(label="Load in 4-bit", value=False)
                with gr.Row():
                    bf16 = gr.Checkbox(label="BF16", value=False)
                    fp16 = gr.Checkbox(label="FP16", value=False)
                    tf32 = gr.Checkbox(label="TF32", value=False)
                    bfloat16 = gr.Checkbox(label="BFloat16", value=False)
                    float16 = gr.Checkbox(label="Float16", value=False)

        with gr.Tab("GPU & LoRA Settings"):
            gpu_memory_limit = gr.Textbox(label="GPU Memory Limit")
            lora_on_cpu = gr.Checkbox(label="LoRA on CPU", value=False)
            datasets = gr.TextArea(label="Datasets",
                                   placeholder="YAML or JSON format for datasets")
            test_datasets = gr.TextArea(
                label="Test Datasets",
                placeholder="YAML or JSON format for test datasets")
            rl = gr.Textbox(label="RL")
            chat_template = gr.Textbox(label="Chat Template")
            default_system_message = gr.Textbox(label="Default System Message")
            dataset_prepared_path = gr.Textbox(label="Dataset Prepared Path")
            push_dataset_to_hub = gr.Textbox(label="Push Dataset to Hub")
            dataset_processes = gr.Number(label="Dataset Processes", value=1)
            dataset_keep_in_memory = gr.Checkbox(label="Dataset Keep in Memory",
                                                 value=False)
            with gr.Row():
                hub_model_id = gr.Textbox(label="Hub Model ID")
                hub_strategy = gr.Textbox(label="Hub Strategy")
                hf_use_auth_token = gr.Checkbox(label="HF Use Auth Token",
                                                value=False)
            with gr.Row():
                val_set_size = gr.Number(label="Validation Set Size",
                                         value=0.04, step=0.01)
                dataset_shard_num = gr.Number(label="Dataset Shard Num")
                dataset_shard_idx = gr.Number(label="Dataset Shard Index")

        with gr.Tab("Training & Evaluation"):
            with gr.Row():
                sequence_len = gr.Number(label="Sequence Length", value=2048)
                pad_to_sequence_len = gr.Checkbox(label="Pad to Sequence Length",
                                                  value=False)
            with gr.Row():
                sample_packing = gr.Checkbox(label="Sample Packing", value=False)
                eval_sample_packing = gr.Checkbox(label="Eval Sample Packing",
                                                  value=False)
                sample_packing_eff_est = gr.Number(label="Sample Packing Eff Est")
            with gr.Row():
                total_num_tokens = gr.Number(label="Total Num Tokens")
                device_map = gr.Textbox(label="Device Map")
                max_memory = gr.Textbox(label="Max Memory")
            adapter = gr.Textbox(label="Adapter")
            with gr.Column():
                lora_model_dir = gr.Textbox(label="LoRA Model Dir")
                lora_r = gr.Number(label="LoRA R", value=8)
                lora_alpha = gr.Number(label="LoRA Alpha", value=16)
                lora_dropout = gr.Number(label="LoRA Dropout", value=0.05,
                                         step=0.01)
                lora_target_modules = gr.TextArea(label="LoRA Target Modules")
                lora_target_linear = gr.Checkbox(label="LoRA Target Linear",
                                                 value=False)
                lora_modules_to_save = gr.TextArea(label="LoRA Modules to Save")
                lora_fan_in_fan_out = gr.Checkbox(label="LoRA Fan In Fan Out",
                                                  value=False)
                peft = gr.Textbox(label="PEFT")
            with gr.Row():
                relora_steps = gr.Number(label="ReLoRA Steps")
                relora_warmup_steps = gr.Number(label="ReLoRA Warmup Steps")
                relora_anneal_steps = gr.Number(label="ReLoRA Anneal Steps")
                relora_prune_ratio = gr.Number(label="ReLoRA Prune Ratio")
                relora_cpu_offload = gr.Checkbox(label="ReLoRA CPU Offload",
                                                 value=False)
            with gr.Row():
                wandb_mode = gr.Textbox(label="WandB Mode")
                wandb_project = gr.Textbox(label="WandB Project")
                wandb_entity = gr.Textbox(label="WandB Entity")
                wandb_watch = gr.Checkbox(label="WandB Watch", value=False)
                wandb_name = gr.Textbox(label="WandB Name")
                wandb_run_id = gr.Textbox(label="WandB Run ID")
                wandb_log_model = gr.Checkbox(label="WandB Log Model", value=False)
            with gr.Column():
                mlflow_tracking_uri = gr.Textbox(label="MLFlow Tracking URI")
                mlflow_experiment_name = gr.Textbox(label="MLFlow Experiment Name")
                output_dir = gr.Textbox(label="Output Dir")
                torch_compile = gr.Checkbox(label="Torch Compile", value=False)
                torch_compile_backend = gr.Textbox(label="Torch Compile Backend")
                gradient_accumulation_steps = gr.Number(
                    label="Gradient Accumulation Steps", value=1)
                micro_batch_size = gr.Number(label="Micro Batch Size", value=2)
                eval_batch_size = gr.Number(label="Eval Batch Size", value=2)
                num_epochs = gr.Number(label="Number of Epochs", value=4)
                warmup_steps = gr.Number(label="Warmup Steps", value=100)
                warmup_ratio = gr.Number(label="Warmup Ratio")
                learning_rate = gr.Number(label="Learning Rate", value=0.00003,
                                          step=1e-5)
                lr_quadratic_warmup = gr.Checkbox(label="LR Quadratic Warmup",
                                                  value=False)
                logging_steps = gr.Number(label="Logging Steps", value=1)
                eval_steps = gr.Textbox(label="Eval Steps")
                evals_per_epoch = gr.Number(label="Evals per Epoch", value=4)
                save_strategy = gr.Textbox(label="Save Strategy")
                save_steps = gr.Textbox(label="Save Steps")
                saves_per_epoch = gr.Number(label="Saves per Epoch", value=1)
                save_total_limit = gr.Number(label="Save Total Limit")
                max_steps = gr.Number(label="Max Steps")
                eval_table_size = gr.Number(label="Eval Table Size")
                eval_max_new_tokens = gr.Number(label="Eval Max New Tokens",
                                                value=128)
                eval_causal_lm_metrics = gr.TextArea(label="Eval Causal LM Metrics")
                loss_watchdog_threshold = gr.Number(label="Loss Watchdog Threshold")
                loss_watchdog_patience = gr.Number(label="Loss Watchdog Patience",
                                                   value=3)
                save_safetensors = gr.Checkbox(label="Save SafeTensors",
                                               value=False)
                train_on_inputs = gr.Checkbox(label="Train on Inputs", value=False)
                group_by_length = gr.Checkbox(label="Group by Length", value=False)
                gradient_checkpointing = gr.Checkbox(
                    label="Gradient Checkpointing", value=False)
                early_stopping_patience = gr.Number(
                    label="Early Stopping Patience", value=3)
                lr_scheduler = gr.Textbox(label="LR Scheduler")
                lr_scheduler_kwargs = gr.TextArea(label="LR Scheduler KWArgs")
                cosine_min_lr_ratio = gr.Number(label="Cosine Min LR Ratio")
                cosine_constant_lr_ratio = gr.Number(
                    label="Cosine Constant LR Ratio")
                lr_div_factor = gr.Number(label="LR Div Factor")
                log_sweep_min_lr = gr.Number(label="Log Sweep Min LR")
                log_sweep_max_lr = gr.Number(label="Log Sweep Max LR")
                optimizer = gr.Textbox(label="Optimizer")
                weight_decay = gr.Number(label="Weight Decay", value=0.0,
                                         step=0.01)
                adam_beta1 = gr.Number(label="Adam Beta1", value=0.9, step=0.01)
                adam_beta2 = gr.Number(label="Adam Beta2", value=0.999, step=0.001)
                adam_epsilon = gr.Number(label="Adam Epsilon", value=1e-8,
                                         step=1e-9)
                max_grad_norm = gr.Number(label="Max Grad Norm")
                neftune_noise_alpha = gr.Number(label="NEFTune Noise Alpha")
                flash_optimum = gr.Checkbox(label="Flash Optimum", value=False)
                xformers_attention = gr.Checkbox(label="XFormers Attention",
                                                 value=False)
                flash_attention = gr.Checkbox(label="Flash Attention", value=False)
                flash_attn_cross_entropy = gr.Checkbox(
                    label="Flash Attn Cross Entropy", value=False)
                flash_attn_rms_norm = gr.Checkbox(label="Flash Attn RMS Norm",
                                                  value=False)
                flash_attn_fuse_qkv = gr.Checkbox(label="Flash Attn Fuse QKV",
                                                  value=False)
                flash_attn_fuse_mlp = gr.Checkbox(label="Flash Attn Fuse MLP",
                                                  value=False)
                sdp_attention = gr.Checkbox(label="SDP Attention", value=False)
                s2_attention = gr.Checkbox(label="S2 Attention", value=False)
                resume_from_checkpoint = gr.Textbox(label="Resume From Checkpoint")
                auto_resume_from_checkpoints = gr.Checkbox(
                    label="Auto Resume From Checkpoints", value=False)
                local_rank = gr.Number(label="Local Rank")
                special_tokens = gr.TextArea(label="Special Tokens")
                tokens = gr.TextArea(label="Tokens")
                fsdp = gr.Checkbox(label="FSDP", value=False)
                fsdp_config = gr.TextArea(label="FSDP Config")
                deepspeed = gr.Textbox(label="Deepspeed")
                ddp_timeout = gr.Number(label="DDP Timeout")
                ddp_bucket_cap_mb = gr.Number(label="DDP Bucket Cap MB")
                ddp_broadcast_buffers = gr.Checkbox(label="DDP Broadcast Buffers",
                                                    value=False)
                torchdistx_path = gr.Textbox(label="TorchDistX Path")
                pretraining_dataset = gr.Textbox(label="Pretraining Dataset")
                debug = gr.Checkbox(label="Debug", value=False)
                seed = gr.Number(label="Seed", value=42)
                strict = gr.Checkbox(label="Strict", value=False)

    submit_button = gr.Button("Launch Configuration")
    output_area = gr.TextArea(label="Configuration Output")

    submit_button.click(
        config,
        inputs=[
            base_model, base_model_ignore_patterns, base_model_config,
            model_revision, tokenizer_config, model_type, tokenizer_type,
            trust_remote_code, tokenizer_use_fast, tokenizer_legacy,
            resize_token_embeddings_to_32x, is_falcon_derived_model,
            is_llama_derived_model, is_mistral_derived_model,
            is_qwen_derived_model, model_config, bnb_config_kwargs, gptq,
            gptq_groupsize, gptq_model_v1, load_in_8bit, load_in_4bit, bf16,
            fp16, tf32, bfloat16, float16, gpu_memory_limit, lora_on_cpu,
            datasets, test_datasets, rl, chat_template, default_system_message,
            dataset_prepared_path, push_dataset_to_hub, dataset_processes,
            dataset_keep_in_memory, hub_model_id, hub_strategy,
            hf_use_auth_token, val_set_size, dataset_shard_num,
            dataset_shard_idx, sequence_len, pad_to_sequence_len, sample_packing,
            eval_sample_packing, sample_packing_eff_est, total_num_tokens,
            device_map, max_memory, adapter, lora_model_dir, lora_r, lora_alpha,
            lora_dropout, lora_target_modules, lora_target_linear,
            lora_modules_to_save, lora_fan_in_fan_out, peft, relora_steps,
            relora_warmup_steps, relora_anneal_steps, relora_prune_ratio,
            relora_cpu_offload, wandb_mode, wandb_project, wandb_entity,
            wandb_watch, wandb_name, wandb_run_id, wandb_log_model,
            mlflow_tracking_uri, mlflow_experiment_name, output_dir,
            torch_compile, torch_compile_backend, gradient_accumulation_steps,
            micro_batch_size, eval_batch_size, num_epochs, warmup_steps,
            warmup_ratio, learning_rate, lr_quadratic_warmup, logging_steps,
            eval_steps, evals_per_epoch, save_strategy, save_steps,
            saves_per_epoch, save_total_limit, max_steps, eval_table_size,
            eval_max_new_tokens, eval_causal_lm_metrics, loss_watchdog_threshold,
            loss_watchdog_patience, save_safetensors, train_on_inputs,
            group_by_length, gradient_checkpointing, early_stopping_patience,
            lr_scheduler, lr_scheduler_kwargs, cosine_min_lr_ratio,
            cosine_constant_lr_ratio, lr_div_factor, log_sweep_min_lr,
            log_sweep_max_lr, optimizer, weight_decay, adam_beta1, adam_beta2,
            adam_epsilon, max_grad_norm, neftune_noise_alpha, flash_optimum,
            xformers_attention, flash_attention, flash_attn_cross_entropy,
            flash_attn_rms_norm, flash_attn_fuse_qkv, flash_attn_fuse_mlp,
            sdp_attention, s2_attention, resume_from_checkpoint,
            auto_resume_from_checkpoints, local_rank, special_tokens, tokens,
            fsdp, fsdp_config, deepspeed, ddp_timeout, ddp_bucket_cap_mb,
            ddp_broadcast_buffers, torchdistx_path, pretraining_dataset, debug,
            seed, strict
        ],
        outputs=output_area)

    # This tab creates a configuration file directly from raw YAML text.
    with gr.Tab(label="YML text"):
        yml_config_text = gr.TextArea(label="YML Config", lines=50,
                                      value=example_yml)
        create_config = gr.Button("Create config")
        output = gr.TextArea(label="Generated config")
        create_config.click(
            yml_config,
            inputs=[yml_config_text],
            outputs=output,
        )

demo.launch(share=True)
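For reference, a minimal sketch of the round-trip that yml_config performs, assuming only PyYAML; the sample string below is illustrative input, not part of the app. yaml.safe_load parses the textbox contents into a Python dict, and yaml.dump re-serializes it (sorting keys alphabetically by default) both into config.yml and into the output textbox.

# Illustrative round-trip mirroring yml_config above (assumes PyYAML).
import yaml

sample = "base_model: NousResearch/Llama-2-7b-hf\nload_in_4bit: true"
parsed = yaml.safe_load(sample)   # -> {'base_model': '...', 'load_in_4bit': True}
print(yaml.dump(parsed))          # keys re-emitted in alphabetical order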
config.py ADDED
@@ -0,0 +1,200 @@
import yaml


def config(
        base_model, base_model_ignore_patterns, base_model_config, model_revision,
        tokenizer_config, model_type, tokenizer_type, trust_remote_code,
        tokenizer_use_fast, tokenizer_legacy, resize_token_embeddings_to_32x,
        is_falcon_derived_model, is_llama_derived_model, is_mistral_derived_model,
        is_qwen_derived_model, model_config, bnb_config_kwargs, gptq,
        gptq_groupsize, gptq_model_v1, load_in_8bit, load_in_4bit, bf16, fp16,
        tf32, bfloat16, float16, gpu_memory_limit, lora_on_cpu, datasets,
        test_datasets, rl, chat_template, default_system_message,
        dataset_prepared_path, push_dataset_to_hub, dataset_processes,
        dataset_keep_in_memory, hub_model_id, hub_strategy, hf_use_auth_token,
        val_set_size, dataset_shard_num, dataset_shard_idx, sequence_len,
        pad_to_sequence_len, sample_packing, eval_sample_packing,
        sample_packing_eff_est, total_num_tokens, device_map, max_memory, adapter,
        lora_model_dir, lora_r, lora_alpha, lora_dropout, lora_target_modules,
        lora_target_linear, lora_modules_to_save, lora_fan_in_fan_out, peft,
        relora_steps, relora_warmup_steps, relora_anneal_steps, relora_prune_ratio,
        relora_cpu_offload, wandb_mode, wandb_project, wandb_entity, wandb_watch,
        wandb_name, wandb_run_id, wandb_log_model, mlflow_tracking_uri,
        mlflow_experiment_name, output_dir, torch_compile, torch_compile_backend,
        gradient_accumulation_steps, micro_batch_size, eval_batch_size, num_epochs,
        warmup_steps, warmup_ratio, learning_rate, lr_quadratic_warmup,
        logging_steps, eval_steps, evals_per_epoch, save_strategy, save_steps,
        saves_per_epoch, save_total_limit, max_steps, eval_table_size,
        eval_max_new_tokens, eval_causal_lm_metrics, loss_watchdog_threshold,
        loss_watchdog_patience, save_safetensors, train_on_inputs, group_by_length,
        gradient_checkpointing, early_stopping_patience, lr_scheduler,
        lr_scheduler_kwargs, cosine_min_lr_ratio, cosine_constant_lr_ratio,
        lr_div_factor, log_sweep_min_lr, log_sweep_max_lr, optimizer, weight_decay,
        adam_beta1, adam_beta2, adam_epsilon, max_grad_norm, neftune_noise_alpha,
        flash_optimum, xformers_attention, flash_attention,
        flash_attn_cross_entropy, flash_attn_rms_norm, flash_attn_fuse_qkv,
        flash_attn_fuse_mlp, sdp_attention, s2_attention, resume_from_checkpoint,
        auto_resume_from_checkpoints, local_rank, special_tokens, tokens, fsdp,
        fsdp_config, deepspeed, ddp_timeout, ddp_bucket_cap_mb,
        ddp_broadcast_buffers, torchdistx_path, pretraining_dataset, debug, seed,
        strict):
    """
    Build a configuration dictionary from the provided parameters, save it
    to config.yml, and return the YAML dump.
    """
    config_dict = {
        # Base model configurations
        "base_model": base_model,
        "base_model_ignore_patterns": base_model_ignore_patterns,
        "base_model_config": base_model_config,
        "model_revision": model_revision,
        "tokenizer_config": tokenizer_config,
        "model_type": model_type,
        "tokenizer_type": tokenizer_type,
        "trust_remote_code": trust_remote_code,
        "tokenizer_use_fast": tokenizer_use_fast,
        "tokenizer_legacy": tokenizer_legacy,
        "resize_token_embeddings_to_32x": resize_token_embeddings_to_32x,
        # Derived model flags
        "is_falcon_derived_model": is_falcon_derived_model,
        "is_llama_derived_model": is_llama_derived_model,
        "is_mistral_derived_model": is_mistral_derived_model,
        "is_qwen_derived_model": is_qwen_derived_model,
        # Model configuration overrides
        "model_config": model_config,
        "bnb_config_kwargs": bnb_config_kwargs,
        # Quantization and precision settings
        "gptq": gptq,
        "gptq_groupsize": gptq_groupsize,
        "gptq_model_v1": gptq_model_v1,
        "load_in_8bit": load_in_8bit,
        "load_in_4bit": load_in_4bit,
        "bf16": bf16,
        "fp16": fp16,
        "tf32": tf32,
        "bfloat16": bfloat16,
        "float16": float16,
        "gpu_memory_limit": gpu_memory_limit,
        "lora_on_cpu": lora_on_cpu,
        # Dataset configurations
        "datasets": datasets,
        "test_datasets": test_datasets,
        "rl": rl,
        "chat_template": chat_template,
        "default_system_message": default_system_message,
        "dataset_prepared_path": dataset_prepared_path,
        "push_dataset_to_hub": push_dataset_to_hub,
        "dataset_processes": dataset_processes,
        "dataset_keep_in_memory": dataset_keep_in_memory,
        "hub_model_id": hub_model_id,
        "hub_strategy": hub_strategy,
        "hf_use_auth_token": hf_use_auth_token,
        "val_set_size": val_set_size,
        "dataset_shard_num": dataset_shard_num,
        "dataset_shard_idx": dataset_shard_idx,
        # Training hyperparameters
        "sequence_len": sequence_len,
        "pad_to_sequence_len": pad_to_sequence_len,
        "sample_packing": sample_packing,
        "eval_sample_packing": eval_sample_packing,
        "sample_packing_eff_est": sample_packing_eff_est,
        "total_num_tokens": total_num_tokens,
        "device_map": device_map,
        "max_memory": max_memory,
        # Adapter and LoRA settings
        "adapter": adapter,
        "lora_model_dir": lora_model_dir,
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "lora_target_modules": lora_target_modules,
        "lora_target_linear": lora_target_linear,
        "lora_modules_to_save": lora_modules_to_save,
        "lora_fan_in_fan_out": lora_fan_in_fan_out,
        "peft": peft,
        "relora_steps": relora_steps,
        "relora_warmup_steps": relora_warmup_steps,
        "relora_anneal_steps": relora_anneal_steps,
        "relora_prune_ratio": relora_prune_ratio,
        "relora_cpu_offload": relora_cpu_offload,
        # wandb and mlflow configurations
        "wandb_mode": wandb_mode,
        "wandb_project": wandb_project,
        "wandb_entity": wandb_entity,
        "wandb_watch": wandb_watch,
        "wandb_name": wandb_name,
        "wandb_run_id": wandb_run_id,
        "wandb_log_model": wandb_log_model,
        "mlflow_tracking_uri": mlflow_tracking_uri,
        "mlflow_experiment_name": mlflow_experiment_name,
        "output_dir": output_dir,
        "torch_compile": torch_compile,
        "torch_compile_backend": torch_compile_backend,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "micro_batch_size": micro_batch_size,
        "eval_batch_size": eval_batch_size,
        "num_epochs": num_epochs,
        "warmup_steps": warmup_steps,
        "warmup_ratio": warmup_ratio,
        "learning_rate": learning_rate,
        "lr_quadratic_warmup": lr_quadratic_warmup,
        "logging_steps": logging_steps,
        "eval_steps": eval_steps,
        "evals_per_epoch": evals_per_epoch,
        "save_strategy": save_strategy,
        "save_steps": save_steps,
        "saves_per_epoch": saves_per_epoch,
        "save_total_limit": save_total_limit,
        "max_steps": max_steps,
        "eval_table_size": eval_table_size,
        "eval_max_new_tokens": eval_max_new_tokens,
        "eval_causal_lm_metrics": eval_causal_lm_metrics,
        "loss_watchdog_threshold": loss_watchdog_threshold,
        "loss_watchdog_patience": loss_watchdog_patience,
        "save_safetensors": save_safetensors,
        "train_on_inputs": train_on_inputs,
        "group_by_length": group_by_length,
        "gradient_checkpointing": gradient_checkpointing,
        "early_stopping_patience": early_stopping_patience,
        "lr_scheduler": lr_scheduler,
        "lr_scheduler_kwargs": lr_scheduler_kwargs,
        "cosine_min_lr_ratio": cosine_min_lr_ratio,
        "cosine_constant_lr_ratio": cosine_constant_lr_ratio,
        "lr_div_factor": lr_div_factor,
        "log_sweep_min_lr": log_sweep_min_lr,
        "log_sweep_max_lr": log_sweep_max_lr,
        "optimizer": optimizer,
        "weight_decay": weight_decay,
        "adam_beta1": adam_beta1,
        "adam_beta2": adam_beta2,
        "adam_epsilon": adam_epsilon,
        "max_grad_norm": max_grad_norm,
        "neftune_noise_alpha": neftune_noise_alpha,
        "flash_optimum": flash_optimum,
        "xformers_attention": xformers_attention,
        "flash_attention": flash_attention,
        "flash_attn_cross_entropy": flash_attn_cross_entropy,
        "flash_attn_rms_norm": flash_attn_rms_norm,
        "flash_attn_fuse_qkv": flash_attn_fuse_qkv,
        "flash_attn_fuse_mlp": flash_attn_fuse_mlp,
        "sdp_attention": sdp_attention,
        "s2_attention": s2_attention,
        "resume_from_checkpoint": resume_from_checkpoint,
        "auto_resume_from_checkpoints": auto_resume_from_checkpoints,
        "local_rank": local_rank,
        "special_tokens": special_tokens,
        "tokens": tokens,
        "fsdp": fsdp,
        "fsdp_config": fsdp_config,
        "deepspeed": deepspeed,
        "ddp_timeout": ddp_timeout,
        "ddp_bucket_cap_mb": ddp_bucket_cap_mb,
        "ddp_broadcast_buffers": ddp_broadcast_buffers,
        "torchdistx_path": torchdistx_path,
        "pretraining_dataset": pretraining_dataset,
        "debug": debug,
        "seed": seed,
        "strict": strict,
    }
    with open("config.yml", "w", encoding="utf-8") as file:
        yaml.dump(config_dict, file)
    return yaml.dump(config_dict)
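As a quick sanity check (a sketch, not part of this commit), the file written by config() can be read back with yaml.safe_load. Note that every widget value is dumped verbatim, so untouched fields may come through as empty strings, zeros, or nulls and could need pruning before being handed to Axolotl.

# Hypothetical read-back check for the config.yml written by config().
import yaml

with open("config.yml", encoding="utf-8") as file:
    cfg = yaml.safe_load(file)

print(type(cfg))                                      # <class 'dict'>
print(cfg.get("base_model"), cfg.get("learning_rate"))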