from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering
import tensorflow as tf
import numpy as np

checkpoint = "distilbert-base-cased-distilled-squad"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = TFAutoModelForQuestionAnswering.from_pretrained(checkpoint)

def question_answering_tf(question, context):
    # Tokenize into overlapping chunks (stride=50) so long contexts are not
    # silently truncated; keep offsets to map tokens back to characters.
    inputs = tokenizer(
        question,
        context,
        max_length=384,
        stride=50,
        truncation="only_second",
        padding=True,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        return_tensors="tf",
    )

    _ = inputs.pop("overflow_to_sample_mapping")
    offset_mapping = inputs.pop("offset_mapping")

    outputs = model(**inputs)

    start_logits = outputs.start_logits
    end_logits = outputs.end_logits

    # Mask everything except the context tokens (sequence id 1); keep the
    # [CLS] token unmasked, and also mask out padding tokens.
    sequence_ids = inputs.sequence_ids()
    mask = [i != 1 for i in sequence_ids]
    mask[0] = False
    mask = tf.math.logical_or(tf.constant(mask)[None], inputs["attention_mask"] == 0)

    start_logits = tf.where(mask, -10000.0, start_logits)
    end_logits = tf.where(mask, -10000.0, end_logits)

    # Softmax over each chunk's logits to get per-token probabilities
    start_probabilities = tf.nn.softmax(start_logits, axis=-1).numpy()
    end_probabilities = tf.nn.softmax(end_logits, axis=-1).numpy()

    # Find the (start token, end token) pair with the best probability score,
    # restricted to pairs with start <= end via the upper triangle of the
    # outer-product score matrix; remember which chunk produced the best pair
    max_score = 0.0
    start_index, end_index, best_chunk = 0, 0, 0
    for i, (start_probs, end_probs) in enumerate(
        zip(start_probabilities, end_probabilities)
    ):
        scores = start_probs[:, np.newaxis] * end_probs[np.newaxis, :]
        index = np.triu(scores).argmax().item()
        row, col = divmod(index, scores.shape[1])
        score = scores[row, col]
        if score > max_score:
            max_score = score
            start_index, end_index, best_chunk = row, col, i

    # Map the best tokens back to character positions in the context.
    # Offsets are (start_char, end_char) pairs with an exclusive end.
    start = int(offset_mapping[best_chunk][start_index][0])
    end = int(offset_mapping[best_chunk][end_index][1])
    return context[start:end]
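
# Minimal usage sketch: the question/context strings below are illustrative,
# not from the original file. With this SQuAD-tuned checkpoint the function
# should extract the answer span directly from the context.
if __name__ == "__main__":
    context = (
        "The Amazon rainforest, also known as Amazonia, is a moist broadleaf "
        "forest that covers most of the Amazon basin of South America."
    )
    question = "What is the Amazon rainforest also known as?"
    print(question_answering_tf(question, context))  # likely prints "Amazonia"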