# File size: 4,262 Bytes
# Revision: cef4f97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import base64
import os
import tempfile

import gradio as gr
import spaces
import torch
from transformers import AutoModel, AutoTokenizer

# Hugging Face Hub repository for GOT-OCR 2.0. trust_remote_code=True means the
# repo's own Python modelling code is executed at load time.
# NOTE(review): consider pinning `revision=` so that remote code cannot change underneath the app.
_GOT_REPO = 'ucaslcl/GOT-OCR2_0'

tokenizer = AutoTokenizer.from_pretrained(_GOT_REPO, trust_remote_code=True)
model = AutoModel.from_pretrained(
    _GOT_REPO,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map='cuda',
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
)
# Inference only: switch to eval mode, then make sure the weights live on the GPU.
model.eval()
model = model.cuda()

def _normalize_ocr_box(ocr_box):
    """Convert user-typed box text into a canonical ``"[a,b,...]"`` integer-list string.

    Accepts 2 or 4 comma-separated integers, optionally wrapped in ``[]`` or
    ``()``; e.g. ``"100,100,200,200"`` or ``"[100, 100, 200, 200]"``.
    Empty or blank input (including ``None``) returns ``''``, which is what the
    UI textbox sends when left empty.

    NOTE(review): the upstream GOT-OCR2_0 ``chat`` appears to ``eval`` the box
    string and then assign into it by index. This needs confirming against the
    model revision in use. Re-serialising from parsed ints therefore avoids two
    problems: a bare ``1,2,3,4`` evaluating to an immutable tuple, and arbitrary
    user text reaching ``eval``.

    Raises:
        ValueError: if the text is not 2 or 4 comma-separated integers.
    """
    text = (ocr_box or '').strip()
    if not text:
        return ''
    parts = [part.strip() for part in text.strip('[]() ').split(',')]
    try:
        coords = [int(part) for part in parts]
    except ValueError:
        raise ValueError(
            f"OCR box must be comma-separated integers like 100,100,200,200; got {ocr_box!r}"
        ) from None
    if len(coords) not in (2, 4):
        raise ValueError(
            f"OCR box must have 2 or 4 integers (x1,y1[,x2,y2]); got {len(coords)}"
        )
    return '[' + ','.join(str(c) for c in coords) + ']'


@spaces.GPU
def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None, render=False):
    """Run one GOT-OCR task on ``image`` and return ``(text, html)``.

    Args:
        image: Image passed to the model. The UI's ``gr.Image`` uses
            ``type="filepath"``, so this is normally a file path.
        task: One of the task labels offered by the UI dropdown.
        ocr_type: ``"ocr"`` or ``"format"``, used only by the fine-grained tasks.
            Falls back to ``"ocr"`` when unset.
        ocr_box: Box text for "Fine-grained OCR (Box)", e.g. ``"100,100,200,200"``.
        ocr_color: Color name for "Fine-grained OCR (Color)".
        render: Unused; kept for call compatibility. "Render Formatted OCR"
            always renders.

    Returns:
        ``(res, html)``. ``html`` is the rendered page for
        "Render Formatted OCR" and ``None`` for every other task.

    Raises:
        gr.Error: for an unknown task or a malformed box, so the UI shows a
            readable message instead of an internal exception.
    """
    fine_type = ocr_type or 'ocr'

    if task == "Plain Text OCR":
        return model.chat(tokenizer, image, ocr_type='ocr'), None
    if task == "Format Text OCR":
        return model.chat(tokenizer, image, ocr_type='format'), None
    if task == "Fine-grained OCR (Box)":
        try:
            box = _normalize_ocr_box(ocr_box)
        except ValueError as err:
            raise gr.Error(str(err)) from err
        return model.chat(tokenizer, image, ocr_type=fine_type, ocr_box=box), None
    if task == "Fine-grained OCR (Color)":
        return model.chat(tokenizer, image, ocr_type=fine_type, ocr_color=ocr_color), None
    if task == "Multi-crop OCR":
        # The upstream usage examples always pass ocr_type to chat_crop.
        return model.chat_crop(tokenizer, image_file=image, ocr_type='ocr'), None
    if task == "Render Formatted OCR":
        # Use a private per-request directory instead of a shared ./demo.html,
        # so concurrent requests cannot overwrite or read each other's output.
        with tempfile.TemporaryDirectory() as tmp_dir:
            html_path = os.path.join(tmp_dir, 'demo.html')
            res = model.chat(
                tokenizer, image, ocr_type='format', render=True, save_render_file=html_path
            )
            with open(html_path, 'r', encoding='utf-8') as f:
                html_content = f.read()
        return res, html_content

    raise gr.Error(f"Unknown task: {task!r}")

