wjbmattingly committed
Commit 62421d3 · verified · 1 Parent(s): 854f4c6

Update app.py

Files changed (1)
  1. app.py +95 -96
app.py CHANGED
@@ -25,12 +25,6 @@ DEFAULT_NER_LABELS = "person, organization, location, date, event"
 
 # }
 
-class TextWithMetadata(list):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args)
-        self.original_text = kwargs.get('original_text', '')
-        self.entities = kwargs.get('entities', [])
-
 def array_to_image_path(image_array):
     # Convert numpy array to PIL Image
     img = Image.fromarray(np.uint8(image_array))
@@ -80,92 +74,104 @@ prompt_suffix = "<|end|>\n"
 
 @spaces.GPU
 def run_example(image, model_id="Qwen/Qwen2.5-VL-7B-Instruct", run_ner=False, ner_labels=DEFAULT_NER_LABELS):
-    # First get the OCR text
-    text_input = "Convert the image to text."
-    image_path = array_to_image_path(image)
-
-    model = models[model_id]
-    processor = processors[model_id]
-
-    prompt = f"{user_prompt}<|image_1|>\n{text_input}{prompt_suffix}{assistant_prompt}"
-    image = Image.fromarray(image).convert("RGB")
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {
-                    "type": "image",
-                    "image": image_path,
-                },
-                {"type": "text", "text": text_input},
-            ],
-        }
-    ]
-
-    # Preparation for inference
-    text = processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-    image_inputs, video_inputs = process_vision_info(messages)
-    inputs = processor(
-        text=[text],
-        images=image_inputs,
-        videos=video_inputs,
-        padding=True,
-        return_tensors="pt",
-    )
-    inputs = inputs.to("cuda")
-
-    # Inference: Generation of the output
-    generated_ids = model.generate(**inputs, max_new_tokens=1024)
-    generated_ids_trimmed = [
-        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
-    output_text = processor.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )
-
-    ocr_text = output_text[0]
-
-    # If NER is enabled, process the OCR text
-    if run_ner:
-        ner_results = gliner_model.predict_entities(
-            ocr_text,
-            ner_labels.split(","),
-            threshold=0.3
-        )
-
-        # Create a list of tuples (text, label) for highlighting
-        highlighted_text = []
-        last_end = 0
-
-        # Sort entities by start position
-        sorted_entities = sorted(ner_results, key=lambda x: x["start"])
-
-        # Process each entity and add non-entity text segments
-        for entity in sorted_entities:
-            # Add non-entity text before the current entity
-            if last_end < entity["start"]:
-                highlighted_text.append((ocr_text[last_end:entity["start"]], None))
-
-            # Add the entity text with its label
-            highlighted_text.append((
-                ocr_text[entity["start"]:entity["end"]],
-                entity["label"]
-            ))
-            last_end = entity["end"]
-
-        # Add any remaining text after the last entity
-        if last_end < len(ocr_text):
-            highlighted_text.append((ocr_text[last_end:], None))
-
-        # Create TextWithMetadata instance with the highlighted text and metadata
-        result = TextWithMetadata(highlighted_text, original_text=ocr_text, entities=ner_results)
-        return result, result  # Return twice: once for display, once for state
-
-    # If NER is disabled, return the text without highlighting
-    result = TextWithMetadata([(ocr_text, None)], original_text=ocr_text, entities=[])
-    return result, result  # Return twice: once for display, once for state
+    try:
+        # First get the OCR text
+        text_input = "Convert the image to text."
+        image_path = array_to_image_path(image)
+
+        model = models[model_id]
+        processor = processors[model_id]
+
+        prompt = f"{user_prompt}<|image_1|>\n{text_input}{prompt_suffix}{assistant_prompt}"
+        image = Image.fromarray(image).convert("RGB")
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": image_path,
+                    },
+                    {"type": "text", "text": text_input},
+                ],
+            }
+        ]
+
+        # Preparation for inference
+        text = processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        image_inputs, video_inputs = process_vision_info(messages)
+        inputs = processor(
+            text=[text],
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            return_tensors="pt",
+        )
+        inputs = inputs.to("cuda")
+
+        # Inference: Generation of the output
+        generated_ids = model.generate(**inputs, max_new_tokens=1024)
+        generated_ids_trimmed = [
+            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        output_text = processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
+        ocr_text = output_text[0]
+
+        # Create state dictionary
+        state_dict = {
+            "original_text": ocr_text,
+            "entities": []
+        }
+
+        # If NER is enabled, process the OCR text
+        if run_ner:
+            ner_results = gliner_model.predict_entities(
+                ocr_text,
+                ner_labels.split(","),
+                threshold=0.3
+            )
+
+            # Update state with entities
+            state_dict["entities"] = ner_results
+
+            # Create a list of tuples (text, label) for highlighting
+            highlighted_text = []
+            last_end = 0
+
+            # Sort entities by start position
+            sorted_entities = sorted(ner_results, key=lambda x: x["start"])
+
+            # Process each entity and add non-entity text segments
+            for entity in sorted_entities:
+                # Add non-entity text before the current entity
+                if last_end < entity["start"]:
+                    highlighted_text.append((ocr_text[last_end:entity["start"]], None))
+
+                # Add the entity text with its label
+                highlighted_text.append((
+                    ocr_text[entity["start"]:entity["end"]],
+                    entity["label"]
+                ))
+                last_end = entity["end"]
+
+            # Add any remaining text after the last entity
+            if last_end < len(ocr_text):
+                highlighted_text.append((ocr_text[last_end:], None))
+
+            return highlighted_text, state_dict
+
+        # If NER is disabled, return the text without highlighting
+        highlighted_text = [(ocr_text, None)]
+        return highlighted_text, state_dict
+
+    except Exception as e:
+        error_msg = f"Error processing image: {str(e)}"
+        return [(error_msg, None)], {"original_text": error_msg, "entities": []}
 
 css = """
 /* Overall app styling */
@@ -175,7 +181,6 @@ css = """
     padding: 20px;
     background-color: #f8f9fa;
 }
-
 /* Tabs styling */
 .tabs {
     border-radius: 8px;
@@ -183,7 +188,6 @@
     padding: 20px;
     box-shadow: 0 2px 6px rgba(0, 0, 0, 0.1);
 }
-
 /* Input/Output containers */
 .input-container, .output-container {
     background: white;
@@ -192,7 +196,6 @@
     margin: 10px 0;
     box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
 }
-
 /* Button styling */
 .submit-btn {
     background-color: #2d31fa !important;
@@ -207,7 +210,6 @@
     background-color: #1f24c7 !important;
     transform: translateY(-1px);
 }
-
 /* Output text area */
 #output {
     height: 500px;
@@ -218,13 +220,11 @@
     background: #ffffff;
     font-family: 'Arial', sans-serif;
 }
-
 /* Dropdown styling */
 .gr-dropdown {
     border-radius: 6px !important;
     border: 1px solid #e0e0e0 !important;
 }
-
 /* Image upload area */
 .gr-image-input {
     border: 2px dashed #ccc;
@@ -232,7 +232,6 @@
     padding: 20px;
     transition: all 0.3s ease;
 }
-
 .gr-image-input:hover {
     border-color: #2d31fa;
 }
@@ -283,7 +282,7 @@ with gr.Blocks(css=css) as demo:
     # Modify create_zip to use the state data
     def create_zip(image, fname, ocr_result):
         # Validate inputs
-        if not fname or image is None:  # Changed the validation check
+        if not fname or image is None:
            return None
 
        try:
@@ -298,9 +297,9 @@ with gr.Blocks(css=css) as demo:
            img_path = os.path.join(temp_dir, f"{fname}.png")
            image.save(img_path)
 
-            # Use the OCR result from state
-            original_text = ocr_result.original_text if ocr_result else ""
-            entities = ocr_result.entities if ocr_result else []
+            # Use the OCR result from state - now it's a dictionary
+            original_text = ocr_result.get("original_text", "") if ocr_result else ""
+            entities = ocr_result.get("entities", []) if ocr_result else []
 
            # Save text
            txt_path = os.path.join(temp_dir, f"{fname}.txt")
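
A minimal wiring sketch, not taken from this commit: run_example now returns a plain list of (text, label) tuples for display plus a plain dict for state, replacing the removed TextWithMetadata list subclass, and create_zip reads that dict back with .get(). The sketch below shows one plausible way those two return values connect to Gradio components inside the existing gr.Blocks context. All component and variable names here (image_input, output, ocr_state, filename_box, zip_file, the buttons) are assumptions for illustration, and run_example, create_zip, and css are assumed to be in scope as defined in app.py above.

# Hypothetical wiring sketch; names are illustrative, not from app.py.
import gradio as gr

with gr.Blocks(css=css) as demo:
    image_input = gr.Image(label="Input image")                    # assumed component
    output = gr.HighlightedText(label="Output", elem_id="output")  # renders the (text, label) tuples
    ocr_state = gr.State()                                         # holds the plain state dict
    filename_box = gr.Textbox(label="Filename")                    # assumed component
    zip_file = gr.File(label="Download ZIP")                       # assumed component

    submit_btn = gr.Button("Submit")
    # run_example returns (highlighted_text, state_dict): the list feeds the
    # HighlightedText display, the dict is stored in gr.State for later events.
    submit_btn.click(run_example, inputs=[image_input], outputs=[output, ocr_state])

    download_btn = gr.Button("Create ZIP")
    # create_zip receives the dict from state and reads it with .get(), matching the new code.
    download_btn.click(create_zip, inputs=[image_input, filename_box, ocr_state], outputs=[zip_file])

Returning the display value and the state value as two separate, plain Python objects means nothing depends on custom attributes of a list subclass surviving Gradio's component processing, which is presumably why TextWithMetadata could be dropped.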