import os
from typing import List

import pandas as pd

from .utils import process_kernels, process_quantizations

# Local directory used to cache the processed leaderboard CSV files.
DATASET_DIRECTORY = "dataset"

# Raw dataset column -> display name. The insertion order of this mapping is
# also the column order of the processed table.
COLUMNS_MAPPING = {
    # experiment identity
    "config.name": "Experiment 🧪",
    "config.backend.model": "Model 🤗",
    # primary measurements
    "report.prefill.latency.p50": "Prefill (s)",
    "report.per_token.latency.p50": "Per Token (s)",
    "report.decode.throughput.value": "Decode (tokens/s)",
    "report.decode.efficiency.value": "Energy (tokens/kWh)",
    # deployment settings
    "config.backend.name": "Backend 🏭",
    "config.backend.torch_dtype": "Precision 📥",
    "quantization": "Quantization 🗜️",
    "attention": "Attention 👁️",
    "kernel": "Kernel ⚛️",
    # additional information
    "architecture": "Architecture 🏛️",
    "prefill+decode": "End-to-End (s)",
    "Average ⬆️": "Open LLM Score (%)",
    "#Params (B)": "Params (B)",
}

# Hardware-specific variants: each one appends its own memory measurement,
# displayed under the same "Memory (MB)" name.
CUDA_COLUMNS_MAPPING = {
    **COLUMNS_MAPPING,
    "report.decode.memory.max_allocated": "Memory (MB)",
}

INTEL_COLUMNS_MAPPING = {
    **COLUMNS_MAPPING,
    "report.decode.memory.max_ram": "Memory (MB)",
}

# (display column, ascending?) pairs, in sort priority order.
# NOTE(review): throughput is sorted ascending and prefill latency descending,
# i.e. slower entries come first among ties on the score — confirm this
# tie-break direction is intended.
_SORT_ORDER = (
    ("Open LLM Score (%)", False),
    ("Decode (tokens/s)", True),
    ("Prefill (s)", False),
)
SORTING_COLUMNS = [column for column, _ in _SORT_ORDER]
SORTING_ASCENDING = [ascending for _, ascending in _SORT_ORDER]


def get_raw_llm_perf_df(
    machine: str, subsets: List[str], backends: List[str], hardware_type: str
):
    """Download the raw benchmark tables for a machine and join model metadata.

    One CSV is requested per (subset, backend) pair from the
    ``optimum-benchmark/llm-perf-leaderboard`` dataset. Files that cannot be
    loaded are reported on stdout and skipped (best effort); the remaining
    tables are concatenated and merged with ``llm-df.csv`` on the model name.

    Args:
        machine: Machine identifier used in the dataset file names.
        subsets: Benchmark subsets to load.
        backends: Backends to load.
        hardware_type: Hardware type used in the dataset file names.

    Returns:
        The merge of ``llm-df.csv`` (key ``Model``) with the concatenated
        performance tables (key ``config.backend.model``), using the default
        ``pd.merge`` join.

    Raises:
        ValueError: If none of the requested performance tables could be
            loaded.
    """
    repo_id = "optimum-benchmark/llm-perf-leaderboard"

    dfs = []
    for subset in subsets:
        for backend in backends:
            filename = f"perf-df-{backend}-{hardware_type}-{subset}-{machine}.csv"
            try:
                dfs.append(pd.read_csv(f"hf://datasets/{repo_id}/{filename}"))
            except Exception as error:
                # Deliberately best effort: a missing combination must not
                # abort the whole download. The underlying error is printed
                # so that network/auth/parsing failures are not mistaken for
                # a genuinely missing file.
                print("Dataset not found for:")
                print(f"  • Backend: {backend}")
                print(f"  • Subset: {subset}")
                print(f"  • Machine: {machine}")
                print(f"  • Hardware Type: {hardware_type}")
                url = f"https://huggingface.co/datasets/{repo_id}/blob/main/{filename}"
                print(f"  • URL: {url}")
                print(f"  • Error: {type(error).__name__}: {error}")

    if not dfs:
        raise ValueError(
            f"No datasets found for machine {machine}, check your hardware.yml config file or your dataset on huggingface"
        )

    perf_df = pd.concat(dfs)
    llm_df = pd.read_csv(f"hf://datasets/{repo_id}/llm-df.csv")

    return pd.merge(
        llm_df, perf_df, left_on="Model", right_on="config.backend.model"
    )


