from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import numpy as np
from tqdm import tqdm
import torch
import collections

# Decoding / preprocessing hyperparameters
luke_beam_size = 5
n_best = 20
max_length = 512
stride = 128
batch_size = 8
max_answer_length = 30
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

luke_model = AutoModelForQuestionAnswering.from_pretrained(
    "botcon/LUKE_squadshift_finetuned_large"
).to(device)
luke_tokenizer = AutoTokenizer.from_pretrained("roberta-base")


def compute_beam(start_logits, end_logits, features, examples):
    # Map each example id to the indices of its (possibly overflowing) features.
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue
                    answers.append(
                        {
                            "text": context[offsets[start_index][0] : offsets[end_index][1]],
                            "logit_score": start_logit[start_index] + end_logit[end_index],
                        }
                    )

        # Keep the top `luke_beam_size` answers, padding with empty answers if needed
        if len(answers) > 0:
            best_answers = sorted(answers, key=lambda x: x["logit_score"], reverse=True)
            best_ans = []
            best_logits = []
            for candidate in best_answers[:luke_beam_size]:
                best_ans.append(candidate["text"])
                best_logits.append(candidate["logit_score"])
            while len(best_ans) < luke_beam_size:
                best_ans.append("")
                best_logits.append(1e-5)  # near-zero placeholder score for padded beam entries
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_ans, "logits": best_logits}
            )
        else:
            # No valid span found: return an all-empty beam so the output schema stays consistent
            predicted_answers.append(
                {
                    "id": example_id,
                    "prediction_text": [""] * luke_beam_size,
                    "logits": [1e-5] * luke_beam_size,
                }
            )
    return predicted_answers


def preprocess_validation_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = luke_tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []
    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        # Keep offsets only for context tokens; question and special tokens are masked with None
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs


def generate(dataset):
    luke_model.eval()
    preprocessed = dataset.map(
        preprocess_validation_examples,
        batched=True,
        remove_columns=dataset.column_names,
    )
    eval_set_for_model = preprocessed.remove_columns(["example_id", "offset_mapping"])
    eval_set_for_model.set_format("torch")

    # Run the model in mini-batches so large validation sets fit in GPU memory
    start_logits = []
    end_logits = []
    with torch.no_grad():
        for start in range(0, len(eval_set_for_model), batch_size):
            batch = {
                k: v.to(device)
                for k, v in eval_set_for_model[start : start + batch_size].items()
            }
            outputs = luke_model(**batch)
            start_logits.append(outputs.start_logits.cpu().numpy())
            end_logits.append(outputs.end_logits.cpu().numpy())

    start_logits = np.concatenate(start_logits)
    end_logits = np.concatenate(end_logits)
    return compute_beam(start_logits, end_logits, preprocessed, dataset)
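

# Example usage (sketch): run beam prediction on a SQuAD-style validation split.
# The dataset name and split below are assumptions chosen for illustration; any dataset
# with "id", "question", and "context" columns in SQuAD format should work with generate().
if __name__ == "__main__":
    from datasets import load_dataset

    # "squadshifts" / "new_wiki" is one plausible evaluation set given the checkpoint name;
    # swap in whichever SQuAD-style dataset you are actually evaluating.
    eval_dataset = load_dataset("squadshifts", "new_wiki", split="test")
    predictions = generate(eval_dataset)
    print(predictions[0])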