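"""Gradio demo for microsoft/Florence-2-large-ft: image captioning, object
detection, dense region captioning, region proposals, phrase grounding, and
OCR (with optional region boxes) on an uploaded image."""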
import os
from unittest.mock import patch

import gradio as gr
from PIL import ImageDraw
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.dynamic_module_utils import get_imports

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Workaround for https://huggingface.co/microsoft/phi-1_5/discussions/72:
    drop the optional flash_attn import so the model loads without it."""
    if not str(filename).endswith("/modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    if "flash_attn" in imports:  # guard: may already be absent in newer revisions
        imports.remove("flash_attn")
    return imports

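# Patch get_imports only while loading, so importing the checkpoint's remote
# code does not require flash_attn to be installed.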
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
    processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)

def draw_boxes(image, quad_boxes):
    """Draw OCR quadrilaterals onto the image in place.

    Each box is a flat [x1, y1, ..., x4, y4] list, which ImageDraw.polygon
    accepts directly.
    """
    draw = ImageDraw.Draw(image)
    for box in quad_boxes:
        # width= on polygon outlines needs a reasonably recent Pillow;
        # drop the argument if your version rejects it.
        draw.polygon(box, outline="red", width=2)
    return image

def run_example(image, task, additional_text=""):
    if image is None:
        return "Please upload an image.", None

    # Florence-2 takes the task token as the prompt; for phrase grounding,
    # the caption to ground is appended directly after the token.
    task_prompt = f"<{task}>"
    if task == "CAPTION_TO_PHRASE_GROUNDING" and additional_text:
        prompt = task_prompt + additional_text
    else:
        prompt = task_prompt
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
    )
    # Keep special tokens: post_process_generation needs them to parse regions.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text, task=task_prompt, image_size=(image.width, image.height)
    )

    result_text = str(parsed_answer)
    result_image = image.copy()

    if task == "OCR_WITH_REGION":
        # parsed_answer is already a dict keyed by the task token, so read the
        # boxes from it directly rather than round-tripping through str() and
        # ast.literal_eval().
        try:
            quad_boxes = parsed_answer[task_prompt]["quad_boxes"]
            result_image = draw_boxes(result_image, quad_boxes)
        except (KeyError, TypeError):
            print("Failed to draw bounding boxes.")

    return result_text, result_image

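# Show the caption textbox only when the phrase-grounding task is selected.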
def update_additional_text_visibility(task):
    return gr.update(visible=(task == "CAPTION_TO_PHRASE_GROUNDING"))

# Define the Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Florence-2 Image Analysis")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload an image")
        with gr.Column():
            task_dropdown = gr.Dropdown(
                choices=[
                    "CAPTION", "DETAILED_CAPTION", "MORE_DETAILED_CAPTION",
                    "CAPTION_TO_PHRASE_GROUNDING", "OD", "DENSE_REGION_CAPTION",
                    "REGION_PROPOSAL", "OCR", "OCR_WITH_REGION"
                ],
                label="Select Task",
                value="CAPTION"
            )
            additional_text = gr.Textbox(
                label="Additional Text (for Caption to Phrase Grounding)",
                placeholder="Enter caption here",
                visible=False
            )
            submit_button = gr.Button("Analyze Image")
    with gr.Row():
        text_output = gr.Textbox(label="Result")
        image_output = gr.Image(label="Processed Image")

    task_dropdown.change(fn=update_additional_text_visibility, inputs=task_dropdown, outputs=additional_text)
    submit_button.click(
        fn=run_example,
        inputs=[image_input, task_dropdown, additional_text],
        outputs=[text_output, image_output]
    )

# Launch the interface
iface.launch()
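# To get a temporary public URL instead of localhost, use iface.launch(share=True).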