drlon committed on
Commit
8b47a07
·
1 Parent(s): 3e42347

update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -89
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import Optional
2
  import spaces
3
  import gradio as gr
@@ -73,7 +74,7 @@ This demo is powered by [Gradio](https://gradio.app/) and uses OmniParserv2 to g
73
  DEVICE = torch.device('cuda')
74
 
75
  @spaces.GPU
76
- @torch.inference_mode()
77
  def get_som_response(instruction, image_som):
78
  prompt = magma_som_prompt.format(instruction)
79
  if magam_model.config.mm_use_image_start_end:
@@ -110,7 +111,7 @@ def get_som_response(instruction, image_som):
110
  return response
111
 
112
  @spaces.GPU
113
- @torch.inference_mode()
114
  def get_qa_response(instruction, image):
115
  prompt = magma_qa_prompt.format(instruction)
116
  if magam_model.config.mm_use_image_start_end:
@@ -147,7 +148,7 @@ def get_qa_response(instruction, image):
147
  return response
148
 
149
  @spaces.GPU
150
- @torch.inference_mode()
151
  # @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
152
  def process(
153
  image_input,
@@ -158,98 +159,103 @@ def process(
158
  instruction,
159
  ) -> Optional[Image.Image]:
160
 
161
- # image_save_path = 'imgs/saved_image_demo.png'
162
- # image_input.save(image_save_path)
163
- # image = Image.open(image_save_path)
164
- box_overlay_ratio = image_input.size[0] / 3200
165
- draw_bbox_config = {
166
- 'text_scale': 0.8 * box_overlay_ratio,
167
- 'text_thickness': max(int(2 * box_overlay_ratio), 1),
168
- 'text_padding': max(int(3 * box_overlay_ratio), 1),
169
- 'thickness': max(int(3 * box_overlay_ratio), 1),
170
- }
171
-
172
- ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_input, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}, use_paddleocr=use_paddleocr)
173
- text, ocr_bbox = ocr_bbox_rslt
174
- dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_input, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=iou_threshold, imgsz=imgsz,)
175
- parsed_content_list = '\n'.join([f'icon {i}: ' + str(v) for i,v in enumerate(parsed_content_list)])
176
-
177
- if len(instruction) == 0:
178
- print('finish processing')
179
- image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
180
- return image, str(parsed_content_list)
181
-
182
- elif instruction.startswith('Q:'):
183
- response = get_qa_response(instruction, image_input)
184
- return image_input, response
185
-
186
- # parsed_content_list = str(parsed_content_list)
187
- # convert xywh to yxhw
188
- label_coordinates_yxhw = {}
189
- for key, val in label_coordinates.items():
190
- if val[2] < 0 or val[3] < 0:
191
- continue
192
- label_coordinates_yxhw[key] = [val[1], val[0], val[3], val[2]]
193
- image_som = plot_boxes_with_marks(image_input.copy(), [val for key, val in label_coordinates_yxhw.items()], som_generator, edgecolor=(255,0,0), fn_save=None, normalized_to_pixel=False)
194
-
195
- # convert xywh to xyxy
196
- for key, val in label_coordinates.items():
197
- label_coordinates[key] = [val[0], val[1], val[0] + val[2], val[1] + val[3]]
198
-
199
- # normalize label_coordinates
200
- for key, val in label_coordinates.items():
201
- label_coordinates[key] = [val[0] / image_input.size[0], val[1] / image_input.size[1], val[2] / image_input.size[0], val[3] / image_input.size[1]]
202
-
203
- magma_response = get_som_response(instruction, image_som)
204
- print("magma repsonse: ", magma_response)
205
-
206
- # map magma_response into the mark id
207
- mark_id = extract_mark_id(magma_response)
208
- if mark_id is not None:
209
- if str(mark_id) in label_coordinates:
210
- bbox_for_mark = label_coordinates[str(mark_id)]
 
 
 
211
  else:
212
  bbox_for_mark = None
213
- else:
214
- bbox_for_mark = None
215
-
216
- if bbox_for_mark:
217
- # draw bbox_for_mark on the image
218
- image_som = plot_boxes_with_marks(
219
- image_input,
220
- [label_coordinates_yxhw[str(mark_id)]],
221
- som_generator,
222
- edgecolor=(255,127,111),
223
- alpha=30,
224
- fn_save=None,
225
- normalized_to_pixel=False,
226
- add_mark=False
227
- )
228
- else:
229
- try:
230
- if 'box' in magma_response:
231
- pred_bbox = extract_bbox(magma_response)
232
- click_point = [(pred_bbox[0][0] + pred_bbox[1][0]) / 2, (pred_bbox[0][1] + pred_bbox[1][1]) / 2]
233
- click_point = [item / 1000 for item in click_point]
234
- else:
235
- click_point = pred_2_point(magma_response)
236
- # de-normalize click_point (width, height)
237
- click_point = [click_point[0] * image_input.size[0], click_point[1] * image_input.size[1]]
238
-
239
- image_som = plot_circles_with_marks(
240
  image_input,
241
- [click_point],
242
- som_generator,
243
  edgecolor=(255,127,111),
244
- linewidth=3,
245
- fn_save=None,
246
  normalized_to_pixel=False,
247
  add_mark=False
248
  )
