File size: 791 Bytes
7dd2453
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering

import tensorflow as tf

# Checkpoint identifier shared by both `from_pretrained` calls below, so the
# tokenizer and the question-answering model always come from the same source.
checkpoint = "distilbert-base-cased-distilled-squad"

# Loaded once, at module import time; `question_answering_tf` reads these
# module-level globals rather than loading its own copies.
tokenizer, model = (
    AutoTokenizer.from_pretrained(checkpoint),
    TFAutoModelForQuestionAnswering.from_pretrained(checkpoint),
)

def question_answering_tf(question, context):
    """Return the answer span to *question* extracted from *context*.

    Runs the module-level ``tokenizer`` and ``model`` on the
    (question, context) pair and decodes the highest-scoring token span.
    The start and end positions are chosen jointly, not by two independent
    argmaxes. This guarantees ``start <= end`` and keeps the span entirely
    inside the context tokens, never in the question or on a special token.

    Args:
        question: The question text.
        context: The passage expected to contain the answer.

    Returns:
        The decoded answer text, or ``""`` if the context produced no tokens.

    Note:
        The tokenizer call applies no truncation or striding, so very long
        contexts are not handled here. TODO: confirm the desired behavior
        for inputs longer than the model's maximum length.
    """
    inputs = tokenizer(question, context, return_tensors="tf")
    outputs = model(inputs)

    # sequence_ids() labels question tokens 0, context tokens 1 and special
    # tokens None; only context tokens may start or end an answer.
    context_flags = [seq_id == 1 for seq_id in inputs.sequence_ids(0)]
    if not any(context_flags):
        return ""
    is_context = tf.constant(context_flags)

    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]
    # A very low logit on non-context positions gives them ~0 probability
    # after softmax, so they can never win the argmax below.
    masked_value = tf.cast(-10000.0, start_logits.dtype)
    start_probs = tf.nn.softmax(tf.where(is_context, start_logits, masked_value))
    end_probs = tf.nn.softmax(tf.where(is_context, end_logits, masked_value))

    # scores[i, j] = P(start = i) * P(end = j). band_part(..., 0, -1) keeps
    # only the upper triangle (j >= i), which rules out end-before-start spans.
    scores = tf.linalg.band_part(start_probs[:, None] * end_probs[None, :], 0, -1)

    # Argmax over the flattened matrix, then recover the (row, column) pair.
    seq_len = int(scores.shape[1])
    best_flat_index = int(tf.math.argmax(tf.reshape(scores, [-1])))
    start_index, end_index = divmod(best_flat_index, seq_len)
    return tokenizer.decode(inputs["input_ids"][0][start_index:end_index + 1])