import os
import json
from typing import Dict, Any, Optional

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizerFast,
    TrainingArguments,
    Trainer
)


class ConfigLoader:
    """A utility class to load configs and instantiate transformers objects."""

    def __init__(self, config_path: str, default_dir: str = "../configs"):
        """Initialize with a config file path."""
        self.config_path = config_path if os.path.isabs(config_path) else os.path.join(default_dir, config_path)
        self.config: Dict[str, Any] = {}
        self.default_dir = default_dir
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._load_config()

    def _load_config(self) -> None:
        """Load the configuration from a JSON file."""
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Config file not found: {self.config_path}")

        try:
            with open(self.config_path, "r", encoding="utf-8") as f:
                self.config = json.load(f)
            print(f"✅ Loaded config from {self.config_path}")
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in {self.config_path}: {e}")
        except Exception as e:
            raise RuntimeError(f"Error loading config: {e}")

    def get(self, key: str, default: Any = None) -> Any:
        """Get a value from the config with an optional default."""
        return self.config.get(key, default)

    def validate(self, required_keys: Optional[list] = None) -> None:
        """Validate that required keys are present in the config."""
        if required_keys:
            missing = [key for key in required_keys if key not in self.config]
            if missing:
                raise KeyError(f"Missing required keys in config: {missing}")

    def save(self, save_path: Optional[str] = None) -> None:
        """Save the current config to a file."""
        path = save_path or self.config_path
        directory = os.path.dirname(path)
        if directory:  # os.makedirs("") raises, so only create non-empty parent dirs
            os.makedirs(directory, exist_ok=True)
        try:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(self.config, f, indent=4)
            print(f"✅ Config saved to {path}")
        except Exception as e:
            raise RuntimeError(f"Error saving config: {e}")

    def load_model(self, model_path: Optional[str] = None) -> AutoModelForCausalLM:
        """Load a transformers model based on config or path."""
        try:
            model_name_or_path = model_path or self.config.get("model_name", "mistralai/Mixtral-8x7B-Instruct-v0.1")
            model_config = self.config.get("model_config", {})
            if model_config:
                # Build an untrained model from the architecture settings in the config.
                from transformers import MistralConfig
                config = MistralConfig(**model_config)
                model = AutoModelForCausalLM.from_config(config).to(self.device)
            else:
                # Load pretrained weights. device_map="auto" already places the weights,
                # so the model must not be moved again with .to().
                model = AutoModelForCausalLM.from_pretrained(
                    model_name_or_path,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    low_cpu_mem_usage=True
                )
            return model
        except Exception as e:
            raise RuntimeError(f"Error loading model: {e}")

    def load_tokenizer(self, tokenizer_path: Optional[str] = None) -> PreTrainedTokenizerFast:
        """Load a tokenizer based on config or path."""
        try:
            tokenizer_path = tokenizer_path or self.config.get("tokenizer_path", "../finetuned_charm15/tokenizer.json")
            if tokenizer_path.endswith(".json"):
                # A bare tokenizer.json file is wrapped directly in a fast tokenizer.
                tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
            else:
                tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

            # Causal LMs often ship without a pad token; fall back to EOS.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
                tokenizer.pad_token_id = tokenizer.eos_token_id
            print(f"✅ Loaded tokenizer from {tokenizer_path}")
            return tokenizer
        except Exception as e:
            raise RuntimeError(f"Error loading tokenizer: {e}")

    def get_training_args(self) -> TrainingArguments:
        """Create TrainingArguments from config."""
        try:
            training_config = self.config.get("training_config", {
                "output_dir": "../finetuned_charm15",
                "per_device_train_batch_size": 1,
                "num_train_epochs": 3,
                "learning_rate": 5e-5,
                "gradient_accumulation_steps": 8,
                "bf16": True,
                "save_strategy": "epoch",
                "evaluation_strategy": "epoch",
                "save_total_limit": 2,
                "logging_steps": 100,
                "report_to": "none"
            })
            return TrainingArguments(**training_config)
        except Exception as e:
            raise RuntimeError(f"Error creating TrainingArguments: {e}")

    @staticmethod
    def get_default_config() -> Dict[str, Any]:
        """Return a default config combining model, tokenizer, and training settings."""
        return {
            "model_name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "tokenizer_path": "../finetuned_charm15/tokenizer.json",
            "model_config": {
                "architectures": ["MistralForCausalLM"],
                "hidden_size": 4096,
                "num_hidden_layers": 8,
                "vocab_size": 32000,
                "max_position_embeddings": 4096,
                "torch_dtype": "bfloat16"
            },
            "training_config": {
                "output_dir": "../finetuned_charm15",
                "per_device_train_batch_size": 1,
                "num_train_epochs": 3,
                "learning_rate": 5e-5,
                "gradient_accumulation_steps": 8,
                "bf16": True,
                "save_strategy": "epoch",
                "evaluation_strategy": "epoch",
                "save_total_limit": 2,
                "logging_steps": 100,
                "report_to": "none"
            },
            "generation_config": {
                "max_length": 2048,
                "temperature": 0.7,
                "top_p": 0.9,
                "top_k": 50,
                "repetition_penalty": 1.2,
                "do_sample": True
            }
        }
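

# Bootstrap sketch (assumption, not part of the original flow): when no config file
# exists yet, the defaults above can be written out once and then loaded through
# ConfigLoader as usual. The "../configs/charm15_config.json" path is illustrative
# and mirrors the default_dir convention used by __init__.
#
#     if not os.path.exists("../configs/charm15_config.json"):
#         os.makedirs("../configs", exist_ok=True)
#         with open("../configs/charm15_config.json", "w", encoding="utf-8") as f:
#             json.dump(ConfigLoader.get_default_config(), f, indent=4)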


if __name__ == "__main__":
    # Load the config and fail fast if required keys are missing.
    config_loader = ConfigLoader("charm15_config.json")
    config_loader.validate(["model_name", "training_config"])

    # Instantiate the model and tokenizer described by the config.
    model = config_loader.load_model()
    tokenizer = config_loader.load_tokenizer()

    # Build TrainingArguments for a later fine-tuning run.
    training_args = config_loader.get_training_args()

    # Quick generation sanity check.
    inputs = tokenizer("Hello, Charm 15!", return_tensors="pt").to(config_loader.device)
    outputs = model.generate(**inputs, **config_loader.get("generation_config", {}))
    print(f"Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")

    # Persist the config alongside the fine-tuned artifacts.
    config_loader.save("../finetuned_charm15/config.json")
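
    # Optional next step (sketch, not exercised here): the same objects can feed a
    # Trainer run. `train_dataset` is a hypothetical placeholder for a tokenized
    # dataset; nothing in this config defines it, so the call stays commented out.
    #
    #     trainer = Trainer(
    #         model=model,
    #         args=training_args,
    #         train_dataset=train_dataset,
    #         tokenizer=tokenizer,
    #     )
    #     trainer.train()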