from transformers import BertTokenizer, BertForQuestionAnswering
import torch

# Local directory holding the SQuAD-finetuned BERT checkpoint and its vocabulary.
model_path = "D:/code/bert_easy/bert-large-uncased-whole-word-masking-finetuned-squad"

# Loaded once at import time so every get_answer() call reuses the same objects.
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertForQuestionAnswering.from_pretrained(model_path)


def _best_valid_span(start_logits: torch.Tensor, end_logits: torch.Tensor) -> tuple[int, int]:
    """Return the (start, end) index pair with the highest summed logit and start <= end.

    Both arguments are 1-D tensors of the same length (one score per token).
    """
    # span_scores[s, e] = start_logits[s] + end_logits[e] for every candidate pair.
    span_scores = start_logits.unsqueeze(1) + end_logits.unsqueeze(0)

    # Rule out pairs whose end lies before their start so they can never win.
    positions = torch.arange(start_logits.size(0), device=start_logits.device)
    inverted = positions.unsqueeze(1) > positions.unsqueeze(0)
    span_scores = span_scores.masked_fill(inverted, float("-inf"))

    # argmax without `dim` indexes the flattened matrix; divmod recovers (row, col).
    flat_index = int(torch.argmax(span_scores))
    return divmod(flat_index, span_scores.size(1))


def get_answer(question: str, context: str) -> str:
    """Answer a question using BERT on the given context.

    The question and context are encoded together as one sequence pair,
    truncated to at most 512 tokens, and passed through the module-level
    ``model``. The answer is the text of the tokens between the predicted
    start and end positions (inclusive).

    Args:
        question: The question to answer.
        context: The passage in which to look for the answer.

    Returns:
        The predicted answer span, rebuilt from its word-piece tokens.
    """
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    # Inference only: skip gradient bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # A single sequence pair was encoded, so take row 0 of the batch.
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    start = int(torch.argmax(start_logits))
    end = int(torch.argmax(end_logits))

    # The two argmaxes are independent, so the best end can land before the
    # best start, which would yield an empty slice. Only in that case fall
    # back to the best-scoring span that is actually well-formed.
    if end < start:
        start, end = _best_valid_span(start_logits, end_logits)

    # `end` is inclusive, hence the + 1 on the slice bound.
    answer_ids = inputs["input_ids"][0][start:end + 1].tolist()
    answer_tokens = tokenizer.convert_ids_to_tokens(answer_ids)
    return tokenizer.convert_tokens_to_string(answer_tokens)