import subprocess
import gradio as gr
import pandas as pd
from ansi2html import Ansi2HTMLConverter

# Shared converter that turns the ANSI-colored subprocess logs into HTML for the
# Gradio UI (used by run_experiment). `inline=True` is ansi2html's option for
# inline styles rather than CSS classes — NOTE(review): confirm against the
# ansi2html documentation.
ansi2html_converter = Ansi2HTMLConverter(inline=True)


# Labels of the UI components that configure the run itself; every other
# component is treated as a backend/benchmark setting.
_SETTING_LABELS = (
    "Compare to Baseline",
    "experiment_name",
    "model",
    "task",
    "device",
    "backend",
    "benchmark",
)
# Settings the command line cannot be built without ("Compare to Baseline" is optional).
_REQUIRED_LABELS = _SETTING_LABELS[1:]


def _split_inputs(kwargs):
    """Split the Gradio inputs into run settings and component overrides.

    Returns ``(settings, overrides)``: ``settings`` maps each label listed in
    ``_SETTING_LABELS`` to its value, ``overrides`` maps every other component
    to its value. The caller's ``kwargs`` is left untouched.
    """
    settings = {}
    overrides = {}
    for component, value in kwargs.items():
        if component.label in _SETTING_LABELS:
            settings[component.label] = value
        else:
            overrides[component] = value
    return settings, overrides


def _base_command(task, model, device, backend, benchmark, experiment_name):
    """Return the optimum-benchmark command line shared by every run."""
    return [
        "optimum-benchmark",
        "--config-dir",
        "./configs",
        "--config-name",
        "base_config",
        f"task={task}",
        f"model={model}",
        f"device={device}",
        f"backend={backend}",
        f"benchmark={benchmark}",
        f"experiment_name={experiment_name}",
    ]


def _component_overrides(components, replacements):
    """Turn UI component values into Hydra command-line overrides.

    ``replacements`` is a sequence of ``(prefix, group)`` pairs such as
    ``("pytorch.", "backend.")``. Only components whose label contains at least
    one prefix are used, and every prefix is rewritten to its group, in order.
    Dataframe components yield one ``++label.header=value`` override per column
    of their first row; any other component yields ``label=value``.
    """
    arguments = []
    for component, value in components.items():
        # Unlabelled components cannot match a prefix; treat them as "".
        label = component.label or ""
        if not any(prefix in label for prefix, _ in replacements):
            continue
        for prefix, group in replacements:
            label = label.replace(prefix, group)
        if isinstance(component, gr.Dataframe):
            arguments.extend(
                f"++{label}.{sub_key}={sub_value}"
                for sub_key, sub_value in zip(component.headers, value[0])
            )
        else:
            arguments.append(f"{label}={value}")
    return arguments


def _results_csv(experiment_name, benchmark):
    """Path of the results CSV read back for run *experiment_name*."""
    return f"runs/{experiment_name}/{benchmark}_results.csv"


def _failure_updates(html_text):
    """Updates yielded after a failed run: show the logs, re-enable the second
    output (presumably the run control — confirm in the UI wiring) and hide the table."""
    return gr.update(value=html_text), gr.update(interactive=True), gr.update(visible=False)


def run_benchmark(kwargs):
    """Run optimum-benchmark from the Gradio inputs and stream its progress.

    ``kwargs`` maps Gradio components to their current values. Components
    labelled with one of ``_SETTING_LABELS`` configure the run; the others are
    forwarded as Hydra overrides when their label contains ``"<backend>."`` or
    ``"<benchmark>."``.

    When "Compare to Baseline" is truthy, a reference run on the pytorch
    backend (experiment name "baseline") is executed first, and the final
    table combines both runs via ``postprocess_table``.

    Yields:
        Triples of ``gr.update`` objects (log HTML, interactive toggle, results
        table). The last triple carries either the results table or the logs
        of the run that failed.

    Raises:
        ValueError: if a required setting is missing from ``kwargs``.
    """
    settings, components = _split_inputs(kwargs)
    missing = [label for label in _REQUIRED_LABELS if label not in settings]
    if missing:
        raise ValueError(f"Missing required inputs: {', '.join(missing)}")

    baseline = settings.get("Compare to Baseline", False)
    experiment_name = settings["experiment_name"]
    model = settings["model"]
    task = settings["task"]
    device = settings["device"]
    backend = settings["backend"]
    benchmark = settings["benchmark"]

    html_text = ""
    if baseline:
        # The baseline always runs on pytorch, so only benchmark settings apply.
        baseline_arguments = _base_command(task, model, device, "pytorch", benchmark, "baseline")
        baseline_arguments += _component_overrides(components, [(f"{benchmark}.", "benchmark.")])
        return_code, html_text = yield from run_experiment(baseline_arguments, html_text)
        # A None return code means the exit status is unknown; only non-zero is a failure.
        if return_code is not None and return_code != 0:
            yield _failure_updates(html_text)
            return

    arguments = _base_command(task, model, device, backend, benchmark, experiment_name)
    arguments += _component_overrides(
        components, [(f"{backend}.", "backend."), (f"{benchmark}.", "benchmark.")]
    )
    return_code, html_text = yield from run_experiment(arguments, html_text)
    if return_code is not None and return_code != 0:
        yield _failure_updates(html_text)
        return

    table = pd.read_csv(_results_csv(experiment_name, benchmark), index_col=0)
    if baseline:
        baseline_table = pd.read_csv(_results_csv("baseline", benchmark), index_col=0)
        # Baseline row first: postprocess_table computes metrics relative to row 0.
        table = postprocess_table(pd.concat([baseline_table, table], axis=0), experiment_name)

    table_update = gr.update(visible=True, value={"headers": list(table.columns), "data": table.values.tolist()})
    yield gr.update(value=html_text), gr.update(interactive=True), table_update


