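"""Training script for the Key_Ner model.

Defines Key_Ner_Training, which runs training and validation epochs, tracks
loss and macro F1-score, plots both metrics, and saves the trained model.
"""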
import matplotlib.pyplot as plt
import torch
from sklearn.metrics import f1_score
from torch.optim import AdamW
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

from .config_train import (device, epochs, eps, lr, model_saved_path,
                           weight_decay)
from .load_data import train_dataloader, validation_dataloader
from .model import model


class Key_Ner_Training:
    def __init__(self, model, train_dataloader, validation_dataloader, epochs,
                 lr, eps, weight_decay, device, model_saved_path):
        """
        Initializes the Key_Ner_Training with the necessary components for training.

        Args:
            model (torch.nn.Module): The model to be trained.
            train_dataloader (DataLoader): DataLoader for training data.
            validation_dataloader (DataLoader): DataLoader for validation data.
            epochs (int): Number of training epochs.
            lr (float): Learning rate for the optimizer.
            eps (float): Epsilon value for the optimizer.
            weight_decay (float): Weight decay for the optimizer.
            device (str): Device to run the model on ("cuda" or "cpu").
            model_saved_path (str): Path to save the trained model.
        """
        self.model = model.to(device)
        self.train_dataloader = train_dataloader
        self.validation_dataloader = validation_dataloader
        self.epochs = epochs
        self.device = device
        self.model_saved_path = model_saved_path

        # AdamW optimizer
        self.optimizer = AdamW(self.model.parameters(), lr=lr, eps=eps, weight_decay=weight_decay)

        # Total number of training steps
        self.total_steps = len(train_dataloader) * epochs

        # Learning rate scheduler
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=0, num_training_steps=self.total_steps
        )
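        # With num_warmup_steps=0 there is no warmup phase: the learning rate
        # decays linearly from lr to 0 over total_steps.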

        # Metrics
        self.train_losses = []
        self.val_losses = []
        self.train_f1_scores = []
        self.val_f1_scores = []

    def train(self):
        """Trains the model over the specified number of epochs."""
        for epoch in range(self.epochs):
            print(f'Epoch {epoch + 1}/{self.epochs}')
            print('-' * 10)

            # Training
            avg_train_loss, train_f1 = self._train_epoch()
            self.train_losses.append(avg_train_loss)
            self.train_f1_scores.append(train_f1)
            print(f'Training loss: {avg_train_loss}, F1-score: {train_f1}')

            # Validation
            avg_val_loss, val_f1 = self._validate_epoch()
            self.val_losses.append(avg_val_loss)
            self.val_f1_scores.append(val_f1)
            print(f'Validation Loss: {avg_val_loss}, F1-score: {val_f1}')

        print("Training complete!")

        # Plot losses and F1 scores
        self._plot_metrics()

        # Save model
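        # save_pretrained() writes the model weights and config to the target
        # directory (this assumes the model is a transformers PreTrainedModel
        # or otherwise implements save_pretrained()).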
        self.model.save_pretrained(self.model_saved_path)

    def _train_epoch(self):
        """Runs a single training epoch."""
        self.model.train()
        total_loss = 0
        train_predictions = []
        train_targets = []

        train_dataloader_iterator = tqdm(self.train_dataloader, desc="Training")
        for batch in train_dataloader_iterator:
            b_input_ids = batch[0].to(self.device)
            b_input_mask = batch[1].to(self.device)
            b_labels = batch[2].to(self.device)

            # Forward pass
            self.optimizer.zero_grad()
            outputs = self.model(b_input_ids, token_type_ids=None,
                                 attention_mask=b_input_mask, labels=b_labels)
            loss = outputs.loss
            total_loss += loss.item()

            # Backward pass: compute gradients, clip them, then update the
            # weights and the learning-rate schedule.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.scheduler.step()

            train_dataloader_iterator.set_postfix({"Loss": loss.item()})

            # Token-level predictions for the epoch F1-score
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=2)
            train_predictions.extend(predictions.cpu().numpy().flatten())
            train_targets.extend(b_labels.cpu().numpy().flatten())
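
        # Note (data assumption): flatten() keeps every token position, so if
        # the labels are padded (e.g. with -100), those positions are counted
        # in the macro F1-score below; masking them with the attention mask
        # would score real tokens only.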
        avg_train_loss = total_loss / len(self.train_dataloader)
        train_f1 = f1_score(train_targets, train_predictions, average='macro')
        return avg_train_loss, train_f1

    def _validate_epoch(self):
        """Runs a single validation epoch."""
        self.model.eval()
        total_eval_loss = 0
        val_predictions = []
        val_targets = []

        validation_dataloader_iterator = tqdm(self.validation_dataloader, desc="Validation")
        for batch in validation_dataloader_iterator:
            b_input_ids = batch[0].to(self.device)
            b_input_mask = batch[1].to(self.device)
            b_labels = batch[2].to(self.device)

            with torch.no_grad():
                outputs = self.model(b_input_ids, token_type_ids=None,
                                     attention_mask=b_input_mask, labels=b_labels)

            loss = outputs.loss
            total_eval_loss += loss.item()
            validation_dataloader_iterator.set_postfix({"Loss": loss.item()})

            logits = outputs.logits
            predictions = torch.argmax(logits, dim=2)
            val_predictions.extend(predictions.cpu().numpy().flatten())
            val_targets.extend(b_labels.cpu().numpy().flatten())

        avg_val_loss = total_eval_loss / len(self.validation_dataloader)
        val_f1 = f1_score(val_targets, val_predictions, average='macro')
        return avg_val_loss, val_f1

    def _plot_metrics(self):
        """Plots training and validation losses and F1 scores."""
        epochs_range = range(1, self.epochs + 1)

        # Plotting Loss
        plt.figure(figsize=(12, 6))
        plt.plot(epochs_range, self.train_losses, label='Training Loss')
        plt.plot(epochs_range, self.val_losses, label='Validation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss')
        plt.legend()
        plt.show()

        # Plotting F1-score
        plt.figure(figsize=(12, 6))
        plt.plot(epochs_range, self.train_f1_scores, label='Training F1-score')
        plt.plot(epochs_range, self.val_f1_scores, label='Validation F1-score')
        plt.xlabel('Epochs')
        plt.ylabel('F1-score')
        plt.title('Training and Validation F1-score')
        plt.legend()
        plt.show()
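        # Note: plt.show() needs an interactive backend; when running headless,
        # calling plt.savefig(<path>) before (or instead of) each show() would
        # write the figures to disk.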


# Example usage:
if __name__ == "__main__":
    trainer = Key_Ner_Training(
        model=model,
        train_dataloader=train_dataloader,
        validation_dataloader=validation_dataloader,
        epochs=epochs,
        lr=lr,
        eps=eps,
        weight_decay=weight_decay,
        device=device,
        model_saved_path=model_saved_path
    )
    trainer.train()
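
    # The saved directory can later be reloaded with the matching Auto class,
    # e.g. (sketch, assuming a token-classification model):
    #   from transformers import AutoModelForTokenClassification
    #   reloaded = AutoModelForTokenClassification.from_pretrained(model_saved_path)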