Akshat1000 committed on
Commit
26c6fcf
·
verified ·
1 Parent(s): 809119f

Update generate_answers.py

Browse files
Files changed (1) hide show
  1. generate_answers.py +11 -8
generate_answers.py CHANGED
@@ -1,15 +1,18 @@
1
from transformers import BertTokenizer, BertForQuestionAnswering
import torch

# SQuAD-fine-tuned BERT checkpoint for extractive question answering.
# NOTE(review): loaded eagerly at import time — importing this module
# triggers the (potentially slow) model download/load.
model_path = "bert-large-uncased-whole-word-masking-finetuned-squad"
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertForQuestionAnswering.from_pretrained(model_path)


def get_answer(question, context):
    """Answers a question using BERT on given context.

    Args:
        question: Natural-language question string.
        context: Passage of text expected to contain the answer.

    Returns:
        The answer text extracted from ``context``, or "" when the
        predicted span is empty or inverted (no confident answer).
    """
    # Truncate to BERT's 512-token limit so long contexts do not error out.
    inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():  # inference only — no gradients needed
        outputs = model(**inputs)
    # Most likely start/end token positions; end is exclusive for slicing.
    start = torch.argmax(outputs.start_logits)
    end = torch.argmax(outputs.end_logits) + 1
    # An inverted span (end before start) means the model found no answer.
    if end <= start:
        return ""
    # decode(..., skip_special_tokens=True) keeps [CLS]/[SEP] out of the
    # answer, unlike convert_tokens_to_string on the raw token slice.
    return tokenizer.decode(inputs["input_ids"][0][start:end], skip_special_tokens=True)
 
 
 
 
 
1
# Third-party deps: Hugging Face Transformers for the QA model, PyTorch for inference.
from transformers import BertForQuestionAnswering, BertTokenizer
import torch

# BERT-large checkpoint fine-tuned on SQuAD for extractive question answering.
# NOTE(review): loaded eagerly at import time — importing this module
# triggers the (potentially slow) model download/load.
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForQuestionAnswering.from_pretrained(model_name)
7
 
8
def get_answer(question, context):
    """Answer *question* extractively from *context* using the BERT QA model.

    Args:
        question: Natural-language question string.
        context: Passage of text expected to contain the answer.

    Returns:
        The answer substring decoded from ``context``, or "" when the
        predicted span is empty or inverted (no confident answer).
    """
    # Truncate to BERT's 512-token limit so long contexts do not crash the model.
    inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():  # inference only — no gradients needed
        outputs = model(**inputs)

    # Most likely start/end token positions; end is exclusive for slicing.
    start_idx = torch.argmax(outputs.start_logits)
    end_idx = torch.argmax(outputs.end_logits) + 1

    # An inverted span (end before start) means the model found no answer;
    # return "" explicitly rather than decoding an empty slice.
    if end_idx <= start_idx:
        return ""

    answer_tokens = inputs["input_ids"][0][start_idx:end_idx]
    # skip_special_tokens keeps [CLS]/[SEP] out of the returned answer text.
    answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
    return answer