"""Train a binary BERT classifier ("right_span_bert") on SQuAD that judges
whether a candidate answer span is the exact gold span (label 1) or a
near-miss built by adding or dropping a boundary word (label 0)."""

import random

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments,
    default_data_collator,
)

# TF32 matmuls speed up training on Ampere+ GPUs at negligible precision cost
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Binary head: label 1 = exact gold span, label 0 = corrupted span
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2).to(device)

def preprocess_training_examples(examples):
    """Tokenize (question, span) pairs: the gold SQuAD answer span gets label 1,
    and a near-miss span built by shifting one boundary word gets label 0."""
    questions = [q.strip() for q in examples["question"]]
    answers = []
    labels = []
    final_questions = []
    # Generate close-looking wrong answers from the surrounding context
    for i in range(len(questions)):
        context = examples["context"][i]
        final_questions.append(questions[i])
        original_answer = examples["answers"][i]["text"][0]
        original_words = original_answer.split()
        answers.append(original_answer)
        labels.append(1)
        answer_start = examples["answers"][i]["answer_start"][0]
        answer_end = answer_start + len(original_answer)

        begin_context = context[:answer_start]
        end_context = context[answer_end:]

        # Single-word answers can only be extended (cases 0-1); dropping their
        # only word (cases 2-3) would leave an empty span.
        end = 1 if len(original_words) == 1 else 3
        case_ind = random.randint(0, end)
        pre_words = begin_context.rsplit(maxsplit=1)
        post_words = end_context.split(maxsplit=1)
        wrong_span = None
        if case_ind == 0 and pre_words:  # Prepend the context word left of the span
            wrong_span = " ".join([pre_words[-1], original_answer])
        elif case_ind == 1 and post_words:  # Append the context word right of the span
            wrong_span = " ".join([original_answer, post_words[0]])
        elif case_ind == 2:  # Drop the leftmost answer word
            wrong_span = " ".join(original_words[1:])
        elif case_ind == 3:  # Drop the rightmost answer word
            wrong_span = " ".join(original_words[:-1])
        if wrong_span is not None:  # None when the answer sits at a context edge
            final_questions.append(questions[i])
            answers.append(wrong_span)
            labels.append(0)

    inputs = tokenizer(
        final_questions,
        answers,
        truncation=True,  # guard against rare overlong question/span pairs
        padding="max_length",
    )
    inputs["labels"] = labels
    return inputs
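
# Illustrative output for one record (a hypothetical walk-through, assuming the
# gold answer "in the late 1990s" is preceded by the word "fame" in its context):
#   ("When did Beyonce start becoming popular?", "in the late 1990s")       -> label 1
#   ("When did Beyonce start becoming popular?", "fame in the late 1990s")  -> label 0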

raw_datasets = load_dataset("squad")
raw_train = raw_datasets["train"]
raw_eval = raw_datasets["validation"]

train_dataset = raw_train.map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_train.column_names,
)

eval_dataset = raw_eval.map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_eval.column_names,
)
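
# Each original example yields the gold span plus (usually) one corrupted span,
# so the mapped datasets are roughly twice the size of raw SQuAD.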

batch_size = 8

# train_dataset = train_dataset.with_format("torch")

args = TrainingArguments(
    "right_span_bert",
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=2,
    weight_decay=0.01,
    push_to_hub=True,
    fp16=True,
)
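
# Note: push_to_hub=True assumes an authenticated Hugging Face Hub session
# (e.g. via `huggingface-cli login`); set it to False to keep checkpoints local.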

# Load the metrics once rather than on every evaluation call
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)["accuracy"]
    f1 = f1_metric.compute(predictions=predictions, references=labels)["f1"]
    return {"accuracy": accuracy, "f1": f1}

trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

# Manual evaluation pass (evaluation_strategy is "no" during training)
res = trainer.evaluate()
print(res)
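
# A minimal sketch (an assumption, not part of the original pipeline) of how the
# trained classifier could re-score a candidate span at inference time;
# `score_span` is a hypothetical helper, not a transformers API.
def score_span(question: str, span: str) -> float:
    model.eval()  # disable dropout for scoring
    enc = tokenizer(question, span, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**enc).logits
    # Probability that `span` is the exact gold answer (class index 1)
    return torch.softmax(logits, dim=-1)[0, 1].item()

# Example: score_span("When did Beyonce start becoming popular?", "in the late 1990s")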