# Charm_15/config_loader.py
import os
import json
from typing import Any, Dict, List, Optional
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerFast,
    TrainingArguments,
    Trainer
)
class ConfigLoader:
"""A utility class to load configs and instantiate transformers objects."""
    def __init__(self, config_path: str, default_dir: str = "../configs"):
        """Initialize with a config file path; relative paths resolve under default_dir."""
        if not os.path.isabs(config_path):
            config_path = os.path.join(default_dir, config_path)
        self.config_path = config_path
        self.config: Dict[str, Any] = {}
        self.default_dir = default_dir
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._load_config()
def _load_config(self) -> None:
"""Load the configuration from a JSON file."""
if not os.path.exists(self.config_path):
raise FileNotFoundError(f"Config file not found: {self.config_path}")
try:
with open(self.config_path, "r", encoding="utf-8") as f:
self.config = json.load(f)
print(f"✅ Loaded config from {self.config_path}")
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in {self.config_path}: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Error loading config: {e}") from e
def get(self, key: str, default: Any = None) -> Any:
"""Get a value from the config with an optional default."""
return self.config.get(key, default)
    def validate(self, required_keys: Optional[List[str]] = None) -> None:
        """Validate that required keys are present in the config."""
        if required_keys:
            missing = [key for key in required_keys if key not in self.config]
            if missing:
                raise KeyError(f"Missing required keys in config: {missing}")
def save(self, save_path: Optional[str] = None) -> None:
"""Save the current config to a file."""
        path = save_path or self.config_path
        parent = os.path.dirname(path)
        if parent:  # os.makedirs("") raises, so only create a directory when one is named
            os.makedirs(parent, exist_ok=True)
        try:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(self.config, f, indent=4)
            print(f"✅ Config saved to {path}")
        except Exception as e:
            raise RuntimeError(f"Error saving config: {e}") from e
    def load_model(self, model_path: Optional[str] = None) -> PreTrainedModel:
        """Load pretrained weights, or build an untrained model from "model_config"."""
        try:
            model_name_or_path = model_path or self.config.get(
                "model_name", "mistralai/Mixtral-8x7B-Instruct-v0.1"
            )
            model_config = self.config.get("model_config", {})
            if model_path or not model_config:
                # Pretrained weights: device_map="auto" dispatches layers to
                # devices itself, so the model must not be moved again with .to().
                return AutoModelForCausalLM.from_pretrained(
                    model_name_or_path,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    low_cpu_mem_usage=True
                )
            # No explicit path and a "model_config" is present: build a fresh,
            # untrained model from that config.
            from transformers import MistralConfig
            config = MistralConfig(**model_config)
            model = AutoModelForCausalLM.from_config(config)
            return model.to(self.device)
        except Exception as e:
            raise RuntimeError(f"Error loading model: {e}") from e
    def load_tokenizer(self, tokenizer_path: Optional[str] = None) -> PreTrainedTokenizerFast:
        """Load a tokenizer from a raw tokenizer.json file or a pretrained directory/repo."""
        try:
            tokenizer_path = tokenizer_path or self.config.get(
                "tokenizer_path", "../finetuned_charm15/tokenizer.json"
            )
            if tokenizer_path.endswith(".json"):
                # A bare tokenizer.json carries no special-token config, so
                # eos/pad may still be None after this call.
                tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
            else:
                tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            if tokenizer.pad_token is None and tokenizer.eos_token is not None:
                tokenizer.pad_token = tokenizer.eos_token
                tokenizer.pad_token_id = tokenizer.eos_token_id
            print(f"✅ Loaded tokenizer from {tokenizer_path}")
            return tokenizer
        except Exception as e:
            raise RuntimeError(f"Error loading tokenizer: {e}") from e
def get_training_args(self) -> TrainingArguments:
"""Create TrainingArguments from config."""
try:
training_config = self.config.get("training_config", {
"output_dir": "../finetuned_charm15",
"per_device_train_batch_size": 1,
"num_train_epochs": 3,
"learning_rate": 5e-5,
"gradient_accumulation_steps": 8,
"bf16": True,
"save_strategy": "epoch",
"evaluation_strategy": "epoch",
"save_total_limit": 2,
"logging_steps": 100,
"report_to": "none"
})
return TrainingArguments(**training_config)
        except Exception as e:
            raise RuntimeError(f"Error creating TrainingArguments: {e}") from e
@staticmethod
def get_default_config() -> Dict[str, Any]:
"""Return a default config combining model, tokenizer, and training settings."""
return {
"model_name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
"tokenizer_path": "../finetuned_charm15/tokenizer.json",
"model_config": {
"architectures": ["MistralForCausalLM"],
"hidden_size": 4096,
"num_hidden_layers": 8,
"vocab_size": 32000,
"max_position_embeddings": 4096,
"torch_dtype": "bfloat16"
},
"training_config": {
"output_dir": "../finetuned_charm15",
"per_device_train_batch_size": 1,
"num_train_epochs": 3,
"learning_rate": 5e-5,
"gradient_accumulation_steps": 8,
"bf16": True,
"save_strategy": "epoch",
"evaluation_strategy": "epoch",
"save_total_limit": 2,
"logging_steps": 100,
"report_to": "none"
},
"generation_config": {
"max_length": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 50,
"repetition_penalty": 1.2,
"do_sample": True
}
}
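# A sketch (paths illustrative) for bootstrapping a config file from the
# defaults above, matching the ../configs convention used by __init__:
#
#   os.makedirs("../configs", exist_ok=True)
#   with open("../configs/charm15_config.json", "w", encoding="utf-8") as f:
#       json.dump(ConfigLoader.get_default_config(), f, indent=4)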
if __name__ == "__main__":
# Example usage
config_loader = ConfigLoader("charm15_config.json")
# Load model and tokenizer
model = config_loader.load_model()
tokenizer = config_loader.load_tokenizer()
# Get training args
training_args = config_loader.get_training_args()
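    # Hypothetical fine-tuning wiring; train_dataset is a placeholder name and
    # must be a tokenized torch Dataset before this would actually run:
    #   trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
    #   trainer.train()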
    # Validate required keys before relying on them
    config_loader.validate(["model_name", "training_config"])
    # Smoke-test generation
    inputs = tokenizer("Hello, Charm 15!", return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, **config_loader.get("generation_config", {}))
print(f"Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
# Save updated config
config_loader.save("../finetuned_charm15/config.json")