File size: 3,078 Bytes
cc5f7b7
 
 
a671104
cc5f7b7
 
a671104
cc5f7b7
 
bb7b589
cc5f7b7
 
 
 
a671104
 
cc5f7b7
 
 
 
 
a671104
cc5f7b7
 
 
 
 
 
a671104
cc5f7b7
a671104
cc5f7b7
 
 
 
 
 
 
 
 
 
 
 
 
a671104
cc5f7b7
 
 
 
 
 
 
 
 
 
a671104
 
cc5f7b7
 
 
 
 
 
a671104
cc5f7b7
 
 
 
 
a671104
cc5f7b7
 
 
a671104
cc5f7b7
 
 
 
a671104
 
cc5f7b7
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import argparse
import os
import sys
import numpy as np

import cv2
import torch
import gradio as gr
from PIL import Image

sys.path.insert(0, os.path.join(os.getcwd(), ".."))
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor


class ImageProcessor:
    """Load a UniMERNet model once and run it on individual image files.

    The model and its evaluation-time image processor are built from the
    config file at ``cfg_path`` and placed on CUDA when available, otherwise
    on the CPU.
    """

    def __init__(self, cfg_path):
        """Remember ``cfg_path``, choose the device and build model + processor."""
        self.cfg_path = cfg_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.vis_processor = self.load_model_and_processor()

    def load_model_and_processor(self):
        """Build the model and the visual processor described by ``self.cfg_path``.

        Returns:
            A ``(model, vis_processor)`` tuple; the model is already moved to
            ``self.device``.
        """
        # Config expects an argparse-style namespace rather than a bare path.
        args = argparse.Namespace(cfg_path=self.cfg_path, options=None)
        cfg = Config(args)
        task = tasks.setup_task(cfg)
        model = task.build_model(cfg).to(self.device)
        vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)

        return model, vis_processor

    @staticmethod
    def _to_bgr(pixels):
        """Return ``pixels`` in the channel order OpenCV expects for display.

        * 2-D arrays (single-channel images) are returned unchanged.
        * Arrays with three or more channels are reduced to B, G, R; an alpha
          plane is dropped instead of being shuffled into the colour channels
          (a plain ``[:, :, ::-1]`` would turn RGBA into ABGR).
        * Arrays with fewer than three channels (e.g. luminance + alpha) keep
          only the first plane.
        """
        if pixels.ndim != 3:
            return pixels
        if pixels.shape[2] >= 3:
            # Indices 2, 1, 0 -> B, G, R; identical to [::-1] for RGB input.
            return pixels[:, :, 2::-1].copy()
        return pixels[:, :, 0].copy()

    def process_single_image(self, image_path, *, show=True):
        """Run the model on one image file and return the predicted string.

        Args:
            image_path: Path (or file object) accepted by ``PIL.Image.open``.
            show: When true (the default, matching the historical behaviour)
                the input image is displayed in an OpenCV window and the call
                blocks until a key is pressed. Pass ``False`` when running
                without a display, e.g. behind a web UI.

        Returns:
            The first entry of the model output's ``"pred_str"`` list, or
            ``None`` when the image cannot be opened (an error message is
            printed in that case).
        """
        try:
            raw_image = Image.open(image_path)
        except OSError:
            print(f"Error: Unable to open image at {image_path}")
            return None

        preview = None
        if show:
            # Convert the PIL image to an OpenCV-style array for display only;
            # the model itself is fed the untouched PIL image below.
            preview = self._to_bgr(np.array(raw_image))

        # Add a leading batch dimension before handing the tensor to the model.
        image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        output = self.model.generate({"image": image})
        pred = output["pred_str"][0]
        print(f'Prediction:\n{pred}')

        if preview is not None:
            cv2.imshow('Original Image', preview)
            cv2.waitKey(0)  # 0 = wait indefinitely for a key press
            cv2.destroyAllWindows()

        return pred
    

def recognize_image(input_img):
    """Return the recognition result shown in the UI for ``input_img``.

    Currently a stub: the input is ignored and a fixed placeholder string is
    returned because the model call is disabled.
    """
    # Disabled model call, kept for reference:
    # latex_code = processor.process_single_image(input_img.name)
    placeholder = "100"
    return placeholder
    
def gradio_reset():
    """Clear the demo's input image and prediction textbox.

    The "Clear" button is wired with two output components (the image and
    the LaTeX textbox), and Gradio expects one return value per output, so
    this returns one update for each, in that order.

    Returns:
        A 2-tuple ``(image_update, textbox_update)``.
    """
    return gr.update(value=None), gr.update(value="")

    
if __name__ == "__main__":
    # Model initialisation is currently disabled; the UI runs with the stubbed
    # recognize_image(). The original setup is kept below for reference.
    # NOTE(review): this creates `processor_tiny`, while the disabled call in
    # recognize_image() refers to `processor` - reconcile the names when
    # re-enabling.
    # == init model ==
    # root_path = os.path.abspath(os.getcwd())
    # config_path = os.path.join(root_path, "cfg_tiny.yaml")

    # processor_tiny = ImageProcessor(config_path)
    # print("== all models init. ==")
    # == init model ==

    # Read the page header as UTF-8 explicitly so the result does not depend
    # on the platform's default locale encoding. The path is relative to the
    # current working directory.
    with open("header.html", "r", encoding="utf-8") as file:
        header = file.read()

    with gr.Blocks() as demo:
        gr.HTML(header)

        with gr.Row():
            # Left column: image input plus the action buttons.
            with gr.Column():
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")
            # Right column: a non-interactive button used as a caption, and
            # the read-only textbox that receives the prediction.
            with gr.Column():
                gr.Button(value="Predict Latex:", interactive=False)
                pred_latex = gr.Textbox(label='Latex', interactive=False)

        # "Clear" resets both the image and the textbox; "Recognize" sends the
        # image to recognize_image() and writes the result to the textbox.
        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
        predict.click(recognize_image, inputs=[input_img], outputs=[pred_latex])

    # Listen on all interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)