rubentito committed
Commit c9c5e85
1 Parent(s): 75152a8

Update README.md

Files changed (1)
  1. README.md +14 -2
README.md CHANGED
@@ -21,10 +21,11 @@ This model was used as a baseline in [Hierarchical multimodal transformers for M
 
 ## How to use
 
-Here is how to use this model to get the features of a given text in PyTorch:
+Here is how to use this model to run inference on a sample question and context in PyTorch:
 
 ```python
-from transformers import BigBirdForQuestionAnswering
+import torch
+from transformers import BigBirdForQuestionAnswering, BigBirdTokenizerFast
 
 # by default it's in `block_sparse` mode with num_random_blocks=3, block_size=64
 model = BigBirdForQuestionAnswering.from_pretrained("rubentito/bigbird-base-itc-mpdocvqa")
@@ -35,10 +36,21 @@ model = BigBirdForQuestionAnswering.from_pretrained("rubentito/bigbird-base-itc-
 # you can change `block_size` & `num_random_blocks` like this:
 model = BigBirdForQuestionAnswering.from_pretrained("rubentito/bigbird-base-itc-mpdocvqa", block_size=16, num_random_blocks=2)
 
+tokenizer = BigBirdTokenizerFast.from_pretrained("rubentito/bigbird-base-itc-mpdocvqa")
+
 question = "Replace me by any text you'd like."
 context = "Put some context for answering"
+
 encoded_input = tokenizer(question, context, return_tensors='pt')
 output = model(**encoded_input)
+
+# take the highest-scoring start/end positions; the end index is inclusive
+start_pos = torch.argmax(output.start_logits, dim=-1).item()
+end_pos = torch.argmax(output.end_logits, dim=-1).item()
+
+context_tokens = tokenizer.convert_ids_to_tokens(encoded_input["input_ids"][0].tolist())
+answer_tokens = context_tokens[start_pos: end_pos + 1]
+answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens))
 ```
 
 ## Model results
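
For reference, the inference flow added by this commit can be condensed into a single self-contained script. The sketch below assumes the checkpoint named above; the question/context strings are placeholders, and decoding the span directly from the input ids (plus `torch.no_grad()` and `skip_special_tokens=True`) is a small simplification of the committed snippet, not part of it.

```python
import torch
from transformers import BigBirdForQuestionAnswering, BigBirdTokenizerFast

checkpoint = "rubentito/bigbird-base-itc-mpdocvqa"
model = BigBirdForQuestionAnswering.from_pretrained(checkpoint)
tokenizer = BigBirdTokenizerFast.from_pretrained(checkpoint)

question = "Replace me by any text you'd like."  # placeholder question
context = "Put some context for answering"       # placeholder context

encoded_input = tokenizer(question, context, return_tensors="pt")
with torch.no_grad():  # inference only; no gradients needed
    output = model(**encoded_input)

# highest-scoring start/end positions; the end index is inclusive
start_pos = output.start_logits.argmax(dim=-1).item()
end_pos = output.end_logits.argmax(dim=-1).item()

# decode the predicted span directly from the input ids
answer = tokenizer.decode(
    encoded_input["input_ids"][0, start_pos : end_pos + 1],
    skip_special_tokens=True,
)
print(answer)
```

Decoding from the input ids avoids the ids-to-tokens-to-ids round trip, and `skip_special_tokens=True` keeps markers such as [SEP] out of the answer string.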