import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

# Load the fine-tuned BERT question-answering model and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained("Vardan-verma/Question_Answering_model_finetuned_on_bert")
model = AutoModelForQuestionAnswering.from_pretrained("Vardan-verma/Question_Answering_model_finetuned_on_bert")

def get_answer(question, context):
    # Encode the question/context pair, truncating to BERT's 512-token limit.
    inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)

    # Take the most likely start and end positions of the answer span.
    start_idx = torch.argmax(outputs.start_logits)
    end_idx = torch.argmax(outputs.end_logits) + 1

    # Decode the tokens between the predicted start and end back into text.
    answer_tokens = inputs["input_ids"][0][start_idx:end_idx]
    answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
    return answer