# OmniParser2 / app.py
# derekalia's picture
# update this
# fece146
import base64
import io
import os
from typing import Optional, Tuple

# NOTE: `spaces` stays ahead of `torch`, matching the original import order.
import spaces
import gradio as gr
import numpy as np
import torch
from PIL import Image

from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
# Earlier loading path through the project helpers, kept for reference:
# yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
# caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption_florence")

# Both models are loaded once, at import time, and reused by every call to process().

# Detection model: YOLO weights from the local checkpoint, moved to the CUDA device.
from ultralytics import YOLO
yolo_model = YOLO('weights/icon_detect/best.pt').to('cuda')

# Caption model: the processor is fetched by the "microsoft/Florence-2-base" identifier,
# while the model itself is loaded from the local 'weights/icon_caption_florence' path
# in float16 and moved to the CUDA device.
from transformers import AutoProcessor, AutoModelForCausalLM
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("weights/icon_caption_florence", torch_dtype=torch.float16, trust_remote_code=True).to('cuda')

# Processor and model are bundled in one dict; process() hands it to
# get_som_labeled_img via the caption_model_processor argument.
caption_model_processor = {'processor': processor, 'model': model}
print('finish loading model!!!')
# Platform whose drawing preset is selected below.
platform = 'pc'

# Drawing presets keyed by platform name. The selected preset is passed to
# get_som_labeled_img as ``draw_bbox_config`` when a screenshot is processed.
_DRAW_BBOX_CONFIGS = {
    'pc': {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 2,
        'thickness': 2,
    },
    'web': {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 3,
        'thickness': 3,
    },
    'mobile': {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 3,
        'thickness': 3,
    },
}

# Fail at import time with a clear message instead of leaving
# ``draw_bbox_config`` undefined (which would only surface later as a
# NameError inside a request handler).
if platform not in _DRAW_BBOX_CONFIGS:
    raise ValueError(
        f"Unknown platform {platform!r}; expected one of {sorted(_DRAW_BBOX_CONFIGS)}"
    )

# Copy so that later mutation of the active config cannot alter the preset table.
draw_bbox_config = dict(_DRAW_BBOX_CONFIGS[platform])
# Markdown/HTML header rendered through gr.Markdown inside the Blocks layout:
# a title, an arXiv badge linking to the paper PDF, and a one-line description.
MARKDOWN = """
# OmniParser for Pure Vision Based General GUI Agent 🔥
<div>
<a href="https://arxiv.org/pdf/2408.00203">
<img src="https://img.shields.io/badge/arXiv-2408.00203-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
</a>
</div>
OmniParser is a screen parsing tool to convert general GUI screen to structured elements. ✅
"""
# DEVICE = torch.device('cuda')
# @spaces.GPU
def _attach_coordinates(parsed_content_list, label_coordinates):
    """Append rounded box coordinates to each parsed element description.

    Entries starting with ``'Text Box ID'`` are looked up by their position in
    the list; all other entries are expected to look like
    ``'Icon Box ID <n>: ...'`` and are looked up by the ``<n>`` they carry.
    Entries without a usable ID or without a matching coordinate entry are
    passed through unchanged.

    Args:
        parsed_content_list: Sequence of element description strings.
        label_coordinates: Mapping from box ID (as a string) to an iterable of
            numeric coordinates.

    Returns:
        A list of strings, one per input entry.
    """
    combined = []
    for index, content in enumerate(parsed_content_list):
        if content.startswith('Text Box ID'):
            box_id = str(index)
        else:
            # Take the ID between the 'Icon Box ID ' marker and the next colon;
            # an entry without the marker simply gets no coordinates.
            _, marker, tail = content.partition('Icon Box ID ')
            box_id = tail.split(':')[0] if marker else None

        coords = label_coordinates.get(box_id) if box_id is not None else None
        # Explicit None check: a truthiness test is not usable on numpy arrays.
        if coords is None:
            combined.append(content)
            continue
        rounded = [round(value) for value in coords]
        combined.append(f"{content} | Coordinates: {rounded}")
    return combined


@torch.inference_mode()
@spaces.GPU(duration=65)
def process(
    image_input,
    box_threshold,
    iou_threshold
) -> Tuple[Image.Image, str, str]:
    """Parse a screenshot and return the annotated image plus text summaries.

    The image is written to disk, run through OCR (``check_ocr_box``) and then
    through ``get_som_labeled_img`` together with the module-level detection
    and caption models.

    Args:
        image_input: Uploaded image; must provide ``.save(path)`` (the UI
            supplies a PIL image). ``None`` when nothing was uploaded.
        box_threshold: Forwarded to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Forwarded to ``get_som_labeled_img`` as ``iou_threshold``.

    Returns:
        A 3-tuple ``(image, parsed_text, combined_text)``:
        the annotated image, the parsed element descriptions joined by
        newlines, and the ``str()`` of the list of descriptions with
        coordinates appended.

    Raises:
        gr.Error: If no image was provided.
    """
    if image_input is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error('Please upload an image before submitting.')

    image_save_path = 'imgs/saved_image_demo.png'
    # The downstream helpers take a file path, so the upload is persisted
    # first; make sure the target directory exists on a fresh checkout.
    os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
    image_input.save(image_save_path)

    ocr_bbox_rslt, _ = check_ocr_box(image_save_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}, use_paddleocr=True)
    text, ocr_bbox = ocr_bbox_rslt

    labeled_img_b64, label_coordinates, parsed_content_list = get_som_labeled_img(image_save_path, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=iou_threshold)
    # The labelled image comes back base64-encoded; decode it into a PIL image.
    image = Image.open(io.BytesIO(base64.b64decode(labeled_img_b64)))
    print('finish processing')

    combined_content = _attach_coordinates(parsed_content_list, label_coordinates)
    print(combined_content)

    parsed_text = '\n'.join(parsed_content_list)
    return image, parsed_text, str(combined_content)
# Page layout: header on top, then one row with two columns — the upload and
# threshold controls in the first, the three result widgets in the second.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(type='pil', label='Upload image')
            box_threshold_component = gr.Slider(
                label='Box Threshold',
                minimum=0.01,
                maximum=1.0,
                step=0.01,
                value=0.05,
            )
            iou_threshold_component = gr.Slider(
                label='IOU Threshold',
                minimum=0.01,
                maximum=1.0,
                step=0.01,
                value=0.1,
            )
            submit_button_component = gr.Button(value='Submit', variant='primary')

        with gr.Column():
            image_output_component = gr.Image(type='pil', label='Image Output')
            text_output_component = gr.Textbox(
                label='Parsed screen elements',
                placeholder='Text Output',
            )
            # Read-only, tall enough to show many coordinate entries.
            coordinates_output_component = gr.Textbox(
                label='Bounding Box Coordinates',
                placeholder='Coordinates will appear here',
                lines=20,
                interactive=False,
            )

    # Order matches process(): three arguments in, three return values out.
    process_inputs = [
        image_input_component,
        box_threshold_component,
        iou_threshold_component,
    ]
    process_outputs = [
        image_output_component,
        text_output_component,
        coordinates_output_component,
    ]
    submit_button_component.click(
        fn=process,
        inputs=process_inputs,
        outputs=process_outputs,
    )

# Enable request queuing and start the server; errors are shown in the UI.
demo.queue().launch(share=False, show_error=True)
# demo.launch(debug=False, show_error=True, share=True)
# demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
# demo.queue().launch(share=False)