Akshat1000 committed on
Commit 997926b · verified · 1 Parent(s): ad9331d

Update generate_answers.py

Files changed (1): generate_answers.py (+15 -15)
generate_answers.py CHANGED
@@ -1,15 +1,15 @@
- from transformers import BertTokenizer, BertForQuestionAnswering
- import torch
-
- model_path = "D:/code/bert_easy/bert-large-uncased-whole-word-masking-finetuned-squad"
- tokenizer = BertTokenizer.from_pretrained(model_path)
- model = BertForQuestionAnswering.from_pretrained(model_path)
-
- def get_answer(question, context):
-     """Answers a question using BERT on given context."""
-     inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=512)
-     with torch.no_grad():
-         outputs = model(**inputs)
-     start = torch.argmax(outputs.start_logits)
-     end = torch.argmax(outputs.end_logits) + 1
-     return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][start:end]))
 
+ from transformers import BertTokenizer, BertForQuestionAnswering
+ import torch
+
+ model_path = "bert-large-uncased-whole-word-masking-finetuned-squad"
+ tokenizer = BertTokenizer.from_pretrained(model_path)
+ model = BertForQuestionAnswering.from_pretrained(model_path)
+
+ def get_answer(question, context):
+     """Answers a question using BERT on given context."""
+     inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=512)
+     with torch.no_grad():
+         outputs = model(**inputs)
+     start = torch.argmax(outputs.start_logits)
+     end = torch.argmax(outputs.end_logits) + 1
+     return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][start:end]))