ginipick committed
Commit 1974dbd · verified · 1 Parent(s): 0ef105d

Update app.py

Files changed (1):
  1. app.py +125 -61
app.py CHANGED
@@ -5,29 +5,26 @@ import numpy as np
  import torch
  from PIL import Image
  import io
-
-
  import base64, os
- from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
- import torch
- from PIL import Image
-
  from huggingface_hub import snapshot_download

- # Define repository and local directory
- repo_id = "microsoft/OmniParser-v2.0" # HF repo
- local_dir = "weights" # Target local directory
+ # Import utility functions
+ from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
+
+ # Download repository (if not already downloaded)
+ repo_id = "microsoft/OmniParser-v2.0" # HF repository ID
+ local_dir = "weights" # Local directory for weights

- # Download the entire repository
  snapshot_download(repo_id=repo_id, local_dir=local_dir)
-
  print(f"Repository downloaded to: {local_dir}")

-
+ # Load models
  yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
  caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption")
+ # Alternative caption model (BLIP2) can be used as below:
  # caption_model_processor = get_caption_model_processor(model_name="blip2", model_name_or_path="weights/icon_caption_blip2")

+ # Markdown header text
  MARKDOWN = """
  # OmniParser V2 for Pure Vision Based General GUI Agent 🔥
  <div>
@@ -35,66 +32,134 @@ MARKDOWN = """
  <img src="https://img.shields.io/badge/arXiv-2408.00203-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
  </a>
  </div>
-
- OmniParser is a screen parsing tool to convert general GUI screen to structured elements.
+ OmniParser converts general GUI screens into structured elements using pure vision-based parsing.
  """

  DEVICE = torch.device('cuda')

+ # Custom CSS for UI enhancement
+ custom_css = """
+ body { background-color: #f0f2f5; }
+ .gradio-container { font-family: 'Segoe UI', sans-serif; }
+ h1, h2, h3, h4 { color: #283E51; }
+ button { border-radius: 6px; }
+ .accordion { background-color: #ffffff; border: 1px solid #ddd; border-radius: 6px; padding: 10px; }
+ """
+
  @spaces.GPU
  @torch.inference_mode()
- # @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
  def process(
      image_input,
      box_threshold,
      iou_threshold,
      use_paddleocr,
      imgsz
- ) -> Optional[Image.Image]:
-
-     # image_save_path = 'imgs/saved_image_demo.png'
-     # image_input.save(image_save_path)
-     # image = Image.open(image_save_path)
-     box_overlay_ratio = image_input.size[0] / 3200
-     draw_bbox_config = {
-         'text_scale': 0.8 * box_overlay_ratio,
-         'text_thickness': max(int(2 * box_overlay_ratio), 1),
-         'text_padding': max(int(3 * box_overlay_ratio), 1),
-         'thickness': max(int(3 * box_overlay_ratio), 1),
-     }
-     # import pdb; pdb.set_trace()
-
-     ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_input, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}, use_paddleocr=use_paddleocr)
-     text, ocr_bbox = ocr_bbox_rslt
-     dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_input, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=iou_threshold, imgsz=imgsz,)
-     image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
-     print('finish processing')
-     parsed_content_list = '\n'.join([f'icon {i}: ' + str(v) for i,v in enumerate(parsed_content_list)])
-     # parsed_content_list = str(parsed_content_list)
-     return image, str(parsed_content_list)
-
- with gr.Blocks() as demo:
+ ) -> Optional[tuple]:
+     # Validate input
+     if image_input is None:
+         return None, "Please upload an image for processing."
+
+     try:
+         # Calculate overlay ratio based on input image width
+         box_overlay_ratio = image_input.size[0] / 3200
+         draw_bbox_config = {
+             'text_scale': 0.8 * box_overlay_ratio,
+             'text_thickness': max(int(2 * box_overlay_ratio), 1),
+             'text_padding': max(int(3 * box_overlay_ratio), 1),
+             'thickness': max(int(3 * box_overlay_ratio), 1),
+         }
+
+         # Run OCR bounding box detection
+         ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
+             image_input,
+             display_img=False,
+             output_bb_format='xyxy',
+             goal_filtering=None,
+             easyocr_args={'paragraph': False, 'text_threshold': 0.9},
+             use_paddleocr=use_paddleocr
+         )
+         text, ocr_bbox = ocr_bbox_rslt
+
+         # Get labeled image and parsed content via SOM (YOLO + caption model)
+         dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
+             image_input,
+             yolo_model,
+             BOX_TRESHOLD=box_threshold,
+             output_coord_in_ratio=True,
+             ocr_bbox=ocr_bbox,
+             draw_bbox_config=draw_bbox_config,
+             caption_model_processor=caption_model_processor,
+             ocr_text=text,
+             iou_threshold=iou_threshold,
+             imgsz=imgsz
+         )
+
+         # Decode processed image from base64
+         image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
+         print('Finish processing image.')
+
+         # Format parsed content list into a multi-line string
+         parsed_text = "\n".join([f"icon {i}: {v}" for i, v in enumerate(parsed_content_list)])
+         return image, parsed_text
+     except Exception as e:
+         print(f"Error during processing: {str(e)}")
+         return None, f"Error: {str(e)}"
+
+ # Build Gradio UI with enhanced layout and functionality
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
      gr.Markdown(MARKDOWN)
+
      with gr.Row():
-         with gr.Column():
-             image_input_component = gr.Image(
-                 type='pil', label='Upload image')
-             # set the threshold for removing the bounding boxes with low confidence, default is 0.05
-             box_threshold_component = gr.Slider(
-                 label='Box Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.05)
-             # set the threshold for removing the bounding boxes with large overlap, default is 0.1
-             iou_threshold_component = gr.Slider(
-                 label='IOU Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.1)
-             use_paddleocr_component = gr.Checkbox(
-                 label='Use PaddleOCR', value=True)
-             imgsz_component = gr.Slider(
-                 label='Icon Detect Image Size', minimum=640, maximum=1920, step=32, value=640)
-             submit_button_component = gr.Button(
-                 value='Submit', variant='primary')
-         with gr.Column():
-             image_output_component = gr.Image(type='pil', label='Image Output')
-             text_output_component = gr.Textbox(label='Parsed screen elements', placeholder='Text Output')
-
+         # Left sidebar (accordion): image upload and settings
+         with gr.Column(scale=1):
+             with gr.Accordion("Upload Image & Settings", open=True):
+                 image_input_component = gr.Image(
+                     type='pil',
+                     label='Upload Image',
+                     tool="editor",
+                     elem_id="input_image"
+                 )
+                 gr.Markdown("### Detection Settings")
+                 box_threshold_component = gr.Slider(
+                     label='Box Threshold',
+                     minimum=0.01, maximum=1.0, step=0.01, value=0.05,
+                     info="Minimum confidence for bounding boxes."
+                 )
+                 iou_threshold_component = gr.Slider(
+                     label='IOU Threshold',
+                     minimum=0.01, maximum=1.0, step=0.01, value=0.1,
+                     info="Threshold for non-maximum suppression overlap."
+                 )
+                 use_paddleocr_component = gr.Checkbox(
+                     label='Use PaddleOCR', value=True,
+                     info="Toggle between PaddleOCR and EasyOCR."
+                 )
+                 imgsz_component = gr.Slider(
+                     label='Icon Detect Image Size',
+                     minimum=640, maximum=1920, step=32, value=640,
+                     info="Resize input image for icon detection."
+                 )
+                 submit_button_component = gr.Button(
+                     value='Process Image', variant='primary'
+                 )
+
+         # Right main area: result tabs
+         with gr.Column(scale=2):
+             with gr.Tabs():
+                 with gr.Tab("Output Image"):
+                     with gr.Box():
+                         image_output_component = gr.Image(
+                             type='pil', label='Processed Image'
+                         )
+                 with gr.Tab("Parsed Text"):
+                     with gr.Box():
+                         text_output_component = gr.Textbox(
+                             label='Parsed Screen Elements',
+                             placeholder='The structured elements will appear here.',
+                             lines=10
+                         )
+
+     # Run the process on button click (loading spinner applied)
      submit_button_component.click(
          fn=process,
          inputs=[
@@ -107,6 +172,5 @@ with gr.Blocks() as demo:
          outputs=[image_output_component, text_output_component]
      )

- # demo.launch(debug=False, show_error=True, share=True)
- # demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
- demo.queue().launch(share=False)
+ # Launch with queue support
+ demo.queue().launch(share=False)
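With the new try/except in place, process() returns a (None, error-message) tuple instead of raising, which makes it easy to smoke-test headlessly. Below is a minimal sketch, assuming app.py's globals (process, yolo_model, caption_model_processor) are already loaded in the current session and the weights/ folder has been downloaded; the screenshot path and output filename are placeholders, not part of the commit.

```python
# Hypothetical smoke test for the updated process(); not part of this commit.
# Assumes app.py's globals are defined in this session and weights/ is present.
from PIL import Image

img = Image.open("screenshot.png")  # placeholder path to any GUI screenshot

annotated, parsed = process(
    img,
    box_threshold=0.05,  # UI default: minimum box confidence
    iou_threshold=0.1,   # UI default: overlap threshold for suppression
    use_paddleocr=True,  # PaddleOCR branch of check_ocr_box
    imgsz=640,           # icon-detect input size
)

# On failure the new code path yields (None, "Error: ...") instead of raising.
print(parsed)
if annotated is not None:
    annotated.save("annotated_screenshot.png")  # placeholder output name
```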
 
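One compatibility note: gr.Box and the tool= argument of gr.Image used in this revision are Gradio 3.x APIs that were removed in Gradio 4, so the Space presumably pins gradio<4 in its requirements. A small fail-fast guard, sketched here as an illustration (not part of the commit):

```python
# Sketch of a fail-fast version guard; not part of this commit.
# gr.Box and gr.Image(tool=...) exist in Gradio 3.x but were removed in 4.x.
import gradio as gr

major = int(gr.__version__.split(".")[0])
if major >= 4:
    raise RuntimeError(
        f"app.py uses Gradio 3.x UI APIs (gr.Box, Image tool=); "
        f"found gradio {gr.__version__}. Pin gradio<4 in requirements.txt."
    )
```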