def processed_llm_perf_df(llm_perf_df, hardware_type: str):
    """Clean, enrich, rename and sort a raw leaderboard table.

    The input frame is left untouched; all work happens on a copy.

    Args:
        llm_perf_df: Raw merged table (model metadata + benchmark reports).
        hardware_type: ``"cuda"`` or ``"cpu"``; selects which memory column
            is kept and exposed as ``"Memory (MB)"``.

    Returns:
        A new DataFrame restricted to the columns of the selected mapping,
        renamed to their display names and sorted by ``SORTING_COLUMNS`` /
        ``SORTING_ASCENDING``.

    Raises:
        ValueError: If ``hardware_type`` is neither ``"cuda"`` nor ``"cpu"``.
        AssertionError: If the benchmark scenario is not identical across
            rows, if no memory value is present, or if the resulting column
            names are not unique.
    """
    # Resolve the column mapping up front so an unsupported hardware type
    # fails before any processing work is done.
    if hardware_type == "cuda":
        columns_mapping = CUDA_COLUMNS_MAPPING
    elif hardware_type == "cpu":
        columns_mapping = INTEL_COLUMNS_MAPPING
    else:
        raise ValueError(f"Hardware type {hardware_type} not supported")

    # some assertions: every row must come from the same benchmark scenario
    for scenario_column in (
        "config.scenario.input_shapes.batch_size",
        "config.scenario.input_shapes.sequence_length",
        "config.scenario.generate_kwargs.max_new_tokens",
        "config.scenario.generate_kwargs.min_new_tokens",
    ):
        assert (
            llm_perf_df[scenario_column].nunique() == 1
        ), f"Expected a single value in column {scenario_column}"

    # Drop rows without a decode latency. The explicit copy guarantees that
    # the column assignments below never touch the caller's DataFrame.
    llm_perf_df = llm_perf_df.dropna(subset=["report.decode.latency.p50"]).copy()

    # fix couple stuff (literal, non-regex replacements)
    llm_perf_df["config.name"] = llm_perf_df["config.name"].str.replace(
        "flash_attention_2", "fa2", regex=False
    )
    llm_perf_df["prefill+decode"] = (
        llm_perf_df["report.prefill.latency.p50"]
        + llm_perf_df["report.decode.latency.p50"]
    )
    llm_perf_df["architecture"] = llm_perf_df["Architecture"]
    llm_perf_df["attention"] = (
        llm_perf_df["config.backend.attn_implementation"]
        .str.replace("flash_attention_2", "FAv2", regex=False)
        .str.replace("eager", "Eager", regex=False)
        .str.replace("sdpa", "SDPA", regex=False)
    )
    # Row-wise project helpers derive the quantization and kernel labels.
    llm_perf_df["quantization"] = llm_perf_df.apply(process_quantizations, axis=1)
    llm_perf_df["kernel"] = llm_perf_df.apply(process_kernels, axis=1)

    # round numerical columns (keys that are not columns are ignored by round)
    llm_perf_df = llm_perf_df.round(
        {
            "report.prefill.latency.p50": 3,
            "report.decode.latency.p50": 3,
            "report.decode.throughput.value": 3,
            "report.decode.efficiency.value": 3,
            "report.decode.memory.max_allocated": 3,
            "report.decode.memory.max_ram": 3,
            "Average ⬆️": 3,
            "prefill+decode": 3,
            "#Params (B)": 3,
        }
    )

    # filter columns, switch to display names, then sort by metric
    llm_perf_df = (
        llm_perf_df[list(columns_mapping)]
        .rename(columns=columns_mapping)
        .sort_values(by=SORTING_COLUMNS, ascending=SORTING_ASCENDING)
    )

    assert llm_perf_df["Memory (MB)"].notna().any(), "The dataset should contain at least one memory value, otherwise this implies that all the benchmarks have failed (contains only a traceback)"
    assert llm_perf_df.columns.is_unique, "All columns should be unique"

    return llm_perf_df


def get_llm_perf_df(
    machine: str, subsets: List[str], backends: List[str], hardware_type: str
):
    """Return the processed leaderboard table for a machine, using a CSV cache.

    If ``{DATASET_DIRECTORY}/llm-perf-leaderboard-{machine}.csv`` exists it is
    read and returned as-is. Otherwise the raw data is downloaded, processed,
    written to that path and returned.

    Args:
        machine: Machine identifier; also the only component of the cache key.
        subsets: Benchmark subsets to download on a cache miss.
        backends: Backends to download on a cache miss.
        hardware_type: Hardware type forwarded to the download and processing
            steps on a cache miss.

    Returns:
        The processed leaderboard DataFrame.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(DATASET_DIRECTORY, exist_ok=True)

    # NOTE(review): the cache file name depends on `machine` only, so a cached
    # file is reused even when subsets/backends/hardware_type differ — confirm
    # this is acceptable for all callers.
    cache_path = f"{DATASET_DIRECTORY}/llm-perf-leaderboard-{machine}.csv"

    if os.path.exists(cache_path):
        return pd.read_csv(cache_path)

    print(f"Dataset machine {machine} not found, downloading...")
    llm_perf_df = get_raw_llm_perf_df(machine, subsets, backends, hardware_type)
    llm_perf_df = processed_llm_perf_df(llm_perf_df, hardware_type)
    llm_perf_df.to_csv(cache_path, index=False)

    return llm_perf_df