de-Rodrigo committed on
Commit
e76a04b
·
1 Parent(s): d0d6669

Fix Dataset Loading and Streamline Code

Browse files

- Add global dataset variable and load_merit_dataset() function
- Implement get_image_from_dataset() to ensure dataset is loaded before access
- Load dataset at the start of main block
- Update Gradio slider to use dataset length for maximum value
- Remove unused methods and imports
- Refactor process_image() to handle dataset image selection
- Adjust main block to initialize dataset before Gradio interface creation

This commit resolves the NameError related to undefined 'dataset' and
ensures proper dataset loading and access throughout the application.
It also removes unnecessary code, improving overall efficiency and readability.

Files changed (1) hide show
  1. app.py +19 -56
app.py CHANGED
@@ -1,76 +1,43 @@
1
- import io
2
- import requests
3
  import gradio as gr
4
  from huggingface_hub import list_models
5
- from datasets import load_dataset
6
  from typing import List
7
- from PIL import Image
8
  import torch
9
  from transformers import DonutProcessor, VisionEncoderDecoderModel
 
10
  import json
11
  import re
12
  import logging
 
13
 
14
  # Logging configuration
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
- # Global variables for Donut model and processor
19
  donut_model = None
20
  donut_processor = None
 
21
 
22
 
23
- def get_image_names(dataset):
24
- return [str(i) for i in range(len(dataset))]
 
 
 
25
 
26
 
27
  def get_image_from_dataset(index):
 
 
 
28
  image_data = dataset[int(index)]["image"]
29
  return image_data
30
 
31
 
32
- def process_image(image=None, dataset_image_index=None):
33
- if dataset_image_index:
34
- image = get_image_from_dataset(dataset_image_index)
35
-
36
- return image
37
-
38
-
39
- def create_interface(tag, image_indices):
40
- """Create Gradio interface"""
41
- iface = gr.Interface(
42
- fn=process_image,
43
- inputs=[
44
- gr.Dropdown(choices=get_collection_models(tag), label="Select Model"),
45
- gr.Image(type="pil", label="Upload Image"),
46
- gr.Dropdown(
47
- choices=image_indices, label="Select one from MERIT Dataset test-set"
48
- ),
49
- ],
50
- outputs=gr.Image(label="Output Image"),
51
- title="Saliency Visualization",
52
- description="Upload your image or select one from the MERIT Dataset test-set.",
53
- )
54
- return iface
55
-
56
-
57
  def get_collection_models(tag: str) -> List[str]:
58
  """Get a list of models from a specific Hugging Face collection."""
59
  models = list_models(author="de-Rodrigo")
60
-
61
- model_names = []
62
- for model in models:
63
- if tag in model.tags:
64
- model_names.append(model.modelId)
65
-
66
- return model_names
67
-
68
-
69
- def load_model(model_name: str):
70
- """Load a model from Hugging Face Hub."""
71
- model = AutoModel.from_pretrained(model_name)
72
- tokenizer = AutoTokenizer.from_pretrained(model_name)
73
- return model, tokenizer
74
 
75
 
76
  def get_donut():
@@ -145,6 +112,9 @@ def process_image(model_name, image=None, dataset_image_index=None):
145
 
146
 
147
  if __name__ == "__main__":
 
 
 
148
  models = get_collection_models("saliency")
149
  models.append("de-Rodrigo/donut-merit")
150
 
@@ -153,7 +123,9 @@ if __name__ == "__main__":
153
  inputs=[
154
  gr.Dropdown(choices=models, label="Select Model"),
155
  gr.Image(type="pil", label="Upload Image"),
156
- gr.Slider(minimum=0, maximum=99, step=1, label="Dataset Image Index"),
 
 
157
  ],
158
  outputs=[gr.Image(label="Processed Image"), gr.Textbox(label="Result")],
159
  title="Document Understanding with Donut",
@@ -161,12 +133,3 @@ if __name__ == "__main__":
161
  )
162
 
163
  demo.launch()
164
-
165
- dataset_name = "de-Rodrigo/merit"
166
- dataset = load_dataset(dataset_name, name="en-digital-seq", split="train", num_proc=8)
167
- image_indices = get_image_names(dataset)
168
-
169
- models_tag = "saliency-merit"
170
-
171
- iface = create_interface(models_tag, image_indices)
172
- iface.launch()
 
 
 
1
  import gradio as gr
2
  from huggingface_hub import list_models
 
3
  from typing import List
 
4
  import torch
5
  from transformers import DonutProcessor, VisionEncoderDecoderModel
6
+ from PIL import Image
7
  import json
8
  import re
9
  import logging
10
+ from datasets import load_dataset
11
 
12
  # Logging configuration
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
16
+ # Global variables for Donut model, processor, and dataset
17
  donut_model = None
18
  donut_processor = None
19
+ dataset = None
20
 
21
 
22
+ def load_merit_dataset():
23
+ global dataset
24
+ if dataset is None:
25
+ dataset = load_dataset("de-Rodrigo/merit", name="en-digital-seq", split="train")
26
+ return dataset
27
 
28
 
29
  def get_image_from_dataset(index):
30
+ global dataset
31
+ if dataset is None:
32
+ dataset = load_merit_dataset()
33
  image_data = dataset[int(index)]["image"]
34
  return image_data
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def get_collection_models(tag: str) -> List[str]:
38
  """Get a list of models from a specific Hugging Face collection."""
39
  models = list_models(author="de-Rodrigo")
40
+ return [model.modelId for model in models if tag in model.tags]
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
 
43
  def get_donut():
 
112
 
113
 
114
  if __name__ == "__main__":
115
+ # Load the dataset
116
+ load_merit_dataset()
117
+
118
  models = get_collection_models("saliency")
119
  models.append("de-Rodrigo/donut-merit")
120
 
 
123
  inputs=[
124
  gr.Dropdown(choices=models, label="Select Model"),
125
  gr.Image(type="pil", label="Upload Image"),
126
+ gr.Slider(
127
+ minimum=0, maximum=len(dataset) - 1, step=1, label="Dataset Image Index"
128
+ ),
129
  ],
130
  outputs=[gr.Image(label="Processed Image"), gr.Textbox(label="Result")],
131
  title="Document Understanding with Donut",
 
133
  )
134
 
135
  demo.launch()