# File size: 5,719 Bytes | 2b4f5ff
import pytorch_lightning as pl
import torch
from transformers.optimization import AdamW
import torchmetrics
from torchmetrics.classification import F1Score
class SequenceClassificationModule(pl.LightningModule):
    """Lightning module that trains and scores a sequence-classification model.

    Two scoring modes are supported for validation and test:

    * ``use_question_stance_approach=False``: every example is scored on its
      own, directly inside ``validation_step`` / ``test_step``.
    * ``use_question_stance_approach=True`` (default): the per-example
      predictions are grouped by claim at the end of the epoch and reduced to
      a single label per claim (see ``handle_end_of_epoch_scoring``); the
      metrics are then updated once per claim.

    Args:
        tokenizer: Object providing ``batch_decode``; used to turn the input
            ids back into text so that examples can be grouped by claim.
        model: The wrapped classifier. It is called as
            ``model(input_ids, **kwargs)``, must expose ``num_labels`` and
            return an object with ``logits`` (and ``loss`` when labels are
            passed).
        use_question_stance_approach: Selects the scoring mode described
            above.
        learning_rate: Learning rate handed to the optimizer.
    """

    # Label names in label-index order; used to name the per-class F1 metrics.
    CLASS_NAMES = ("supported", "refuted", "nei", "conflicting")

    def __init__(
        self, tokenizer, model, use_question_stance_approach=True, learning_rate=1e-3
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.learning_rate = learning_rate
        self.train_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )
        self.val_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )
        self.test_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )
        self.train_f1 = F1Score(
            task="multiclass", num_classes=model.num_labels, average="macro"
        )
        # average=None keeps one F1 value per class so that the epoch-end hooks
        # can log both the per-class scores and their mean.
        self.val_f1 = F1Score(
            task="multiclass", num_classes=model.num_labels, average=None
        )
        self.test_f1 = F1Score(
            task="multiclass", num_classes=model.num_labels, average=None
        )
        self.use_question_stance_approach = use_question_stance_approach

    def forward(self, input_ids, **kwargs):
        """Run the wrapped model on ``input_ids``, forwarding all keyword arguments."""
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters with the configured learning rate."""
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        Args:
            batch: Triple ``(x, x_mask, y)`` passed to the model as input ids,
                attention mask and labels respectively.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the loss under the key ``"loss"``.
        """
        x, x_mask, y = batch
        loss = self(x, attention_mask=x_mask, labels=y).loss
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Predict one validation batch and, in per-example mode, update the metrics.

        Args:
            batch: Triple ``(x, x_mask, y)`` passed to the model as input ids,
                attention mask and labels respectively.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the loss (``"val_loss"``), the input ids (``"src"``), the
            arg-max predictions (``"pred"``) and the labels (``"target"``).
        """
        x, x_mask, y = batch
        outputs = self(x, attention_mask=x_mask, labels=y)
        loss = outputs.loss
        preds = torch.argmax(outputs.logits, axis=1)

        if not self.use_question_stance_approach:
            self.val_acc(preds, y)
            self.log("val_acc_step", self.val_acc)
            self.val_f1(preds, y)
            # NOTE(review): "val_loss" is only logged in per-example mode, so a
            # callback monitoring it will not find it in question-stance mode.
            # Confirm whether this is intended before moving it out of the branch.
            self.log("val_loss", loss)

        return {"val_loss": loss, "src": x, "pred": preds, "target": y}

    def validation_epoch_end(self, outs):
        """Finalize and log the validation accuracy and F1 scores for the epoch."""
        self._score_and_log_epoch("val", outs, self.val_acc, self.val_f1)

    def test_step(self, batch, batch_idx):
        """Predict one test batch and, in per-example mode, update the metrics.

        The labels are not passed to the model here, so no loss is computed.

        Args:
            batch: Triple ``(x, x_mask, y)``; ``x`` and ``x_mask`` are passed
                to the model as input ids and attention mask, ``y`` is only
                used for scoring.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the input ids (``"src"``), the arg-max predictions
            (``"pred"``) and the labels (``"target"``).
        """
        x, x_mask, y = batch
        outputs = self(x, attention_mask=x_mask)
        preds = torch.argmax(outputs.logits, axis=1)

        if not self.use_question_stance_approach:
            self.test_acc(preds, y)
            self.log("test_acc_step", self.test_acc)
            self.test_f1(preds, y)

        return {"src": x, "pred": preds, "target": y}

    def test_epoch_end(self, outs):
        """Finalize and log the test accuracy and F1 scores for the epoch."""
        self._score_and_log_epoch("test", outs, self.test_acc, self.test_f1)

    def _score_and_log_epoch(self, stage, outs, acc_scorer, f1_scorer):
        """Shared epoch-end logic for the validation and test stages.

        Args:
            stage: Metric-name prefix, ``"val"`` or ``"test"``.
            outs: The step outputs collected over the epoch.
            acc_scorer: Accuracy metric of the stage.
            f1_scorer: Per-class F1 metric of the stage.
        """
        if self.use_question_stance_approach:
            # In this mode the step hooks did not touch the metrics; they are
            # filled here from the claim-level aggregation.
            self.handle_end_of_epoch_scoring(outs, acc_scorer, f1_scorer)

        self.log(f"{stage}_acc_epoch", acc_scorer)

        # The F1 metric is computed and reset by hand because its per-class
        # vector is logged as several scalars rather than as the metric object.
        per_class_f1 = f1_scorer.compute()
        f1_scorer.reset()
        self.log(f"{stage}_f1_epoch", torch.mean(per_class_f1))
        # zip() stops at the shorter sequence, so a model with fewer classes
        # than names does not raise an IndexError.
        for c_name, score in zip(self.CLASS_NAMES, per_class_f1):
            self.log(f"{stage}_f1_{c_name}", score)

    @staticmethod
    def _aggregate_stances(support):
        """Reduce the per-question predictions of one claim to a single label.

        Rules, in order of precedence:

        * any prediction of 2 or 3 -> 2
        * only 0 (and no 1)        -> 0
        * only 1 (and no 0)        -> 1
        * both 0 and 1             -> 3

        Args:
            support: Iterable of integer predictions for the claim.

        Returns:
            The aggregated integer label.

        Raises:
            ValueError: If no prediction is one of 0, 1, 2 or 3, in which case
                none of the rules applies.
        """
        votes = set(support)
        # TODO very ugly hack -- we cant have different numbers of labels for
        # train and test so a per-question 3 is treated like a 2 here.
        if votes & {2, 3}:
            return 2
        has_true = 0 in votes
        has_false = 1 in votes
        if has_true and has_false:
            return 3
        if has_true:
            return 0
        if has_false:
            return 1
        raise ValueError(
            f"cannot aggregate predictions {sorted(votes)!r}: expected labels in 0..3"
        )

    def handle_end_of_epoch_scoring(self, outputs, acc_scorer, f1_scorer):
        """Score an epoch at claim level and update both metrics.

        Every input is decoded back to text; the part before the first
        ``"[ question ]"`` marker identifies the claim. All predictions that
        share a claim are reduced to one label by ``_aggregate_stances`` and
        compared with the claim's gold label, which is taken from the first
        example seen for that claim.

        Args:
            outputs: Step outputs of the epoch; each is a dict with the keys
                ``"src"``, ``"pred"`` and ``"target"``.
            acc_scorer: Accuracy metric to update.
            f1_scorer: F1 metric to update.

        Raises:
            ValueError: If the predictions of a claim contain no label in 0..3.
        """
        gold_labels = {}
        question_support = {}
        for out in outputs:
            texts = self.tokenizer.batch_decode(
                out["src"], skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            # One bulk conversion per batch instead of comparing 0-dim tensors
            # one element at a time. Assumes one prediction/label per example.
            preds = out["pred"].reshape(-1).tolist()
            tgts = out["target"].reshape(-1).tolist()
            for src, pred, tgt in zip(texts, preds, tgts):
                claim_id = src.split("[ question ]")[0]
                gold_labels.setdefault(claim_id, tgt)
                question_support.setdefault(claim_id, []).append(pred)

        if not gold_labels:
            # Nothing to score; leave the metrics untouched.
            return

        answers = [self._aggregate_stances(question_support[k]) for k in gold_labels]
        golds = list(gold_labels.values())

        # Build the tensors on the module's own device instead of a hard-coded
        # "cuda:0" so that CPU and multi-GPU runs work as well.
        answer_tensor = torch.as_tensor(answers, dtype=torch.long, device=self.device)
        gold_tensor = torch.as_tensor(golds, dtype=torch.long, device=self.device)
        acc_scorer(answer_tensor, gold_tensor)
        f1_scorer(answer_tensor, gold_tensor)