de-Rodrigo committed on
Commit d0d6669 · 1 Parent(s): 00b05e0

Update to Get Donut Results

Files changed (1):
  1. app.py +109 -11
app.py CHANGED
@@ -1,20 +1,34 @@
 import io
 import requests
 import gradio as gr
-# from transformers import AutoModel, AutoTokenizer
 from huggingface_hub import list_models
 from datasets import load_dataset
 from typing import List
 from PIL import Image
+import torch
+from transformers import DonutProcessor, VisionEncoderDecoderModel
+import json
+import re
+import logging
+
+# Logging configuration
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Global variables for Donut model and processor
+donut_model = None
+donut_processor = None
 
 
 def get_image_names(dataset):
     return [str(i) for i in range(len(dataset))]
 
+
 def get_image_from_dataset(index):
     image_data = dataset[int(index)]["image"]
     return image_data
 
+
 def process_image(image=None, dataset_image_index=None):
     if dataset_image_index:
         image = get_image_from_dataset(dataset_image_index)
@@ -22,19 +36,20 @@ def process_image(image=None, dataset_image_index=None):
     return image
 
 
-
 def create_interface(tag, image_indices):
-    """ Create Gradio interface"""
+    """Create Gradio interface"""
     iface = gr.Interface(
         fn=process_image,
         inputs=[
             gr.Dropdown(choices=get_collection_models(tag), label="Select Model"),
             gr.Image(type="pil", label="Upload Image"),
-            gr.Dropdown(choices=image_indices, label="Select one from MERIT Dataset test-set"),
+            gr.Dropdown(
+                choices=image_indices, label="Select one from MERIT Dataset test-set"
+            ),
         ],
         outputs=gr.Image(label="Output Image"),
         title="Saliency Visualization",
-        description="Upload your image or select one from the MERIT Dataset test-set."
+        description="Upload your image or select one from the MERIT Dataset test-set.",
     )
     return iface
 
@@ -50,19 +65,102 @@ def get_collection_models(tag: str) -> List[str]:
 
     return model_names
 
+
 def load_model(model_name: str):
     """Load a model from Hugging Face Hub."""
     model = AutoModel.from_pretrained(model_name)
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     return model, tokenizer
 
-# # Example processing function
-# def process_input(text: str, model_name: str) -> str:
-#     model, tokenizer = load_model(model_name)
-#     inputs = tokenizer(text, return_tensors="pt")
-#     outputs = model(**inputs)
-#     return f"Processed output with {model_name}"
 
+def get_donut():
+    global donut_model, donut_processor
+    if donut_model is None or donut_processor is None:
+        try:
+            donut_model = VisionEncoderDecoderModel.from_pretrained(
+                "de-Rodrigo/donut-merit"
+            )
+            donut_processor = DonutProcessor.from_pretrained("de-Rodrigo/donut-merit")
+            if torch.cuda.is_available():
+                donut_model = donut_model.to("cuda")
+            logger.info("Donut model loaded successfully")
+        except Exception as e:
+            logger.error(f"Error loading Donut model: {str(e)}")
+            raise
+    return donut_model, donut_processor
+
+
+def process_image_donut(model, processor, image):
+    try:
+        if not isinstance(image, Image.Image):
+            image = Image.fromarray(image)
+
+        pixel_values = processor(image, return_tensors="pt").pixel_values
+        if torch.cuda.is_available():
+            pixel_values = pixel_values.to("cuda")
+
+        task_prompt = "<s_cord-v2>"
+        decoder_input_ids = processor.tokenizer(
+            task_prompt, add_special_tokens=False, return_tensors="pt"
+        )["input_ids"]
+
+        outputs = model.generate(
+            pixel_values,
+            decoder_input_ids=decoder_input_ids,
+            max_length=model.decoder.config.max_position_embeddings,
+            early_stopping=True,
+            pad_token_id=processor.tokenizer.pad_token_id,
+            eos_token_id=processor.tokenizer.eos_token_id,
+            use_cache=True,
+            num_beams=1,
+            bad_words_ids=[[processor.tokenizer.unk_token_id]],
+            return_dict_in_generate=True,
+        )
+
+        sequence = processor.batch_decode(outputs.sequences)[0]
+        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
+            processor.tokenizer.pad_token, ""
+        )
+        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
+
+        result = processor.token2json(sequence)
+        return json.dumps(result, indent=2)
+    except Exception as e:
+        logger.error(f"Error processing image with Donut: {str(e)}")
+        return f"Error: {str(e)}"
+
+
+def process_image(model_name, image=None, dataset_image_index=None):
+    if dataset_image_index is not None:
+        image = get_image_from_dataset(dataset_image_index)
+
+    if model_name == "de-Rodrigo/donut-merit":
+        model, processor = get_donut()
+        result = process_image_donut(model, processor, image)
+    else:
+        # Here you should implement processing for other models
+        result = f"Processing for model {model_name} not implemented"
+
+    return image, result
+
+
+if __name__ == "__main__":
+    models = get_collection_models("saliency")
+    models.append("de-Rodrigo/donut-merit")
+
+    demo = gr.Interface(
+        fn=process_image,
+        inputs=[
+            gr.Dropdown(choices=models, label="Select Model"),
+            gr.Image(type="pil", label="Upload Image"),
+            gr.Slider(minimum=0, maximum=99, step=1, label="Dataset Image Index"),
+        ],
+        outputs=[gr.Image(label="Processed Image"), gr.Textbox(label="Result")],
+        title="Document Understanding with Donut",
+        description="Upload an image or select one from the dataset to process with the selected model.",
+    )
+
+    demo.launch()
 
 dataset_name = "de-Rodrigo/merit"
 dataset = load_dataset(dataset_name, name="en-digital-seq", split="train", num_proc=8)
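
For a quick local check, the Donut path added in this commit can be exercised without launching the Gradio UI. The snippet below is a sketch, not part of the commit: it assumes this app.py is importable from the working directory (importing it also runs the load_dataset call at the bottom of the file), and the dataset index 0 is purely illustrative.

# Hypothetical usage sketch of the functions introduced in this commit.
from app import get_donut, get_image_from_dataset, process_image_donut

model, processor = get_donut()      # lazily loads de-Rodrigo/donut-merit (moved to GPU if available)
image = get_image_from_dataset(0)   # any index into the loaded MERIT split; 0 is illustrative
print(process_image_donut(model, processor, image))  # prints the parsed fields as a JSON string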