ajimeno committed on
Commit
5bb196c
·
1 Parent(s): c77986f

Updated selections

Browse files
Files changed (1) hide show
  1. app.py +20 -7
app.py CHANGED
@@ -8,7 +8,18 @@ from PIL import Image
8
  from io import BytesIO
9
  from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, DonutProcessor, DonutImageProcessor, AutoTokenizer
10
 
11
- def run_prediction(sample, model, processor, prompt):
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  pixel_values = processor(np.array(
14
  sample,
@@ -20,10 +31,12 @@ def run_prediction(sample, model, processor, prompt):
20
  pixel_values.to(device),
21
  decoder_input_ids=processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device),
22
  do_sample=True,
23
- top_p=0.92,
24
- top_k=5,
25
- no_repeat_ngram_size=10,
26
- num_beams=3
 
 
27
  )
28
 
29
  # process output
@@ -54,7 +67,7 @@ with st.sidebar:
54
  image_bytes_data = uploaded_file.getvalue()
55
  image_upload = Image.open(BytesIO(image_bytes_data))
56
 
57
- prompt = st.selectbox('Prompt', ('<s><s_pretraining>', '<s><s_plain>', '<s><s_hierarchical>'), index=2)
58
 
59
  if image_upload:
60
  image = image_upload
@@ -95,6 +108,6 @@ with st.spinner(f'Processing the document ...'):
95
  st.session_state['model'] = model
96
 
97
  st.info(f'Parsing document')
98
- parsed_info = run_prediction(image.convert("RGB"), model, processor, prompt)
99
  st.text(f'\nDocument:')
100
  st.text_area('Output text', value=parsed_info, height=800)
 
8
  from io import BytesIO
9
  from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, DonutProcessor, DonutImageProcessor, AutoTokenizer
10
 
11
+ def run_prediction(sample, model, processor, mode):
12
+
13
+ if mode == "OCR":
14
+ prompt = "<s><s_pretraining>"
15
+ no_repeat_ngram_size = 10
16
+ elif mode == "Table":
17
+ prompt = "<s><s_hierarchical>"
18
+ no_repeat_ngram_size = 45
19
+ else:
20
+ prompt = "<s><s_hierarchical>"
21
+ no_repeat_ngram_size = 10
22
+
23
 
24
  pixel_values = processor(np.array(
25
  sample,
 
31
  pixel_values.to(device),
32
  decoder_input_ids=processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device),
33
  do_sample=True,
34
+ top_p=0.92, #.92,
35
+ top_k=10,
36
+ no_repeat_ngram_size=no_repeat_ngram_size,
37
+ num_beams=3,
38
+ output_attentions=False,
39
+ output_hidden_states=False,
40
  )
41
 
42
  # process output
 
67
  image_bytes_data = uploaded_file.getvalue()
68
  image_upload = Image.open(BytesIO(image_bytes_data))
69
 
70
+ mode = st.selectbox('Mode', ('OCR', 'Tables', 'Element annotation'), index=2)
71
 
72
  if image_upload:
73
  image = image_upload
 
108
  st.session_state['model'] = model
109
 
110
  st.info(f'Parsing document')
111
+ parsed_info = run_prediction(image.convert("RGB"), model, processor, mode)
112
  st.text(f'\nDocument:')
113
  st.text_area('Output text', value=parsed_info, height=800)