Update README.md
README.md
## How to use

To perform inference on a sample question and context in PyTorch:

```python
import torch
from transformers import BigBirdForQuestionAnswering, BigBirdTokenizerFast

# by default it's in `block_sparse` mode with num_random_blocks=3, block_size=64
model = BigBirdForQuestionAnswering.from_pretrained("rubentito/bigbird-base-itc-mpdocvqa")

# you can change `block_size` & `num_random_blocks` like this:
model = BigBirdForQuestionAnswering.from_pretrained("rubentito/bigbird-base-itc-mpdocvqa", block_size=16, num_random_blocks=2)

tokenizer = BigBirdTokenizerFast.from_pretrained("rubentito/bigbird-base-itc-mpdocvqa")

question = "Replace me by any text you'd like."
context = "Put some context for answering"

encoded_input = tokenizer(question, context, return_tensors='pt')
output = model(**encoded_input)

# the predicted answer span is the argmax of the start and end logits
start_pos = torch.argmax(output.start_logits, dim=-1).item()
end_pos = torch.argmax(output.end_logits, dim=-1).item()

# slice the answer tokens out of the input (the end index is inclusive) and decode
context_tokens = tokenizer.convert_ids_to_tokens(encoded_input["input_ids"][0].tolist())
answer_tokens = context_tokens[start_pos: end_pos + 1]
answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens))
```
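
The argmax-and-decode steps above are what the generic `question-answering` pipeline does internally, so a shorter route may also work. A minimal sketch, assuming this checkpoint is compatible with the standard pipeline (not shown in the original card):

```python
from transformers import pipeline

# the pipeline handles tokenization, start/end argmax and decoding internally
qa = pipeline("question-answering", model="rubentito/bigbird-base-itc-mpdocvqa")

result = qa(question="Replace me by any text you'd like.",
            context="Put some context for answering")
print(result["answer"], result["score"])
```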

## Model results