|
from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering |
|
|
|
import tensorflow as tf |
|
|
|
# Hugging Face checkpoint identifier; the name indicates a DistilBERT model
# distilled on SQuAD for extractive question answering.
checkpoint = "distilbert-base-cased-distilled-squad"

# Tokenizer and TensorFlow QA model are loaded once, at import time, and are
# read as module-level globals by question_answering_tf.
tokenizer, model = (
    AutoTokenizer.from_pretrained(checkpoint),
    TFAutoModelForQuestionAnswering.from_pretrained(checkpoint),
)
|
|
|
def question_answering_tf(question, context, *, verbose=False, max_answer_len=None):
    """Extract the answer to *question* from *context* with the QA model.

    Runs the module-level ``tokenizer`` and ``model`` on the
    (question, context) pair and decodes the highest-scoring token span.
    The span is chosen jointly rather than by two independent argmaxes, so
    it always lies inside the context tokens and never ends before it
    starts. Previously such a case sliced to an empty string.

    Args:
        question: The question text.
        context: The passage expected to contain the answer.
        verbose: If true, print the chosen start/end token indices (the
            original implementation always printed them as debug output).
        max_answer_len: Optional maximum answer length in tokens. ``None``
            (the default) means no limit.

    Returns:
        The decoded answer text, or ``""`` when the context contributes no
        tokens (e.g. an empty context string).

    Raises:
        ValueError: If ``max_answer_len`` is given and is less than 1.
    """
    if max_answer_len is not None and max_answer_len < 1:
        raise ValueError("max_answer_len must be a positive integer or None")

    # Truncate only the context so long passages fit the model's maximum
    # sequence length; any answer past the cut-off cannot be found.
    inputs = tokenizer(
        question, context, return_tensors="tf", truncation="only_second"
    )
    outputs = model(inputs)

    # sequence_ids(0) tags each token: None = special, 0 = question,
    # 1 = context. This needs a fast tokenizer (AutoTokenizer's default).
    context_mask = tf.constant([seq_id == 1 for seq_id in inputs.sequence_ids(0)])
    if not bool(tf.reduce_any(context_mask)):
        return ""

    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]
    # Give non-context positions effectively zero probability after softmax.
    very_negative = tf.fill(
        tf.shape(start_logits), tf.constant(-1e4, dtype=start_logits.dtype)
    )
    start_probs = tf.nn.softmax(tf.where(context_mask, start_logits, very_negative))
    end_probs = tf.nn.softmax(tf.where(context_mask, end_logits, very_negative))

    # scores[i, j] = P(start=i) * P(end=j). Keeping only the upper band
    # zeroes every span with j < i (and j - i >= max_answer_len if bounded).
    # Because probabilities are non-negative, the argmax then picks a
    # valid span.
    scores = start_probs[:, tf.newaxis] * end_probs[tf.newaxis, :]
    num_upper = -1 if max_answer_len is None else int(max_answer_len) - 1
    scores = tf.linalg.band_part(scores, 0, num_upper)

    seq_len = int(scores.shape[-1])
    best_flat_index = int(tf.argmax(tf.reshape(scores, [-1])))
    start_index, end_index = divmod(best_flat_index, seq_len)

    if verbose:
        print(start_index, end_index)

    return tokenizer.decode(inputs["input_ids"][0][start_index : end_index + 1])