import yaml


def config(
    base_model, base_model_ignore_patterns, base_model_config, model_revision,
    tokenizer_config, model_type, tokenizer_type, trust_remote_code,
    tokenizer_use_fast, tokenizer_legacy, resize_token_embeddings_to_32x,
    is_falcon_derived_model, is_llama_derived_model, is_mistral_derived_model,
    is_qwen_derived_model, model_config, bnb_config_kwargs, gptq,
    gptq_groupsize, gptq_model_v1, load_in_8bit, load_in_4bit, bf16, fp16,
    tf32, bfloat16, float16, gpu_memory_limit, lora_on_cpu, datasets,
    test_datasets, rl, chat_template, default_system_message,
    dataset_prepared_path, push_dataset_to_hub, dataset_processes,
    dataset_keep_in_memory, hub_model_id, hub_strategy, hf_use_auth_token,
    val_set_size, dataset_shard_num, dataset_shard_idx, sequence_len,
    pad_to_sequence_len, sample_packing, eval_sample_packing,
    sample_packing_eff_est, total_num_tokens, device_map, max_memory, adapter,
    lora_model_dir, lora_r, lora_alpha, lora_dropout, lora_target_modules,
    lora_target_linear, lora_modules_to_save, lora_fan_in_fan_out, peft,
    relora_steps, relora_warmup_steps, relora_anneal_steps, relora_prune_ratio,
    relora_cpu_offload, wandb_mode, wandb_project, wandb_entity, wandb_watch,
    wandb_name, wandb_run_id, wandb_log_model, mlflow_tracking_uri,
    mlflow_experiment_name, output_dir, torch_compile, torch_compile_backend,
    gradient_accumulation_steps, micro_batch_size, eval_batch_size, num_epochs,
    warmup_steps, warmup_ratio, learning_rate, lr_quadratic_warmup,
    logging_steps, eval_steps, evals_per_epoch, save_strategy, save_steps,
    saves_per_epoch, save_total_limit, max_steps, eval_table_size,
    eval_max_new_tokens, eval_causal_lm_metrics, loss_watchdog_threshold,
    loss_watchdog_patience, save_safetensors, train_on_inputs, group_by_length,
    gradient_checkpointing, early_stopping_patience, lr_scheduler,
    lr_scheduler_kwargs, cosine_min_lr_ratio, cosine_constant_lr_ratio,
    lr_div_factor, log_sweep_min_lr, log_sweep_max_lr, optimizer, weight_decay,
    adam_beta1, adam_beta2, adam_epsilon, max_grad_norm, neftune_noise_alpha,
    flash_optimum, xformers_attention, flash_attention,
    flash_attn_cross_entropy, flash_attn_rms_norm, flash_attn_fuse_qkv,
    flash_attn_fuse_mlp, sdp_attention, s2_attention, resume_from_checkpoint,
    auto_resume_from_checkpoints, local_rank, special_tokens, tokens, fsdp,
    fsdp_config, deepspeed, ddp_timeout, ddp_bucket_cap_mb,
    ddp_broadcast_buffers, torchdistx_path, pretraining_dataset, debug, seed,
    strict):
  """
      This function generates a configuration dictionary based on the provided parameters and saves it as a yaml file.
      """
  config_dict = {
      # Base model configurations
      "base_model": base_model,
      "base_model_ignore_patterns": base_model_ignore_patterns,
      "base_model_config": base_model_config,
      "model_revision": model_revision,
      "tokenizer_config": tokenizer_config,
      "model_type": model_type,
      "tokenizer_type": tokenizer_type,
      "trust_remote_code": trust_remote_code,
      "tokenizer_use_fast": tokenizer_use_fast,
      "tokenizer_legacy": tokenizer_legacy,
      "resize_token_embeddings_to_32x": resize_token_embeddings_to_32x,
      # Derived model flags
      "is_falcon_derived_model": is_falcon_derived_model,
      "is_llama_derived_model": is_llama_derived_model,
      "is_mistral_derived_model": is_mistral_derived_model,
      "is_qwen_derived_model": is_qwen_derived_model,
      # Model configuration overrides
      "model_config": model_config,
      "bnb_config_kwargs": bnb_config_kwargs,
      # Quantization and precision settings
      "gptq": gptq,
      "gptq_groupsize": gptq_groupsize,
      "gptq_model_v1": gptq_model_v1,
      "load_in_8bit": load_in_8bit,
      "load_in_4bit": load_in_4bit,
      "bf16": bf16,
      "fp16": fp16,
      "tf32": tf32,
      "bfloat16": bfloat16,
      "float16": float16,
      "gpu_memory_limit": gpu_memory_limit,
      "lora_on_cpu": lora_on_cpu,
      # Dataset configurations
      "datasets": datasets,
      "test_datasets": test_datasets,
      "rl": rl,
      "chat_template": chat_template,
      "default_system_message": default_system_message,
      "dataset_prepared_path": dataset_prepared_path,
      "push_dataset_to_hub": push_dataset_to_hub,
      "dataset_processes": dataset_processes,
      "dataset_keep_in_memory": dataset_keep_in_memory,
      "hub_model_id": hub_model_id,
      "hub_strategy": hub_strategy,
      "hf_use_auth_token": hf_use_auth_token,
      "val_set_size": val_set_size,
      "dataset_shard_num": dataset_shard_num,
      "dataset_shard_idx": dataset_shard_idx,
      # Training hyperparameters
      "sequence_len": sequence_len,
      "pad_to_sequence_len": pad_to_sequence_len,
      "sample_packing": sample_packing,
      "eval_sample_packing": eval_sample_packing,
      "sample_packing_eff_est": sample_packing_eff_est,
      "total_num_tokens": total_num_tokens,
      "device_map": device_map,
      "max_memory": max_memory,
      # Adapter and LoRA settings
      "adapter": adapter,
      "lora_model_dir": lora_model_dir,
      "lora_r": lora_r,
      "lora_alpha": lora_alpha,
      "lora_dropout": lora_dropout,
      "lora_target_modules": lora_target_modules,
      "lora_target_linear": lora_target_linear,
      "lora_modules_to_save": lora_modules_to_save,
      "lora_fan_in_fan_out": lora_fan_in_fan_out,
      "peft": peft,
      "relora_steps": relora_steps,
      "relora_warmup_steps": relora_warmup_steps,
      "relora_anneal_steps": relora_anneal_steps,
      "relora_prune_ratio": relora_prune_ratio,
      "relora_cpu_offload": relora_cpu_offload,
      # wandb and mlflow configurations
      "wandb_mode": wandb_mode,
      "wandb_project": wandb_project,
      "wandb_entity": wandb_entity,
      "wandb_watch": wandb_watch,
      "wandb_name": wandb_name,
      "wandb_run_id": wandb_run_id,
      "wandb_log_model": wandb_log_model,
      "mlflow_tracking_uri": mlflow_tracking_uri,
      "mlflow_experiment_name": mlflow_experiment_name,
      "output_dir": output_dir,
      "torch_compile": torch_compile,
      "torch_compile_backend": torch_compile_backend,
      "gradient_accumulation_steps": gradient_accumulation_steps,
      "micro_batch_size": micro_batch_size,
      "eval_batch_size": eval_batch_size,
      "num_epochs": num_epochs,
      "warmup_steps": warmup_steps,
      "warmup_ratio": warmup_ratio,
      "learning_rate": learning_rate,
      "lr_quadratic_warmup": lr_quadratic_warmup,
      "logging_steps": logging_steps,
      "eval_steps": eval_steps,
      "evals_per_epoch": evals_per_epoch,
      "save_strategy": save_strategy,
      "save_steps": save_steps,
      "saves_per_epoch": saves_per_epoch,
      "save_total_limit": save_total_limit,
      "max_steps": max_steps,
      "eval_table_size": eval_table_size,
      "eval_max_new_tokens": eval_max_new_tokens,
      "eval_causal_lm_metrics": eval_causal_lm_metrics,
      "loss_watchdog_threshold": loss_watchdog_threshold,
      "loss_watchdog_patience": loss_watchdog_patience,
      "save_safetensors": save_safetensors,
      "train_on_inputs": train_on_inputs,
      "group_by_length": group_by_length,
      "gradient_checkpointing": gradient_checkpointing,
      "early_stopping_patience": early_stopping_patience,
      "lr_scheduler": lr_scheduler,
      "lr_scheduler_kwargs": lr_scheduler_kwargs,
      "cosine_min_lr_ratio": cosine_min_lr_ratio,
      "cosine_constant_lr_ratio": cosine_constant_lr_ratio,
      "lr_div_factor": lr_div_factor,
      "log_sweep_min_lr": log_sweep_min_lr,
      "log_sweep_max_lr": log_sweep_max_lr,
      "optimizer": optimizer,
      "weight_decay": weight_decay,
      "adam_beta1": adam_beta1,
      "adam_beta2": adam_beta2,
      "adam_epsilon": adam_epsilon,
      "max_grad_norm": max_grad_norm,
      "neftune_noise_alpha": neftune_noise_alpha,
      "flash_optimum": flash_optimum,
      "xformers_attention": xformers_attention,
      "flash_attention": flash_attention,
      "flash_attn_cross_entropy": flash_attn_cross_entropy,
      "flash_attn_rms_norm": flash_attn_rms_norm,
      "flash_attn_fuse_qkv": flash_attn_fuse_qkv,
      "flash_attn_fuse_mlp": flash_attn_fuse_mlp,
      "sdp_attention": sdp_attention,
      "s2_attention": s2_attention,
      "resume_from_checkpoint": resume_from_checkpoint,
      "auto_resume_from_checkpoints": auto_resume_from_checkpoints,
      "local_rank": local_rank,
      "special_tokens": special_tokens,
      "tokens": tokens,
      "fsdp": fsdp,
      "fsdp_config": fsdp_config,
      "deepspeed": deepspeed,
      "ddp_timeout": ddp_timeout,
      "ddp_bucket_cap_mb": ddp_bucket_cap_mb,
      "ddp_broadcast_buffers": ddp_broadcast_buffers,
      "torchdistx_path": torchdistx_path,
      "pretraining_dataset": pretraining_dataset,
      "debug": debug,
      "seed": seed,
      "strict": strict,
  }
  # Serialize once, write the YAML to disk, and return the same string.
  config_yaml = yaml.dump(config_dict)
  with open("config.yml", "w", encoding="utf-8") as file:
    file.write(config_yaml)
  return config_yaml
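

# Example usage (illustrative sketch, not part of the original module): every
# parameter of config() is positional with no default, so one convenient way
# to try it out is to prefill all arguments with None via inspect and then
# override only the handful you care about. The concrete values below
# ("NousResearch/Llama-2-7b-hf", "./outputs", etc.) are placeholder
# assumptions, not settings taken from this file.
if __name__ == "__main__":
  import inspect

  params = {name: None for name in inspect.signature(config).parameters}
  params.update(
      base_model="NousResearch/Llama-2-7b-hf",  # assumed placeholder model
      sequence_len=4096,
      micro_batch_size=2,
      num_epochs=3,
      learning_rate=2e-4,
      output_dir="./outputs",
  )
  # Writes config.yml to the current directory and prints the YAML string.
  print(config(**params))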