# sample_bert/bert.py
import random

import numpy as np
import torch
import evaluate
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments,
    default_data_collator,
)
# Allow TF32 matmuls on Ampere+ GPUs: faster training at negligible precision cost.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased").to(device)
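# The default classification head has num_labels=2, which matches the binary
# right-span (1) / wrong-span (0) labels built in the preprocessing below.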
def preprocess_training_examples(examples):
    """Build (question, candidate-answer) pairs: the gold SQuAD span is
    labeled 1, and a randomly perturbed span from the same context is labeled 0."""
    questions = [q.strip() for q in examples["question"]]
    answers = []
    labels = []
    final_questions = []
    # Generate close-looking wrong answers from the existing context
    for i in range(len(questions)):
        context = examples["context"][i]
        final_questions.append(questions[i])
        original_answer = examples["answers"][i]["text"][0]
        original_words = original_answer.split()
        answers.append(original_answer)
        labels.append(1)
        answer_start = examples["answers"][i]["answer_start"][0]
        answer_end = answer_start + len(original_answer)
        begin_context = context[:answer_start]
        end_context = context[answer_end:]
        # Single-word answers can only be extended; longer answers can also be shortened.
        end = 1 if len(original_words) == 1 else 3
        case_ind = random.randint(0, end)
        pre_words = begin_context.rsplit(maxsplit=1)
        post_words = end_context.split(maxsplit=1)
        if case_ind == 0 and pre_words:  # Extend left: prepend the preceding word
            wrong_span = " ".join([pre_words[-1], original_answer])
            final_questions.append(questions[i])
            answers.append(wrong_span)
            labels.append(0)
        elif case_ind == 1 and post_words:  # Extend right: append the following word
            wrong_span = " ".join([original_answer, post_words[0]])
            final_questions.append(questions[i])
            answers.append(wrong_span)
            labels.append(0)
        elif case_ind == 2:  # Drop left: remove the first word of the answer
            wrong_span = " ".join(original_words[1:])
            final_questions.append(questions[i])
            answers.append(wrong_span)
            labels.append(0)
        elif case_ind == 3:  # Drop right: remove the last word of the answer
            wrong_span = " ".join(original_words[:-1])
            final_questions.append(questions[i])
            answers.append(wrong_span)
            labels.append(0)
    inputs = tokenizer(
        final_questions,
        answers,
        padding="max_length",
        truncation=True,
    )
    inputs["labels"] = labels
    return inputs
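# Minimal sanity check on one hand-made example (the strings are illustrative;
# the fields follow the SQuAD schema used above).
_demo = {
    "question": ["Who wrote Hamlet?"],
    "context": ["William Shakespeare wrote Hamlet in the early 1600s."],
    "answers": [{"text": ["William Shakespeare"], "answer_start": [0]}],
}
_out = preprocess_training_examples(_demo)
assert _out["labels"][0] == 1  # the first pair is always the gold span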
raw_datasets = load_dataset("squad")
raw_train = raw_datasets["train"]
raw_eval = raw_datasets["validation"]
train_dataset = raw_train.map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_train.column_names,
)
eval_dataset = raw_eval.map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_eval.column_names,
)
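# batched=True lets the map return more rows than it receives (one positive
# plus up to one negative per original example); the original columns must be
# removed because their lengths no longer match the new rows.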
batch_size = 8
# train_dataset = train_dataset.with_format("torch")
args = TrainingArguments(
    "right_span_bert",
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=2,
    weight_decay=0.01,
    push_to_hub=True,
    fp16=True,
)
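# Note: push_to_hub=True requires prior Hub authentication (e.g. `huggingface-cli login`);
# the Trainer pushes checkpoints to the "right_span_bert" repo when it saves.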
def compute_metrics(eval_pred):
    load_accuracy = evaluate.load("accuracy")
    load_f1 = evaluate.load("f1")
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = load_accuracy.compute(predictions=predictions, references=labels)["accuracy"]
    f1 = load_f1.compute(predictions=predictions, references=labels)["f1"]
    return {"accuracy": accuracy, "f1": f1}
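# With evaluation_strategy="no", these metrics are only computed by the
# explicit trainer.evaluate() call after training finishes.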
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()
res = trainer.evaluate()
print(res)
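# Hedged usage sketch: scoring a candidate span with the fine-tuned classifier.
# The question/span strings are made up; index 1 is the "correct span" class
# under the labeling scheme above.
def score_span(question, span):
    enc = tokenizer(question, span, return_tensors="pt", truncation=True).to(device)
    model.eval()
    with torch.no_grad():
        logits = model(**enc).logits
    return torch.softmax(logits, dim=-1)[0, 1].item()

print(score_span("Who wrote Hamlet?", "William Shakespeare"))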