GeminiFan207 committed
Commit 651dc30 · verified · 1 Parent(s): 434d0b4

Create config_loader.py

Files changed (1)
  1. config_loader.py +174 -0
config_loader.py ADDED
@@ -0,0 +1,174 @@
import os
import json
from typing import Dict, Any, Optional

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizerFast,
    TrainingArguments,
    Trainer,
)


class ConfigLoader:
    """A utility class to load configs and instantiate transformers objects."""

    def __init__(self, config_path: str, default_dir: str = "../configs"):
        """Initialize with a config file path (relative paths resolve against default_dir)."""
        self.config_path = os.path.join(default_dir, config_path) if not os.path.isabs(config_path) else config_path
        self.config = {}
        self.default_dir = default_dir
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._load_config()

    def _load_config(self) -> None:
        """Load the configuration from a JSON file."""
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Config file not found: {self.config_path}")

        try:
            with open(self.config_path, "r", encoding="utf-8") as f:
                self.config = json.load(f)
            print(f"✅ Loaded config from {self.config_path}")
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in {self.config_path}: {e}")
        except Exception as e:
            raise RuntimeError(f"Error loading config: {e}")

    def get(self, key: str, default: Any = None) -> Any:
        """Get a value from the config with an optional default."""
        return self.config.get(key, default)

    def validate(self, required_keys: Optional[list] = None) -> None:
        """Validate that all required keys are present in the config."""
        if required_keys:
            missing = [key for key in required_keys if key not in self.config]
            if missing:
                raise KeyError(f"Missing required keys in config: {missing}")

    def save(self, save_path: Optional[str] = None) -> None:
        """Save the current config to a file."""
        path = save_path or self.config_path
        # Only create parent directories when the path actually has one.
        if os.path.dirname(path):
            os.makedirs(os.path.dirname(path), exist_ok=True)
        try:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(self.config, f, indent=4)
            print(f"✅ Config saved to {path}")
        except Exception as e:
            raise RuntimeError(f"Error saving config: {e}")

    def load_model(self, model_path: Optional[str] = None) -> AutoModelForCausalLM:
        """Load a transformers model based on config or path."""
        try:
            model_name_or_path = model_path or self.config.get("model_name", "mistralai/Mixtral-8x7B-Instruct-v0.1")
            model_config = self.config.get("model_config", {})
            if model_path and not model_config:  # Pretrained weights, no custom architecture
                model = AutoModelForCausalLM.from_pretrained(
                    model_name_or_path,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    low_cpu_mem_usage=True,
                )
                # device_map="auto" already places the weights; moving them again can fail.
                return model
            # Otherwise build a (randomly initialized) model from the custom config.
            from transformers import MistralConfig
            config = MistralConfig(**model_config)
            model = AutoModelForCausalLM.from_config(config)
            return model.to(self.device)
        except Exception as e:
            raise RuntimeError(f"Error loading model: {e}")

    def load_tokenizer(self, tokenizer_path: Optional[str] = None) -> PreTrainedTokenizerFast:
        """Load a tokenizer based on config or path."""
        try:
            tokenizer_path = tokenizer_path or self.config.get("tokenizer_path", "../finetuned_charm15/tokenizer.json")
            if tokenizer_path.endswith(".json"):
                tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
            else:
                tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
                tokenizer.pad_token_id = tokenizer.eos_token_id
            print(f"✅ Loaded tokenizer from {tokenizer_path}")
            return tokenizer
        except Exception as e:
            raise RuntimeError(f"Error loading tokenizer: {e}")

    def get_training_args(self) -> TrainingArguments:
        """Create TrainingArguments from config."""
        try:
            training_config = self.config.get("training_config", {
                "output_dir": "../finetuned_charm15",
                "per_device_train_batch_size": 1,
                "num_train_epochs": 3,
                "learning_rate": 5e-5,
                "gradient_accumulation_steps": 8,
                "bf16": True,
                "save_strategy": "epoch",
                "evaluation_strategy": "epoch",
                "save_total_limit": 2,
                "logging_steps": 100,
                "report_to": "none",
            })
            return TrainingArguments(**training_config)
        except Exception as e:
            raise RuntimeError(f"Error creating TrainingArguments: {e}")

    @staticmethod
    def get_default_config() -> Dict[str, Any]:
        """Return a default config combining model, tokenizer, and training settings."""
        return {
            "model_name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "tokenizer_path": "../finetuned_charm15/tokenizer.json",
            "model_config": {
                "architectures": ["MistralForCausalLM"],
                "hidden_size": 4096,
                "num_hidden_layers": 8,
                "vocab_size": 32000,
                "max_position_embeddings": 4096,
                "torch_dtype": "bfloat16"
            },
            "training_config": {
                "output_dir": "../finetuned_charm15",
                "per_device_train_batch_size": 1,
                "num_train_epochs": 3,
                "learning_rate": 5e-5,
                "gradient_accumulation_steps": 8,
                "bf16": True,
                "save_strategy": "epoch",
                "evaluation_strategy": "epoch",
                "save_total_limit": 2,
                "logging_steps": 100,
                "report_to": "none"
            },
            "generation_config": {
                "max_length": 2048,
                "temperature": 0.7,
                "top_p": 0.9,
                "top_k": 50,
                "repetition_penalty": 1.2,
                "do_sample": True
            }
        }


if __name__ == "__main__":
    # Example usage
    config_loader = ConfigLoader("charm15_config.json")

    # Load model and tokenizer
    model = config_loader.load_model()
    tokenizer = config_loader.load_tokenizer()

    # Get training args
    training_args = config_loader.get_training_args()

    # Validate
    config_loader.validate(["model_name", "training_config"])

    # Test generation
    inputs = tokenizer("Hello, Charm 15!", return_tensors="pt").to(config_loader.device)
    outputs = model.generate(**inputs, **config_loader.get("generation_config", {}))
    print(f"Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")

    # Save updated config
    config_loader.save("../finetuned_charm15/config.json")
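
Note that the example usage block at the bottom expects ../configs/charm15_config.json to exist already; otherwise _load_config raises FileNotFoundError. A minimal bootstrap sketch, assuming the relative layout implied by the default paths (the ../configs directory and the file name charm15_config.json are the loader's defaults, not files that ship with this commit):

# Hypothetical bootstrap: write the built-in defaults to the path ConfigLoader resolves by default.
import json
import os

from config_loader import ConfigLoader

os.makedirs("../configs", exist_ok=True)
with open("../configs/charm15_config.json", "w", encoding="utf-8") as f:
    json.dump(ConfigLoader.get_default_config(), f, indent=4)

loader = ConfigLoader("charm15_config.json")  # now finds the file written above
print(loader.get("model_name"))

Because get_default_config() includes a model_config section, load_model() will build a randomly initialized Mistral-style model from that config rather than downloading pretrained weights; pass an explicit model_path and remove model_config from the JSON to load pretrained weights instead.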