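"""Configuration utilities for the Charm 15 fine-tuning pipeline.

ConfigLoader reads a JSON config shaped like the dict returned by
ConfigLoader.get_default_config() below (top-level keys: model_name,
tokenizer_path, model_config, training_config, generation_config) and turns
it into transformers objects: a model, a tokenizer, and TrainingArguments.
"""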
import os
import json
from typing import Dict, Any, Optional
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerFast,
    TrainingArguments,
    Trainer
)

class ConfigLoader:
    """A utility class to load configs and instantiate transformers objects."""
    
    def __init__(self, config_path: str, default_dir: str = "../configs"):
        """Initialize with a config file path; relative paths resolve under default_dir."""
        if not os.path.isabs(config_path):
            config_path = os.path.join(default_dir, config_path)
        self.config_path = config_path
        self.config = {}
        self.default_dir = default_dir
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._load_config()

    def _load_config(self) -> None:
        """Load the configuration from a JSON file."""
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Config file not found: {self.config_path}")
        
        try:
            with open(self.config_path, "r", encoding="utf-8") as f:
                self.config = json.load(f)
            print(f"✅ Loaded config from {self.config_path}")
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in {self.config_path}: {e}")
        except Exception as e:
            raise RuntimeError(f"Error loading config: {e}")

    def get(self, key: str, default: Any = None) -> Any:
        """Get a value from the config with an optional default."""
        return self.config.get(key, default)

    def validate(self, required_keys: Optional[list] = None) -> None:
        """Validate required keys in the config."""
        if required_keys:
            missing = [key for key in required_keys if key not in self.config]
            if missing:
                raise KeyError(f"Missing required keys in config: {missing}")

    def save(self, save_path: Optional[str] = None) -> None:
        """Save the current config to a file."""
        path = save_path or self.config_path
        directory = os.path.dirname(path)
        if directory:  # guard: a bare filename has no directory component
            os.makedirs(directory, exist_ok=True)
        try:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(self.config, f, indent=4)
            print(f"✅ Config saved to {path}")
        except Exception as e:
            raise RuntimeError(f"Error saving config: {e}")

    def load_model(self, model_path: Optional[str] = None) -> PreTrainedModel:
        """Load pretrained weights, or build an untrained model from model_config."""
        try:
            model_name_or_path = model_path or self.config.get(
                "model_name", "mistralai/Mixtral-8x7B-Instruct-v0.1"
            )
            model_config = self.config.get("model_config", {})
            if model_config and model_path is None:
                # An explicit model_path overrides this branch; otherwise a
                # model_config entry means "build the architecture from scratch".
                from transformers import MistralConfig
                config = MistralConfig(**model_config)
                model = AutoModelForCausalLM.from_config(config).to(self.device)
            else:
                # device_map="auto" lets accelerate place the weights itself, so
                # the model must not be moved again with .to() afterwards.
                model = AutoModelForCausalLM.from_pretrained(
                    model_name_or_path,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    low_cpu_mem_usage=True
                )
            return model
        except Exception as e:
            raise RuntimeError(f"Error loading model: {e}")

    def load_tokenizer(self, tokenizer_path: Optional[str] = None) -> PreTrainedTokenizerFast:
        """Load a tokenizer from a raw tokenizer.json file or a pretrained directory/hub id."""
        try:
            tokenizer_path = tokenizer_path or self.config.get("tokenizer_path", "../finetuned_charm15/tokenizer.json")
            if tokenizer_path.endswith(".json"):
                tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
            else:
                tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

            # Fall back to EOS for padding; pad_token_id follows automatically.
            if tokenizer.pad_token is None and tokenizer.eos_token is not None:
                tokenizer.pad_token = tokenizer.eos_token
            print(f"✅ Loaded tokenizer from {tokenizer_path}")
            return tokenizer
        except Exception as e:
            raise RuntimeError(f"Error loading tokenizer: {e}")

    def get_training_args(self) -> TrainingArguments:
        """Create TrainingArguments from config."""
        try:
            training_config = self.config.get(
                "training_config", self.get_default_config()["training_config"]
            )
            return TrainingArguments(**training_config)
        except Exception as e:
            raise RuntimeError(f"Error creating TrainingArguments: {e}")

    @staticmethod
    def get_default_config() -> Dict[str, Any]:
        """Return a default config combining model, tokenizer, and training settings."""
        return {
            "model_name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "tokenizer_path": "../finetuned_charm15/tokenizer.json",
            "model_config": {
                "architectures": ["MistralForCausalLM"],
                "hidden_size": 4096,
                "num_hidden_layers": 8,
                "vocab_size": 32000,
                "max_position_embeddings": 4096,
                "torch_dtype": "bfloat16"
            },
            "training_config": {
                "output_dir": "../finetuned_charm15",
                "per_device_train_batch_size": 1,
                "num_train_epochs": 3,
                "learning_rate": 5e-5,
                "gradient_accumulation_steps": 8,
                "bf16": True,
                "save_strategy": "epoch",
                "evaluation_strategy": "epoch",
                "save_total_limit": 2,
                "logging_steps": 100,
                "report_to": "none"
            },
            "generation_config": {
                "max_length": 2048,
                "temperature": 0.7,
                "top_p": 0.9,
                "top_k": 50,
                "repetition_penalty": 1.2,
                "do_sample": True
            }
        }
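
    # Convenience sketch (an addition, not part of the original API): write the
    # default config to disk so a fresh checkout has a file to load. The target
    # path mirrors the class's default_dir and is an assumption.
    @classmethod
    def write_default_config(cls, path: str = "../configs/charm15_config.json") -> None:
        """Serialize get_default_config() to `path` as JSON."""
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(cls.get_default_config(), f, indent=4)
        print(f"✅ Default config written to {path}")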

if __name__ == "__main__":
    # Example usage
    config_loader = ConfigLoader("charm15_config.json")

    # Validate before any heavy loading
    config_loader.validate(["model_name", "training_config"])

    # Load model and tokenizer
    model = config_loader.load_model()
    tokenizer = config_loader.load_tokenizer()

    # Get training args
    training_args = config_loader.get_training_args()
    
    # Test generation
    inputs = tokenizer("Hello, Charm 15!", return_tensors="pt").to(config_loader.device)
    outputs = model.generate(**inputs, **config_loader.get("generation_config", {}))
    print(f"Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
    
    # Save updated config
    config_loader.save("../finetuned_charm15/config.json")
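
    # Hedged sketch of the natural next step: wiring these objects into a
    # Trainer. `train_dataset` is hypothetical here (this file does not define
    # a data pipeline), so the call stays commented out rather than posing as
    # a runnable training entry point.
    # trainer = Trainer(
    #     model=model,
    #     args=training_args,
    #     train_dataset=train_dataset,  # hypothetical: a tokenized Dataset
    # )
    # trainer.train()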