249
- except:
250
- image_som = image_input
251
-
252
- return image_som, str(parsed_content_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
  with gr.Blocks() as demo:
255
  gr.Markdown(MARKDOWN)
@@ -291,4 +297,4 @@ with gr.Blocks() as demo:
291
 
292
  demo.launch(debug=True, show_error=True, share=True)
293
  # demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
294
- # demo.queue().launch(share=False)
 
1
+ import traceback
2
  from typing import Optional
3
  import spaces
4
  import gradio as gr
 
74
  DEVICE = torch.device('cuda')
75
 
76
  @spaces.GPU
77
+ # @torch.inference_mode()
78
  def get_som_response(instruction, image_som):
79
  prompt = magma_som_prompt.format(instruction)
80
  if magam_model.config.mm_use_image_start_end:
 
111
  return response
112
 
113
  @spaces.GPU
114
+ # @torch.inference_mode()
115
  def get_qa_response(instruction, image):
116
  prompt = magma_qa_prompt.format(instruction)
117
  if magam_model.config.mm_use_image_start_end:
 
148
  return response
149
 
150
  @spaces.GPU
151
+ # @torch.inference_mode()
152
  # @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
153
  def process(
154
  image_input,
 
159
  instruction,
160
  ) -> Optional[Image.Image]:
161
 
162
+ try:
163
+ # image_save_path = 'imgs/saved_image_demo.png'
164
+ # image_input.save(image_save_path)
165
+ # image = Image.open(image_save_path)
166
+ box_overlay_ratio = image_input.size[0] / 3200
167
+ draw_bbox_config = {
168
+ 'text_scale': 0.8 * box_overlay_ratio,
169
+ 'text_thickness': max(int(2 * box_overlay_ratio), 1),
170
+ 'text_padding': max(int(3 * box_overlay_ratio), 1),
171
+ 'thickness': max(int(3 * box_overlay_ratio), 1),
172
+ }
173
+
174
+ ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_input, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}, use_paddleocr=use_paddleocr)
175
+ text, ocr_bbox = ocr_bbox_rslt
176
+ dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_input, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=iou_threshold, imgsz=imgsz,)
177
+ parsed_content_list = '\n'.join([f'icon {i}: ' + str(v) for i,v in enumerate(parsed_content_list)])
178
+
179
+ if len(instruction) == 0:
180
+ print('finish processing')
181
+ image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
182
+ return image, str(parsed_content_list)
183
+
184
+ elif instruction.startswith('Q:'):
185
+ response = get_qa_response(instruction, image_input)
186
+ return image_input, response
187
+
188
+ # parsed_content_list = str(parsed_content_list)
189
+ # convert xywh to yxhw
190
+ label_coordinates_yxhw = {}
191
+ for key, val in label_coordinates.items():
192
+ if val[2] < 0 or val[3] < 0:
193
+ continue
194
+ label_coordinates_yxhw[key] = [val[1], val[0], val[3], val[2]]
195
+ image_som = plot_boxes_with_marks(image_input.copy(), [val for key, val in label_coordinates_yxhw.items()], som_generator, edgecolor=(255,0,0), fn_save=None, normalized_to_pixel=False)
196
+
197
+ # convert xywh to xyxy
198
+ for key, val in label_coordinates.items():
199
+ label_coordinates[key] = [val[0], val[1], val[0] + val[2], val[1] + val[3]]
200
+
201
+ # normalize label_coordinates
202
+ for key, val in label_coordinates.items():
203
+ label_coordinates[key] = [val[0] / image_input.size[0], val[1] / image_input.size[1], val[2] / image_input.size[0], val[3] / image_input.size[1]]
204
+
205
+ magma_response = get_som_response(instruction, image_som)
206
+ print("magma repsonse: ", magma_response)
207
+
208
+ # map magma_response into the mark id
209
+ mark_id = extract_mark_id(magma_response)
210
+ if mark_id is not None:
211
+ if str(mark_id) in label_coordinates:
212
+ bbox_for_mark = label_coordinates[str(mark_id)]
213
+ else:
214
+ bbox_for_mark = None
215
  else:
216
  bbox_for_mark = None
217
+
218
+ if bbox_for_mark:
219
+ # draw bbox_for_mark on the image
220
+ image_som = plot_boxes_with_marks(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  image_input,
222
+ [label_coordinates_yxhw[str(mark_id)]],
223
+ som_generator,
224
  edgecolor=(255,127,111),
225
+ alpha=30,
226
+ fn_save=None,
227
  normalized_to_pixel=False,
228
  add_mark=False
229
  )
230
+ else:
231
+ try:
232
+ if 'box' in magma_response:
233
+ pred_bbox = extract_bbox(magma_response)
234
+ click_point = [(pred_bbox[0][0] + pred_bbox[1][0]) / 2, (pred_bbox[0][1] + pred_bbox[1][1]) / 2]
235
+ click_point = [item / 1000 for item in click_point]
236
+ else:
237
+ click_point = pred_2_point(magma_response)
238
+ # de-normalize click_point (width, height)
239
+ click_point = [click_point[0] * image_input.size[0], click_point[1] * image_input.size[1]]
240
+
241
+ image_som = plot_circles_with_marks(
242
+ image_input,
243
+ [click_point],
244
+ som_generator,
245
+ edgecolor=(255,127,111),
246
+ linewidth=3,
247
+ fn_save=None,
248
+ normalized_to_pixel=False,
249
+ add_mark=False
250
+ )
251
+ except:
252
+ image_som = image_input
253
+
254
+ return image_som, str(parsed_content_list)
255
+ except Exception as e:
256
+ print('error in process')
257
+ traceback.print_exc()
258
+ return image_input, 'error in process'
259
 
260
  with gr.Blocks() as demo:
261
  gr.Markdown(MARKDOWN)
 
297
 
298
  demo.launch(debug=True, show_error=True, share=True)
299
  # demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
300
+ # demo.queue().launch(share=False)