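"""Charm 15 model wrapper.

Wraps a Hugging Face causal language model with convenience methods for text
generation, fine-tuning, evaluation, saving/loading, and basic dynamic
quantization.
"""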
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer


class Charm15Model(nn.Module):
    def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        """Initialize Charm 15 with a pretrained model."""
        super().__init__()
        self.device = device
        self.model_name = model_name
        try:
            # Load tokenizer; fall back to the EOS token when no pad token is defined
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            # Load model with memory-saving options; device_map="auto" lets accelerate
            # place the weights, so no extra .to(device) call is needed here
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16,  # memory-efficient half precision
                device_map="auto",           # auto-distribute across available devices
                low_cpu_mem_usage=True,
            )
            print(f"Loaded model {model_name} on {self.device}")
        except Exception as e:
            print(f"Error initializing model/tokenizer: {e}")
            raise
    def generate_text(self, prompt: str, max_length: int = 2048, temperature: float = 0.7,
                      top_k: int = 50, top_p: float = 0.9):
        """Generate text from a prompt with sampling."""
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_length=max_length,       # total length, prompt included
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    do_sample=True,              # sampling rather than greedy decoding
                    repetition_penalty=1.1,      # discourage repetition
                    pad_token_id=self.tokenizer.pad_token_id,
                    use_cache=True,              # reuse past key/values for speed
                )
            return self.tokenizer.decode(output[0], skip_special_tokens=True)
        except Exception as e:
            print(f"Error generating text: {e}")
            return None
    def fine_tune(self, train_dataloader: DataLoader, eval_dataloader: DataLoader = None,
                  epochs: int = 3, lr: float = 5e-5, gradient_accumulation_steps: int = 4):
        """Fine-tune the model on a DataLoader of tokenized batches (must include labels)."""
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
        self.model.train()
        try:
            for epoch in range(epochs):
                total_loss = 0.0
                optimizer.zero_grad()
                for step, batch in enumerate(train_dataloader):
                    batch = {k: v.to(self.device) for k, v in batch.items()}
                    outputs = self.model(**batch)
                    loss = outputs.loss / gradient_accumulation_steps  # normalize for accumulation
                    loss.backward()
                    if (step + 1) % gradient_accumulation_steps == 0:
                        optimizer.step()
                        optimizer.zero_grad()
                    total_loss += loss.item() * gradient_accumulation_steps
                avg_loss = total_loss / len(train_dataloader)
                print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_loss:.4f}")
                # Optional evaluation after each epoch
                if eval_dataloader:
                    eval_loss = self._evaluate(eval_dataloader)
                    print(f"Eval Loss: {eval_loss:.4f}")
        except Exception as e:
            print(f"Error during fine-tuning: {e}")
            raise
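    # Note on batch format (an assumption about the DataLoader, which is not defined
    # in this file): each batch is expected to be a dict of tensors the model accepts
    # directly, e.g.
    #     {"input_ids": ..., "attention_mask": ..., "labels": ...}
    # The "labels" key is required; without it, outputs.loss above will be None.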
    def _evaluate(self, dataloader: DataLoader):
        """Evaluate the model on a DataLoader and return the mean loss."""
        self.model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.to(self.device) for k, v in batch.items()}
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
        self.model.train()
        return total_loss / len(dataloader)
    def save_model(self, save_path: str):
        """Save model and tokenizer to a directory."""
        try:
            os.makedirs(save_path, exist_ok=True)
            self.model.save_pretrained(save_path)
            self.tokenizer.save_pretrained(save_path)
            print(f"Model saved to {save_path}")
        except Exception as e:
            print(f"Error saving model: {e}")
    def load_model(self, load_path: str):
        """Load model and tokenizer from a path."""
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                load_path, torch_dtype=torch.bfloat16, device_map="auto"
            )
            self.tokenizer = AutoTokenizer.from_pretrained(load_path)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # Re-sync the input device with where the reloaded weights are placed
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Model loaded from {load_path}")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
    def quantize_model(self, bits: int = 8):
        """Quantize the model for efficiency (basic dynamic quantization, 8-bit, CPU only)."""
        try:
            if bits != 8:
                print("⚠️ Only 8-bit quantization supported with torch.qint8")
            # Dynamic quantization runs on CPU over float32 weights, so move the
            # model there first and keep input placement consistent
            self.model = self.model.to("cpu", dtype=torch.float32)
            self.device = "cpu"
            self.model = torch.quantization.quantize_dynamic(
                self.model, {nn.Linear}, dtype=torch.qint8
            )
            print("Model quantized to 8 bits (dynamic quantization)")
        except Exception as e:
            print(f"Error quantizing model: {e}")


if __name__ == "__main__":
    # Example usage
    model = Charm15Model(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1")

    # Generate text
    prompt = "Charm 15 is amazing because"
    text = model.generate_text(prompt)
    print(f"Generated: {text}")

    # Build a training DataLoader (DataLoaderHandler lives in a separate script;
    # adjust the import to match your project layout)
    from your_dataloader_script import DataLoaderHandler
    train_loader = DataLoaderHandler(
        "../datasets/eclipse_corpuz_1.1.jsonl",
        "../finetuned_charm15/tokenizer.json",
        batch_size=4
    ).get_dataloader()

    # Fine-tune
    model.fine_tune(train_loader)

    # Save
    model.save_model("../finetuned_charm15")

    # Quantize for 6G edge
    model.quantize_model()

    # Reload and test
    model.load_model("../finetuned_charm15")
    print(model.generate_text("Testing reloaded model"))
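    # Note: torch.quantization.quantize_dynamic targets CPU inference, so after
    # quantize_model() the wrapper runs on CPU; load_model() above then moves the
    # weights back to the GPU (if one is available), discarding the quantized copy.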