def _append_log_line(log, line):
    """Append *line* to *log*, collapsing consecutive "Downloading " lines.

    If both *line* and the last line already in *log* contain "Downloading ",
    that last line is replaced, so repeated download progress messages occupy
    a single line instead of piling up. Otherwise *line* is simply appended.
    """
    if "Downloading " in line:
        body = log[:-1] if log.endswith("\n") else log
        head, sep, last_line = body.rpartition("\n")
        if "Downloading " in last_line:
            return head + sep + line
    return log + line


def run_experiment(args, html_text=""):
    """Run *args* as a subprocess and stream its output to Gradio as HTML.

    Args:
        args: command line to execute; passed to ``subprocess.Popen`` as a list,
            so no shell is involved.
        html_text: HTML already displayed; the command and logs are appended.

    Yields:
        ``(log_html, interactive_update, table_update)`` triples of
        ``gr.update`` objects while the process runs. The second output is kept
        non-interactive and the third hidden.

    Returns:
        ``(return_code, html)``: the process exit status and the final HTML,
        obtained by the caller through ``yield from``.
    """
    command = "<br>".join(args)
    html_text += f"<h3>Running command:</h3>{command}"
    yield gr.update(value=html_text), gr.update(interactive=False), gr.update(visible=False)

    process = subprocess.Popen(
        args,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr so errors show up in the streamed log
        universal_newlines=True,
    )

    # If the process prints nothing, the final HTML is just the command header.
    cumul_html_text = html_text
    curr_ansi_text = ""
    try:
        for ansi_line in iter(process.stdout.readline, ""):
            # echo the raw process output to our own stdout
            print(ansi_line, end="")
            # skip torch.distributed.nn.jit.instantiator messages
            if "torch.distributed.nn.jit.instantiator" in ansi_line:
                continue
            curr_ansi_text = _append_log_line(curr_ansi_text, ansi_line)
            curr_html_text = ansi2html_converter.convert(curr_ansi_text)
            cumul_html_text = html_text + "<br><h3>Streaming logs:</h3>" + curr_html_text
            yield gr.update(value=cumul_html_text), gr.update(interactive=False), gr.update(visible=False)
    finally:
        process.stdout.close()

    # Reaching EOF on stdout does not set returncode; wait() reaps the child and does.
    return_code = process.wait()
    return return_code, cumul_html_text


def postprocess_table(table, experiment_name):
    """Label a baseline-vs-experiment results table and add relative metrics.

    Args:
        table: results with exactly two rows, the baseline run first and the
            experiment second.
        experiment_name: label used for the second row.

    Returns:
        A new DataFrame with a leading ``experiment_name`` column and a fresh
        0-based index. For every known metric column present, a derived
        percentage column relative to the baseline row is added, rounded to 2
        decimals:

        - latency and peak memory use ``(1 - x / baseline) * 100``, so positive
          means faster or smaller;
        - throughput uses ``(x / baseline - 1) * 100``, so positive means higher.

        The input table is not modified.
    """
    # (source column, derived column, higher_is_better)
    metrics = (
        ("forward.latency(s)", "forward.latency.speedup(%)", False),
        ("forward.throughput(samples/s)", "forward.throughput.speedup(%)", True),
        ("forward.peak_memory(MB)", "forward.peak_memory.savings(%)", False),
        ("generate.latency(s)", "generate.latency.speedup(%)", False),
        ("generate.throughput(tokens/s)", "generate.throughput.speedup(%)", True),
        ("generate.peak_memory(MB)", "generate.peak_memory.savings(%)", False),
    )

    # Work on a copy so the caller's DataFrame is left untouched.
    table = table.copy()
    table["experiment_name"] = ["baseline", experiment_name]
    # Move the label column to the front and replace the CSV index with 0..n-1.
    table = table.set_index("experiment_name").reset_index()

    for source, target, higher_is_better in metrics:
        if source not in table.columns:
            continue
        ratio = table[source] / table[source].iloc[0]
        change = ratio - 1 if higher_is_better else 1 - ratio
        table[target] = (change * 100).round(2)

    return table