File size: 4,379 Bytes
8a8d449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
# Pin transformers at start-up, before the `unimernet` modules imported below are
# loaded (they may import transformers themselves).
# NOTE(review): installing packages at runtime via os.system ignores failures and
# slows cold start; consider pinning this in requirements.txt instead.
os.system('pip install -U transformers==4.44.2')
import sys
import shutil
import torch
import argparse
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
import spaces

# == download weights ==
# Fetch the three UniMERNet checkpoints from the Hugging Face Hub directly into
# ./models/<name>. snapshot_download returns the local directory; these *_model_dir
# names are only referenced by the commented-out legacy steps below.
tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny', local_dir='./models/unimernet_tiny')
small_model_dir = snapshot_download('wanderkid/unimernet_small', local_dir='./models/unimernet_small')
base_model_dir = snapshot_download('wanderkid/unimernet_base', local_dir='./models/unimernet_base')
# Debug aid: list the downloaded tiny checkpoint files in the startup log.
os.system("ls -l models/unimernet_tiny")
# Legacy setup, kept for reference: it patched MODEL_DIR into the cfg_*.yaml files
# and moved the downloads under ./models. It appears superseded by the
# local_dir=... arguments above, which already place the files under ./models/.
# os.system(f"sed -i 's/MODEL_DIR/{tiny_model_dir}/g' cfg_tiny.yaml")
# os.system(f"sed -i 's/MODEL_DIR/{small_model_dir}/g' cfg_small.yaml")
# os.system(f"sed -i 's/MODEL_DIR/{base_model_dir}/g' cfg_base.yaml")
# root_path = os.path.abspath(os.getcwd())
# os.makedirs(os.path.join(root_path, "models"), exist_ok=True)
# shutil.move(tiny_model_dir, os.path.join(root_path, "models", "unimernet_tiny"))
# shutil.move(small_model_dir, os.path.join(root_path, "models", "unimernet_small"))
# shutil.move(base_model_dir, os.path.join(root_path, "models", "unimernet_base"))
# == download weights ==

# Put the parent of the working directory first on the import path so the
# `unimernet` package below can be found. This presumably assumes the demo is
# launched from a subdirectory of the repository (TODO: confirm).
sys.path.insert(0, os.path.join(os.getcwd(), ".."))
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor


def load_model_and_processor(cfg_path):
    """Build a UniMERNet model plus its evaluation image processor from a config file.

    Args:
        cfg_path: Path to the YAML config (e.g. ``cfg_tiny.yaml``).

    Returns:
        A ``(model, vis_processor)`` pair. The processor is the
        ``formula_image_eval`` processor configured from
        ``datasets.formula_rec_eval.vis_processor.eval`` in the config.
    """
    # Config expects an argparse-style namespace, as if parsed from the CLI.
    config = Config(argparse.Namespace(cfg_path=cfg_path, options=None))
    model = tasks.setup_task(config).build_model(config)

    eval_settings = config.config.datasets.formula_rec_eval.vis_processor.eval
    processor = load_processor('formula_image_eval', eval_settings)
    return model, processor

@spaces.GPU
def recognize_image(input_img, model_type):
    """Run formula recognition on a single image and return the predicted LaTeX.

    Args:
        input_img: Image array from the ``gr.Image`` component, or ``None`` when
            the user has not provided an image.
        model_type: ``"base"`` or ``"small"``; any other value selects the tiny model.

    Returns:
        The first string in ``output["pred_str"]`` produced by ``model.generate``.

    Raises:
        gr.Error: If no image was provided. Gradio shows this as a message
            instead of an ``AttributeError`` traceback on ``None.shape``.
    """
    if input_img is None:
        raise gr.Error("Please upload an image before clicking Recognize.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model_base / model_small / model_tiny and vis_processor are module-level
    # globals created in the ``__main__`` block.
    if model_type == "base":
        model = model_base.to(device)
    elif model_type == "small":
        model = model_small.to(device)
    else:
        model = model_tiny.to(device)

    if len(input_img.shape) == 3:
        # Reverse the channel order (e.g. RGB -> BGR); .copy() makes the
        # reversed view contiguous for PIL.
        # NOTE(review): gr.Image normally delivers RGB, so PIL will read this
        # array with red/blue swapped. Confirm the processor expects that order.
        input_img = input_img[:, :, ::-1].copy()

    img = Image.fromarray(input_img)
    image = vis_processor(img).unsqueeze(0).to(device)
    # Inference only: skip autograd bookkeeping to reduce memory use.
    with torch.no_grad():
        output = model.generate({"image": image})
    return output["pred_str"][0]
    
def gradio_reset():
    """Return Gradio updates that clear both wired outputs (image input and LaTeX box).

    Returns:
        A 2-tuple of independent ``gr.update(value=None)`` objects, one per output.
    """
    # Build a fresh update per output rather than reusing a single object.
    return tuple(gr.update(value=None) for _ in range(2))

    
if __name__ == "__main__":
    root_path = os.path.abspath(os.getcwd())
    # == load model ==
    # Each call also returns an eval image processor. Only the last assignment
    # (from cfg_base.yaml) survives, and recognize_image uses that single
    # processor for every model size. Presumably the three configs share the
    # same vis_processor settings (TODO: confirm).
    model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
    model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
    model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
    print("== load all models ==")
    # == load model ==

    # Explicit encoding so reading the header does not depend on the host
    # locale (assumes header.html is UTF-8).
    with open("header.html", "r", encoding="utf-8") as header_file:
        header = header_file.read()
    with gr.Blocks() as demo:
        gr.HTML(header)

        with gr.Row():
            with gr.Column():
                model_type = gr.Radio(
                    choices=["tiny", "small", "base"],
                    value="tiny",
                    label="Model Type",
                    interactive=True,
                )
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")

                with gr.Accordion("Examples:"):
                    example_root = os.path.join(os.path.dirname(__file__), "examples")
                    # os.listdir order is arbitrary; sort so examples appear in a
                    # stable order. Match the ".png" extension case-insensitively.
                    example_names = sorted(
                        name for name in os.listdir(example_root) if name.lower().endswith(".png")
                    )
                    gr.Examples(
                        examples=[os.path.join(example_root, name) for name in example_names],
                        inputs=input_img,
                    )
            with gr.Column():
                # Non-interactive button used purely as a section heading.
                gr.Button(value="Predict Latex:", interactive=False)
                pred_latex = gr.Textbox(label='Latex', interactive=False)

        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
        predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex])

    # 0.0.0.0 listens on all network interfaces.
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)