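"""Training entry point for the Key_Ner model.

Defines Key_Ner_Training, which wires together the model, data loaders,
AdamW optimizer, and linear learning-rate schedule, tracks per-epoch loss
and macro F1, plots both curves, and saves the trained model.
"""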
import matplotlib.pyplot as plt
import torch
from sklearn.metrics import f1_score
from torch.optim import AdamW
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup
from .config_train import (device, epochs, eps, lr, model_saved_path,
weight_decay)
from .load_data import train_dataloader, validation_dataloader
from .model import model
class Key_Ner_Training:
def __init__(self, model, train_dataloader, validation_dataloader, epochs, lr, eps, weight_decay, device, model_saved_path):
"""
        Initializes the Key_Ner_Training trainer with the components needed for training.
Args:
model (torch.nn.Module): The model to be trained.
train_dataloader (DataLoader): DataLoader for training data.
validation_dataloader (DataLoader): DataLoader for validation data.
epochs (int): Number of training epochs.
lr (float): Learning rate for the optimizer.
eps (float): Epsilon value for the optimizer.
weight_decay (float): Weight decay for the optimizer.
device (str): Device to run the model on ("cuda" or "cpu").
model_saved_path (str): Path to save the trained model.
"""
self.model = model.to(device)
self.train_dataloader = train_dataloader
self.validation_dataloader = validation_dataloader
self.epochs = epochs
self.device = device
self.model_saved_path = model_saved_path
# AdamW optimizer
self.optimizer = AdamW(self.model.parameters(), lr=lr, eps=eps, weight_decay=weight_decay)
# Total number of training steps
self.total_steps = len(train_dataloader) * epochs
# Learning rate scheduler
self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=0, num_training_steps=self.total_steps)
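        # With num_warmup_steps=0, the learning rate decays linearly from lr to 0 over total_steps.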
        # Metrics: per-epoch histories of loss and macro F1, plotted by _plot_metrics after training.
self.train_losses = []
self.val_losses = []
self.train_f1_scores = []
self.val_f1_scores = []
def train(self):
"""Trains the model over the specified number of epochs."""
for epoch in range(self.epochs):
print(f'Epoch {epoch + 1}/{self.epochs}')
print('-' * 10)
# Training
avg_train_loss, train_f1 = self._train_epoch()
self.train_losses.append(avg_train_loss)
self.train_f1_scores.append(train_f1)
print(f'Training loss: {avg_train_loss}, F1-score: {train_f1}')
# Validation
avg_val_loss, val_f1 = self._validate_epoch()
self.val_losses.append(avg_val_loss)
self.val_f1_scores.append(val_f1)
            print(f'Validation loss: {avg_val_loss}, F1-score: {val_f1}')
print("Training complete!")
# Plot losses and F1 scores
self._plot_metrics()
# Save model
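        # (save_pretrained assumes the wrapped model is a Hugging Face PreTrainedModel.)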
self.model.save_pretrained(self.model_saved_path)
def _train_epoch(self):
"""Runs a single training epoch."""
self.model.train()
total_loss = 0
train_predictions = []
train_targets = []
train_dataloader_iterator = tqdm(self.train_dataloader, desc="Training")
for step, batch in enumerate(train_dataloader_iterator):
b_input_ids = batch[0].to(self.device)
b_input_mask = batch[1].to(self.device)
b_labels = batch[2].to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
            loss = outputs.loss
            total_loss += loss.item()
            # Backpropagate first, then clip, so the clipping applies to this step's gradients.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.scheduler.step()
train_dataloader_iterator.set_postfix({"Loss": loss.item()})
logits = outputs.logits
predictions = torch.argmax(logits, dim=2)
train_predictions.extend(predictions.cpu().numpy().flatten())
train_targets.extend(b_labels.cpu().numpy().flatten())
avg_train_loss = total_loss / len(self.train_dataloader)
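        # Macro F1 over all flattened token positions (padding/special tokens are included in this score).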
train_f1 = f1_score(train_targets, train_predictions, average='macro')
return avg_train_loss, train_f1
def _validate_epoch(self):
"""Runs a single validation epoch."""
self.model.eval()
total_eval_loss = 0
val_predictions = []
val_targets = []
validation_dataloader_iterator = tqdm(self.validation_dataloader, desc="Validation")
for batch in validation_dataloader_iterator:
b_input_ids = batch[0].to(self.device)
b_input_mask = batch[1].to(self.device)
b_labels = batch[2].to(self.device)
with torch.no_grad():
outputs = self.model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
loss = outputs.loss
total_eval_loss += loss.item()
validation_dataloader_iterator.set_postfix({"Loss": loss.item()})
logits = outputs.logits
predictions = torch.argmax(logits, dim=2)
val_predictions.extend(predictions.cpu().numpy().flatten())
val_targets.extend(b_labels.cpu().numpy().flatten())
avg_val_loss = total_eval_loss / len(self.validation_dataloader)
val_f1 = f1_score(val_targets, val_predictions, average='macro')
return avg_val_loss, val_f1
def _plot_metrics(self):
"""Plots training and validation losses and F1 scores."""
epochs_range = range(1, self.epochs + 1)
# Plotting Loss
plt.figure(figsize=(12, 6))
plt.plot(epochs_range, self.train_losses, label='Training Loss')
plt.plot(epochs_range, self.val_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
# Plotting F1-score
plt.figure(figsize=(12, 6))
plt.plot(epochs_range, self.train_f1_scores, label='Training F1-score')
plt.plot(epochs_range, self.val_f1_scores, label='Validation F1-score')
plt.xlabel('Epochs')
plt.ylabel('F1-score')
plt.title('Training and Validation F1-score')
plt.legend()
plt.show()
# Example usage:
if __name__ == "__main__":
trainer = Key_Ner_Training(
model=model,
train_dataloader=train_dataloader,
validation_dataloader=validation_dataloader,
epochs=epochs,
lr=lr,
eps=eps,
weight_decay=weight_decay,
device=device,
model_saved_path=model_saved_path
)
trainer.train()
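    # A possible next step (sketch, assuming the model was saved as a Hugging Face
    # token-classification checkpoint via save_pretrained above):
    #
    #   from transformers import AutoModelForTokenClassification
    #   reloaded = AutoModelForTokenClassification.from_pretrained(model_saved_path)
    #   reloaded.eval()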