import os
# Runs at import time, before the remaining imports execute: shells out to pip
# to install/upgrade transformers to exactly 4.44.2. The exit status returned
# by os.system is discarded, so a failed install is not detected here.
# NOTE(review): presumably this pins the version the unimernet code needs in
# the hosting environment -- confirm, and consider a requirements file instead.
os.system('pip install -U transformers==4.44.2')
import sys
import shutil
import torch
import argparse
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download

# Put the parent of the current working directory first on the import path.
# NOTE(review): presumably this is what makes the `unimernet` package below
# importable when the script is started from its own folder -- confirm.
sys.path.insert(0, os.path.join(os.getcwd(), ".."))
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor


def load_model_and_processor(cfg_path):
    """Build a model and its evaluation image processor from a config file.

    Args:
        cfg_path: Path of the configuration file handed to ``Config``.

    Returns:
        A ``(model, vis_processor)`` tuple: the model built by the task that
        the configuration selects, and the ``formula_image_eval`` processor
        configured from the ``formula_rec_eval`` dataset's eval settings.
    """
    # Config is given an argparse-style namespace rather than a bare path.
    cfg = Config(argparse.Namespace(cfg_path=cfg_path, options=None))
    model = tasks.setup_task(cfg).build_model(cfg)
    processor_cfg = cfg.config.datasets.formula_rec_eval.vis_processor.eval
    return model, load_processor('formula_image_eval', processor_cfg)
    
def recognize_image(input_img, model_type):
    """Recognize the formula shown in an image and return it as LaTeX.

    Relies on the module-level ``model_base``, ``model_small``, ``model_tiny``
    and ``vis_processor`` objects created by the script's ``__main__`` section.

    Args:
        input_img: Image as a NumPy array, or ``None`` when the image widget
            is empty (``gradio_reset`` sets it to ``None``).
        model_type: ``"base"`` or ``"small"``; any other value selects the
            tiny model.

    Returns:
        The first predicted LaTeX string, or ``""`` when ``input_img`` is
        ``None``.
    """
    # Nothing to recognize: return an empty result instead of failing with an
    # AttributeError on ``None.shape`` further down.
    if input_img is None:
        return ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model_type == "base":
        model = model_base.to(device)
    elif model_type == "small":
        model = model_small.to(device)
    else:
        model = model_tiny.to(device)

    # Reverse the channel axis of 3-D arrays; 2-D arrays are passed through
    # unchanged. NOTE(review): presumably an RGB -> BGR swap to match the
    # channel order the model expects -- confirm.
    if input_img.ndim == 3:
        input_img = input_img[:, :, ::-1].copy()

    img = Image.fromarray(input_img)
    # unsqueeze(0) adds the leading batch dimension for a single image.
    image = vis_processor(img).unsqueeze(0).to(device)
    # Pure inference: no autograd bookkeeping is needed.
    with torch.no_grad():
        output = model.generate({"image": image})
    return output["pred_str"][0]
    
def gradio_reset():
    """Return two ``gr.update(value=None)`` objects, one per cleared output."""
    cleared_first = gr.update(value=None)
    cleared_second = gr.update(value=None)
    return cleared_first, cleared_second

    
if __name__ == "__main__":
    # Entry point: fetch the three checkpoints, load the models, then build
    # and launch the Gradio UI. The model objects and ``vis_processor`` are
    # deliberately module-level names because ``recognize_image`` reads them
    # as globals.
    root_path = os.path.abspath(os.getcwd())

    # == download weights ==
    # Download each checkpoint directly into ./models/<name>. Moving the
    # snapshot folders out of the Hugging Face cache (the previous approach)
    # is not repeatable: on a second run the destination already exists, so
    # the snapshot is either nested inside it or the move fails.
    models_root = os.path.join(root_path, "models")
    os.makedirs(models_root, exist_ok=True)
    for model_name in ("unimernet_tiny", "unimernet_small", "unimernet_base"):
        snapshot_download(
            f"wanderkid/{model_name}",
            local_dir=os.path.join(models_root, model_name),
        )
    # == download weights ==

    # == load model ==
    # Only the processor returned for the base config is kept, matching the
    # previous behavior where later assignments overwrote earlier ones.
    model_tiny, _ = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
    model_small, _ = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
    model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
    print("== load all models ==")
    # == load model ==

    # header.html is resolved against the current working directory.
    with open("header.html", "r", encoding="utf-8") as file:
        header = file.read()

    with gr.Blocks() as demo:
        gr.HTML(header)

        with gr.Row():
            with gr.Column():
                model_type = gr.Radio(
                    choices=["tiny", "small", "base"],
                    value="tiny",
                    label="Model Type",
                    interactive=True,
                )
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")

                with gr.Accordion("Examples:"):
                    example_root = os.path.join(os.path.dirname(__file__), "examples")
                    # Sorted so the example gallery has a stable order;
                    # os.listdir returns entries in arbitrary order.
                    example_files = sorted(
                        name for name in os.listdir(example_root) if name.endswith("png")
                    )
                    gr.Examples(
                        examples=[os.path.join(example_root, name) for name in example_files],
                        inputs=input_img,
                    )
            with gr.Column():
                # Non-interactive button used as a heading for the output box.
                gr.Button(value="Predict Latex:", interactive=False)
                pred_latex = gr.Textbox(label='Latex', interactive=False)

        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
        predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex])

    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)