# Hugging Face Space app — hardware: ZeroGPU ("Running on Zero").
# Standard library
import base64
import io
import os
import traceback
from typing import Optional, Tuple

# Third-party. `spaces` is kept ahead of torch: the ZeroGPU `spaces` package
# asks to be imported before CUDA-related libraries.
import spaces
import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image

# Import utility functions (project-local)
from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
# Download the OmniParser v2 weights from the Hugging Face Hub
# (files already present in `local_dir` are not re-downloaded).
repo_id = "microsoft/OmniParser-v2.0"  # HF repository ID
local_dir = "weights"  # Local directory for weights
snapshot_download(repo_id=repo_id, local_dir=local_dir)
print(f"Repository downloaded to: {local_dir}")

# Load models from the downloaded snapshot. Paths are derived from `local_dir`
# so that changing the download location cannot desynchronize them.
yolo_model = get_yolo_model(model_path=os.path.join(local_dir, 'icon_detect', 'model.pt'))
caption_model_processor = get_caption_model_processor(
    model_name="florence2",
    model_name_or_path=os.path.join(local_dir, 'icon_caption'),
)
# Alternative caption model (BLIP2) can be used as below:
# caption_model_processor = get_caption_model_processor(
#     model_name="blip2",
#     model_name_or_path=os.path.join(local_dir, "icon_caption_blip2"),
# )
# Markdown header shown at the top of the UI.
MARKDOWN = """
# OmniParser V2 Pro🔥
"""

# NOTE(review): DEVICE is not referenced in the code shown here; the util
# helpers presumably choose their own device — confirm before removing.
DEVICE = torch.device('cuda')

# Custom CSS for UI enhancement: page background, font, heading colour and
# rounded buttons.
custom_css = """
body { background-color: #f0f2f5; }
.gradio-container { font-family: 'Segoe UI', sans-serif; }
h1, h2, h3, h4 { color: #283E51; }
button { border-radius: 6px; }
"""
# On ZeroGPU hardware a GPU is attached only while a @spaces.GPU-decorated
# function runs; without the decorator the CUDA-resident models cannot execute.
@spaces.GPU
@torch.inference_mode()
def process(
    image_input,
    box_threshold,
    iou_threshold,
    use_paddleocr,
    imgsz
) -> Tuple[Optional[Image.Image], str]:
    """Detect, caption and label the UI elements in an uploaded screenshot.

    First runs OCR (PaddleOCR or EasyOCR) to find text boxes. Then runs
    ``get_som_labeled_img`` (YOLO icon detector plus caption model) to produce
    an annotated image and a list of parsed elements.

    Args:
        image_input: Uploaded image as a PIL image (the input component uses
            ``type='pil'``), or ``None`` when nothing was uploaded.
        box_threshold: Minimum confidence for bounding boxes; forwarded as
            ``BOX_TRESHOLD``.
        iou_threshold: Overlap threshold forwarded to ``get_som_labeled_img``.
        use_paddleocr: ``True`` to use PaddleOCR, ``False`` to use EasyOCR.
        imgsz: Image size for icon detection. Coerced to ``int`` because UI
            slider values may arrive as floats.

    Returns:
        ``(annotated_image, parsed_text)`` on success, where ``parsed_text``
        has one ``"icon {i}: {content}"`` line per element. Returns
        ``(None, message)`` when no image was supplied or processing raised.
    """
    # Input validation.
    if image_input is None:
        return None, "Please upload an image for processing."

    try:
        # Scale label/box drawing with image width (ratio 1.0 at 3200 px) so
        # annotations stay proportionate on small and large screenshots.
        box_overlay_ratio = image_input.size[0] / 3200
        draw_bbox_config = {
            'text_scale': 0.8 * box_overlay_ratio,
            'text_thickness': max(int(2 * box_overlay_ratio), 1),
            'text_padding': max(int(3 * box_overlay_ratio), 1),
            'thickness': max(int(3 * box_overlay_ratio), 1),
        }

        # OCR pass; the second return value (goal-filter flag) is not needed
        # because goal filtering is disabled.
        (ocr_text, ocr_bbox), _ = check_ocr_box(
            image_input,
            display_img=False,
            output_bb_format='xyxy',
            goal_filtering=None,
            easyocr_args={'paragraph': False, 'text_threshold': 0.9},
            use_paddleocr=use_paddleocr,
        )

        # Set-of-Mark labelling (YOLO + caption model). `BOX_TRESHOLD` must
        # match get_som_labeled_img's keyword exactly — do not "fix" it here.
        labeled_img_b64, _, parsed_content_list = get_som_labeled_img(
            image_input,
            yolo_model,
            BOX_TRESHOLD=box_threshold,
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config=draw_bbox_config,
            caption_model_processor=caption_model_processor,
            ocr_text=ocr_text,
            iou_threshold=iou_threshold,
            imgsz=int(imgsz),
        )

        # The labelled image is returned base64-encoded; decode it to PIL.
        image = Image.open(io.BytesIO(base64.b64decode(labeled_img_b64)))
        print('Finish processing image.')

        parsed_text = "\n".join(
            f"icon {i}: {v}" for i, v in enumerate(parsed_content_list)
        )
        return image, parsed_text
    except Exception as e:
        # UI boundary: keep the app responsive and surface the error to the
        # user, but log the full traceback so failures remain debuggable.
        traceback.print_exc()
        print(f"Error during processing: {str(e)}")
        return None, f"Error: {str(e)}"
# Build the Gradio UI: upload/settings sidebar on the left, result tabs on the right.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        # Left sidebar (accordion): image upload and detection settings.
        with gr.Column(scale=1):
            with gr.Accordion("Upload Image & Settings", open=True):
                image_input_component = gr.Image(
                    type='pil',
                    label='Upload Image',
                    elem_id="input_image"
                )
                gr.Markdown("### Detection Settings")
                box_threshold_component = gr.Slider(
                    label='Box Threshold',
                    minimum=0.01, maximum=1.0, step=0.01, value=0.05,
                    info="Minimum confidence for bounding boxes."
                )
                iou_threshold_component = gr.Slider(
                    label='IOU Threshold',
                    minimum=0.01, maximum=1.0, step=0.01, value=0.1,
                    info="Threshold for non-maximum suppression overlap."
                )
                use_paddleocr_component = gr.Checkbox(
                    label='Use PaddleOCR', value=True,
                    info="Toggle between PaddleOCR and EasyOCR."
                )
                imgsz_component = gr.Slider(
                    label='Icon Detect Image Size',
                    minimum=640, maximum=1920, step=32, value=640,
                    info="Resize input image for icon detection."
                )
            # Kept outside the accordion so collapsing the settings does not
            # hide the action button.
            submit_button_component = gr.Button(
                value='Process Image', variant='primary'
            )
        # Right main area: result tabs.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("Output Image"):
                    image_output_component = gr.Image(
                        type='pil', label='Processed Image'
                    )
                with gr.Tab("Parsed Text"):
                    text_output_component = gr.Textbox(
                        label='Parsed Screen Elements',
                        placeholder='The structured elements will appear here.',
                        lines=10
                    )

    # Run `process` on button click; its two outputs fill the two result tabs.
    # (Gradio shows its own progress indicator while the call is running.)
    submit_button_component.click(
        fn=process,
        inputs=[
            image_input_component,
            box_threshold_component,
            iou_threshold_component,
            use_paddleocr_component,
            imgsz_component
        ],
        outputs=[image_output_component, text_output_component]
    )

# Launch with queue support only when run as a script (`python app.py`), so
# importing this module to reuse `demo` does not start a server.
if __name__ == "__main__":
    demo.queue().launch(share=False)