File size: 6,825 Bytes
280d87f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import matplotlib.pyplot as plt
import torch
from sklearn.metrics import f1_score
from torch.optim import AdamW
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

from .config_train import (device, epochs, eps, lr, model_saved_path,
                           weight_decay)
from .load_data import train_dataloader, validation_dataloader
from .model import model


class Key_Ner_Training:
    """Trains a token-classification (NER) model and tracks loss and macro-F1.

    Wires together an AdamW optimizer, a linear-warmup LR scheduler, a
    per-epoch train/validation loop, metric plotting, and model saving.
    """

    def __init__(self, model, train_dataloader, validation_dataloader, epochs, lr, eps, weight_decay, device, model_saved_path):
        """
        Initializes the Key_Ner_Training with the necessary components for training.

        Args:
            model (torch.nn.Module): The model to be trained.
            train_dataloader (DataLoader): DataLoader for training data.
            validation_dataloader (DataLoader): DataLoader for validation data.
            epochs (int): Number of training epochs.
            lr (float): Learning rate for the optimizer.
            eps (float): Epsilon value for the optimizer.
            weight_decay (float): Weight decay for the optimizer.
            device (str): Device to run the model on ("cuda" or "cpu").
            model_saved_path (str): Path to save the trained model.
        """
        self.model = model.to(device)
        self.train_dataloader = train_dataloader
        self.validation_dataloader = validation_dataloader
        self.epochs = epochs
        self.device = device
        self.model_saved_path = model_saved_path

        # AdamW optimizer
        self.optimizer = AdamW(self.model.parameters(), lr=lr, eps=eps, weight_decay=weight_decay)

        # Total number of optimizer steps (one per batch per epoch) — the
        # scheduler decays the LR linearly over exactly this many steps.
        self.total_steps = len(train_dataloader) * epochs

        # Learning rate scheduler (no warmup phase)
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=0, num_training_steps=self.total_steps)

        # Per-epoch metric history, consumed by _plot_metrics().
        self.train_losses = []
        self.val_losses = []
        self.train_f1_scores = []
        self.val_f1_scores = []

    def train(self):
        """Trains the model over the specified number of epochs.

        After the final epoch, plots the loss/F1 curves and saves the model
        (via `save_pretrained`) to `self.model_saved_path`.
        """
        for epoch in range(self.epochs):
            print(f'Epoch {epoch + 1}/{self.epochs}')
            print('-' * 10)

            # Training
            avg_train_loss, train_f1 = self._train_epoch()
            self.train_losses.append(avg_train_loss)
            self.train_f1_scores.append(train_f1)
            print(f'Training loss: {avg_train_loss}, F1-score: {train_f1}')

            # Validation
            avg_val_loss, val_f1 = self._validate_epoch()
            self.val_losses.append(avg_val_loss)
            self.val_f1_scores.append(val_f1)
            print(f'Validation Loss: {avg_val_loss}, F1-score: {val_f1}')

        print("Training complete!")

        # Plot losses and F1 scores
        self._plot_metrics()

        # Save model
        self.model.save_pretrained(self.model_saved_path)

    def _train_epoch(self):
        """Runs a single training epoch.

        Returns:
            tuple[float, float]: (average training loss, macro F1-score).
        """
        self.model.train()

        total_loss = 0
        train_predictions = []
        train_targets = []

        train_dataloader_iterator = tqdm(self.train_dataloader, desc="Training")

        for step, batch in enumerate(train_dataloader_iterator):
            # Batch layout: (input_ids, attention_mask, labels).
            b_input_ids = batch[0].to(self.device)
            b_input_mask = batch[1].to(self.device)
            b_labels = batch[2].to(self.device)

            # Clear gradients from the previous step before the forward pass.
            self.optimizer.zero_grad()

            outputs = self.model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
            loss = outputs.loss

            total_loss += loss.item()

            loss.backward()

            # BUGFIX: clipping must happen AFTER backward() (gradients exist)
            # and BEFORE optimizer.step(); previously it ran on zeroed
            # gradients and had no effect.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

            self.optimizer.step()
            self.scheduler.step()

            train_dataloader_iterator.set_postfix({"Loss": loss.item()})

            # Flatten per-token predictions for the epoch-level F1.
            # NOTE(review): this includes every token position — if labels use
            # a padding/ignore index, it will skew the F1; confirm upstream.
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=2)
            train_predictions.extend(predictions.cpu().numpy().flatten())
            train_targets.extend(b_labels.cpu().numpy().flatten())

        avg_train_loss = total_loss / len(self.train_dataloader)
        train_f1 = f1_score(train_targets, train_predictions, average='macro')

        return avg_train_loss, train_f1

    def _validate_epoch(self):
        """Runs a single validation epoch (no gradient updates).

        Returns:
            tuple[float, float]: (average validation loss, macro F1-score).
        """
        self.model.eval()

        total_eval_loss = 0
        val_predictions = []
        val_targets = []

        validation_dataloader_iterator = tqdm(self.validation_dataloader, desc="Validation")

        for batch in validation_dataloader_iterator:
            b_input_ids = batch[0].to(self.device)
            b_input_mask = batch[1].to(self.device)
            b_labels = batch[2].to(self.device)

            # No autograd bookkeeping needed during evaluation.
            with torch.no_grad():
                outputs = self.model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
                loss = outputs.loss

            total_eval_loss += loss.item()

            validation_dataloader_iterator.set_postfix({"Loss": loss.item()})

            logits = outputs.logits
            predictions = torch.argmax(logits, dim=2)
            val_predictions.extend(predictions.cpu().numpy().flatten())
            val_targets.extend(b_labels.cpu().numpy().flatten())

        avg_val_loss = total_eval_loss / len(self.validation_dataloader)
        val_f1 = f1_score(val_targets, val_predictions, average='macro')

        return avg_val_loss, val_f1

    def _plot_metrics(self):
        """Plots training and validation losses and F1 scores, one figure each."""
        epochs_range = range(1, self.epochs + 1)

        # Plotting Loss
        plt.figure(figsize=(12, 6))
        plt.plot(epochs_range, self.train_losses, label='Training Loss')
        plt.plot(epochs_range, self.val_losses, label='Validation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss')
        plt.legend()
        plt.show()

        # Plotting F1-score
        plt.figure(figsize=(12, 6))
        plt.plot(epochs_range, self.train_f1_scores, label='Training F1-score')
        plt.plot(epochs_range, self.val_f1_scores, label='Validation F1-score')
        plt.xlabel('Epochs')
        plt.ylabel('F1-score')
        plt.title('Training and Validation F1-score')
        plt.legend()
        plt.show()


# Example usage:
if __name__ == "__main__":
    # Assemble the trainer from the package-level config, data, and model
    # objects, then run the full training pipeline.
    trainer_config = {
        "model": model,
        "train_dataloader": train_dataloader,
        "validation_dataloader": validation_dataloader,
        "epochs": epochs,
        "lr": lr,
        "eps": eps,
        "weight_decay": weight_decay,
        "device": device,
        "model_saved_path": model_saved_path,
    }
    Key_Ner_Training(**trainer_config).train()