# training/train.py
import os

import pandas as pd
import torch
import torch.nn as nn
from torch.optim import AdamW  # transformers.AdamW is deprecated; use the torch implementation
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

from model.luna_model import LunaAI


class TextDataset(Dataset):
    """Wraps a two-column CSV (text, label) and tokenizes each row for BERT."""

    def __init__(self, csv_file, tokenizer, max_length=128):
        self.data = pd.read_csv(csv_file)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data.iloc[idx, 0]
        label = self.data.iloc[idx, 1]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            return_tensors='pt',
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long),
        }


def evaluate_model(model, dataloader):
    """Run the model over a dataloader and return accuracy, precision, recall, F1."""
    model.eval()
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in dataloader:
            outputs = model(batch['input_ids'], batch['attention_mask'])
            _, preds = torch.max(outputs, dim=1)
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(batch['labels'].cpu().numpy())
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predictions, average='weighted'
    )
    return accuracy, precision, recall, f1


def save_checkpoint(epoch, model, optimizer, loss, path="./checkpoints"):
    """Persist model/optimizer state plus the last loss for the given epoch."""
    os.makedirs(path, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, os.path.join(path, f"checkpoint_epoch_{epoch}.pth"))


def train_model(model, dataset, epochs=3, batch_size=16, learning_rate=5e-5):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        # evaluate_model() switches to eval mode, so re-enable training every epoch
        model.train()
        for batch in dataloader:
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            labels = batch['labels']

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        print(f'Epoch {epoch}, Loss: {loss.item()}')
        save_checkpoint(epoch, model, optimizer, loss.item())

        # Optional: evaluate at the end of each epoch (note: this reuses the training dataloader)
        accuracy, precision, recall, f1 = evaluate_model(model, dataloader)
        print(f'Epoch {epoch} - Accuracy: {accuracy}, Precision: {precision}, '
              f'Recall: {recall}, F1 Score: {f1}')


if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    dataset = TextDataset('data/dataset.csv', tokenizer)
    model = LunaAI(num_classes=2)  # Adjust num_classes if necessary
    train_model(model, dataset)
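
# --- Example (not part of the original script) ---------------------------------
# save_checkpoint above bundles the epoch, model/optimizer state dicts, and the
# last loss into one file. A minimal sketch of how such a checkpoint could be
# restored to resume training is shown below; load_checkpoint and the example
# path are illustrative assumptions, not an existing API in this repo.
#
# def load_checkpoint(path, model, optimizer):
#     checkpoint = torch.load(path)                                  # read the saved dict
#     model.load_state_dict(checkpoint['model_state_dict'])          # restore weights
#     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # restore optimizer state
#     return checkpoint['epoch'], checkpoint['loss']
#
# start_epoch, last_loss = load_checkpoint('./checkpoints/checkpoint_epoch_2.pth', model, optimizer)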