import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerFast


class Charm15Model(nn.Module):
    """Wrapper around a Hugging Face causal language model and its tokenizer.

    Bundles loading, sampling-based text generation, a gradient-accumulating
    fine-tuning loop, save/load, and dynamic 8-bit quantization.
    """

    def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        """Initialize Charm 15 with a pretrained model.

        Args:
            model_name: Identifier or path handed to ``from_pretrained``.
            device: Device string the model is loaded onto and inputs are
                moved to.

        Raises:
            Exception: Whatever the tokenizer/model loaders raise; the error
                is printed and re-raised unchanged.
        """
        super().__init__()
        self.device = device
        self.model_name = model_name

        try:
            self.tokenizer, self.model = self._load_pretrained(model_name)
            print(f"Loaded model {model_name} on {self.device}")
        except Exception as e:
            print(f"Error initializing model/tokenizer: {e}")
            raise

    def _load_pretrained(self, source: str):
        """Load and return ``(tokenizer, model)`` from *source*.

        Nothing on ``self`` is modified, so a failing load leaves the
        previously held tokenizer/model untouched.
        """
        tokenizer = AutoTokenizer.from_pretrained(source)
        if tokenizer.pad_token is None:
            # generate() is given pad_token_id below, so make sure one exists.
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id

        # The requested device is passed as the device map so the weights are
        # placed there at load time. Loading with device_map="auto" and then
        # calling .to(device) conflicts: a model dispatched over several
        # devices (or offloaded) must not be moved afterwards.
        model = AutoModelForCausalLM.from_pretrained(
            source,
            torch_dtype=torch.bfloat16,
            device_map=self.device,
            low_cpu_mem_usage=True,
        )
        return tokenizer, model

    def generate_text(self, prompt: str, max_length: int = 2048, temperature: float = 0.7,
                      top_k: int = 50, top_p: float = 0.9):
        """Generate text with the model.

        Args:
            prompt: Text that is tokenized and passed to ``generate``.
            max_length: Forwarded to ``generate`` as ``max_length``.
            temperature: Sampling temperature.
            top_k: Top-k sampling cutoff.
            top_p: Nucleus sampling cutoff.

        Returns:
            The first returned sequence decoded with special tokens skipped,
            or ``None`` if anything raised (the error is printed).
        """
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_length=max_length,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    do_sample=True,
                    repetition_penalty=1.1,
                    pad_token_id=self.tokenizer.pad_token_id,
                    use_cache=True,
                )
            return self.tokenizer.decode(output[0], skip_special_tokens=True)
        except Exception as e:
            print(f"Error generating text: {e}")
            return None

    def fine_tune(self, train_dataloader: DataLoader, eval_dataloader: DataLoader = None,
                  epochs: int = 3, lr: float = 5e-5, gradient_accumulation_steps: int = 4):
        """Fine-tune the model with a DataLoader.

        Args:
            train_dataloader: Iterable of dict batches; each value is moved to
                ``self.device`` and the dict is passed to the model as keyword
                arguments. The model output must expose ``.loss``.
            eval_dataloader: Optional iterable evaluated after every epoch.
            epochs: Number of passes over ``train_dataloader``.
            lr: AdamW learning rate.
            gradient_accumulation_steps: Number of batches whose gradients are
                accumulated before each optimizer step.

        Raises:
            ValueError: If ``gradient_accumulation_steps`` is not positive or
                the training dataloader yields no batches.
            Exception: Any error raised while training is printed and
                re-raised.
        """
        if gradient_accumulation_steps < 1:
            raise ValueError("gradient_accumulation_steps must be >= 1")

        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
        # Remember the mode so that later generation does not run with
        # training-only behaviour still enabled.
        was_training = self.model.training
        self.model.train()

        try:
            for epoch in range(epochs):
                total_loss = 0.0
                num_batches = 0
                pending = False  # gradients accumulated but not yet applied
                optimizer.zero_grad()

                for step, batch in enumerate(train_dataloader):
                    batch = {k: v.to(self.device) for k, v in batch.items()}
                    raw_loss = self.model(**batch).loss

                    # Scale so the accumulated gradient averages the group.
                    (raw_loss / gradient_accumulation_steps).backward()
                    pending = True
                    total_loss += raw_loss.item()
                    num_batches += 1

                    if (step + 1) % gradient_accumulation_steps == 0:
                        optimizer.step()
                        optimizer.zero_grad()
                        pending = False

                if pending:
                    # Apply the final, incomplete accumulation group instead
                    # of dropping it or leaking it into the next epoch.
                    optimizer.step()
                    optimizer.zero_grad()

                if num_batches == 0:
                    raise ValueError("train_dataloader yielded no batches")

                avg_loss = total_loss / num_batches
                print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_loss:.4f}")

                # Identity check: truth-testing a DataLoader calls len(),
                # which is not defined for every dataset.
                if eval_dataloader is not None:
                    eval_loss = self._evaluate(eval_dataloader)
                    print(f"Eval Loss: {eval_loss:.4f}")
        except Exception as e:
            print(f"Error during fine-tuning: {e}")
            raise
        finally:
            self.model.train(was_training)

    def _evaluate(self, dataloader: DataLoader) -> float:
        """Evaluate the model on a DataLoader.

        Returns:
            Mean of the per-batch ``.loss`` values, or ``nan`` when the
            dataloader yields no batches. The model's train/eval mode is
            restored to what it was on entry.
        """
        was_training = self.model.training
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        try:
            with torch.no_grad():
                for batch in dataloader:
                    batch = {k: v.to(self.device) for k, v in batch.items()}
                    total_loss += self.model(**batch).loss.item()
                    num_batches += 1
        finally:
            self.model.train(was_training)

        if num_batches == 0:
            return float("nan")
        return total_loss / num_batches

    def save_model(self, save_path: str) -> None:
        """Save model and tokenizer.

        Creates ``save_path`` if needed and writes both artifacts into it.
        Failures are printed and not re-raised.
        """
        try:
            os.makedirs(save_path, exist_ok=True)
            self.model.save_pretrained(save_path)
            self.tokenizer.save_pretrained(save_path)
            print(f"Model saved to {save_path}")
        except Exception as e:
            print(f"Error saving model: {e}")

    def load_model(self, load_path: str) -> None:
        """Load model and tokenizer from a path.

        Both objects are loaded before either attribute is replaced, so a
        failed load keeps the current model and tokenizer.

        Raises:
            Exception: Any loader error is printed and re-raised.
        """
        try:
            tokenizer, model = self._load_pretrained(load_path)
            self.model = model
            self.tokenizer = tokenizer
            print(f"Model loaded from {load_path}")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def quantize_model(self, bits: int = 8) -> None:
        """Quantize model for efficiency (basic dynamic quantization).

        Args:
            bits: Requested width. Any value other than 8 only triggers a
                warning; 8-bit quantization of ``nn.Linear`` layers is applied
                regardless.

        Failures are printed and not re-raised.
        """
        # NOTE(review): dynamic quantization is documented for float models on
        # CPU; this model is loaded as bfloat16 on self.device - confirm the
        # result is usable before relying on it.
        try:
            if bits != 8:
                print("⚠️ Only 8-bit quantization supported with torch.qint8")
            self.model = torch.quantization.quantize_dynamic(
                self.model, {nn.Linear}, dtype=torch.qint8
            )
            print("Model quantized to 8 bits (dynamic quantization)")
        except Exception as e:
            print(f"Error quantizing model: {e}")


if __name__ == "__main__":
    # Demo pipeline: load the base model, sample from it, fine-tune, save,
    # quantize, then reload the saved weights and sample again.
    charm = Charm15Model(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1")

    # Sample from the freshly loaded model.
    seed_prompt = "Charm 15 is amazing because"
    completion = charm.generate_text(seed_prompt)
    print(f"Generated: {completion}")

    # The dataloader helper lives in a separate script and is imported only
    # when this file is run directly.
    from your_dataloader_script import DataLoaderHandler

    loader_handler = DataLoaderHandler(
        "../datasets/eclipse_corpuz_1.1.jsonl",
        "../finetuned_charm15/tokenizer.json",
        batch_size=4,
    )
    train_batches = loader_handler.get_dataloader()

    # Fine-tune with the default hyperparameters and persist the result.
    charm.fine_tune(train_batches)
    charm.save_model("../finetuned_charm15")

    # Quantize the in-memory model; load_model below then assigns a freshly
    # loaded model from the saved directory.
    charm.quantize_model()
    charm.load_model("../finetuned_charm15")

    reloaded_text = charm.generate_text("Testing reloaded model")
    print(reloaded_text)