# Spaces status (banner captured along with this file): Runtime error
import argparse
import os
import sys
from pathlib import Path

import numpy as np
import cv2
import torch
import gradio as gr
from PIL import Image

# Put the parent of the working directory first on sys.path — presumably so the
# local `unimernet` package imported below resolves; confirm against repo layout.
sys.path.insert(0, os.path.join(os.getcwd(), ".."))

from unimernet.common.config import Config  # noqa: E402
import unimernet.tasks as tasks  # noqa: E402
from unimernet.processors import load_processor  # noqa: E402
class ImageProcessor:
    """Build a model from a config file and run it on single image files.

    Args:
        cfg_path: Path of the config file handed to ``Config``.
    """

    def __init__(self, cfg_path):
        self.cfg_path = cfg_path
        # Use the GPU when torch can see one, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.vis_processor = self.load_model_and_processor()

    def load_model_and_processor(self):
        """Build the model and the evaluation image processor from ``self.cfg_path``.

        Returns:
            A ``(model, vis_processor)`` tuple. The model is moved to
            ``self.device`` and put in eval mode.
        """
        # Config is constructed from an argparse-style namespace, not a bare path.
        args = argparse.Namespace(cfg_path=self.cfg_path, options=None)
        cfg = Config(args)
        task = tasks.setup_task(cfg)
        model = task.build_model(cfg).to(self.device)
        # This object only runs inference. eval() disables training-time
        # behaviour (dropout etc.) whatever mode build_model left the model in.
        model.eval()
        vis_processor = load_processor(
            'formula_image_eval',
            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
        )
        return model, vis_processor

    @staticmethod
    def _to_bgr(image_array):
        """Reorder colour channels of a PIL-derived array into OpenCV order.

        A 4-channel array is treated as RGBA and becomes BGRA, with alpha kept
        last. Any other 3-D array has its channel axis reversed, so RGB becomes
        BGR. Both cases return a new C-contiguous array. Arrays that are not 3-D,
        such as single-channel grayscale, are returned unchanged.
        """
        if image_array.ndim != 3:
            return image_array
        if image_array.shape[2] == 4:
            # A plain [..., ::-1] would produce ABGR, which OpenCV misreads as BGRA.
            return np.ascontiguousarray(image_array[:, :, [2, 1, 0, 3]])
        return image_array[:, :, ::-1].copy()

    def process_single_image(self, image_path, show=True):
        """Predict the LaTeX string for the image stored at *image_path*.

        Args:
            image_path: Path of the image file to recognise.
            show: If true (the default), also display the input image in an
                OpenCV window and block until a key is pressed. Pass ``False``
                when no display is available, e.g. on a server.

        Returns:
            The first entry of the model's ``"pred_str"`` output. Returns
            ``None`` if the image cannot be opened or decoded.
        """
        try:
            raw_image = Image.open(image_path)
            # Image.open is lazy. Decode now so truncated or corrupt files are
            # reported here rather than failing later in numpy or the processor.
            raw_image.load()
        except OSError:  # IOError is an alias of OSError in Python 3.
            print(f"Error: Unable to open image at {image_path}")
            return None

        image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        # Inference only: do not record an autograd graph.
        with torch.no_grad():
            output = self.model.generate({"image": image})
        pred = output["pred_str"][0]
        print(f'Prediction:\n{pred}')

        if show:
            # Display the input with OpenCV, which expects BGR(A) channel order.
            cv2.imshow('Original Image', self._to_bgr(np.array(raw_image)))
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        return pred
def recognize_image(input_img):
    """Handle the "Recognize" button. This is currently a placeholder.

    Args:
        input_img: Value delivered by the Gradio image input. It is ignored.

    Returns:
        The fixed string ``"100"``, because the model-backed call is disabled.
    """
    # TODO: route to ImageProcessor.process_single_image once a processor
    # instance is created at startup.
    del input_img  # intentionally unused while the model call is disabled
    placeholder = "100"
    return placeholder
def gradio_reset():
    """Clear the demo's input image and its predicted-LaTeX textbox.

    The "Clear" button is wired with two output components (the image input
    and the LaTeX textbox). Gradio requires one returned value per output
    component, so one update is returned for each.

    Returns:
        A 2-tuple of ``gr.update(value=None)`` updates.
    """
    return gr.update(value=None), gr.update(value=None)
if __name__ == "__main__":
    # == init model ==
    # Model initialisation is currently disabled, so the demo serves the
    # placeholder answer from recognize_image. To enable real predictions:
    # root_path = os.path.abspath(os.getcwd())
    # config_path = os.path.join(root_path, "cfg_tiny.yaml")
    # processor_tiny = ImageProcessor(config_path)
    # print("== all models init. ==")
    # == init model ==

    # Resolve header.html next to this script so launching from another CWD
    # still works. Read it as UTF-8 so non-ASCII markup does not depend on the
    # platform's default encoding.
    header_path = Path(__file__).resolve().parent / "header.html"
    header = header_path.read_text(encoding="utf-8")

    with gr.Blocks() as demo:
        gr.HTML(header)
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")
            with gr.Column():
                # A disabled button serves only as a heading for the output box.
                gr.Button(value="Predict Latex:", interactive=False)
                pred_latex = gr.Textbox(label='Latex', interactive=False)

        # Event listeners must be registered inside the Blocks context.
        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
        predict.click(recognize_image, inputs=[input_img], outputs=[pred_latex])

    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)