# unimer_demo / app.py
# Uploaded by wufan ("Upload app.py"), commit a68694f.
import os
import sys
import argparse
import numpy as np
import cv2
import torch
import gradio as gr
from PIL import Image
# sys.path.insert(0, os.path.join(os.getcwd(), ".."))
# from unimernet.common.config import Config
# import unimernet.tasks as tasks
# from unimernet.processors import load_processor
# class ImageProcessor:
# def __init__(self, cfg_path):
# self.cfg_path = cfg_path
# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# self.model, self.vis_processor = self.load_model_and_processor()
# def load_model_and_processor(self):
# args = argparse.Namespace(cfg_path=self.cfg_path, options=None)
# cfg = Config(args)
# task = tasks.setup_task(cfg)
# model = task.build_model(cfg).to(self.device)
# vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
# return model, vis_processor
# def process_single_image(self, image_path):
# try:
# raw_image = Image.open(image_path)
# except IOError:
# print(f"Error: Unable to open image at {image_path}")
# return
# # Convert PIL Image to OpenCV format
# open_cv_image = np.array(raw_image)
# # Convert RGB to BGR
# if len(open_cv_image.shape) == 3:
# # Convert RGB to BGR
# open_cv_image = open_cv_image[:, :, ::-1].copy()
# # Display the image using cv2
# image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
# output = self.model.generate({"image": image})
# pred = output["pred_str"][0]
# print(f'Prediction:\n{pred}')
# cv2.imshow('Original Image', open_cv_image)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
# return pred
def recognize_image(input_img):
    """Return the LaTeX prediction for the image supplied by the UI.

    Placeholder implementation: the UniMERNet model loading in this file is
    commented out, so a fixed string is returned and ``input_img`` is ignored.

    Args:
        input_img: Value delivered by the ``gr.Image`` input component.

    Returns:
        The text shown in the "Latex" output textbox.
    """
    # TODO: call the model once ImageProcessor is re-enabled, e.g.
    #   latex_code = processor.process_single_image(input_img.name)
    # NOTE(review): gr.Image's default value type may be an array without a
    # `.name` attribute; confirm the component's `type=` before enabling this.
    placeholder_latex = "100"
    return placeholder_latex
def gradio_reset():
    """Reset the UI by clearing both wired output components.

    Returns:
        A 2-tuple of ``gr.update(value=None)`` objects, one for each output
        component the Clear button is connected to (the input image and the
        predicted-LaTeX textbox, in that order).
    """
    # A fresh update object is built for each output component.
    return tuple(gr.update(value=None) for _ in range(2))
if __name__ == "__main__":
    # == init model ==
    # root_path = os.path.abspath(os.getcwd())
    # config_path = os.path.join(root_path, "cfg_tiny.yaml")
    # processor_tiny = ImageProcessor(config_path)
    # print("== all models init. ==")
    # == init model ==

    # Locate header.html next to this script instead of in the current working
    # directory, so the demo starts no matter where it is launched from.
    header_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "header.html"
    )
    # Decode explicitly as UTF-8 rather than relying on the platform default.
    with open(header_path, "r", encoding="utf-8") as fh:
        header = fh.read()

    with gr.Blocks() as demo:
        gr.HTML(header)
        with gr.Row():
            # Left column: image input plus the Clear / Recognize buttons.
            with gr.Column():
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(
                        value="Recognize", interactive=True, variant="primary"
                    )
            # Right column: a non-interactive button used as a caption, then
            # the read-only textbox that receives the prediction.
            with gr.Column():
                gr.Button(value="Predict Latex:", interactive=False)
                pred_latex = gr.Textbox(label='Latex', interactive=False)

        # gradio_reset returns one update per listed output, in this order.
        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
        predict.click(recognize_image, inputs=[input_img], outputs=[pred_latex])

    # Listen on all interfaces on port 7860; debug=True keeps the process
    # attached and prints errors to the console.
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)