import os
import json
from typing import Dict, Any, Optional
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizerFast,
    TrainingArguments,
    Trainer,
)


class ConfigLoader:
    """A utility class to load configs and instantiate transformers objects."""

    def __init__(self, config_path: str, default_dir: str = "../configs"):
        """Initialize with a config file path, resolving relative paths against default_dir."""
        self.config_path = (
            config_path if os.path.isabs(config_path) else os.path.join(default_dir, config_path)
        )
        self.config: Dict[str, Any] = {}
        self.default_dir = default_dir
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._load_config()

    def _load_config(self) -> None:
        """Load the configuration from a JSON file."""
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Config file not found: {self.config_path}")
        try:
            with open(self.config_path, "r", encoding="utf-8") as f:
                self.config = json.load(f)
            print(f"✅ Loaded config from {self.config_path}")
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in {self.config_path}: {e}")
        except Exception as e:
            raise RuntimeError(f"Error loading config: {e}")

    def get(self, key: str, default: Any = None) -> Any:
        """Get a value from the config with an optional default."""
        return self.config.get(key, default)

    def validate(self, required_keys: Optional[list] = None) -> None:
        """Validate that the required keys are present in the config."""
        if required_keys:
            missing = [key for key in required_keys if key not in self.config]
            if missing:
                raise KeyError(f"Missing required keys in config: {missing}")

    def save(self, save_path: Optional[str] = None) -> None:
        """Save the current config to a file."""
        path = save_path or self.config_path
        # Guard against a bare filename: os.makedirs("") raises FileNotFoundError.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        try:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(self.config, f, indent=4)
            print(f"✅ Config saved to {path}")
        except Exception as e:
            raise RuntimeError(f"Error saving config: {e}")

    def load_model(self, model_path: Optional[str] = None) -> AutoModelForCausalLM:
        """Load a pretrained model, or build a fresh one when `model_config` is set."""
        try:
            model_name_or_path = model_path or self.config.get(
                "model_name", "mistralai/Mixtral-8x7B-Instruct-v0.1"
            )
            model_config = self.config.get("model_config", {})
            if model_config:
                # Custom architecture: build a freshly initialized (untrained) model
                # from the config dictionary instead of loading pretrained weights.
                from transformers import MistralConfig

                config = MistralConfig(**model_config)
                model = AutoModelForCausalLM.from_config(config)
                return model.to(self.device)
            # No custom config: load pretrained weights. device_map="auto" already
            # places the weights, so we do not call .to(self.device) on top of it.
            model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                low_cpu_mem_usage=True,
            )
            return model
        except Exception as e:
            raise RuntimeError(f"Error loading model: {e}")

    def load_tokenizer(self, tokenizer_path: Optional[str] = None) -> PreTrainedTokenizerFast:
        """Load a tokenizer from a tokenizer.json file or a pretrained directory/hub name."""
        try:
            tokenizer_path = tokenizer_path or self.config.get(
                "tokenizer_path", "../finetuned_charm15/tokenizer.json"
            )
            if tokenizer_path.endswith(".json"):
                # A raw tokenizer.json file is wrapped in a fast tokenizer directly.
                tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
            else:
                tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            if tokenizer.pad_token is None:
                # Fall back to the EOS token for padding, as is common for causal LMs.
                tokenizer.pad_token = tokenizer.eos_token
                tokenizer.pad_token_id = tokenizer.eos_token_id
            print(f"✅ Loaded tokenizer from {tokenizer_path}")
            return tokenizer
        except Exception as e:
            raise RuntimeError(f"Error loading tokenizer: {e}")

    def get_training_args(self) -> TrainingArguments:
        """Create TrainingArguments from the config, falling back to sensible defaults."""
        try:
            training_config = self.config.get(
                "training_config",
                {
                    "output_dir": "../finetuned_charm15",
                    "per_device_train_batch_size": 1,
                    "num_train_epochs": 3,
                    "learning_rate": 5e-5,
                    "gradient_accumulation_steps": 8,
                    "bf16": True,
                    "save_strategy": "epoch",
                    "evaluation_strategy": "epoch",
                    "save_total_limit": 2,
                    "logging_steps": 100,
                    "report_to": "none",
                },
            )
            return TrainingArguments(**training_config)
        except Exception as e:
            raise RuntimeError(f"Error creating TrainingArguments: {e}")

    @staticmethod
    def get_default_config() -> Dict[str, Any]:
        """Return a default config combining model, tokenizer, and training settings."""
        return {
            "model_name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "tokenizer_path": "../finetuned_charm15/tokenizer.json",
            "model_config": {
                "architectures": ["MistralForCausalLM"],
                "hidden_size": 4096,
                "num_hidden_layers": 8,
                "vocab_size": 32000,
                "max_position_embeddings": 4096,
                "torch_dtype": "bfloat16",
            },
            "training_config": {
                "output_dir": "../finetuned_charm15",
                "per_device_train_batch_size": 1,
                "num_train_epochs": 3,
                "learning_rate": 5e-5,
                "gradient_accumulation_steps": 8,
                "bf16": True,
                "save_strategy": "epoch",
                "evaluation_strategy": "epoch",
                "save_total_limit": 2,
                "logging_steps": 100,
                "report_to": "none",
            },
            "generation_config": {
                "max_length": 2048,
                "temperature": 0.7,
                "top_p": 0.9,
                "top_k": 50,
                "repetition_penalty": 1.2,
                "do_sample": True,
            },
        }
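

# Illustrative helper (not part of the original API): materialize the defaults above
# as a JSON file so that ConfigLoader("charm15_config.json") has something to load.
# The path below mirrors the default_dir assumption ("../configs"); adjust as needed.
def write_default_config(path: str = "../configs/charm15_config.json") -> None:
    """Write ConfigLoader's default config to disk (sketch; the path is an assumption)."""
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(ConfigLoader.get_default_config(), f, indent=4)
    print(f"✅ Wrote default config to {path}")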


if __name__ == "__main__":
    # Example usage
    config_loader = ConfigLoader("charm15_config.json")

    # Validate before loading any heavy objects
    config_loader.validate(["model_name", "training_config"])

    # Load model and tokenizer
    model = config_loader.load_model()
    tokenizer = config_loader.load_tokenizer()

    # Get training args
    training_args = config_loader.get_training_args()

    # Test generation
    inputs = tokenizer("Hello, Charm 15!", return_tensors="pt").to(config_loader.device)
    outputs = model.generate(
        **inputs,
        pad_token_id=tokenizer.pad_token_id,
        **config_loader.get("generation_config", {}),
    )
    print(f"Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")

    # Save updated config
    config_loader.save("../finetuned_charm15/config.json")
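
    # Sketch (assumption, not part of the original script): wiring the pieces above
    # into a Trainer would look roughly like the commented lines below. `train_dataset`
    # is a hypothetical tokenized dataset that this file does not define.
    # from transformers import DataCollatorForLanguageModeling
    # data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    # trainer = Trainer(
    #     model=model,
    #     args=training_args,
    #     train_dataset=train_dataset,  # hypothetical: supply your own dataset
    #     data_collator=data_collator,
    # )
    # trainer.train()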