Soft_Computing_Project / generate_answers.py
Akshat1000's picture
Update generate_answers.py
26c6fcf verified
raw
history blame
710 Bytes
from transformers import BertForQuestionAnswering, BertTokenizer
import torch
# Checkpoint identifier resolved by `from_pretrained` (name indicates BERT-large,
# whole-word masking, fine-tuned on SQuAD). Tokenizer and QA model are loaded
# once at import time and reused by every get_answer() call.
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
tokenizer, model = (
    BertTokenizer.from_pretrained(model_name),
    BertForQuestionAnswering.from_pretrained(model_name),
)
def get_answer(question, context, max_answer_len=None):
    """Extract the answer to *question* from *context* using the BERT QA model.

    The question/context pair is tokenized as one sequence (truncated to 512
    tokens) and run through the module-level ``model``. The model gives every
    token a start score and an end score. This function then picks the
    highest-scoring span (start score + end score) that satisfies both rules:

    * the span starts at or before it ends, so it is never empty or reversed;
    * the span lies entirely inside the context segment, so it never includes
      the question, ``[CLS]`` or ``[SEP]``.

    Args:
        question: The question to answer.
        context: The passage to search for the answer.
        max_answer_len: Optional maximum answer length, in tokens. ``None``
            (the default) places no limit beyond the sequence length.

    Returns:
        The decoded answer text. Returns ``""`` when the tokenized input
        contains no context tokens (e.g. *context* is empty).

    Raises:
        ValueError: If *max_answer_len* is given and is less than 1.
    """
    if max_answer_len is not None and max_answer_len < 1:
        raise ValueError("max_answer_len must be a positive integer or None")

    inputs = tokenizer(
        question, context, return_tensors="pt", truncation=True, max_length=512
    )
    with torch.no_grad():
        outputs = model(**inputs)

    input_ids = inputs["input_ids"][0]
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]
    seq_len = input_ids.size(0)

    # Second segment (token_type_id == 1) is the context; drop its trailing
    # [SEP] so the answer can only be made of real context tokens.
    in_context = (inputs["token_type_ids"][0] == 1) & (
        input_ids != tokenizer.sep_token_id
    )
    if not bool(in_context.any()):
        return ""

    # span_scores[i, j] = score of the span that starts at token i and ends at token j.
    span_scores = start_logits[:, None] + end_logits[None, :]

    # Upper triangle keeps start <= end (taking the two argmaxes separately
    # could give end < start). tril(k) keeps end - start <= k.
    allowed = torch.ones(seq_len, seq_len, device=span_scores.device).triu()
    if max_answer_len is not None:
        allowed = allowed.tril(max_answer_len - 1)
    allowed = allowed.bool() & in_context[:, None] & in_context[None, :]
    span_scores = span_scores.masked_fill(~allowed, float("-inf"))

    # Flattened argmax over the (seq_len x seq_len) grid -> (start, end) pair.
    start_idx, end_idx = divmod(int(torch.argmax(span_scores)), seq_len)
    answer_tokens = input_ids[start_idx:end_idx + 1]
    return tokenizer.decode(answer_tokens, skip_special_tokens=True)