unimer_demo / app.py
wufan's picture
Upload 3 files
ca41cb1 verified
raw
history blame
3.15 kB
import os
import sys
import argparse
import numpy as np
import cv2
import torch
import gradio as gr
from PIL import Image
# sys.path.insert(0, os.path.join(os.getcwd(), ".."))
# from unimernet.common.config import Config
# import unimernet.tasks as tasks
# from unimernet.processors import load_processor
# class ImageProcessor:
# def __init__(self, cfg_path):
# self.cfg_path = cfg_path
# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# self.model, self.vis_processor = self.load_model_and_processor()
# def load_model_and_processor(self):
# args = argparse.Namespace(cfg_path=self.cfg_path, options=None)
# cfg = Config(args)
# task = tasks.setup_task(cfg)
# model = task.build_model(cfg).to(self.device)
# vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
# return model, vis_processor
# def process_single_image(self, image_path):
# try:
# raw_image = Image.open(image_path)
# except IOError:
# print(f"Error: Unable to open image at {image_path}")
# return
# # Convert PIL Image to OpenCV format
# open_cv_image = np.array(raw_image)
# # Convert RGB to BGR
# if len(open_cv_image.shape) == 3:
# # Convert RGB to BGR
# open_cv_image = open_cv_image[:, :, ::-1].copy()
# # Display the image using cv2
# image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
# output = self.model.generate({"image": image})
# pred = output["pred_str"][0]
# print(f'Prediction:\n{pred}')
# cv2.imshow('Original Image', open_cv_image)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
# return pred
def recognize_image(input_img):
    """Return the LaTeX string to display for the submitted image.

    The model-backed path is disabled in this file (the ``ImageProcessor``
    code is commented out), so ``input_img`` is ignored and a fixed
    placeholder string is returned for every call.
    """
    # Disabled model call, kept for reference:
    # latex_code = processor.process_single_image(input_img.name)
    placeholder_latex = "100"
    return placeholder_latex
def gradio_reset():
    """Build the Gradio update that clears a component's value.

    Takes no inputs and returns a single ``gr.update`` with ``value=None``.
    """
    cleared = gr.update(value=None)
    return cleared
if __name__ == "__main__":
    # Script entry point: build the two-column demo UI and start the server.

    # == init model ==
    # Model loading is currently disabled, so no processor is created here.
    # root_path = os.path.abspath(os.getcwd())
    # config_path = os.path.join(root_path, "cfg_tiny.yaml")
    # processor_tiny = ImageProcessor(config_path)
    # print("== all models init. ==")
    # == init model ==

    # Page banner markup; the path is resolved against the current working
    # directory. Decode explicitly as UTF-8 so the result does not depend on
    # the platform's default locale encoding.
    with open("header.html", "r", encoding="utf-8") as file:
        header = file.read()

    with gr.Blocks() as demo:
        gr.HTML(header)

        with gr.Row():
            # Left column: image input plus the action buttons.
            with gr.Column():
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")

            # Right column: a non-interactive button used as a caption, and
            # the read-only textbox that receives the prediction.
            with gr.Column():
                gr.Button(value="Predict Latex:", interactive=False)
                pred_latex = gr.Textbox(label='Latex', interactive=False)

        # gradio_reset returns exactly one update, while Gradio requires one
        # return value per declared output. Register one listener per
        # component so each call clears a single output, instead of a single
        # listener with two outputs (which fails with "didn't receive enough
        # output values" when Clear is pressed).
        clear.click(gradio_reset, inputs=None, outputs=[input_img])
        clear.click(gradio_reset, inputs=None, outputs=[pred_latex])
        predict.click(recognize_image, inputs=[input_img], outputs=[pred_latex])

    # Listen on all interfaces on port 7860; debug=True is passed through to
    # Gradio's launch unchanged.
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)