import matplotlib.pyplot as plt
import torch
from sklearn.metrics import f1_score
from torch.optim import AdamW
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

from .config_train import (device, epochs, eps, lr, model_saved_path,
                           weight_decay)
from .load_data import train_dataloader, validation_dataloader
from .model import model


class Key_Ner_Training:
    """Trains and validates a token-classification (NER) model.

    Wraps the optimizer, linear-warmup scheduler, per-epoch train/validation
    loops, metric tracking (loss + macro F1), plotting, and model saving.
    """

    def __init__(self, model, train_dataloader, validation_dataloader,
                 epochs, lr, eps, weight_decay, device, model_saved_path):
        """
        Initializes the Key_Ner_Training with the necessary components for training.

        Args:
            model (torch.nn.Module): The model to be trained.
            train_dataloader (DataLoader): DataLoader for training data.
            validation_dataloader (DataLoader): DataLoader for validation data.
            epochs (int): Number of training epochs.
            lr (float): Learning rate for the optimizer.
            eps (float): Epsilon value for the optimizer.
            weight_decay (float): Weight decay for the optimizer.
            device (str): Device to run the model on ("cuda" or "cpu").
            model_saved_path (str): Path to save the trained model.
        """
        self.model = model.to(device)
        self.train_dataloader = train_dataloader
        self.validation_dataloader = validation_dataloader
        self.epochs = epochs
        self.device = device
        self.model_saved_path = model_saved_path

        # AdamW optimizer
        self.optimizer = AdamW(self.model.parameters(),
                               lr=lr,
                               eps=eps,
                               weight_decay=weight_decay)

        # Total number of training steps (batches per epoch * epochs),
        # needed by the linear-decay scheduler.
        self.total_steps = len(train_dataloader) * epochs

        # Learning rate scheduler: no warmup, linear decay to zero.
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=0,
            num_training_steps=self.total_steps)

        # Per-epoch metric history, consumed by _plot_metrics().
        self.train_losses = []
        self.val_losses = []
        self.train_f1_scores = []
        self.val_f1_scores = []

    def train(self):
        """Trains the model over the specified number of epochs.

        After the final epoch, plots the loss/F1 curves and saves the model
        to ``self.model_saved_path`` via ``save_pretrained``.
        """
        for epoch in range(self.epochs):
            print(f'Epoch {epoch + 1}/{self.epochs}')
            print('-' * 10)

            # Training
            avg_train_loss, train_f1 = self._train_epoch()
            self.train_losses.append(avg_train_loss)
            self.train_f1_scores.append(train_f1)
            print(f'Training loss: {avg_train_loss}, F1-score: {train_f1}')

            # Validation
            avg_val_loss, val_f1 = self._validate_epoch()
            self.val_losses.append(avg_val_loss)
            self.val_f1_scores.append(val_f1)
            print(f'Validation Loss: {avg_val_loss}, F1-score: {val_f1}')

        print("Training complete!")

        # Plot losses and F1 scores
        self._plot_metrics()

        # Save model
        self.model.save_pretrained(self.model_saved_path)

    def _train_epoch(self):
        """Runs a single training epoch.

        Returns:
            tuple[float, float]: (average training loss, macro F1 over all
            token predictions in the epoch).
        """
        self.model.train()
        total_loss = 0
        train_predictions = []
        train_targets = []

        train_dataloader_iterator = tqdm(self.train_dataloader, desc="Training")

        for step, batch in enumerate(train_dataloader_iterator):
            # Each batch is (input_ids, attention_mask, labels).
            b_input_ids = batch[0].to(self.device)
            b_input_mask = batch[1].to(self.device)
            b_labels = batch[2].to(self.device)

            # Reset gradients once per step (the original code zeroed twice:
            # model.zero_grad() before the forward pass AND
            # optimizer.zero_grad() after it — one call suffices).
            self.optimizer.zero_grad()

            outputs = self.model(b_input_ids,
                                 token_type_ids=None,
                                 attention_mask=b_input_mask,
                                 labels=b_labels)
            loss = outputs.loss
            total_loss += loss.item()

            loss.backward()

            # BUGFIX: gradient clipping must happen AFTER backward() and
            # BEFORE optimizer.step(); the original clipped before backward,
            # which clips stale/absent gradients and has no effect.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

            self.optimizer.step()
            self.scheduler.step()

            train_dataloader_iterator.set_postfix({"Loss": loss.item()})

            # Token-level predictions: argmax over the label dimension.
            # NOTE(review): labels are flattened including any padding
            # positions — if padding uses an ignore index (e.g. -100),
            # the F1 score will be skewed; confirm against the data prep.
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=2)
            train_predictions.extend(predictions.cpu().numpy().flatten())
            train_targets.extend(b_labels.cpu().numpy().flatten())

        avg_train_loss = total_loss / len(self.train_dataloader)
        train_f1 = f1_score(train_targets, train_predictions, average='macro')
        return avg_train_loss, train_f1

    def _validate_epoch(self):
        """Runs a single validation epoch.

        Returns:
            tuple[float, float]: (average validation loss, macro F1 over all
            token predictions in the epoch).
        """
        self.model.eval()
        total_eval_loss = 0
        val_predictions = []
        val_targets = []

        validation_dataloader_iterator = tqdm(self.validation_dataloader,
                                              desc="Validation")

        for batch in validation_dataloader_iterator:
            b_input_ids = batch[0].to(self.device)
            b_input_mask = batch[1].to(self.device)
            b_labels = batch[2].to(self.device)

            # No gradients needed during evaluation.
            with torch.no_grad():
                outputs = self.model(b_input_ids,
                                     token_type_ids=None,
                                     attention_mask=b_input_mask,
                                     labels=b_labels)

            loss = outputs.loss
            total_eval_loss += loss.item()

            validation_dataloader_iterator.set_postfix({"Loss": loss.item()})

            logits = outputs.logits
            predictions = torch.argmax(logits, dim=2)
            val_predictions.extend(predictions.cpu().numpy().flatten())
            val_targets.extend(b_labels.cpu().numpy().flatten())

        avg_val_loss = total_eval_loss / len(self.validation_dataloader)
        val_f1 = f1_score(val_targets, val_predictions, average='macro')
        return avg_val_loss, val_f1

    def _plot_metrics(self):
        """Plots training and validation losses and F1 scores."""
        epochs_range = range(1, self.epochs + 1)

        # Plotting Loss
        plt.figure(figsize=(12, 6))
        plt.plot(epochs_range, self.train_losses, label='Training Loss')
        plt.plot(epochs_range, self.val_losses, label='Validation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss')
        plt.legend()
        plt.show()

        # Plotting F1-score
        plt.figure(figsize=(12, 6))
        plt.plot(epochs_range, self.train_f1_scores, label='Training F1-score')
        plt.plot(epochs_range, self.val_f1_scores,
                 label='Validation F1-score')
        plt.xlabel('Epochs')
        plt.ylabel('F1-score')
        plt.title('Training and Validation F1-score')
        plt.legend()
        plt.show()


# Example usage:
if __name__ == "__main__":
    trainer = Key_Ner_Training(
        model=model,
        train_dataloader=train_dataloader,
        validation_dataloader=validation_dataloader,
        epochs=epochs,
        lr=lr,
        eps=eps,
        weight_decay=weight_decay,
        device=device,
        model_saved_path=model_saved_path
    )
    trainer.train()