rubentito commited on
Commit
6d5463e
1 Parent(s): b3e51bd

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +9 -11
README.md CHANGED
@@ -21,11 +21,10 @@ This model was used as a baseline in [Hierarchical multimodal transformers for M
21
 
22
 
23
  ## How to use
24
-
25
- Here is how to use this model to get the features of a given text in PyTorch:
26
 
27
  ```python
28
- import torch
29
  from transformers import LongformerTokenizerFast, LongformerForQuestionAnswering
30
 
31
  tokenizer = LongformerTokenizerFast.from_pretrained("rubentito/longformer-base-mpdocvqa")
@@ -33,17 +32,16 @@ model = LongformerForQuestionAnswering.from_pretrained("rubentito/longformer-bas
33
 
34
  text = "Huggingface has democratized NLP. Huge thanks to Huggingface for this."
35
  question = "What has Huggingface done?"
36
- encoding = tokenizer(question, text, return_tensors="pt")
37
- input_ids = encoding["input_ids"]
38
 
39
- # default is local attention everywhere
40
- # the forward method will automatically set global attention on question tokens attention_mask=encoding["attention_mask"]
41
 
42
- start_scores, end_scores = model(input_ids, attention_mask=attention_mask)
43
- all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
44
 
45
- answer_tokens = all_tokens[torch.argmax(start_scores) :torch.argmax(end_scores)+1]
46
- answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens))
 
47
  ```
48
 
49
  ## Model results
 
21
 
22
 
23
  ## How to use
24
+ ### Inference
25
+ How to use this model to perform inference on a sample question and context in PyTorch:
26
 
27
  ```python
 
28
  from transformers import LongformerTokenizerFast, LongformerForQuestionAnswering
29
 
30
  tokenizer = LongformerTokenizerFast.from_pretrained("rubentito/longformer-base-mpdocvqa")
 
32
 
33
  text = "Huggingface has democratized NLP. Huge thanks to Huggingface for this."
34
  question = "What has Huggingface done?"
 
 
35
 
36
+ encoding = tokenizer(question, text, return_tensors="pt")
37
+ output = model(encoding["input_ids"], attention_mask=encoding["attention_mask"])
38
 
39
+ start_pos = torch.argmax(output.start_logits, dim=-1).item()
40
+ end_pos = torch.argmax(output.end_logits, dim=-1).item()
41
 
42
+ context_tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].tolist())
43
+ answer_tokens = context_tokens[start_pos: end_pos + 1]
44
+ pred_answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens))
45
  ```
46
 
47
  ## Model results