"""Fine-tune BERT as a binary "right span" classifier on SQuAD: given a
question and a candidate answer span, predict whether the span is the gold
answer (1) or a slightly corrupted version of it (0)."""

import random

import numpy as np
import torch
import evaluate
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments,
    default_data_collator,
)

# Allow TF32 matmuls on Ampere and newer GPUs for faster fp32 math.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# BERT with a freshly initialized 2-way classification head, trained below.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased").to(device)

def preprocess_training_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    answers = []
    labels = []
    final_questions = []

    for i in range(len(questions)):
        context = examples["context"][i]

        # Positive pair: the gold answer span, labeled 1.
        original_answer = examples["answers"][i]["text"][0]
        original_words = original_answer.split()
        final_questions.append(questions[i])
        answers.append(original_answer)
        labels.append(1)

        answer_start = examples["answers"][i]["answer_start"][0]
        answer_end = answer_start + len(original_answer)
        begin_context = context[:answer_start]
        end_context = context[answer_end:]

        # Negative pair: corrupt the gold span in one of four ways. A
        # single-word answer can only be extended (cases 0-1), since
        # dropping a word would leave an empty span.
        end = 1 if len(original_words) == 1 else 3
        case_ind = random.randint(0, end)  # randint is inclusive on both ends
        pre_words = begin_context.rsplit(maxsplit=1)
        post_words = end_context.split(maxsplit=1)

        wrong_span = None
        if case_ind == 0 and pre_words:
            # Extend the span by one word to the left.
            wrong_span = " ".join([pre_words[-1], original_answer])
        elif case_ind == 1 and post_words:
            # Extend the span by one word to the right.
            wrong_span = " ".join([original_answer, post_words[0]])
        elif case_ind == 2:
            # Drop the first word of the answer.
            wrong_span = " ".join(original_words[1:])
        elif case_ind == 3:
            # Drop the last word of the answer.
            wrong_span = " ".join(original_words[:-1])
        # If the chosen case has no neighboring word to borrow (the answer
        # sits at the very start or end of the context), skip the negative.

        if wrong_span is not None:
            final_questions.append(questions[i])
            answers.append(wrong_span)
            labels.append(0)

    inputs = tokenizer(
        final_questions,
        answers,
        truncation=True,
        padding="max_length",
    )
    inputs["labels"] = labels
    return inputs

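# For illustration (hypothetical strings, not taken from the dataset): given
# the gold answer "New York City", preprocess_training_examples emits the
# positive pair (question, "New York City") with label 1 and at most one
# corrupted pair, e.g. (question, "York City") or (question,
# "in New York City"), with label 0.
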
raw_datasets = load_dataset("squad")
raw_train = raw_datasets["train"]
raw_eval = raw_datasets["validation"]

train_dataset = raw_train.map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_train.column_names,
)

eval_dataset = raw_eval.map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_eval.column_names,
)

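# Optional sanity check: each raw example yields one positive pair and at
# most one corrupted pair, so the mapped splits should be just under twice
# the size of the raw splits.
print(len(raw_train), "->", len(train_dataset))
print(len(raw_eval), "->", len(eval_dataset))
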
batch_size = 8

args = TrainingArguments(
    "right_span_bert",
    evaluation_strategy="no",  # evaluated once manually after training
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=2,
    weight_decay=0.01,
    push_to_hub=True,  # requires a prior `huggingface-cli login`
    fp16=True,
)

# Load the metrics once rather than on every evaluation call.
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)["accuracy"]
    f1 = f1_metric.compute(predictions=predictions, references=labels)["f1"]
    return {"accuracy": accuracy, "f1": f1}

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

res = trainer.evaluate()
print(res)

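# A minimal inference sketch (the question and candidate span below are made
# up for illustration): score whether a span answers a question with the
# fine-tuned classifier. Class 1 means "this is the right span".
model.eval()
with torch.no_grad():
    enc = tokenizer(
        "When was the Eiffel Tower built?",
        "1887 to 1889",
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    ).to(device)
    probs = torch.softmax(model(**enc).logits, dim=-1)
    print(f"P(right span) = {probs[0, 1].item():.3f}")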