# unimer_demo / app.py
# wufan's picture
# Upload 4 files
# cc5f7b7 verified
# raw
# history blame
# 3.08 kB
import argparse
import os
import sys
import numpy as np
import cv2
import torch
import gradio as gr
from PIL import Image
sys.path.insert(0, os.path.join(os.getcwd(), ".."))
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor
class ImageProcessor:
    """Build a model from a config file and run it on single images.

    Instance attributes set by ``__init__``:
        cfg_path: path handed to ``Config`` when the model is built.
        device: CUDA device when available, otherwise CPU.
        model: module returned by the task's ``build_model``, moved to ``device``.
        vis_processor: the ``formula_image_eval`` processor applied to each
            image before it is given to the model.
    """

    def __init__(self, cfg_path):
        """Remember the config path, pick a device and load the model."""
        self.cfg_path = cfg_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.vis_processor = self.load_model_and_processor()

    def load_model_and_processor(self):
        """Create the model and its evaluation image processor.

        Returns:
            A ``(model, vis_processor)`` tuple; the model already lives on
            ``self.device``.
        """
        args = argparse.Namespace(cfg_path=self.cfg_path, options=None)
        cfg = Config(args)
        task = tasks.setup_task(cfg)
        model = task.build_model(cfg).to(self.device)
        # This class only ever runs inference.  Nothing visible here switches
        # the module out of training mode, so do it explicitly.
        model.eval()
        vis_processor = load_processor(
            'formula_image_eval',
            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
        )
        return model, vis_processor

    def process_single_image(self, image_path, *, show=True):
        """Run the model on one image file and return the predicted string.

        Args:
            image_path: path (or file object) accepted by ``PIL.Image.open``.
            show: when true (the default, matching the historical behaviour)
                the input image is displayed in an OpenCV window after the
                prediction and the call blocks until a key is pressed.  Pass
                ``show=False`` when running without a display.

        Returns:
            The first entry of the model's ``pred_str`` output, or ``None``
            when the image cannot be opened or decoded.
        """
        try:
            raw_image = Image.open(image_path)
            # Image.open is lazy; force decoding here so that a corrupt file
            # is reported through the same error path as a missing one.
            raw_image.load()
        except IOError:
            print(f"Error: Unable to open image at {image_path}")
            return None

        image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        # Gradients are never needed for a prediction.
        with torch.no_grad():
            output = self.model.generate({"image": image})
        pred = output["pred_str"][0]
        print(f'Prediction:\n{pred}')

        if show:
            self._show_image(raw_image)
        return pred

    @staticmethod
    def _show_image(raw_image):
        """Display a PIL image in a blocking OpenCV window."""
        # Normalise to 3-channel RGB first: reversing the channel axis of an
        # RGBA (or other non-RGB) array would not produce the BGR layout
        # OpenCV expects.
        rgb_pixels = np.array(raw_image.convert("RGB"))
        bgr_pixels = rgb_pixels[:, :, ::-1].copy()
        try:
            cv2.imshow('Original Image', bgr_pixels)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        except cv2.error as err:
            # The window is only a convenience; a missing GUI backend must
            # not discard a prediction that was already computed.
            print(f"Warning: unable to display image: {err}")
def recognize_image(input_img):
    """Return the recognised LaTeX for ``input_img``.

    The real model call is currently disabled, so the argument is ignored
    and a fixed placeholder string is returned.
    """
    # Disabled model call, kept for reference:
    # latex_code = processor.process_single_image(input_img.name)
    placeholder_latex = "100"
    return placeholder_latex
def gradio_reset():
    """Clear the demo's image input and LaTeX output.

    The "Clear" button wires this handler to two output components (the
    input image and the prediction textbox), so one update must be returned
    per component; a single update does not match the declared outputs.

    Returns:
        A 2-tuple of ``gr.update(value=None)``, one for each output.
    """
    return gr.update(value=None), gr.update(value=None)
if __name__ == "__main__":
    # Model initialisation is currently disabled; the UI below runs against
    # the placeholder ``recognize_image`` handler.
    # == init model ==
    # root_path = os.path.abspath(os.getcwd())
    # config_path = os.path.join(root_path, "cfg_tiny.yaml")
    # processor_tiny = ImageProcessor(config_path)
    # print("== all models init. ==")
    # == init model ==

    # Read the page header with an explicit encoding so the result does not
    # depend on the platform's default locale.  The path is relative to the
    # current working directory.
    with open("header.html", "r", encoding="utf-8") as file:
        header = file.read()

    with gr.Blocks() as demo:
        gr.HTML(header)
        with gr.Row():
            # Left column: image input with its action buttons.
            with gr.Column():
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")
            # Right column: a non-interactive button used as a caption,
            # followed by the read-only prediction textbox.
            with gr.Column():
                gr.Button(value="Predict Latex:", interactive=False)
                pred_latex = gr.Textbox(label='Latex', interactive=False)

        # "Clear" resets both the image and the prediction; "Recognize"
        # fills the prediction textbox from the current image.
        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
        predict.click(recognize_image, inputs=[input_img], outputs=[pred_latex])

    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)