import matplotlib.pyplot as plt
import torch
from sklearn.metrics import f1_score
from torch.optim import AdamW
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

from .config_train import (device, epochs, eps, lr, model_saved_path,
                           weight_decay)
from .load_data import train_dataloader, validation_dataloader
from .model import model
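
# Note: the relative imports above assume this file is executed as part of its
# package (e.g. "python -m <package>.train", where <package> is a placeholder,
# not a name taken from this repo); running the file directly as a script
# would raise an ImportError.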


class Key_Ner_Training:
    def __init__(self, model, train_dataloader, validation_dataloader, epochs,
                 lr, eps, weight_decay, device, model_saved_path):
        """
        Initializes Key_Ner_Training with the components needed for training.

        Args:
            model (torch.nn.Module): The model to be trained.
            train_dataloader (DataLoader): DataLoader for the training data.
            validation_dataloader (DataLoader): DataLoader for the validation data.
            epochs (int): Number of training epochs.
            lr (float): Learning rate for the optimizer.
            eps (float): Epsilon value for the optimizer.
            weight_decay (float): Weight decay for the optimizer.
            device (str): Device to run the model on ("cuda" or "cpu").
            model_saved_path (str): Path where the trained model is saved.
        """
        self.model = model.to(device)
        self.train_dataloader = train_dataloader
        self.validation_dataloader = validation_dataloader
        self.epochs = epochs
        self.device = device
        self.model_saved_path = model_saved_path

        # AdamW optimizer
        self.optimizer = AdamW(self.model.parameters(), lr=lr, eps=eps,
                               weight_decay=weight_decay)

        # Total number of training steps (batches per epoch * epochs)
        self.total_steps = len(train_dataloader) * epochs

        # Linear learning-rate schedule: with num_warmup_steps=0 there is no
        # warmup phase, and the rate decays linearly from lr to 0 over training.
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=0, num_training_steps=self.total_steps
        )

        # Per-epoch metric history
        self.train_losses = []
        self.val_losses = []
        self.train_f1_scores = []
        self.val_f1_scores = []
    def train(self):
        """Trains the model for the specified number of epochs."""
        for epoch in range(self.epochs):
            print(f'Epoch {epoch + 1}/{self.epochs}')
            print('-' * 10)

            # Training
            avg_train_loss, train_f1 = self._train_epoch()
            self.train_losses.append(avg_train_loss)
            self.train_f1_scores.append(train_f1)
            print(f'Training loss: {avg_train_loss:.4f}, F1-score: {train_f1:.4f}')

            # Validation
            avg_val_loss, val_f1 = self._validate_epoch()
            self.val_losses.append(avg_val_loss)
            self.val_f1_scores.append(val_f1)
            print(f'Validation loss: {avg_val_loss:.4f}, F1-score: {val_f1:.4f}')

        print("Training complete!")

        # Plot losses and F1 scores
        self._plot_metrics()

        # Save the trained model
        self.model.save_pretrained(self.model_saved_path)
    def _train_epoch(self):
        """Runs a single training epoch and returns (average loss, macro F1)."""
        self.model.train()
        total_loss = 0
        train_predictions = []
        train_targets = []

        train_dataloader_iterator = tqdm(self.train_dataloader, desc="Training")
        for batch in train_dataloader_iterator:
            b_input_ids = batch[0].to(self.device)
            b_input_mask = batch[1].to(self.device)
            b_labels = batch[2].to(self.device)

            # Reset gradients before the backward pass.
            self.optimizer.zero_grad()

            outputs = self.model(b_input_ids, token_type_ids=None,
                                 attention_mask=b_input_mask, labels=b_labels)
            loss = outputs.loss
            total_loss += loss.item()

            # Backpropagate, then clip gradients before the optimizer step;
            # clipping must happen after backward() so gradients exist.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.scheduler.step()

            train_dataloader_iterator.set_postfix({"Loss": loss.item()})

            # Collect token-level predictions for the epoch F1 score.
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=2)
            train_predictions.extend(predictions.cpu().numpy().flatten())
            train_targets.extend(b_labels.cpu().numpy().flatten())

        avg_train_loss = total_loss / len(self.train_dataloader)
        train_f1 = f1_score(train_targets, train_predictions, average='macro')
        return avg_train_loss, train_f1
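
    # Note on the epoch F1 above (and in _validate_epoch below): flattening the
    # label tensors scores every token position, padding included. If the data
    # pipeline marks ignored positions with -100 (the usual Hugging Face
    # convention; an assumption about load_data, not something this file
    # confirms), masking them first gives a cleaner macro F1. A minimal sketch
    # of that change inside the batch loop:
    #
    #     mask = b_labels != -100
    #     train_predictions.extend(predictions[mask].cpu().numpy())
    #     train_targets.extend(b_labels[mask].cpu().numpy())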
    def _validate_epoch(self):
        """Runs a single validation epoch and returns (average loss, macro F1)."""
        self.model.eval()
        total_eval_loss = 0
        val_predictions = []
        val_targets = []

        validation_dataloader_iterator = tqdm(self.validation_dataloader, desc="Validation")
        for batch in validation_dataloader_iterator:
            b_input_ids = batch[0].to(self.device)
            b_input_mask = batch[1].to(self.device)
            b_labels = batch[2].to(self.device)

            # No gradients are needed during evaluation.
            with torch.no_grad():
                outputs = self.model(b_input_ids, token_type_ids=None,
                                     attention_mask=b_input_mask, labels=b_labels)

            loss = outputs.loss
            total_eval_loss += loss.item()
            validation_dataloader_iterator.set_postfix({"Loss": loss.item()})

            logits = outputs.logits
            predictions = torch.argmax(logits, dim=2)
            val_predictions.extend(predictions.cpu().numpy().flatten())
            val_targets.extend(b_labels.cpu().numpy().flatten())

        avg_val_loss = total_eval_loss / len(self.validation_dataloader)
        val_f1 = f1_score(val_targets, val_predictions, average='macro')
        return avg_val_loss, val_f1
    def _plot_metrics(self):
        """Plots training and validation losses and F1 scores."""
        epochs_range = range(1, self.epochs + 1)

        # Loss curves
        plt.figure(figsize=(12, 6))
        plt.plot(epochs_range, self.train_losses, label='Training Loss')
        plt.plot(epochs_range, self.val_losses, label='Validation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss')
        plt.legend()
        plt.show()

        # F1-score curves
        plt.figure(figsize=(12, 6))
        plt.plot(epochs_range, self.train_f1_scores, label='Training F1-score')
        plt.plot(epochs_range, self.val_f1_scores, label='Validation F1-score')
        plt.xlabel('Epochs')
        plt.ylabel('F1-score')
        plt.title('Training and Validation F1-score')
        plt.legend()
        plt.show()
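
    # plt.show() assumes an interactive display; in a headless training run
    # (a common setup, though not stated in this file), saving the figures
    # instead would look like the following, with an illustrative filename:
    #
    #     plt.savefig('loss_curves.png')  # in place of plt.show()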


# Example usage:
if __name__ == "__main__":
    trainer = Key_Ner_Training(
        model=model,
        train_dataloader=train_dataloader,
        validation_dataloader=validation_dataloader,
        epochs=epochs,
        lr=lr,
        eps=eps,
        weight_decay=weight_decay,
        device=device,
        model_saved_path=model_saved_path
    )
    trainer.train()
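
    # Optional sanity check (a sketch, not part of the original flow): reload
    # the weights written by save_pretrained. This assumes `model` follows the
    # Hugging Face PreTrainedModel API, which the save_pretrained call in
    # train() already relies on.
    reloaded = type(model).from_pretrained(model_saved_path)
    reloaded.to(device)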