goku6045 committed
Commit 2601523 · verified · 1 Parent(s): b06c01f

Update app.py

Files changed (1)
  1. app.py +126 -364
app.py CHANGED
@@ -1,367 +1,129 @@
- """
- This module is used to launch Axolotl with user defined configurations.
- """
-
  import gradio as gr
- import yaml
- from config import config
-
- example_yml = """
- base_model: NousResearch/Llama-2-7b-hf
- model_type: LlamaForCausalLM
- tokenizer_type: LlamaTokenizer
-
- load_in_8bit: false
- load_in_4bit: true
- strict: false
-
- datasets:
-   - path: mhenrichsen/alpaca_2k_test
-     type: alpaca
- dataset_prepared_path:
- val_set_size: 0.05
- output_dir: ./qlora-out
-
- adapter: qlora
- lora_model_dir:
-
- sequence_len: 4096
- sample_packing: true
- pad_to_sequence_len: true
-
- lora_r: 32
- lora_alpha: 16
- lora_dropout: 0.05
- lora_target_modules:
- lora_target_linear: true
- lora_fan_in_fan_out:
-
- wandb_project:
- wandb_entity:
- wandb_watch:
- wandb_name:
- wandb_log_model:
-
- gradient_accumulation_steps: 4
- micro_batch_size: 2
- num_epochs: 4
- optimizer: paged_adamw_32bit
- lr_scheduler: cosine
- learning_rate: 0.0002
-
- train_on_inputs: false
- group_by_length: false
- bf16: auto
- fp16:
- tf32: false
-
- gradient_checkpointing: true
- early_stopping_patience:
- resume_from_checkpoint:
- local_rank:
- logging_steps: 1
- xformers_attention:
- flash_attention: true
-
- warmup_steps: 10
- evals_per_epoch: 4
- eval_table_size:
- saves_per_epoch: 1
- debug:
- deepspeed:
- weight_decay: 0.0
- fsdp:
- fsdp_config:
- special_tokens:
- """
-
-
- def yml_config(yml_config):
-     """
-     Save the user-supplied text as a YAML file and return the normalized dump.
-     """
-     yml_config = yaml.safe_load(yml_config)
-     with open("config.yml", "w", encoding="utf-8") as file:
-         yaml.dump(yml_config, file)
-     # print(yml_config)
-     return yaml.dump(yml_config)
-
-
- with gr.Blocks(title="Axolotl Launcher") as demo:
-     gr.Markdown("""
-     # Axolotl Launcher
-     Fill out the required fields below to create a training run.
-     """)
-     with gr.Tab("Base Model & Tokenizer"):
-         with gr.Column():
-             with gr.Row():
-                 base_model = gr.Textbox(label="Base Model")
-                 base_model_ignore_patterns = gr.Textbox(label="Base Model Ignore Patterns")
-                 base_model_config = gr.Textbox(label="Base Model Config")
-                 model_revision = gr.Textbox(label="Model Revision")
-             with gr.Row():
-                 tokenizer_config = gr.Textbox(label="Tokenizer Config")
-                 model_type = gr.Textbox(label="Model Type")
-                 tokenizer_type = gr.Textbox(label="Tokenizer Type")
-             with gr.Row():
-                 trust_remote_code = gr.Checkbox(label="Trust Remote Code", value=False)
-                 tokenizer_use_fast = gr.Checkbox(label="Use Fast Tokenizer", value=True)
-                 tokenizer_legacy = gr.Checkbox(label="Use Legacy Tokenizer", value=False)
-                 resize_token_embeddings_to_32x = gr.Checkbox(label="Resize Token Embeddings to 32x", value=False)
-         with gr.Accordion("Adv. Config", open=False):
-             with gr.Accordion("Model Derivation & Configuration Overrides", open=False):
-                 with gr.Column():
-                     is_falcon_derived_model = gr.Checkbox(label="Is Falcon Derived Model", value=False)
-                     is_llama_derived_model = gr.Checkbox(label="Is Llama Derived Model", value=False)
-                     is_mistral_derived_model = gr.Checkbox(label="Is Mistral Derived Model", value=False)
-                     is_qwen_derived_model = gr.Checkbox(label="Is Qwen Derived Model", value=False)
-                     model_config = gr.TextArea(label="Model Config Overrides", placeholder="YAML or JSON format")
-                     bnb_config_kwargs = gr.TextArea(label="BnB Config KWArgs", placeholder="YAML or JSON format")
-
-             with gr.Accordion("Quantization & Precision", open=False):
-                 with gr.Column():
-                     with gr.Row():
-                         gptq = gr.Checkbox(label="GPTQ", value=False)
-                         gptq_groupsize = gr.Number(label="GPTQ Groupsize", value=128)
-                         gptq_model_v1 = gr.Checkbox(label="GPTQ Model V1", value=False)
-                         load_in_8bit = gr.Checkbox(label="Load in 8-bit", value=False)
-                         load_in_4bit = gr.Checkbox(label="Load in 4-bit", value=False)
-                     with gr.Row():
-                         bf16 = gr.Checkbox(label="BF16", value=False)
-                         fp16 = gr.Checkbox(label="FP16", value=False)
-                         tf32 = gr.Checkbox(label="TF32", value=False)
-                         bfloat16 = gr.Checkbox(label="BFloat16", value=False)
-                         float16 = gr.Checkbox(label="Float16", value=False)
-
-             with gr.Accordion("GPU & LoRA Settings", open=False):
-                 gpu_memory_limit = gr.Textbox(label="GPU Memory Limit")
-                 lora_on_cpu = gr.Checkbox(label="LoRA on CPU", value=False)
-                 datasets = gr.TextArea(label="Datasets", placeholder="YAML or JSON format for datasets")
-                 test_datasets = gr.TextArea(label="Test Datasets", placeholder="YAML or JSON format for test datasets")
-                 rl = gr.Textbox(label="RL")
-                 chat_template = gr.Textbox(label="Chat Template")
-                 default_system_message = gr.Textbox(label="Default System Message")
-                 dataset_prepared_path = gr.Textbox(label="Dataset Prepared Path")
-                 push_dataset_to_hub = gr.Textbox(label="Push Dataset to Hub")
-                 dataset_processes = gr.Number(label="Dataset Processes", value=1)
-                 dataset_keep_in_memory = gr.Checkbox(label="Dataset Keep in Memory", value=False)
-                 with gr.Row():
-                     hub_model_id = gr.Textbox(label="Hub Model ID")
-                     hub_strategy = gr.Textbox(label="Hub Strategy")
-                     hf_use_auth_token = gr.Checkbox(label="HF Use Auth Token", value=False)
-                 with gr.Row():
-                     val_set_size = gr.Number(label="Validation Set Size", value=0.04, step=0.01)
-                     dataset_shard_num = gr.Number(label="Dataset Shard Num")
-                     dataset_shard_idx = gr.Number(label="Dataset Shard Index")
-
-             with gr.Accordion("Training & Evaluation", open=False):
-                 with gr.Row():
-                     sequence_len = gr.Number(label="Sequence Length", value=2048)
-                     pad_to_sequence_len = gr.Checkbox(label="Pad to Sequence Length", value=False)
-                 with gr.Row():
-                     sample_packing = gr.Checkbox(label="Sample Packing", value=False)
-                     eval_sample_packing = gr.Checkbox(label="Eval Sample Packing", value=False)
-                     sample_packing_eff_est = gr.Number(label="Sample Packing Eff Est")
-                 with gr.Row():
-                     total_num_tokens = gr.Number(label="Total Num Tokens")
-                     device_map = gr.Textbox(label="Device Map")
-                     max_memory = gr.Textbox(label="Max Memory")
-                     adapter = gr.Textbox(label="Adapter")
-                 with gr.Column():
-                     lora_model_dir = gr.Textbox(label="LoRA Model Dir")
-                     lora_r = gr.Number(label="LoRA R", value=8)
-                     lora_alpha = gr.Number(label="LoRA Alpha", value=16)
-                     lora_dropout = gr.Number(label="LoRA Dropout", value=0.05, step=0.01)
-                     lora_target_modules = gr.TextArea(label="LoRA Target Modules")
-                     lora_target_linear = gr.Checkbox(label="LoRA Target Linear", value=False)
-                     lora_modules_to_save = gr.TextArea(label="LoRA Modules to Save")
-                     lora_fan_in_fan_out = gr.Checkbox(label="LoRA Fan In Fan Out", value=False)
-                     peft = gr.Textbox(label="PEFT")
-                 with gr.Row():
-                     relora_steps = gr.Number(label="ReLoRA Steps")
-                     relora_warmup_steps = gr.Number(label="ReLoRA Warmup Steps")
-                     relora_anneal_steps = gr.Number(label="ReLoRA Anneal Steps")
-                     relora_prune_ratio = gr.Number(label="ReLoRA Prune Ratio")
-                     relora_cpu_offload = gr.Checkbox(label="ReLoRA CPU Offload", value=False)
-                 with gr.Row():
-                     wandb_mode = gr.Textbox(label="WandB Mode")
-                     wandb_project = gr.Textbox(label="WandB Project")
-                     wandb_entity = gr.Textbox(label="WandB Entity")
-                     wandb_watch = gr.Checkbox(label="WandB Watch", value=False)
-                     wandb_name = gr.Textbox(label="WandB Name")
-                     wandb_run_id = gr.Textbox(label="WandB Run ID")
-                     wandb_log_model = gr.Checkbox(label="WandB Log Model", value=False)
-                 with gr.Column():
-                     mlflow_tracking_uri = gr.Textbox(label="MLFlow Tracking URI")
-                     mlflow_experiment_name = gr.Textbox(label="MLFlow Experiment Name")
-                     output_dir = gr.Textbox(label="Output Dir")
-                     torch_compile = gr.Checkbox(label="Torch Compile", value=False)
-                     torch_compile_backend = gr.Textbox(label="Torch Compile Backend")
-                     gradient_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
-                     micro_batch_size = gr.Number(label="Micro Batch Size", value=2)
-                     eval_batch_size = gr.Number(label="Eval Batch Size", value=2)
-                     num_epochs = gr.Number(label="Number of Epochs", value=4)
-                     warmup_steps = gr.Number(label="Warmup Steps", value=100)
-                     warmup_ratio = gr.Number(label="Warmup Ratio")
-                     learning_rate = gr.Number(label="Learning Rate", value=0.00003, step=1e-5)
-                     lr_quadratic_warmup = gr.Checkbox(label="LR Quadratic Warmup", value=False)
-                     logging_steps = gr.Number(label="Logging Steps", value=1)
-                     eval_steps = gr.Textbox(label="Eval Steps")
-                     evals_per_epoch = gr.Number(label="Evals per Epoch", value=4)
-                     save_strategy = gr.Textbox(label="Save Strategy")
-                     save_steps = gr.Textbox(label="Save Steps")
-                     saves_per_epoch = gr.Number(label="Saves per Epoch", value=1)
-                     save_total_limit = gr.Number(label="Save Total Limit")
-                     max_steps = gr.Number(label="Max Steps")
-                     eval_table_size = gr.Number(label="Eval Table Size")
-                     eval_max_new_tokens = gr.Number(label="Eval Max New Tokens", value=128)
-                     eval_causal_lm_metrics = gr.TextArea(label="Eval Causal LM Metrics")
-                     loss_watchdog_threshold = gr.Number(label="Loss Watchdog Threshold")
-                     loss_watchdog_patience = gr.Number(label="Loss Watchdog Patience", value=3)
-                     save_safetensors = gr.Checkbox(label="Save SafeTensors", value=False)
-                     train_on_inputs = gr.Checkbox(label="Train on Inputs", value=False)
-                     group_by_length = gr.Checkbox(label="Group by Length", value=False)
-                     gradient_checkpointing = gr.Checkbox(label="Gradient Checkpointing", value=False)
-                     early_stopping_patience = gr.Number(label="Early Stopping Patience", value=3)
-                     lr_scheduler = gr.Textbox(label="LR Scheduler")
-                     lr_scheduler_kwargs = gr.TextArea(label="LR Scheduler KWArgs")
-                     cosine_min_lr_ratio = gr.Number(label="Cosine Min LR Ratio")
-                     cosine_constant_lr_ratio = gr.Number(label="Cosine Constant LR Ratio")
-                     lr_div_factor = gr.Number(label="LR Div Factor")
-                     log_sweep_min_lr = gr.Number(label="Log Sweep Min LR")
-                     log_sweep_max_lr = gr.Number(label="Log Sweep Max LR")
-                     optimizer = gr.Textbox(label="Optimizer")
-                     weight_decay = gr.Number(label="Weight Decay", value=0.0, step=0.01)
-                     adam_beta1 = gr.Number(label="Adam Beta1", value=0.9, step=0.01)
-                     adam_beta2 = gr.Number(label="Adam Beta2", value=0.999, step=0.001)
-                     adam_epsilon = gr.Number(label="Adam Epsilon", value=1e-8, step=1e-9)
-                     max_grad_norm = gr.Number(label="Max Grad Norm")
-                     neftune_noise_alpha = gr.Number(label="NEFTune Noise Alpha")
-                     flash_optimum = gr.Checkbox(label="Flash Optimum", value=False)
-                     xformers_attention = gr.Checkbox(label="XFormers Attention", value=False)
-                     flash_attention = gr.Checkbox(label="Flash Attention", value=False)
-                     flash_attn_cross_entropy = gr.Checkbox(label="Flash Attn Cross Entropy", value=False)
-                     flash_attn_rms_norm = gr.Checkbox(label="Flash Attn RMS Norm", value=False)
-                     flash_attn_fuse_qkv = gr.Checkbox(label="Flash Attn Fuse QKV", value=False)
-                     flash_attn_fuse_mlp = gr.Checkbox(label="Flash Attn Fuse MLP", value=False)
-                     sdp_attention = gr.Checkbox(label="SDP Attention", value=False)
-                     s2_attention = gr.Checkbox(label="S2 Attention", value=False)
-                     resume_from_checkpoint = gr.Textbox(label="Resume From Checkpoint")
-                     auto_resume_from_checkpoints = gr.Checkbox(label="Auto Resume From Checkpoints", value=False)
-                     local_rank = gr.Number(label="Local Rank")
-                     special_tokens = gr.TextArea(label="Special Tokens")
-                     tokens = gr.TextArea(label="Tokens")
-                     fsdp = gr.Checkbox(label="FSDP", value=False)
-                     fsdp_config = gr.TextArea(label="FSDP Config")
-                     deepspeed = gr.Textbox(label="Deepspeed")
-                     ddp_timeout = gr.Number(label="DDP Timeout")
-                     ddp_bucket_cap_mb = gr.Number(label="DDP Bucket Cap MB")
-                     ddp_broadcast_buffers = gr.Checkbox(label="DDP Broadcast Buffers", value=False)
-                     torchdistx_path = gr.Textbox(label="TorchDistX Path")
-                     pretraining_dataset = gr.Textbox(label="Pretraining Dataset")
-                     debug = gr.Checkbox(label="Debug", value=False)
-                     seed = gr.Number(label="Seed", value=42)
-                     strict = gr.Checkbox(label="Strict", value=False)
-
-     submit_button = gr.Button("Launch Configuration")
-     output_area = gr.TextArea(label="Configuration Output")
 
-     submit_button.click(
-         config,
-         inputs=[
-             base_model, base_model_ignore_patterns, base_model_config,
-             model_revision, tokenizer_config, model_type, tokenizer_type,
-             trust_remote_code, tokenizer_use_fast, tokenizer_legacy,
-             resize_token_embeddings_to_32x, is_falcon_derived_model,
-             is_llama_derived_model, is_mistral_derived_model,
-             is_qwen_derived_model, model_config, bnb_config_kwargs, gptq,
-             gptq_groupsize, gptq_model_v1, load_in_8bit, load_in_4bit, bf16,
-             fp16, tf32, bfloat16, float16, gpu_memory_limit, lora_on_cpu,
-             datasets, test_datasets, rl, chat_template, default_system_message,
-             dataset_prepared_path, push_dataset_to_hub, dataset_processes,
-             dataset_keep_in_memory, hub_model_id, hub_strategy,
-             hf_use_auth_token, val_set_size, dataset_shard_num,
-             dataset_shard_idx, sequence_len, pad_to_sequence_len, sample_packing,
-             eval_sample_packing, sample_packing_eff_est, total_num_tokens,
-             device_map, max_memory, adapter, lora_model_dir, lora_r, lora_alpha,
-             lora_dropout, lora_target_modules, lora_target_linear,
-             lora_modules_to_save, lora_fan_in_fan_out, peft, relora_steps,
-             relora_warmup_steps, relora_anneal_steps, relora_prune_ratio,
-             relora_cpu_offload, wandb_mode, wandb_project, wandb_entity,
-             wandb_watch, wandb_name, wandb_run_id, wandb_log_model,
-             mlflow_tracking_uri, mlflow_experiment_name, output_dir,
-             torch_compile, torch_compile_backend, gradient_accumulation_steps,
-             micro_batch_size, eval_batch_size, num_epochs, warmup_steps,
-             warmup_ratio, learning_rate, lr_quadratic_warmup, logging_steps,
-             eval_steps, evals_per_epoch, save_strategy, save_steps,
-             saves_per_epoch, save_total_limit, max_steps, eval_table_size,
-             eval_max_new_tokens, eval_causal_lm_metrics, loss_watchdog_threshold,
-             loss_watchdog_patience, save_safetensors, train_on_inputs,
-             group_by_length, gradient_checkpointing, early_stopping_patience,
-             lr_scheduler, lr_scheduler_kwargs, cosine_min_lr_ratio,
-             cosine_constant_lr_ratio, lr_div_factor, log_sweep_min_lr,
-             log_sweep_max_lr, optimizer, weight_decay, adam_beta1, adam_beta2,
-             adam_epsilon, max_grad_norm, neftune_noise_alpha, flash_optimum,
-             xformers_attention, flash_attention, flash_attn_cross_entropy,
-             flash_attn_rms_norm, flash_attn_fuse_qkv, flash_attn_fuse_mlp,
-             sdp_attention, s2_attention, resume_from_checkpoint,
-             auto_resume_from_checkpoints, local_rank, special_tokens, tokens,
-             fsdp, fsdp_config, deepspeed, ddp_timeout, ddp_bucket_cap_mb,
-             ddp_broadcast_buffers, torchdistx_path, pretraining_dataset, debug,
-             seed, strict
-         ],
-         outputs=output_area)
-     # This section creates a configuration file from user-supplied YAML text.
-     with gr.Tab(label="YML text"):
-         yml_config_text = gr.TextArea(label='YML Config', lines=50, value=example_yml)
-         create_config = gr.Button("Create config")
-         output = gr.TextArea(label="Generated config")
-         create_config.click(
-             yml_config,
-             inputs=[yml_config_text],
-             outputs=output,
-         )
366
 
367
- demo.launch(share=True)
 
 
 
 
 
1
  import gradio as gr
2
+ import os
+
+
+ class Main:
+
+     async def train_model(self, max_steps, base_model, model_type, tokenizer_type, is_llama_derived_model,
+                           strict, datasets_path, dataset_format, shards,
+                           val_set_size, output_dir, adapter, lora_model_dir, sequence_len, sample_packing,
+                           pad_to_sequence_len, lora_r, lora_alpha, lora_dropout,
+                           lora_target_modules, lora_target_linear, lora_fan_in_fan_out, gradient_accumulation_steps,
+                           micro_batch_size, num_epochs, optimizer, lr_scheduler, learning_rate, train_on_inputs,
+                           group_by_length, bf16, fp16, tf32, gradient_checkpointing,
+                           resume_from_checkpoint, local_rank, logging_steps, xformers_attention, flash_attention,
+                           load_best_model_at_end, warmup_steps, evals_per_epoch, eval_table_size, saves_per_epoch,
+                           debug, weight_decay, wandb_project, wandb_entity, wandb_watch,
+                           wandb_name, wandb_log_model, last_tab, progress=gr.Progress(track_tqdm=True)):
+
+         # Collect the UI values in the order they were received; no training is
+         # launched here yet -- the values are simply echoed back to the output box.
+         a = [base_model, model_type, tokenizer_type, is_llama_derived_model,
+              strict, datasets_path, dataset_format, shards,
+              val_set_size, output_dir, adapter, lora_model_dir, sequence_len, sample_packing,
+              pad_to_sequence_len, lora_r, lora_alpha, lora_dropout,
+              lora_target_modules, lora_target_linear, lora_fan_in_fan_out, gradient_accumulation_steps,
+              micro_batch_size, num_epochs, optimizer, lr_scheduler, learning_rate, train_on_inputs,
+              group_by_length, bf16, fp16, tf32, gradient_checkpointing,
+              resume_from_checkpoint, local_rank, logging_steps, xformers_attention, flash_attention,
+              load_best_model_at_end, warmup_steps, evals_per_epoch, eval_table_size, saves_per_epoch,
+              debug, weight_decay, wandb_project, wandb_entity, wandb_watch,
+              wandb_name, wandb_log_model, last_tab]
+
+         return a
+
+     def initiate_userInterface(self):
+         with gr.Blocks() as self.app:
+             gr.Markdown("### Axolotl UI")
+
+             # Finetuning Tab
+             with gr.Tab("FineTuning UI"):
+                 base_model = gr.Dropdown(choices=["NousResearch/Llama-2-7b-hf", "mistralai/Mistral-7B-Instruct-v0.2"], label="Select Model", value="NousResearch/Llama-2-7b-hf")
+                 datasets_path = gr.Textbox(label="datasets_path", value="mhenrichsen/alpaca_2k_test")
+                 dataset_format = gr.Radio(choices=['Alpaca'], label="Dataset Format", value='Alpaca')
+                 shards = gr.Slider(minimum=0, maximum=20, step=1, label="shards", value=10)
+                 last_tab = gr.Checkbox(label='last_tab', value=False, visible=False)
+
+                 with gr.Accordion("Advanced Settings", open=False):
+                     with gr.Tab("YAML Configuration"):
+                         model_type = gr.Radio(label="model_type", choices=['MistralForCausalLM', 'LlamaForCausalLM'], value="LlamaForCausalLM")
+                         tokenizer_type = gr.Textbox(label="tokenizer_type", value="LlamaTokenizer", visible=False)
+                         is_llama_derived_model = gr.Checkbox(label="is_llama_derived_model", value=True, info="Determines the padding strategy based on the model's parent architecture")
+                         strict = gr.Checkbox(label="strict", value=False, visible=False)
+                         val_set_size = gr.Slider(minimum=0, maximum=1, step=0.1, label="val_set_size", value=0.05, info="Fraction of the training data held out for validation")
+                         output_dir = gr.Textbox(label="output_dir", value="./finetune-out", info="Output directory for the finetuned model")
+                         adapter = gr.Radio(choices=["qlora", "lora"], label="adapter", value='qlora', info="Parameter-efficient training strategy")
+                         lora_model_dir = gr.Textbox(label="lora_model_dir", info="Directory of a custom adapter, if one should be loaded", visible=False)
+                         sequence_len = gr.Slider(minimum=512, maximum=4096, step=10, label="sequence_len", value=1024, info="Maximum input length used for training")
+                         sample_packing = gr.Checkbox(label="sample_packing", value=True, info="Speeds up data preparation; recommended false for small datasets")
+                         pad_to_sequence_len = gr.Checkbox(label="pad_to_sequence_len", value=True, info="Pads inputs to the full sequence length to avoid memory fragmentation and out-of-memory issues. Recommended true")
+                         # eval_sample_packing = gr.Checkbox(label="eval_sample_packing", value=False)
+                         lora_r = gr.Slider(minimum=8, maximum=64, step=2, label="lora_r", value=32, info="Rank of the adaptation layers")
+                         lora_alpha = gr.Slider(minimum=8, maximum=64, step=0.1, label="lora_alpha", value=16, info="Scales how strongly the adapted weights affect the base model's")
+                         lora_dropout = gr.Slider(minimum=0, maximum=1, label="lora_dropout", value=0.05, step=0.01, info="Fraction of adapted weights randomly dropped during training")
+                         lora_target_modules = gr.Textbox(label="lora_target_modules", value="q_proj, v_proj, k_proj", info="Dense layers to target with parameter-efficient tuning")
+                         lora_target_linear = gr.Checkbox(label="lora_target_linear", value=True, info="If set, lora_target_modules is ignored and all linear layers are used")
+                         lora_fan_in_fan_out = gr.Textbox(label="lora_fan_in_fan_out", visible=False)
+
+                         gradient_accumulation_steps = gr.Slider(minimum=4, maximum=64, step=1, label="gradient_accumulation_steps", value=4, info="Number of steps over which gradients accumulate before a weight update")
+                         micro_batch_size = gr.Slider(minimum=1, maximum=64, step=2, label="micro_batch_size", value=2, info="Number of samples sent to each GPU")
+                         num_epochs = gr.Slider(minimum=1, maximum=4, step=1, label="num_epochs", value=1)
+                         max_steps = gr.Textbox(label="max_steps", value='1', info="Maximum number of training steps; overrides the number of epochs", visible=False)
+                         optimizer = gr.Radio(choices=["adamw_hf", 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_anyprecision', 'sgd', 'adagrad', 'adamw_bnb_8bit', 'lion_8bit', 'lion_32bit', 'paged_adamw_32bit', 'paged_adamw_8bit', 'paged_lion_32bit', 'paged_lion_8bit'], value="paged_adamw_32bit", label='optimizer', info="Use an optimizer that matches the model's quantization")
+                         lr_scheduler = gr.Radio(label="lr_scheduler", choices=['one_cycle', 'log_sweep', 'cosine'], value="cosine", info="Adjusts the learning rate dynamically based on the current step")
+                         learning_rate = gr.Textbox(label="max_learning_rate", value="2e-5")
+                         train_on_inputs = gr.Checkbox(label="train_on_inputs", value=False, visible=False)
+                         group_by_length = gr.Checkbox(label="group_by_length", value=False, visible=False)
+                         bf16 = gr.Checkbox(label="bfloat16", value=False, info="Enable bfloat16 precision for tensors; supported only on Ampere or newer GPUs")
+                         fp16 = gr.Checkbox(label="Half Precision", value=True, info="Enable half precision (FP16) for tensor processing")
+                         tf32 = gr.Checkbox(label="TensorFloat32", value=False, info="Enable TensorFloat32 precision for tensors; supported only on Ampere or newer GPUs")
+                         gradient_checkpointing = gr.Checkbox(label="gradient_checkpointing", value=True, visible=False)
+                         resume_from_checkpoint = gr.Textbox(label="resume_from_checkpoint", visible=False)
+                         local_rank = gr.Textbox(label="local_rank", visible=False)
+                         logging_steps = gr.Slider(minimum=1, maximum=100, step=1, label="logging_steps", value=1, visible=False)
+                         xformers_attention = gr.Checkbox(label="xformers_attention", value=False, visible=False)
+                         flash_attention = gr.Checkbox(label="flash_attention", value=False, visible=False)
+                         load_best_model_at_end = gr.Checkbox(label="load_best_model_at_end", value=False, visible=False)
+                         warmup_steps = gr.Slider(minimum=1, maximum=100, step=1, label="warmup_steps", value=10, visible=False)
+                         evals_per_epoch = gr.Slider(minimum=1, maximum=100, step=1, label="evals_per_epoch", value=4, info="Number of evaluations per epoch", visible=False)
+                         eval_table_size = gr.Textbox(label="eval_table_size", visible=False)
+                         saves_per_epoch = gr.Slider(minimum=1, maximum=100, step=1, label="saves_per_epoch", value=1, info="Number of checkpoints saved per epoch")
+
+                         debug = gr.Checkbox(label="debug", value=False, visible=False)
+
+                         weight_decay = gr.Number(label="weight_decay", value=0.0, visible=False)
+                         wandb_watch = gr.Checkbox(label="wandb_watch", value=False, visible=False)
+                         wandb_log_model = gr.Checkbox(label="wandb_log_model", value=False, visible=False)
+                         wandb_project = gr.Textbox(label="wandb_project", visible=False)
+                         wandb_entity = gr.Textbox(label="wandb_entity", visible=False)
+                         wandb_name = gr.Textbox(label="wandb_name", visible=False)
+
+                 train_btn = gr.Button("Start Training")
+                 train_btn.click(
+                     self.train_model,
+                     inputs=[max_steps, base_model, model_type, tokenizer_type, is_llama_derived_model,
+                             strict, datasets_path, dataset_format, shards,
+                             val_set_size, output_dir, adapter, lora_model_dir, sequence_len, sample_packing,
+                             pad_to_sequence_len, lora_r, lora_alpha, lora_dropout,
+                             lora_target_modules, lora_target_linear, lora_fan_in_fan_out, gradient_accumulation_steps,
+                             micro_batch_size, num_epochs, optimizer, lr_scheduler, learning_rate, train_on_inputs,
+                             group_by_length, bf16, fp16, tf32, gradient_checkpointing,
+                             resume_from_checkpoint, local_rank, logging_steps, xformers_attention, flash_attention,
+                             load_best_model_at_end, warmup_steps, evals_per_epoch, eval_table_size, saves_per_epoch,
+                             debug, weight_decay, wandb_project, wandb_entity, wandb_watch,
+                             wandb_name, wandb_log_model, last_tab],
+                     outputs=[gr.Textbox(label="Training Output", interactive=False)]
+                 )
+
+         return self.app
+
+
+ if __name__ == "__main__":
+     main = Main()
+     app = main.initiate_userInterface()
+     app.queue().launch(share=True, server_name='0.0.0.0')
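
Note that the new train_model handler only echoes the collected values back to the output box; nothing in this commit serializes them into a config or starts a run. A minimal sketch of the missing wiring (hypothetical, not part of either version of app.py; it assumes axolotl and accelerate are installed in the Space and uses Axolotl's documented "accelerate launch -m axolotl.cli.train" entry point):

import subprocess
import yaml

def launch_axolotl(config_values: dict, config_path: str = "config.yml") -> str:
    """Write the collected UI values to a YAML config, then launch Axolotl."""
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(config_values, f)
    # Blocks until training exits; a real handler would stream output instead.
    result = subprocess.run(
        ["accelerate", "launch", "-m", "axolotl.cli.train", config_path],
        capture_output=True,
        text=True,
        check=False,
    )
    return result.stdout + result.stderr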
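The removed yml_config helper round-tripped user text through yaml.safe_load and yaml.dump with no error handling, so malformed YAML pasted into the textbox would raise an uncaught yaml.YAMLError. A small hardened variant (a sketch, not taken from either version of the file):

import yaml

def yml_config_safe(text: str) -> str:
    """Parse user YAML, save it to config.yml, and return the normalized dump."""
    try:
        parsed = yaml.safe_load(text)
    except yaml.YAMLError as err:
        # Surface parse errors in the output box instead of crashing the handler.
        return f"Invalid YAML: {err}"
    with open("config.yml", "w", encoding="utf-8") as f:
        yaml.dump(parsed, f)
    return yaml.dump(parsed)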