from transformers import BertForQuestionAnswering, BertTokenizer
import torch

model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForQuestionAnswering.from_pretrained(model_name)


def _best_span(start_logits, end_logits, context_mask, max_answer_len=None):
    """Return the (start, end) token indices of the best-scoring valid span.

    A span (i, j) is valid when i <= j, both i and j are positions where
    ``context_mask`` is True, and, if ``max_answer_len`` is given,
    ``j - i + 1 <= max_answer_len``. Its score is
    ``start_logits[i] + end_logits[j]``. ``end`` is inclusive.

    Args:
        start_logits: 1-D tensor of per-token start scores.
        end_logits: 1-D tensor of per-token end scores (same length).
        context_mask: 1-D bool tensor marking tokens allowed in the answer.
        max_answer_len: optional maximum span length in tokens.

    Returns:
        A ``(start, end)`` tuple of ints, or ``None`` if no span is valid.
    """
    seq_len = start_logits.shape[0]
    # scores[i, j] is the score of the span that starts at i and ends at j.
    scores = start_logits[:, None] + end_logits[None, :]
    ones = torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device)
    valid = torch.triu(ones)  # enforce start <= end
    if max_answer_len is not None:
        # Enforce end <= start + max_answer_len - 1.
        valid = valid & torch.tril(ones, diagonal=max_answer_len - 1)
    valid = valid & context_mask[:, None] & context_mask[None, :]
    if not bool(valid.any()):
        return None
    scores = scores.masked_fill(~valid, float("-inf"))
    # argmax over the flattened matrix; divmod recovers (row, col) = (start, end).
    return divmod(int(torch.argmax(scores)), seq_len)


def get_answer(question, context, max_answer_len=None):
    """Extract the answer to ``question`` from ``context`` with the BERT QA model.

    The answer is the highest-scoring span that starts no later than it ends
    and lies entirely inside the context segment. Unlike two independent
    argmax calls, this never yields an inverted (empty) span. It also never
    returns tokens from the question or special tokens.

    Args:
        question: the question text.
        context: the passage to search for the answer.
        max_answer_len: optional maximum answer length in word-piece tokens;
            ``None`` (the default) imposes no limit.

    Returns:
        The decoded answer string, or ``""`` if no valid span exists
        (e.g. the context is empty).
    """
    # Truncate only the context so the question is never cut short.
    # NOTE(review): a question longer than max_length on its own cannot be
    # truncated this way; the tokenizer then logs an error instead.
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation="only_second",
        max_length=512,
    )
    with torch.no_grad():
        outputs = model(**inputs)

    input_ids = inputs["input_ids"][0]
    # Segment id 1 marks the context (plus its trailing [SEP], excluded here).
    context_mask = (inputs["token_type_ids"][0] == 1) & (
        input_ids != tokenizer.sep_token_id
    )
    span = _best_span(
        outputs.start_logits[0],
        outputs.end_logits[0],
        context_mask,
        max_answer_len,
    )
    if span is None:
        return ""
    start, end = span
    answer_tokens = input_ids[start:end + 1]
    return tokenizer.decode(answer_tokens, skip_special_tokens=True)