import os
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerFast
from torch.utils.data import DataLoader


class Charm15Model(nn.Module):
    """Wrapper around a Hugging Face causal LM and its tokenizer.

    Provides sampling-based text generation, a fine-tuning loop with gradient
    accumulation, loss evaluation, save/load helpers, and PyTorch dynamic
    int8 quantization of ``nn.Linear`` layers.
    """

    def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        """Initialize Charm 15 with a pretrained model.

        Args:
            model_name: Hub id or local path passed to ``from_pretrained``.
            device: Target device string. For CUDA devices the weights are
                placed by ``device_map="auto"``; otherwise the model is loaded
                normally and moved to ``device``.

        Raises:
            Exception: Any loading error is printed and re-raised.
        """
        super().__init__()
        self.device = device
        self.model_name = model_name

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self._ensure_pad_token()
            self.model = self._load_pretrained(model_name)
            print(f"Loaded model {model_name} on {self.device}")
        except Exception as e:
            print(f"Error initializing model/tokenizer: {e}")
            raise

    def _ensure_pad_token(self) -> None:
        """Fall back to the EOS token for padding when the tokenizer has none."""
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def _load_pretrained(self, path: str):
        """Load a bf16 causal LM from ``path`` and place it according to ``self.device``."""
        common = dict(torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
        if str(self.device).startswith("cuda"):
            # Let accelerate place (and possibly offload) the weights. Calling
            # .to() on a dispatched model is redundant and fails when some
            # modules were offloaded to CPU/disk, so it is deliberately skipped.
            return AutoModelForCausalLM.from_pretrained(path, device_map="auto", **common)
        return AutoModelForCausalLM.from_pretrained(path, **common).to(self.device)

    def generate_text(self, prompt: str, max_length: int = 2048, temperature: float = 0.7,
                      top_k: int = 50, top_p: float = 0.9) -> Optional[str]:
        """Generate text with the model.

        The model is switched to eval mode for generation (so dropout is off,
        e.g. right after ``fine_tune``) and its previous mode is restored.

        Args:
            prompt: Input text.
            max_length: Maximum total length in tokens, prompt included.
            temperature, top_k, top_p: Sampling parameters for ``generate``.

        Returns:
            The decoded output (prompt included, special tokens stripped), or
            ``None`` if generation failed; the error is printed.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_length=max_length,  # Matches your config
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    do_sample=True,  # From your generation config
                    repetition_penalty=1.1,  # Anti-repetition
                    pad_token_id=self.tokenizer.pad_token_id,
                    use_cache=True,  # Speed up
                )
            return self.tokenizer.decode(output[0], skip_special_tokens=True)
        except Exception as e:
            print(f"Error generating text: {e}")
            return None
        finally:
            self.model.train(was_training)

    def fine_tune(self, train_dataloader: DataLoader, eval_dataloader: Optional[DataLoader] = None,
                  epochs: int = 3, lr: float = 5e-5, gradient_accumulation_steps: int = 4):
        """Fine-tune the model with a DataLoader.

        Args:
            train_dataloader: Yields dicts of tensors accepted by the model's
                forward; they must include labels so ``outputs.loss`` is set
                (assumed, not checked here).
            eval_dataloader: Optional loader evaluated after each epoch.
            epochs: Number of passes over ``train_dataloader``.
            lr: AdamW learning rate.
            gradient_accumulation_steps: Batches whose gradients are summed
                before each optimizer step; must be >= 1.

        Raises:
            ValueError: If ``gradient_accumulation_steps`` < 1.
            Exception: Any training error is printed and re-raised.
        """
        if gradient_accumulation_steps < 1:
            raise ValueError("gradient_accumulation_steps must be >= 1")

        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
        self.model.train()

        try:
            for epoch in range(epochs):
                total_loss = 0.0
                num_batches = 0
                optimizer.zero_grad()

                for step, batch in enumerate(train_dataloader):
                    batch = {k: v.to(self.device) for k, v in batch.items()}
                    outputs = self.model(**batch)
                    loss = outputs.loss / gradient_accumulation_steps  # Normalize for accumulation
                    loss.backward()

                    if (step + 1) % gradient_accumulation_steps == 0:
                        optimizer.step()
                        optimizer.zero_grad()

                    total_loss += loss.item() * gradient_accumulation_steps
                    num_batches += 1

                # Apply a trailing partial accumulation window instead of
                # dropping it (or leaking it into the next epoch).
                if num_batches % gradient_accumulation_steps != 0:
                    optimizer.step()
                    optimizer.zero_grad()

                # Counting batches (not len()) also supports unsized loaders.
                avg_loss = total_loss / num_batches if num_batches else float("nan")
                print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_loss:.4f}")

                # Optional evaluation; `is not None` avoids calling len() on the loader.
                if eval_dataloader is not None:
                    eval_loss = self._evaluate(eval_dataloader)
                    print(f"Eval Loss: {eval_loss:.4f}")
        except Exception as e:
            print(f"Error during fine-tuning: {e}")
            raise

    def _evaluate(self, dataloader: DataLoader) -> float:
        """Return the mean per-batch loss over ``dataloader`` (NaN if it is empty).

        Runs without gradients in eval mode and restores the model's previous
        train/eval mode afterwards.
        """
        was_training = self.model.training
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        try:
            with torch.no_grad():
                for batch in dataloader:
                    batch = {k: v.to(self.device) for k, v in batch.items()}
                    outputs = self.model(**batch)
                    total_loss += outputs.loss.item()
                    num_batches += 1
        finally:
            self.model.train(was_training)
        return total_loss / num_batches if num_batches else float("nan")

    def save_model(self, save_path: str):
        """Save model and tokenizer to ``save_path`` (created if missing).

        Errors are printed, not raised (best-effort save).
        """
        try:
            os.makedirs(save_path, exist_ok=True)
            self.model.save_pretrained(save_path)
            self.tokenizer.save_pretrained(save_path)
            print(f"Model saved to {save_path}")
        except Exception as e:
            print(f"Error saving model: {e}")

    def load_model(self, load_path: str):
        """Replace the current model and tokenizer with those stored at ``load_path``.

        Raises:
            Exception: Any loading error is printed and re-raised.
        """
        try:
            self.model = self._load_pretrained(load_path)
            self.tokenizer = AutoTokenizer.from_pretrained(load_path)
            self._ensure_pad_token()
            print(f"Model loaded from {load_path}")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def quantize_model(self, bits: int = 8):
        """Quantize model for efficiency (basic dynamic quantization).

        Only int8 is supported: other ``bits`` values print a warning and 8-bit
        quantization is applied anyway. Errors are printed, not raised.

        NOTE(review): PyTorch dynamic quantization targets CPU inference and
        float weights; on a CUDA/bf16 model this may fail (and is then only
        reported) — confirm the intended deployment path.
        """
        try:
            if bits != 8:
                print("⚠️ Only 8-bit quantization supported with torch.qint8")
            self.model = torch.quantization.quantize_dynamic(
                self.model, {nn.Linear}, dtype=torch.qint8
            )
            print("Model quantized to 8 bits (dynamic quantization)")
        except Exception as e:
            print(f"Error quantizing model: {e}")


if __name__ == "__main__":
    # Example usage with your prior setup
    model = Charm15Model(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1")

    # Generate text
    prompt = "Charm 15 is amazing because"
    text = model.generate_text(prompt)
    print(f"Generated: {text}")

    # Assuming DataLoader from your earlier code
    from your_dataloader_script import DataLoaderHandler  # Adjust import

    train_loader = DataLoaderHandler(
        "../datasets/eclipse_corpuz_1.1.jsonl",
        "../finetuned_charm15/tokenizer.json",
        batch_size=4,
    ).get_dataloader()

    # Fine-tune
    model.fine_tune(train_loader)

    # Save
    model.save_model("../finetuned_charm15")

    # Quantize for 6G edge
    model.quantize_model()

    # Reload and test
    model.load_model("../finetuned_charm15")
    print(model.generate_text("Testing reloaded model"))