from typing import Tuple
import spaces  # kept before torch, matching the original import order for ZeroGPU

import base64
import io
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from ultralytics import YOLO

from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
# Alternative: load the models through the project helpers.
# yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
# caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption_florence")

# Icon/element detector (YOLO) plus the Florence-2 captioner used to describe each detected element.
yolo_model = YOLO('weights/icon_detect/best.pt').to('cuda')
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("weights/icon_caption_florence", torch_dtype=torch.float16, trust_remote_code=True).to('cuda')
caption_model_processor = {'processor': processor, 'model': model}
print('Finished loading models.')


# Bounding-box drawing presets; this demo annotates with the 'pc' style.
platform = 'pc'
if platform == 'pc':
    draw_bbox_config = {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 2,
        'thickness': 2,
    }
elif platform == 'web':
    draw_bbox_config = {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 3,
        'thickness': 3,
    }
elif platform == 'mobile':
    draw_bbox_config = {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 3,
        'thickness': 3,
    }
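# The selected draw_bbox_config is forwarded to get_som_labeled_img() inside process() to style the overlay.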



MARKDOWN = """
# OmniParser for Pure Vision Based General GUI Agent 🔥
<div>
    <a href="https://arxiv.org/pdf/2408.00203">
        <img src="https://img.shields.io/badge/arXiv-2408.00203-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
    </a>
</div>

OmniParser is a screen-parsing tool that converts general GUI screenshots into structured elements. ✅
"""

@torch.inference_mode()
@spaces.GPU(duration=65)
def process(
    image_input: Image.Image,
    box_threshold: float,
    iou_threshold: float
) -> Tuple[Image.Image, str, str]:
    """Parse a GUI screenshot into an annotated image, element descriptions, and box coordinates."""
    # The parsing helpers expect a file path, so persist the uploaded image first.
    image_save_path = 'imgs/saved_image_demo.png'
    os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
    image_input.save(image_save_path)

    # OCR pass: recover text regions and their boxes from the screenshot.
    ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
        image_save_path, display_img=False, output_bb_format='xyxy', goal_filtering=None,
        easyocr_args={'paragraph': False, 'text_threshold': 0.9}, use_paddleocr=True)
    text, ocr_bbox = ocr_bbox_rslt
    # Detection + captioning pass: label interactable elements and merge them with the OCR boxes.
    labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
        image_save_path, yolo_model, BOX_TRESHOLD=box_threshold, output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox, draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor, ocr_text=text, iou_threshold=iou_threshold)
    # The annotated screenshot comes back base64-encoded; decode it into a PIL image.
    image = Image.open(io.BytesIO(base64.b64decode(labeled_img)))
    print('Finished processing.')
    parsed_content_list = '\n'.join(parsed_content_list)
    
    # Format the coordinates output in a more readable way
    # coordinates_text = "Bounding Box Coordinates (x, y, width, height):\n"
    # for box_id, coords in sorted(label_coordinates.items(), key=lambda x: int(x[0])):
    #     # Convert numpy array to list and round values
    #     coords_list = coords.tolist()
    #     coords_formatted = [f"{coord:.1f}" for coord in coords_list]
    #     coordinates_text += f"Box {box_id}: [{coords_formatted[0]}, {coords_formatted[1]}, {coords_formatted[2]}, {coords_formatted[3]}]\n"
        
    return image, str(parsed_content_list), str(label_coordinates)
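
# Illustrative sketch (not part of the original app): how process() could be exercised without the UI.
# 'screenshot.png' and 'annotated.png' are placeholder file names; kept commented out so the Space
# launches exactly as before.
#
# if __name__ == '__main__':
#     demo_img = Image.open('screenshot.png').convert('RGB')
#     annotated, elements, coords = process(demo_img, box_threshold=0.05, iou_threshold=0.1)
#     annotated.save('annotated.png')
#     print(elements)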


with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(
                type='pil', label='Upload image')
            box_threshold_component = gr.Slider(
                label='Box Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.05)
            iou_threshold_component = gr.Slider(
                label='IOU Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.1)
            submit_button_component = gr.Button(
                value='Submit', variant='primary')
        with gr.Column():
            image_output_component = gr.Image(type='pil', label='Image Output')
            text_output_component = gr.Textbox(label='Parsed screen elements', placeholder='Text Output')
            coordinates_output_component = gr.Textbox(
                label='Bounding Box Coordinates', 
                placeholder='Coordinates will appear here',
                lines=20,  # Increased lines to show more coordinates
                interactive=False  # Make it read-only
            )

    submit_button_component.click(
        fn=process,
        inputs=[
            image_input_component,
            box_threshold_component,
            iou_threshold_component
        ],
        outputs=[
            image_output_component, 
            text_output_component,
            coordinates_output_component
        ]
    )

demo.queue().launch(share=False)

# Alternative launch options:
# demo.launch(debug=False, show_error=True, share=True)
# demo.launch(share=True, server_port=7861, server_name='0.0.0.0')