import yaml
def config(
    base_model, base_model_ignore_patterns, base_model_config, model_revision,
    tokenizer_config, model_type, tokenizer_type, trust_remote_code,
    tokenizer_use_fast, tokenizer_legacy, resize_token_embeddings_to_32x,
    is_falcon_derived_model, is_llama_derived_model, is_mistral_derived_model,
    is_qwen_derived_model, model_config, bnb_config_kwargs, gptq,
    gptq_groupsize, gptq_model_v1, load_in_8bit, load_in_4bit, bf16, fp16,
    tf32, bfloat16, float16, gpu_memory_limit, lora_on_cpu, datasets,
    test_datasets, rl, chat_template, default_system_message,
    dataset_prepared_path, push_dataset_to_hub, dataset_processes,
    dataset_keep_in_memory, hub_model_id, hub_strategy, hf_use_auth_token,
    val_set_size, dataset_shard_num, dataset_shard_idx, sequence_len,
    pad_to_sequence_len, sample_packing, eval_sample_packing,
    sample_packing_eff_est, total_num_tokens, device_map, max_memory, adapter,
    lora_model_dir, lora_r, lora_alpha, lora_dropout, lora_target_modules,
    lora_target_linear, lora_modules_to_save, lora_fan_in_fan_out, peft,
    relora_steps, relora_warmup_steps, relora_anneal_steps, relora_prune_ratio,
    relora_cpu_offload, wandb_mode, wandb_project, wandb_entity, wandb_watch,
    wandb_name, wandb_run_id, wandb_log_model, mlflow_tracking_uri,
    mlflow_experiment_name, output_dir, torch_compile, torch_compile_backend,
    gradient_accumulation_steps, micro_batch_size, eval_batch_size, num_epochs,
    warmup_steps, warmup_ratio, learning_rate, lr_quadratic_warmup,
    logging_steps, eval_steps, evals_per_epoch, save_strategy, save_steps,
    saves_per_epoch, save_total_limit, max_steps, eval_table_size,
    eval_max_new_tokens, eval_causal_lm_metrics, loss_watchdog_threshold,
    loss_watchdog_patience, save_safetensors, train_on_inputs, group_by_length,
    gradient_checkpointing, early_stopping_patience, lr_scheduler,
    lr_scheduler_kwargs, cosine_min_lr_ratio, cosine_constant_lr_ratio,
    lr_div_factor, log_sweep_min_lr, log_sweep_max_lr, optimizer, weight_decay,
    adam_beta1, adam_beta2, adam_epsilon, max_grad_norm, neftune_noise_alpha,
    flash_optimum, xformers_attention, flash_attention,
    flash_attn_cross_entropy, flash_attn_rms_norm, flash_attn_fuse_qkv,
    flash_attn_fuse_mlp, sdp_attention, s2_attention, resume_from_checkpoint,
    auto_resume_from_checkpoints, local_rank, special_tokens, tokens, fsdp,
    fsdp_config, deepspeed, ddp_timeout, ddp_bucket_cap_mb,
    ddp_broadcast_buffers, torchdistx_path, pretraining_dataset, debug, seed,
    strict):
"""
This function generates a configuration dictionary based on the provided parameters and saves it as a yaml file.
"""
    config_dict = {
        # Base model configurations
        "base_model": base_model,
        "base_model_ignore_patterns": base_model_ignore_patterns,
        "base_model_config": base_model_config,
        "model_revision": model_revision,
        "tokenizer_config": tokenizer_config,
        "model_type": model_type,
        "tokenizer_type": tokenizer_type,
        "trust_remote_code": trust_remote_code,
        "tokenizer_use_fast": tokenizer_use_fast,
        "tokenizer_legacy": tokenizer_legacy,
        "resize_token_embeddings_to_32x": resize_token_embeddings_to_32x,
        # Derived model flags
        "is_falcon_derived_model": is_falcon_derived_model,
        "is_llama_derived_model": is_llama_derived_model,
        "is_mistral_derived_model": is_mistral_derived_model,
        "is_qwen_derived_model": is_qwen_derived_model,
        # Model configuration overrides
        "model_config": model_config,
        "bnb_config_kwargs": bnb_config_kwargs,
        # Quantization and precision settings
        "gptq": gptq,
        "gptq_groupsize": gptq_groupsize,
        "gptq_model_v1": gptq_model_v1,
        "load_in_8bit": load_in_8bit,
        "load_in_4bit": load_in_4bit,
        "bf16": bf16,
        "fp16": fp16,
        "tf32": tf32,
        "bfloat16": bfloat16,
        "float16": float16,
        "gpu_memory_limit": gpu_memory_limit,
        "lora_on_cpu": lora_on_cpu,
        # Dataset configurations
        "datasets": datasets,
        "test_datasets": test_datasets,
        "rl": rl,
        "chat_template": chat_template,
        "default_system_message": default_system_message,
        "dataset_prepared_path": dataset_prepared_path,
        "push_dataset_to_hub": push_dataset_to_hub,
        "dataset_processes": dataset_processes,
        "dataset_keep_in_memory": dataset_keep_in_memory,
        "hub_model_id": hub_model_id,
        "hub_strategy": hub_strategy,
        "hf_use_auth_token": hf_use_auth_token,
        "val_set_size": val_set_size,
        "dataset_shard_num": dataset_shard_num,
        "dataset_shard_idx": dataset_shard_idx,
        # Training hyperparameters
        "sequence_len": sequence_len,
        "pad_to_sequence_len": pad_to_sequence_len,
        "sample_packing": sample_packing,
        "eval_sample_packing": eval_sample_packing,
        "sample_packing_eff_est": sample_packing_eff_est,
        "total_num_tokens": total_num_tokens,
        "device_map": device_map,
        "max_memory": max_memory,
        # Adapter and LoRA settings
        "adapter": adapter,
        "lora_model_dir": lora_model_dir,
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "lora_target_modules": lora_target_modules,
        "lora_target_linear": lora_target_linear,
        "lora_modules_to_save": lora_modules_to_save,
        "lora_fan_in_fan_out": lora_fan_in_fan_out,
        "peft": peft,
        "relora_steps": relora_steps,
        "relora_warmup_steps": relora_warmup_steps,
        "relora_anneal_steps": relora_anneal_steps,
        "relora_prune_ratio": relora_prune_ratio,
        "relora_cpu_offload": relora_cpu_offload,
        # wandb and mlflow configurations
        "wandb_mode": wandb_mode,
        "wandb_project": wandb_project,
        "wandb_entity": wandb_entity,
        "wandb_watch": wandb_watch,
        "wandb_name": wandb_name,
        "wandb_run_id": wandb_run_id,
        "wandb_log_model": wandb_log_model,
        "mlflow_tracking_uri": mlflow_tracking_uri,
        "mlflow_experiment_name": mlflow_experiment_name,
        "output_dir": output_dir,
        "torch_compile": torch_compile,
        "torch_compile_backend": torch_compile_backend,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "micro_batch_size": micro_batch_size,
        "eval_batch_size": eval_batch_size,
        "num_epochs": num_epochs,
        "warmup_steps": warmup_steps,
        "warmup_ratio": warmup_ratio,
        "learning_rate": learning_rate,
        "lr_quadratic_warmup": lr_quadratic_warmup,
        "logging_steps": logging_steps,
        "eval_steps": eval_steps,
        "evals_per_epoch": evals_per_epoch,
        "save_strategy": save_strategy,
        "save_steps": save_steps,
        "saves_per_epoch": saves_per_epoch,
        "save_total_limit": save_total_limit,
        "max_steps": max_steps,
        "eval_table_size": eval_table_size,
        "eval_max_new_tokens": eval_max_new_tokens,
        "eval_causal_lm_metrics": eval_causal_lm_metrics,
        "loss_watchdog_threshold": loss_watchdog_threshold,
        "loss_watchdog_patience": loss_watchdog_patience,
        "save_safetensors": save_safetensors,
        "train_on_inputs": train_on_inputs,
        "group_by_length": group_by_length,
        "gradient_checkpointing": gradient_checkpointing,
        "early_stopping_patience": early_stopping_patience,
        "lr_scheduler": lr_scheduler,
        "lr_scheduler_kwargs": lr_scheduler_kwargs,
        "cosine_min_lr_ratio": cosine_min_lr_ratio,
        "cosine_constant_lr_ratio": cosine_constant_lr_ratio,
        "lr_div_factor": lr_div_factor,
        "log_sweep_min_lr": log_sweep_min_lr,
        "log_sweep_max_lr": log_sweep_max_lr,
        "optimizer": optimizer,
        "weight_decay": weight_decay,
        "adam_beta1": adam_beta1,
        "adam_beta2": adam_beta2,
        "adam_epsilon": adam_epsilon,
        "max_grad_norm": max_grad_norm,
        "neftune_noise_alpha": neftune_noise_alpha,
        "flash_optimum": flash_optimum,
        "xformers_attention": xformers_attention,
        "flash_attention": flash_attention,
        "flash_attn_cross_entropy": flash_attn_cross_entropy,
        "flash_attn_rms_norm": flash_attn_rms_norm,
        "flash_attn_fuse_qkv": flash_attn_fuse_qkv,
        "flash_attn_fuse_mlp": flash_attn_fuse_mlp,
        "sdp_attention": sdp_attention,
        "s2_attention": s2_attention,
        "resume_from_checkpoint": resume_from_checkpoint,
        "auto_resume_from_checkpoints": auto_resume_from_checkpoints,
        "local_rank": local_rank,
        "special_tokens": special_tokens,
        "tokens": tokens,
        "fsdp": fsdp,
        "fsdp_config": fsdp_config,
        "deepspeed": deepspeed,
        "ddp_timeout": ddp_timeout,
        "ddp_bucket_cap_mb": ddp_bucket_cap_mb,
        "ddp_broadcast_buffers": ddp_broadcast_buffers,
        "torchdistx_path": torchdistx_path,
        "pretraining_dataset": pretraining_dataset,
        "debug": debug,
        "seed": seed,
        "strict": strict,
    }
with open("config.yml", "w", encoding="utf-8") as file:
yaml.dump(config_dict, file)
return yaml.dump(config_dict)
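

# Usage sketch (not part of the original upload): config() takes one
# positional-or-keyword parameter per setting and defines no defaults, so a
# convenient way to call it is to build a kwargs dict of None values from its
# signature and override only the fields of interest. The keys appear to
# mirror Axolotl-style training options, but that is an assumption, and the
# example values below (base_model, learning_rate, and so on) are
# hypothetical placeholders, not settings from any real training run.
if __name__ == "__main__":
    import inspect

    kwargs = {name: None for name in inspect.signature(config).parameters}
    kwargs.update(
        base_model="meta-llama/Llama-2-7b-hf",  # hypothetical example value
        learning_rate=2e-4,
        num_epochs=3,
        output_dir="./out",
    )
    yaml_text = config(**kwargs)

    # Round-trip check: config() also wrote config.yml to the current working
    # directory, so reload it and confirm a field survived serialization.
    with open("config.yml", "r", encoding="utf-8") as file:
        loaded = yaml.safe_load(file)
    assert loaded["base_model"] == kwargs["base_model"]
    print(yaml_text[:200])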