import yaml


def config(
    # Base model configurations
    base_model, base_model_ignore_patterns, base_model_config, model_revision,
    tokenizer_config, model_type, tokenizer_type, trust_remote_code,
    tokenizer_use_fast, tokenizer_legacy, resize_token_embeddings_to_32x,
    # Derived model flags
    is_falcon_derived_model, is_llama_derived_model, is_mistral_derived_model,
    is_qwen_derived_model,
    # Model configuration overrides
    model_config, bnb_config_kwargs,
    # Quantization and precision settings
    gptq, gptq_groupsize, gptq_model_v1, load_in_8bit, load_in_4bit,
    bf16, fp16, tf32, bfloat16, float16, gpu_memory_limit, lora_on_cpu,
    # Dataset configurations
    datasets, test_datasets, rl, chat_template, default_system_message,
    dataset_prepared_path, push_dataset_to_hub, dataset_processes,
    dataset_keep_in_memory, hub_model_id, hub_strategy, hf_use_auth_token,
    val_set_size, dataset_shard_num, dataset_shard_idx,
    # Training hyperparameters
    sequence_len, pad_to_sequence_len, sample_packing, eval_sample_packing,
    sample_packing_eff_est, total_num_tokens, device_map, max_memory,
    # Adapter and LoRA settings
    adapter, lora_model_dir, lora_r, lora_alpha, lora_dropout,
    lora_target_modules, lora_target_linear, lora_modules_to_save,
    lora_fan_in_fan_out, peft, relora_steps, relora_warmup_steps,
    relora_anneal_steps, relora_prune_ratio, relora_cpu_offload,
    # wandb and mlflow configurations
    wandb_mode, wandb_project, wandb_entity, wandb_watch, wandb_name,
    wandb_run_id, wandb_log_model, mlflow_tracking_uri, mlflow_experiment_name,
    # Training, evaluation, and optimizer settings
    output_dir, torch_compile, torch_compile_backend,
    gradient_accumulation_steps, micro_batch_size, eval_batch_size, num_epochs,
    warmup_steps, warmup_ratio, learning_rate, lr_quadratic_warmup,
    logging_steps, eval_steps, evals_per_epoch, save_strategy, save_steps,
    saves_per_epoch, save_total_limit, max_steps, eval_table_size,
    eval_max_new_tokens, eval_causal_lm_metrics, loss_watchdog_threshold,
    loss_watchdog_patience, save_safetensors, train_on_inputs, group_by_length,
    gradient_checkpointing, early_stopping_patience, lr_scheduler,
    lr_scheduler_kwargs, cosine_min_lr_ratio, cosine_constant_lr_ratio,
    lr_div_factor, log_sweep_min_lr, log_sweep_max_lr, optimizer, weight_decay,
    adam_beta1, adam_beta2, adam_epsilon, max_grad_norm, neftune_noise_alpha,
    # Attention and fused kernel settings
    flash_optimum, xformers_attention, flash_attention,
    flash_attn_cross_entropy, flash_attn_rms_norm, flash_attn_fuse_qkv,
    flash_attn_fuse_mlp, sdp_attention, s2_attention,
    # Distributed training and miscellaneous
    resume_from_checkpoint, auto_resume_from_checkpoints, local_rank,
    special_tokens, tokens, fsdp, fsdp_config, deepspeed, ddp_timeout,
    ddp_bucket_cap_mb, ddp_broadcast_buffers, torchdistx_path,
    pretraining_dataset, debug, seed, strict,
):
    """
    Generate a configuration dictionary from the provided parameters, save it
    to ``config.yml`` in the current working directory, and return the YAML
    string.
    """
    config_dict = {
        # Base model configurations
        "base_model": base_model,
        "base_model_ignore_patterns": base_model_ignore_patterns,
        "base_model_config": base_model_config,
        "model_revision": model_revision,
        "tokenizer_config": tokenizer_config,
        "model_type": model_type,
        "tokenizer_type": tokenizer_type,
        "trust_remote_code": trust_remote_code,
        "tokenizer_use_fast": tokenizer_use_fast,
        "tokenizer_legacy": tokenizer_legacy,
        "resize_token_embeddings_to_32x": resize_token_embeddings_to_32x,
        # Derived model flags
        "is_falcon_derived_model": is_falcon_derived_model,
        "is_llama_derived_model": is_llama_derived_model,
        "is_mistral_derived_model": is_mistral_derived_model,
        "is_qwen_derived_model": is_qwen_derived_model,
        # Model configuration overrides
        "model_config": model_config,
        "bnb_config_kwargs": bnb_config_kwargs,
        # Quantization and precision settings
        "gptq": gptq,
        "gptq_groupsize": gptq_groupsize,
        "gptq_model_v1": gptq_model_v1,
        "load_in_8bit": load_in_8bit,
        "load_in_4bit": load_in_4bit,
        "bf16": bf16,
        "fp16": fp16,
        "tf32": tf32,
        "bfloat16": bfloat16,
        "float16": float16,
        "gpu_memory_limit": gpu_memory_limit,
        "lora_on_cpu": lora_on_cpu,
        # Dataset configurations
        "datasets": datasets,
        "test_datasets": test_datasets,
        "rl": rl,
        "chat_template": chat_template,
        "default_system_message": default_system_message,
        "dataset_prepared_path": dataset_prepared_path,
        "push_dataset_to_hub": push_dataset_to_hub,
        "dataset_processes": dataset_processes,
        "dataset_keep_in_memory": dataset_keep_in_memory,
        "hub_model_id": hub_model_id,
        "hub_strategy": hub_strategy,
        "hf_use_auth_token": hf_use_auth_token,
        "val_set_size": val_set_size,
        "dataset_shard_num": dataset_shard_num,
        "dataset_shard_idx": dataset_shard_idx,
        # Training hyperparameters
        "sequence_len": sequence_len,
        "pad_to_sequence_len": pad_to_sequence_len,
        "sample_packing": sample_packing,
        "eval_sample_packing": eval_sample_packing,
        "sample_packing_eff_est": sample_packing_eff_est,
        "total_num_tokens": total_num_tokens,
        "device_map": device_map,
        "max_memory": max_memory,
        # Adapter and LoRA settings
        "adapter": adapter,
        "lora_model_dir": lora_model_dir,
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "lora_target_modules": lora_target_modules,
        "lora_target_linear": lora_target_linear,
        "lora_modules_to_save": lora_modules_to_save,
        "lora_fan_in_fan_out": lora_fan_in_fan_out,
        "peft": peft,
        "relora_steps": relora_steps,
        "relora_warmup_steps": relora_warmup_steps,
        "relora_anneal_steps": relora_anneal_steps,
        "relora_prune_ratio": relora_prune_ratio,
        "relora_cpu_offload": relora_cpu_offload,
        # wandb and mlflow configurations
        "wandb_mode": wandb_mode,
        "wandb_project": wandb_project,
        "wandb_entity": wandb_entity,
        "wandb_watch": wandb_watch,
        "wandb_name": wandb_name,
        "wandb_run_id": wandb_run_id,
        "wandb_log_model": wandb_log_model,
        "mlflow_tracking_uri": mlflow_tracking_uri,
        "mlflow_experiment_name": mlflow_experiment_name,
        # Training, evaluation, and optimizer settings
        "output_dir": output_dir,
        "torch_compile": torch_compile,
        "torch_compile_backend": torch_compile_backend,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "micro_batch_size": micro_batch_size,
        "eval_batch_size": eval_batch_size,
        "num_epochs": num_epochs,
        "warmup_steps": warmup_steps,
        "warmup_ratio": warmup_ratio,
        "learning_rate": learning_rate,
        "lr_quadratic_warmup": lr_quadratic_warmup,
        "logging_steps": logging_steps,
        "eval_steps": eval_steps,
        "evals_per_epoch": evals_per_epoch,
        "save_strategy": save_strategy,
        "save_steps": save_steps,
        "saves_per_epoch": saves_per_epoch,
        "save_total_limit": save_total_limit,
        "max_steps": max_steps,
        "eval_table_size": eval_table_size,
        "eval_max_new_tokens": eval_max_new_tokens,
        "eval_causal_lm_metrics": eval_causal_lm_metrics,
        "loss_watchdog_threshold": loss_watchdog_threshold,
        "loss_watchdog_patience": loss_watchdog_patience,
        "save_safetensors": save_safetensors,
        "train_on_inputs": train_on_inputs,
        "group_by_length": group_by_length,
        "gradient_checkpointing": gradient_checkpointing,
        "early_stopping_patience": early_stopping_patience,
        "lr_scheduler": lr_scheduler,
        "lr_scheduler_kwargs": lr_scheduler_kwargs,
        "cosine_min_lr_ratio": cosine_min_lr_ratio,
        "cosine_constant_lr_ratio": cosine_constant_lr_ratio,
        "lr_div_factor": lr_div_factor,
        "log_sweep_min_lr": log_sweep_min_lr,
        "log_sweep_max_lr": log_sweep_max_lr,
        "optimizer": optimizer,
        "weight_decay": weight_decay,
        "adam_beta1": adam_beta1,
        "adam_beta2": adam_beta2,
        "adam_epsilon": adam_epsilon,
        "max_grad_norm": max_grad_norm,
        "neftune_noise_alpha": neftune_noise_alpha,
        # Attention and fused kernel settings
        "flash_optimum": flash_optimum,
        "xformers_attention": xformers_attention,
        "flash_attention": flash_attention,
        "flash_attn_cross_entropy": flash_attn_cross_entropy,
        "flash_attn_rms_norm": flash_attn_rms_norm,
        "flash_attn_fuse_qkv": flash_attn_fuse_qkv,
        "flash_attn_fuse_mlp": flash_attn_fuse_mlp,
        "sdp_attention": sdp_attention,
        "s2_attention": s2_attention,
        # Distributed training and miscellaneous
        "resume_from_checkpoint": resume_from_checkpoint,
        "auto_resume_from_checkpoints": auto_resume_from_checkpoints,
        "local_rank": local_rank,
        "special_tokens": special_tokens,
        "tokens": tokens,
        "fsdp": fsdp,
        "fsdp_config": fsdp_config,
        "deepspeed": deepspeed,
        "ddp_timeout": ddp_timeout,
        "ddp_bucket_cap_mb": ddp_bucket_cap_mb,
        "ddp_broadcast_buffers": ddp_broadcast_buffers,
        "torchdistx_path": torchdistx_path,
        "pretraining_dataset": pretraining_dataset,
        "debug": debug,
        "seed": seed,
        "strict": strict,
    }

    # Serialize once, write the result to disk, and return it to the caller.
    yaml_text = yaml.dump(config_dict)
    with open("config.yml", "w", encoding="utf-8") as file:
        file.write(yaml_text)
    return yaml_text
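# Usage sketch (illustrative, not part of the original function): every
# parameter above is required and maps one-to-one to a YAML key, so one
# convenient pattern is to build an all-None argument dict by introspecting the
# signature and override only the fields of interest. The model id and
# hyperparameter values below are hypothetical placeholders.
if __name__ == "__main__":
    import inspect

    example_args = {name: None for name in inspect.signature(config).parameters}
    example_args.update(
        {
            "base_model": "meta-llama/Llama-2-7b-hf",  # hypothetical model id
            "learning_rate": 0.0002,
            "num_epochs": 3,
            "output_dir": "./out",
        }
    )
    yaml_text = config(**example_args)

    # Round-trip check: the file written by config() should load back into the
    # same values that were passed in.
    with open("config.yml", "r", encoding="utf-8") as file:
        loaded = yaml.safe_load(file)
    assert loaded["num_epochs"] == 3
    print(f"Wrote {len(loaded)} config keys to config.yml")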