def update_inputs(task):
    """Return visibility updates for the four task-specific inputs.

    The updates are returned in this order: OCR type dropdown, OCR box
    textbox, OCR color dropdown, render checkbox. This matches the
    ``outputs=`` list of the ``task_dropdown.change`` listener.

    Tasks with no extra inputs hide all four. So does any unexpected or
    cleared task value, which means Gradio always receives exactly four
    updates; previously such a value made the function return ``None``.
    """
    if task == "Fine-grained OCR (Box)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
        ]
    if task == "Fine-grained OCR (Color)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=False),
            gr.update(visible=True, choices=["red", "green", "blue"]),
            gr.update(visible=False),
        ]
    if task == "Render Formatted OCR":
        return [gr.update(visible=False)] * 3 + [gr.update(visible=True)]
    # "Plain Text OCR", "Format Text OCR", "Multi-crop OCR", and any other value.
    return [gr.update(visible=False)] * 4
    
def ocr_demo(image, task, ocr_type, ocr_box, ocr_color):
    """Click handler for the Process button; returns ``(text, html)`` for the output panes.

    A missing image is rejected here, before the ``@spaces.GPU``-decorated
    ``process_image`` is called. An empty submission therefore fails fast with
    a readable message instead of an error raised inside the model call.

    Raises:
        gr.Error: if no image has been uploaded.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    res, html_content = process_image(image, task, ocr_type, ocr_box, ocr_color)
    # Falsy HTML (None or "") is normalised to None for the HTML output.
    return res, html_content or None

with gr.Blocks() as demo:
    # Markdown needs a space after "#" for this to render as a heading.
    gr.Markdown("# 🙋🏻‍♂️Welcome to Tonic's🫴🏻📸GOT-OCR")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="filepath", label="Input Image")
            task_dropdown = gr.Dropdown(
                choices=[
                    "Plain Text OCR",
                    "Format Text OCR",
                    "Fine-grained OCR (Box)",
                    "Fine-grained OCR (Color)",
                    "Multi-crop OCR",
                    "Render Formatted OCR"
                ],
                label="Select Task",
                value="Plain Text OCR"
            )
            # The inputs below start hidden; update_inputs toggles them per task.
            ocr_type_dropdown = gr.Dropdown(
                choices=["ocr", "format"],
                label="OCR Type",
                visible=False
            )
            ocr_box_input = gr.Textbox(
                label="OCR Box (x1,y1,x2,y2)",
                placeholder="e.g., 100,100,200,200",
                visible=False
            )
            ocr_color_dropdown = gr.Dropdown(
                choices=["red", "green", "blue"],
                label="OCR Color",
                visible=False
            )
            # NOTE(review): this checkbox is not among submit_button's inputs, so
            # its value is never read.
            render_checkbox = gr.Checkbox(
                label="Render Result",
                visible=False
            )
            submit_button = gr.Button("Process")

        with gr.Column():
            output_text = gr.Textbox(label="OCR Result")
            output_html = gr.HTML(label="Rendered HTML Output")

    # The outputs order must match the list returned by update_inputs.
    task_dropdown.change(
        update_inputs,
        inputs=[task_dropdown],
        outputs=[ocr_type_dropdown, ocr_box_input, ocr_color_dropdown, render_checkbox]
    )

    submit_button.click(
        ocr_demo,
        inputs=[image_input, task_dropdown, ocr_type_dropdown, ocr_box_input, ocr_color_dropdown],
        outputs=[output_text, output_html]
    )

demo.launch()