import json
import os
from typing import Any, Dict, List, Optional

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
)


class ConfigLoader:
    """Load a JSON config file and build transformers objects from it.

    Resolves a (possibly relative) config path, parses it as JSON, and
    exposes helpers to construct a model, tokenizer, and TrainingArguments
    from the loaded settings.
    """

    def __init__(self, config_path: str, default_dir: str = "../configs"):
        """Initialize with a config file path.

        Args:
            config_path: Absolute path to a JSON config, or a filename
                resolved against ``default_dir`` when relative.
            default_dir: Directory used to resolve relative config paths.
        """
        self.config_path = (
            config_path
            if os.path.isabs(config_path)
            else os.path.join(default_dir, config_path)
        )
        self.config: Dict[str, Any] = {}
        self.default_dir = default_dir
        # Prefer GPU when available; used for explicit model placement.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._load_config()

    def _load_config(self) -> None:
        """Load the configuration from a JSON file.

        Raises:
            FileNotFoundError: If the config file does not exist.
            ValueError: If the file is not valid JSON.
            RuntimeError: For any other read error.
        """
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Config file not found: {self.config_path}")
        try:
            with open(self.config_path, "r", encoding="utf-8") as f:
                self.config = json.load(f)
            print(f"✅ Loaded config from {self.config_path}")
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in {self.config_path}: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Error loading config: {e}") from e

    def get(self, key: str, default: Any = None) -> Any:
        """Get a value from the config with an optional default."""
        return self.config.get(key, default)

    def validate(self, required_keys: Optional[List[str]] = None) -> None:
        """Validate that required keys are present in the config.

        Args:
            required_keys: Keys that must exist in the loaded config.

        Raises:
            KeyError: If any required key is missing.
        """
        if required_keys:
            missing = [key for key in required_keys if key not in self.config]
            if missing:
                raise KeyError(f"Missing required keys in config: {missing}")

    def save(self, save_path: Optional[str] = None) -> None:
        """Save the current config to a file.

        Args:
            save_path: Destination path; defaults to the loaded config path.

        Raises:
            RuntimeError: If the file cannot be written.
        """
        path = save_path or self.config_path
        # dirname is "" for bare filenames; makedirs("") would raise.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        try:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(self.config, f, indent=4)
            print(f"✅ Config saved to {path}")
        except Exception as e:
            raise RuntimeError(f"Error saving config: {e}") from e

    def load_model(self, model_path: Optional[str] = None) -> AutoModelForCausalLM:
        """Load a transformers model based on config or path.

        If the config supplies a ``model_config`` dict and no explicit
        ``model_path`` is given, an untrained model is built from that
        config; otherwise the pretrained weights named by ``model_path``
        or the config's ``model_name`` are loaded.

        Args:
            model_path: Optional explicit model name or local path.

        Returns:
            The instantiated causal-LM model.

        Raises:
            RuntimeError: If the model cannot be constructed.
        """
        try:
            model_name_or_path = model_path or self.config.get(
                "model_name", "mistralai/Mixtral-8x7B-Instruct-v0.1"
            )
            model_config = self.config.get("model_config", {})
            if model_config and model_path is None:
                # Build an architecture-only (random-weight) model from the
                # config dict; explicit placement is needed here.
                from transformers import MistralConfig

                config = MistralConfig(**model_config)
                model = AutoModelForCausalLM.from_config(config)
                model = model.to(self.device)
            else:
                # Pretrained weights: device_map="auto" already dispatches
                # the model across devices — calling .to() afterwards on an
                # accelerate-dispatched model raises, so don't.
                model = AutoModelForCausalLM.from_pretrained(
                    model_name_or_path,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    low_cpu_mem_usage=True,
                )
            return model
        except Exception as e:
            raise RuntimeError(f"Error loading model: {e}") from e

    def load_tokenizer(
        self, tokenizer_path: Optional[str] = None
    ) -> PreTrainedTokenizerFast:
        """Load a tokenizer based on config or path.

        A ``.json`` path is treated as a raw tokenizers file; anything else
        is resolved via ``AutoTokenizer.from_pretrained``. Ensures a pad
        token exists (falls back to the EOS token).

        Args:
            tokenizer_path: Optional explicit tokenizer file or model path.

        Returns:
            The loaded tokenizer.

        Raises:
            RuntimeError: If the tokenizer cannot be loaded.
        """
        try:
            tokenizer_path = tokenizer_path or self.config.get(
                "tokenizer_path", "../finetuned_charm15/tokenizer.json"
            )
            if tokenizer_path.endswith(".json"):
                tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
            else:
                tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            if tokenizer.pad_token is None:
                # Padding is required for batched training; reuse EOS.
                tokenizer.pad_token = tokenizer.eos_token
                tokenizer.pad_token_id = tokenizer.eos_token_id
            print(f"✅ Loaded tokenizer from {tokenizer_path}")
            return tokenizer
        except Exception as e:
            raise RuntimeError(f"Error loading tokenizer: {e}") from e

    def get_training_args(self) -> TrainingArguments:
        """Create TrainingArguments from config.

        Falls back to a conservative default fine-tuning setup when the
        config has no ``training_config`` section.

        Raises:
            RuntimeError: If TrainingArguments rejects the settings.
        """
        try:
            # NOTE(review): "evaluation_strategy" was renamed "eval_strategy"
            # in newer transformers releases — confirm against the pinned
            # transformers version before upgrading.
            training_config = self.config.get(
                "training_config",
                {
                    "output_dir": "../finetuned_charm15",
                    "per_device_train_batch_size": 1,
                    "num_train_epochs": 3,
                    "learning_rate": 5e-5,
                    "gradient_accumulation_steps": 8,
                    "bf16": True,
                    "save_strategy": "epoch",
                    "evaluation_strategy": "epoch",
                    "save_total_limit": 2,
                    "logging_steps": 100,
                    "report_to": "none",
                },
            )
            return TrainingArguments(**training_config)
        except Exception as e:
            raise RuntimeError(f"Error creating TrainingArguments: {e}") from e

    @staticmethod
    def get_default_config() -> Dict[str, Any]:
        """Return a default config combining model, tokenizer, and training settings."""
        return {
            "model_name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "tokenizer_path": "../finetuned_charm15/tokenizer.json",
            "model_config": {
                "architectures": ["MistralForCausalLM"],
                "hidden_size": 4096,
                "num_hidden_layers": 8,
                "vocab_size": 32000,
                "max_position_embeddings": 4096,
                "torch_dtype": "bfloat16",
            },
            "training_config": {
                "output_dir": "../finetuned_charm15",
                "per_device_train_batch_size": 1,
                "num_train_epochs": 3,
                "learning_rate": 5e-5,
                "gradient_accumulation_steps": 8,
                "bf16": True,
                "save_strategy": "epoch",
                "evaluation_strategy": "epoch",
                "save_total_limit": 2,
                "logging_steps": 100,
                "report_to": "none",
            },
            "generation_config": {
                "max_length": 2048,
                "temperature": 0.7,
                "top_p": 0.9,
                "top_k": 50,
                "repetition_penalty": 1.2,
                "do_sample": True,
            },
        }


if __name__ == "__main__":
    # Example usage
    config_loader = ConfigLoader("charm15_config.json")

    # Load model and tokenizer
    model = config_loader.load_model()
    tokenizer = config_loader.load_tokenizer()

    # Get training args
    training_args = config_loader.get_training_args()

    # Validate
    config_loader.validate(["model_name", "training_config"])

    # Test generation
    inputs = tokenizer("Hello, Charm 15!", return_tensors="pt").to(config_loader.device)
    outputs = model.generate(**inputs, **config_loader.get("generation_config", {}))
    print(f"Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")

    # Save updated config
    config_loader.save("../finetuned_charm15/config.json")