"""Charm15Model: a thin wrapper around a Hugging Face causal LM for text generation,
fine-tuning, saving/loading, and basic dynamic quantization."""

import os

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader

class Charm15Model(nn.Module):
    def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        """Initialize Charm 15 with a pretrained model."""
        super().__init__()
        self.device = device
        self.model_name = model_name
        
        try:
            # Load tokenizer with padding fix
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            
            # Load model with optimizations
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16,  # Memory-efficient weights
                device_map="auto",           # Let accelerate place layers automatically
                low_cpu_mem_usage=True
            )  # Note: with device_map="auto", avoid a subsequent .to() call
            print(f"Loaded model {model_name} on {self.device}")
        except Exception as e:
            print(f"Error initializing model/tokenizer: {e}")
            raise

    def generate_text(self, prompt: str, max_length: int = 2048, temperature: float = 0.7, 
                      top_k: int = 50, top_p: float = 0.9):
        """Generate text with the model."""
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_length=max_length,      # Total length: prompt plus generated tokens
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    do_sample=True,             # Sampling so temperature/top-k/top-p take effect
                    repetition_penalty=1.1,     # Discourage repetition
                    pad_token_id=self.tokenizer.pad_token_id,
                    use_cache=True              # Reuse the key/value cache for faster decoding
                )
            return self.tokenizer.decode(output[0], skip_special_tokens=True)
        except Exception as e:
            print(f"Error generating text: {e}")
            return None

    def fine_tune(self, train_dataloader: DataLoader, eval_dataloader: DataLoader = None, 
                  epochs: int = 3, lr: float = 5e-5, gradient_accumulation_steps: int = 4):
        """Fine-tune the model with a DataLoader."""
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
        self.model.train()
        optimizer.zero_grad()

        try:
            for epoch in range(epochs):
                total_loss = 0
                for step, batch in enumerate(train_dataloader):
                    batch = {k: v.to(self.device) for k, v in batch.items()}
                    outputs = self.model(**batch)
                    loss = outputs.loss / gradient_accumulation_steps  # Normalize for accumulation
                    
                    loss.backward()
                    if (step + 1) % gradient_accumulation_steps == 0:
                        optimizer.step()
                        optimizer.zero_grad()
                    
                    total_loss += loss.item() * gradient_accumulation_steps

                # Apply any gradients left over when the step count is not a multiple of the accumulation size
                if len(train_dataloader) % gradient_accumulation_steps != 0:
                    optimizer.step()
                    optimizer.zero_grad()

                avg_loss = total_loss / len(train_dataloader)
                print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_loss:.4f}")
                
                # Optional evaluation
                if eval_dataloader:
                    eval_loss = self._evaluate(eval_dataloader)
                    print(f"Eval Loss: {eval_loss:.4f}")
        except Exception as e:
            print(f"Error during fine-tuning: {e}")
            raise

    def _evaluate(self, dataloader: DataLoader):
        """Evaluate the model on a DataLoader."""
        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.to(self.device) for k, v in batch.items()}
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
        self.model.train()
        return total_loss / len(dataloader)

    def save_model(self, save_path: str):
        """Save model and tokenizer."""
        try:
            os.makedirs(save_path, exist_ok=True)
            self.model.save_pretrained(save_path)
            self.tokenizer.save_pretrained(save_path)
            print(f"Model saved to {save_path}")
        except Exception as e:
            print(f"Error saving model: {e}")

    def load_model(self, load_path: str):
        """Load model and tokenizer from a path."""
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                load_path, torch_dtype=torch.bfloat16, device_map="auto"
            )  # device_map="auto" handles placement; no extra .to() needed
            self.tokenizer = AutoTokenizer.from_pretrained(load_path)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            print(f"Model loaded from {load_path}")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def quantize_model(self, bits: int = 8):
        """Quantize model for efficiency (basic dynamic quantization)."""
        try:
            if bits != 8:
                print("⚠️ Only 8-bit quantization is supported with torch.qint8")
            # Dynamic quantization runs on the CPU backend with float32 weights,
            # so move and convert the model before quantizing
            self.model = self.model.to("cpu").float()
            self.model = torch.quantization.quantize_dynamic(
                self.model, {nn.Linear}, dtype=torch.qint8
            )
            self.device = "cpu"  # Subsequent calls should send inputs to the CPU
            print("Model quantized to 8 bits (dynamic quantization)")
        except Exception as e:
            print(f"Error quantizing model: {e}")

if __name__ == "__main__":
    # Example usage
    model = Charm15Model(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1")
    
    # Generate text
    prompt = "Charm 15 is amazing because"
    text = model.generate_text(prompt)
    print(f"Generated: {text}")
    
    # Build a training DataLoader (DataLoaderHandler is assumed to exist elsewhere in the project)
    from your_dataloader_script import DataLoaderHandler  # Adjust the import to your project layout
    train_loader = DataLoaderHandler(
        "../datasets/eclipse_corpuz_1.1.jsonl", 
        "../finetuned_charm15/tokenizer.json", 
        batch_size=4
    ).get_dataloader()
    
    # Fine-tune
    model.fine_tune(train_loader)
    
    # Save
    model.save_model("../finetuned_charm15")
    
    # Quantize for 6G edge
    model.quantize_model()
    
    # Reload and test
    model.load_model("../finetuned_charm15")
    print(model.generate_text("Testing reloaded model"))
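
    # --- Optional sketch: the batch format fine_tune expects ---
    # fine_tune assumes each batch is a dict of tensors accepted directly by the model,
    # including a "labels" key so that outputs.loss is populated. The dataset class and
    # texts below are illustrative assumptions, not part of this project.
    #
    # from torch.utils.data import Dataset
    #
    # class CausalLMTextDataset(Dataset):
    #     def __init__(self, texts, tokenizer, max_length=512):
    #         self.enc = tokenizer(texts, truncation=True, padding="max_length",
    #                              max_length=max_length, return_tensors="pt")
    #
    #     def __len__(self):
    #         return self.enc["input_ids"].size(0)
    #
    #     def __getitem__(self, idx):
    #         item = {k: v[idx] for k, v in self.enc.items()}
    #         item["labels"] = item["input_ids"].clone()  # causal LM: labels mirror the inputs
    #         return item
    #
    # sketch_loader = DataLoader(CausalLMTextDataset(["example text"], model.tokenizer), batch_size=2)
    # model.fine_tune(sketch_loader, epochs=1)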