import os
import time
from huggingface_hub import create_repo, whoami
import gradio as gr
from config_store import (
    get_inference_config,
    get_onnxruntime_config,
    get_openvino_config,
    get_pytorch_config,
    get_process_config,
)
from optimum_benchmark.backends.ipex.utils import TASKS_TO_IPEXMODEL
from optimum_benchmark.backends.onnxruntime.utils import TASKS_TO_ORTMODELS
from optimum_benchmark.backends.openvino.utils import TASKS_TO_OVMODEL
from optimum_benchmark.backends.transformers_utils import TASKS_TO_MODEL_LOADERS
from optimum_benchmark import (
    BenchmarkConfig,
    PyTorchConfig,
    OVConfig,
    ORTConfig,
    IPEXConfig,
    ProcessConfig,
    InferenceConfig,
    Benchmark,
)
from optimum_benchmark.logging_utils import setup_logging

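# Keep optimum-benchmark logs on stdout (no log files) at INFO verbosity.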
os.environ["LOG_TO_FILE"] = "0"
os.environ["LOG_LEVEL"] = "INFO"
setup_logging(level="INFO", prefix="MAIN-PROCESS")

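# This Space benchmarks every backend on CPU.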
DEVICE = "cpu"
BACKENDS = ["pytorch", "onnxruntime", "openvino", "ipex"]

CHOSEN_MODELS = ["bert-base-uncased", "gpt2"]
# Tasks supported by all four backends: the intersection of each backend's task map,
# sorted so the Dropdown choices are deterministic.
CHOSEN_TASKS = sorted(
    set(TASKS_TO_OVMODEL.keys())
    & set(TASKS_TO_ORTMODELS.keys())
    & set(TASKS_TO_IPEXMODEL.keys())
    & set(TASKS_TO_MODEL_LOADERS.keys())
)


def run_benchmark(kwargs, oauth_token: gr.OAuthToken | None = None):
    """Run one benchmark per backend for the selected model and task, then push the reports to the Hub."""
    if oauth_token is None or oauth_token.token is None:
        return "You must be logged in to use this space."

    username = whoami(oauth_token.token)["name"]
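    # Make sure the user's `benchmarks` dataset repo exists before pushing reports to it.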
    create_repo(
        f"{username}/benchmarks",
        token=oauth_token.token,
        repo_type="dataset",
        exist_ok=True,
    )

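    # Buckets for the keyword arguments collected from the UI, one per config section.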
    configs = {
        "process": {},
        "inference": {},
        "onnxruntime": {},
        "openvino": {},
        "pytorch": {},
        "ipex": {},
    }

    model, task = None, None
    for component, value in kwargs.items():
        # Keys are the Gradio components themselves; their labels encode the target:
        # "model"/"task" directly, or "<section>.<argument>" (labels defined in config_store).
        if component.label == "model":
            model = value
        elif component.label == "task":
            task = value
        elif "." in component.label:
            section, argument = component.label.split(".", maxsplit=1)
            configs[section][argument] = value

    if model is None or task is None:
        return "Could not read the model/task selection from the UI."

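    # The launcher (process) and scenario (inference) configs are shared across all backends.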
    process_config = ProcessConfig(**configs.pop("process"))
    inference_config = InferenceConfig(**configs.pop("inference"))

    configs["onnxruntime"] = ORTConfig(
        task=task,
        model=model,
        device=DEVICE,
        **configs["onnxruntime"],
    )
    configs["openvino"] = OVConfig(
        task=task,
        model=model,
        device=DEVICE,
        **configs["openvino"],
    )
    configs["pytorch"] = PyTorchConfig(
        task=task,
        model=model,
        device=DEVICE,
        **configs["pytorch"],
    )
    configs["ipex"] = IPEXConfig(
        task=task,
        model=model,
        device=DEVICE,
        **configs["ipex"],
    )

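    # Launch one benchmark per backend and push each config + report to the Hub dataset.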
    for backend in configs:
        benchmark_name = (
            f"{model}-{task}-{backend}-{time.strftime('%Y-%m-%d-%H-%M-%S')}"
        )
        benchmark_config = BenchmarkConfig(
            name=benchmark_name,
            launcher=process_config,
            scenario=inference_config,
            backend=configs[backend],
        )
        benchmark_report = Benchmark.run(benchmark_config)
        benchmark = Benchmark(config=benchmark_config, report=benchmark_report)
        benchmark.push_to_hub(
            repo_id=f"{username}/benchmarks",
            subfolder=benchmark_name,
            token=oauth_token.token,
        )

    return f"🚀 Benchmark {benchmark_name} has been pushed to {username}/benchmarks"


with gr.Blocks() as demo:
    # add login button
    gr.LoginButton(min_width=250)

    # add image
    gr.Markdown(
        """<img src="https://huggingface.co/spaces/optimum/optimum-benchmark-ui/resolve/main/huggy_bench.png" style="display: block; margin-left: auto; margin-right: auto; width: 30%;">"""
    )

    # title text
    gr.Markdown("<h1 style='text-align: center'>🤗 Optimum-Benchmark Interface 🏋️</h1>")

    # explanation text
    gr.HTML(
        "<h3 style='text-align: center'>"
        "Zero code Gradio interface of "
        "<a href='https://github.com/huggingface/optimum-benchmark.git'>"
        "Optimum-Benchmark"
        "</a>"
        "<br>"
        "</h3>"
    )

    model = gr.Dropdown(
        label="model",
        choices=CHOSEN_MODELS,
        value="bert-base-uncased",
        info="Model to run the benchmark on.",
    )
    task = gr.Dropdown(
        label="task",
        choices=CHOSEN_TASKS,
        value="feature-extraction",
        info="Task to run the benchmark on.",
    )

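    # Each get_*_config() helper returns a mapping whose values are the Gradio input
    # components for that section, labeled with the "<section>.<argument>" scheme.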
    with gr.Row():
        with gr.Accordion(label="Process Config", open=False, visible=True):
            process_config = get_process_config()

    with gr.Row():
        with gr.Accordion(label="PyTorch Config", open=True, visible=True):
            pytorch_config = get_pytorch_config()
        with gr.Accordion(label="OpenVINO Config", open=True, visible=True):
            openvino_config = get_openvino_config()
        with gr.Accordion(label="OnnxRuntime Config", open=True, visible=True):
            onnxruntime_config = get_onnxruntime_config()

    with gr.Row():
        with gr.Accordion(label="Scenario Config", open=False, visible=True):
            inference_config = get_inference_config()

    button = gr.Button(value="Run Benchmark", variant="primary")

    html_output = gr.HTML()

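    # Passing `inputs` as a set makes Gradio call `run_benchmark` with a single
    # {component: value} dict, which is what the label-based parsing above expects.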
    button.click(
        fn=run_benchmark,
        inputs={
            task,
            model,
            *process_config.values(),
            *inference_config.values(),
            *onnxruntime_config.values(),
            *openvino_config.values(),
            *pytorch_config.values(),
        },
        outputs=[html_output],
        concurrency_limit=1,
    )


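# Queue requests; with concurrency_limit=1 above, benchmarks run one at a time.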
demo.queue(max_size=10).launch()