import random

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments,
    default_data_collator,
)

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased").to(device)


def preprocess_training_examples(examples):
    # Build (question, candidate-answer) pairs: the gold answer is labeled 1, and a
    # close-looking wrong span generated from the surrounding context is labeled 0.
    questions = [q.strip() for q in examples["question"]]
    final_questions = []
    answers = []
    labels = []
    for i in range(len(questions)):
        context = examples["context"][i]

        # Positive example: the gold answer span.
        original_answer = examples["answers"][i]["text"][0]
        original_words = original_answer.split()
        final_questions.append(questions[i])
        answers.append(original_answer)
        labels.append(1)

        # Locate the answer inside the context so its boundaries can be perturbed.
        answer_start = examples["answers"][i]["answer_start"][0]
        answer_end = answer_start + len(original_answer)
        begin_context = context[:answer_start]
        end_context = context[answer_end:]
        pre_words = begin_context.rsplit(maxsplit=1)  # last word before the answer
        post_words = end_context.split(maxsplit=1)    # first word after the answer

        # Pick one of four perturbations. For single-word answers only the two
        # "append" cases apply, since dropping a word would leave an empty span.
        end = 1 if len(original_words) == 1 else 3
        case_ind = random.randint(0, end)

        wrong_answer = None
        if case_ind == 0 and pre_words:
            # Append the word to the left of the span.
            wrong_answer = " ".join([pre_words[-1], original_answer])
        elif case_ind == 1 and post_words:
            # Append the word to the right of the span.
            wrong_answer = " ".join([original_answer, post_words[0]])
        elif case_ind == 2:
            # Drop the leftmost word of the span.
            wrong_answer = " ".join(original_words[1:])
        elif case_ind == 3:
            # Drop the rightmost word of the span.
            wrong_answer = " ".join(original_words[:-1])

        if wrong_answer:
            final_questions.append(questions[i])
            answers.append(wrong_answer)
            labels.append(0)

    inputs = tokenizer(
        final_questions,
        answers,
        padding="max_length",
        truncation=True,
    )
    inputs["labels"] = labels
    return inputs


raw_datasets = load_dataset("squad")
raw_train = raw_datasets["train"]
raw_eval = raw_datasets["validation"]

train_dataset = raw_train.map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_train.column_names,
)
eval_dataset = raw_eval.map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_eval.column_names,
)
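# Optional sanity check (not part of the training pipeline): decode one batch of
# preprocessed pairs to see the positive (label 1) and perturbed negative (label 0)
# candidates side by side. The exact negative varies because the case is random.
sample = preprocess_training_examples(raw_train[:1])
for ids, label in zip(sample["input_ids"], sample["labels"]):
    print(label, tokenizer.decode(ids, skip_special_tokens=True))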
batch_size = 8
# train_dataset = train_dataset.with_format("torch")

args = TrainingArguments(
    "right_span_bert",
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=2,
    weight_decay=0.01,
    push_to_hub=True,
    fp16=True,
)


def compute_metrics(eval_pred):
    # Binary accuracy and F1 over the span-correctness labels.
    load_accuracy = evaluate.load("accuracy")
    load_f1 = evaluate.load("f1")
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = load_accuracy.compute(predictions=predictions, references=labels)["accuracy"]
    f1 = load_f1.compute(predictions=predictions, references=labels)["f1"]
    return {"accuracy": accuracy, "f1": f1}


trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()
res = trainer.evaluate()
print(res)
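# Usage sketch (an assumption, not part of the original script): after training, the
# classifier can score whether a candidate answer span is exactly right for a given
# question, with class 1 = correct span and class 0 = near-miss, matching the labels
# above. The question/candidate pair below is hypothetical, for illustration only.
model.eval()
question = "Where was the Super Bowl held?"
candidate = "Levi's Stadium"
enc = tokenizer(question, candidate, return_tensors="pt", truncation=True).to(device)
with torch.no_grad():
    probs = torch.softmax(model(**enc).logits, dim=-1)
print(f"P(correct span) = {probs[0, 1].item():.3f}")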