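"""Fine-tune bert-base-uncased as a binary "right span" classifier on SQuAD.

Each question is paired with its gold answer span (label 1) and, usually, a
slightly perturbed version of that span (label 0); BertForSequenceClassification
is then trained to tell the two apart.
"""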
import random

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments,
    default_data_collator,
)
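
# TF32 speeds up float32 matmuls on Ampere-or-newer GPUs at a small precision
# cost; both flags are harmless no-ops on older hardware.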
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# num_labels=2: label 1 = exact gold span, label 0 = perturbed span
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2).to(device)
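
# Negatives are built by nudging the gold span one word at a time: extend it
# with the word immediately to its left or right in the context, or drop its
# first or last word. Single-word answers can only be extended.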
def preprocess_training_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    answers = []
    labels = []
    final_questions = []
    for i in range(len(questions)):
        context = examples["context"][i]
        # The gold answer is always kept as a positive example
        final_questions.append(questions[i])
        original_answer = examples["answers"][i]["text"][0]
        original_words = original_answer.split()
        answers.append(original_answer)
        labels.append(1)
        answer_start = examples["answers"][i]["answer_start"][0]
        answer_end = answer_start + len(original_answer)
        begin_context = context[:answer_start]
        end_context = context[answer_end:]
        # Single-word answers can only be extended, never trimmed
        end = 1 if len(original_words) == 1 else 3
        case_ind = random.randint(0, end)
        pre_words = begin_context.rsplit(maxsplit=1)
        post_words = end_context.split(maxsplit=1)
        if case_ind == 0 and pre_words:  # Extend left: prepend the preceding word
            wrong_answer = " ".join([pre_words[-1], original_answer])
            final_questions.append(questions[i])
            answers.append(wrong_answer)
            labels.append(0)
        elif case_ind == 1 and post_words:  # Extend right: append the following word
            wrong_answer = " ".join([original_answer, post_words[0]])
            final_questions.append(questions[i])
            answers.append(wrong_answer)
            labels.append(0)
        elif case_ind == 2:  # Trim left: drop the first word of the span
            wrong_answer = " ".join(original_words[1:])
            final_questions.append(questions[i])
            answers.append(wrong_answer)
            labels.append(0)
        elif case_ind == 3:  # Trim right: drop the last word of the span
            wrong_answer = " ".join(original_words[:-1])
            final_questions.append(questions[i])
            answers.append(wrong_answer)
            labels.append(0)
    inputs = tokenizer(
        final_questions,
        answers,
        padding="max_length",
        truncation=True,  # guard against rare question+answer pairs over the model max length
    )
    inputs["labels"] = labels
    return inputs
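
# Illustrative example (hypothetical SQuAD row): for the gold answer
# "Denver Broncos" inside "... the Denver Broncos defeated ...", the possible
# negatives are "the Denver Broncos", "Denver Broncos defeated", "Broncos",
# and "Denver".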
raw_datasets = load_dataset("squad")
raw_train = raw_datasets["train"]
raw_eval = raw_datasets["validation"]
train_dataset = raw_train.map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_train.column_names,
)
eval_dataset = raw_eval.map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_eval.column_names,
)
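
# With batched=True, map() may return more rows than it receives, so both
# mapped datasets end up nearly twice their original size (one positive per
# question plus, in most cases, one negative).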
batch_size = 8
# train_dataset = train_dataset.with_format("torch")
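# push_to_hub=True assumes you are already authenticated (e.g. via
# `huggingface-cli login`), and fp16=True requires a CUDA device. With
# evaluation_strategy="no", metrics are only computed by the explicit
# trainer.evaluate() call at the end.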
args = TrainingArguments(
    "right_span_bert",
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=2,
    weight_decay=0.01,
    push_to_hub=True,
    fp16=True,
)
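
# Binary accuracy and F1 (positive class = label 1), computed from the argmax
# of the classification logits.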
def compute_metrics(eval_pred):
    load_accuracy = evaluate.load("accuracy")
    load_f1 = evaluate.load("f1")
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = load_accuracy.compute(predictions=predictions, references=labels)["accuracy"]
    f1 = load_f1.compute(predictions=predictions, references=labels)["f1"]
    return {"accuracy": accuracy, "f1": f1}
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()
res = trainer.evaluate()
print(res)
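# Note: these are span-classification accuracy/F1 over the perturbed
# validation pairs, not SQuAD exact-match/F1 on the original QA task.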