File size: 4,718 Bytes
cef4f97
 
 
 
 
 
 
 
00ee90b
cef4f97
 
00ee90b
cef4f97
 
 
00ee90b
 
 
cef4f97
00ee90b
cef4f97
00ee90b
cef4f97
00ee90b
cef4f97
00ee90b
cef4f97
00ee90b
cef4f97
00ee90b
cef4f97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer
import os
import base64
import spaces

# Hugging Face Hub repo id shared by the tokenizer and the model.
MODEL_NAME = 'ucaslcl/GOT-OCR2_0'

# Load the tokenizer and model once at import time so every request reuses them.
# trust_remote_code=True is required because the repo ships its own modelling code.
# (A former `AutoConfig.from_pretrained(model_name, ...)` line was removed here:
# neither `AutoConfig` nor `model_name` existed, so it raised NameError on start-up,
# and its result was never used.)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
# Inference only: switch to eval mode and make sure the weights live on the GPU.
model = model.eval().cuda()
# Use EOS as the padding token id on the loaded config as well.
model.config.pad_token_id = tokenizer.eos_token_id

@spaces.GPU
def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None, render=False):
    """Run the OCR model on ``image`` for the selected ``task``.

    Args:
        image: Image handed straight to the model's ``chat`` / ``chat_crop``
            call (the Gradio image input in this file supplies a file path).
        task: One of the task labels offered in the UI dropdown:
            "Plain Text OCR", "Format Text OCR", "Fine-grained OCR (Box)",
            "Fine-grained OCR (Color)", "Multi-crop OCR" or
            "Render Formatted OCR".
        ocr_type: OCR mode forwarded to the model for the two fine-grained
            tasks; ignored by the other tasks.
        ocr_box: Box specification forwarded for "Fine-grained OCR (Box)".
        ocr_color: Color name forwarded for "Fine-grained OCR (Color)".
        render: Currently unused; rendering is selected through ``task``.

    Returns:
        A ``(text, html)`` tuple. ``html`` is the content of the rendered
        HTML file for "Render Formatted OCR" and ``None`` for every other task.

    Raises:
        ValueError: If ``task`` is not one of the supported labels.
    """
    # All-ones mask spanning the model's maximum sequence length.
    # NOTE(review): confirm that the remote `chat` / `chat_crop` implementations
    # actually accept an `attention_mask` keyword.
    attention_mask = torch.ones((1, model.config.max_position_embeddings), dtype=torch.long, device=model.device)

    if task == "Plain Text OCR":
        res = model.chat(tokenizer, image, ocr_type='ocr', attention_mask=attention_mask)
    elif task == "Format Text OCR":
        res = model.chat(tokenizer, image, ocr_type='format', attention_mask=attention_mask)
    elif task == "Fine-grained OCR (Box)":
        res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_box=ocr_box, attention_mask=attention_mask)
    elif task == "Fine-grained OCR (Color)":
        res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_color=ocr_color, attention_mask=attention_mask)
    elif task == "Multi-crop OCR":
        # NOTE(review): no ocr_type is passed here — verify against chat_crop's signature.
        res = model.chat_crop(tokenizer, image_file=image, attention_mask=attention_mask)
    elif task == "Render Formatted OCR":
        # The model writes the rendered page to disk; read it back for display.
        # NOTE(review): the fixed path is shared by all requests.
        res = model.chat(tokenizer, image, ocr_type='format', render=True, save_render_file='./demo.html', attention_mask=attention_mask)
        with open('./demo.html', 'r') as f:
            html_content = f.read()
        return res, html_content
    else:
        # Previously an unknown task fell through and crashed with
        # UnboundLocalError on `res`; fail with a clear message instead.
        raise ValueError(f"Unsupported task: {task!r}")

    return res, None

def update_inputs(task):
    """Return visibility updates for the four optional input widgets.

    The returned list is ordered as
    ``[ocr_type dropdown, ocr_box textbox, ocr_color dropdown, render checkbox]``,
    matching the ``outputs`` of the task dropdown's change event.

    Args:
        task: The task label currently selected in the dropdown.

    Returns:
        A list of four ``gr.update(...)`` values. Tasks that need no extra
        input — and any unrecognised value, such as a cleared dropdown —
        hide all four widgets, so the caller always receives four updates
        instead of ``None``.
    """
    if task == "Fine-grained OCR (Box)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
        ]
    if task == "Fine-grained OCR (Color)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=False),
            gr.update(visible=True, choices=["red", "green", "blue"]),
            gr.update(visible=False),
        ]
    if task == "Render Formatted OCR":
        return [gr.update(visible=False)] * 3 + [gr.update(visible=True)]
    # "Plain Text OCR", "Format Text OCR", "Multi-crop OCR" and anything
    # unexpected: no extra inputs are needed, so hide everything.
    return [gr.update(visible=False)] * 4
    
def ocr_demo(image, task, ocr_type, ocr_box, ocr_color):
    """Click handler: run ``process_image`` and shape its result for the UI.

    Returns a ``(text, html)`` pair where ``html`` is ``None`` whenever
    ``process_image`` produced no (or empty) rendered HTML.
    """
    text, rendered = process_image(image, task, ocr_type, ocr_box, ocr_color)
    return text, (rendered if rendered else None)

# Build the Gradio UI: inputs in the left column, results in the right column.
with gr.Blocks() as demo:
    gr.Markdown("#🙋🏻‍♂️Welcome to Tonic's🫴🏻📸GOT-OCR")
    with gr.Row():
        with gr.Column():
            # The image is handed to the handler as a path on disk.
            image_input = gr.Image(type="filepath", label="Input Image")
            # These labels must match the strings compared in process_image
            # and update_inputs.
            task_dropdown = gr.Dropdown(
                choices=[
                    "Plain Text OCR",
                    "Format Text OCR",
                    "Fine-grained OCR (Box)",
                    "Fine-grained OCR (Color)",
                    "Multi-crop OCR",
                    "Render Formatted OCR"
                ],
                label="Select Task",
                value="Plain Text OCR"
            )
            # The four widgets below start hidden; update_inputs toggles their
            # visibility when the selected task changes.
            ocr_type_dropdown = gr.Dropdown(
                choices=["ocr", "format"],
                label="OCR Type",
                visible=False
            )
            ocr_box_input = gr.Textbox(
                label="OCR Box (x1,y1,x2,y2)",
                placeholder="e.g., 100,100,200,200",
                visible=False
            )
            ocr_color_dropdown = gr.Dropdown(
                choices=["red", "green", "blue"],
                label="OCR Color",
                visible=False
            )
            # NOTE(review): this checkbox is shown for "Render Formatted OCR" but
            # is not listed in the click handler's inputs below, so its value is
            # never read.
            render_checkbox = gr.Checkbox(
                label="Render Result",
                visible=False
            )
            submit_button = gr.Button("Process")
        
        with gr.Column():
            output_text = gr.Textbox(label="OCR Result")
            output_html = gr.HTML(label="Rendered HTML Output")
    
    # Show/hide the optional inputs whenever a different task is selected.
    # The outputs order must match the list order returned by update_inputs.
    task_dropdown.change(
        update_inputs,
        inputs=[task_dropdown],
        outputs=[ocr_type_dropdown, ocr_box_input, ocr_color_dropdown, render_checkbox]
    )
    
    # Run OCR on click; ocr_demo returns (text, html) for the two output widgets.
    submit_button.click(
        ocr_demo,
        inputs=[image_input, task_dropdown, ocr_type_dropdown, ocr_box_input, ocr_color_dropdown],
        outputs=[output_text, output_html]
    )

# Start the web server as soon as the module is executed.
demo.launch()