from typing import Optional
import spaces
import gradio as gr
import numpy as np
import torch
from PIL import Image
import io
import base64, os
from huggingface_hub import snapshot_download
# Import utility functions
from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
# Download repository (if not already downloaded)
repo_id = "microsoft/OmniParser-v2.0" # HF repository ID
local_dir = "weights" # Local directory for weights
local_path = snapshot_download(repo_id=repo_id, local_dir=local_dir)
print(f"Repository downloaded to: {local_path}")
# Load models
yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption")
# Alternative caption model (BLIP2) can be used as below:
# caption_model_processor = get_caption_model_processor(model_name="blip2", model_name_or_path="weights/icon_caption_blip2")
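# (Switching to BLIP2 assumes the corresponding weights are available under
# weights/icon_caption_blip2; they may need to be fetched separately, as the
# snapshot above only provides the Florence-2 captioner.)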
# Markdown header text
MARKDOWN = """
# OmniParser V2 Pro๐Ÿ”ฅ
"""
DEVICE = torch.device('cuda')  # Target device (currently unused in this script; the loader helpers handle placement)
# Custom CSS for UI enhancement
custom_css = """
body { background-color: #f0f2f5; }
.gradio-container { font-family: 'Segoe UI', sans-serif; }
h1, h2, h3, h4 { color: #283E51; }
button { border-radius: 6px; }
"""
@spaces.GPU
@torch.inference_mode()
def process(
    image_input,
    box_threshold,
    iou_threshold,
    use_paddleocr,
    imgsz
) -> Optional[tuple]:
    # Validate input
    if image_input is None:
        return None, "Please upload an image for processing."
    try:
        # Scale box-drawing parameters relative to the input image width
        box_overlay_ratio = image_input.size[0] / 3200
        draw_bbox_config = {
            'text_scale': 0.8 * box_overlay_ratio,
            'text_thickness': max(int(2 * box_overlay_ratio), 1),
            'text_padding': max(int(3 * box_overlay_ratio), 1),
            'thickness': max(int(3 * box_overlay_ratio), 1),
        }
        # Run OCR bounding box detection (PaddleOCR or EasyOCR)
        ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
            image_input,
            display_img=False,
            output_bb_format='xyxy',
            goal_filtering=None,
            easyocr_args={'paragraph': False, 'text_threshold': 0.9},
            use_paddleocr=use_paddleocr
        )
        text, ocr_bbox = ocr_bbox_rslt
        # Get labeled image and parsed content via SOM (YOLO + caption model).
        # Note: BOX_TRESHOLD is the (misspelled) keyword expected by the upstream util.
        dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
            image_input,
            yolo_model,
            BOX_TRESHOLD=box_threshold,
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config=draw_bbox_config,
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=iou_threshold,
            imgsz=imgsz
        )
        # Decode the annotated image, returned as a base64 string
        image = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))
        print('Finished processing image.')
        # Format the parsed content list into a multi-line string
        parsed_text = "\n".join(f"icon {i}: {v}" for i, v in enumerate(parsed_content_list))
        return image, parsed_text
    except Exception as e:
        print(f"Error during processing: {e}")
        return None, f"Error: {e}"
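# A minimal sketch of calling process() directly, outside the Gradio UI
# (the filename 'screenshot.png' is illustrative):
#
#   img = Image.open('screenshot.png')
#   labeled, elements = process(
#       img, box_threshold=0.05, iou_threshold=0.1,
#       use_paddleocr=True, imgsz=640
#   )
#   labeled.save('labeled.png')
#   print(elements)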
# Build Gradio UI with enhanced layout and functionality
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        # Left sidebar (accordion): image upload and detection settings
        with gr.Column(scale=1):
            with gr.Accordion("Upload Image & Settings", open=True):
                image_input_component = gr.Image(
                    type='pil',
                    label='Upload Image',
                    elem_id="input_image"
                )
                gr.Markdown("### Detection Settings")
                box_threshold_component = gr.Slider(
                    label='Box Threshold',
                    minimum=0.01, maximum=1.0, step=0.01, value=0.05,
                    info="Minimum confidence for bounding boxes."
                )
                iou_threshold_component = gr.Slider(
                    label='IOU Threshold',
                    minimum=0.01, maximum=1.0, step=0.01, value=0.1,
                    info="Overlap threshold for non-maximum suppression."
                )
                use_paddleocr_component = gr.Checkbox(
                    label='Use PaddleOCR', value=True,
                    info="Toggle between PaddleOCR and EasyOCR."
                )
                imgsz_component = gr.Slider(
                    label='Icon Detect Image Size',
                    minimum=640, maximum=1920, step=32, value=640,
                    info="Resize input image for icon detection."
                )
                submit_button_component = gr.Button(
                    value='Process Image', variant='primary'
                )
        # Right main area: result tabs
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("Output Image"):
                    image_output_component = gr.Image(
                        type='pil', label='Processed Image'
                    )
                with gr.Tab("Parsed Text"):
                    text_output_component = gr.Textbox(
                        label='Parsed Screen Elements',
                        placeholder='The structured elements will appear here.',
                        lines=10
                    )
    # Run processing on button click (a loading spinner is shown while queued)
    submit_button_component.click(
        fn=process,
        inputs=[
            image_input_component,
            box_threshold_component,
            iou_threshold_component,
            use_paddleocr_component,
            imgsz_component
        ],
        outputs=[image_output_component, text_output_component]
    )
# Launch with queue support
demo.queue().launch(share=False)
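# To expose a temporary public link when running outside Spaces, launch with
# share=True instead: demo.queue().launch(share=True)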