import requests
import json
import pandas as pd
from tqdm.auto import tqdm
import streamlit as st
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.repocard import metadata_load

# Language codes rewritten before they are stored on a leaderboard row
# (e.g. results reported for "sv" are listed under "sv-SE").
aliases_lang = {"sv": "sv-SE"}
# Languages whose leaderboard is ranked by CER instead of WER.
cer_langs = ["ja", "zh-CN", "zh-HK", "zh-TW"]
# Language code -> display name used in the language dropdown. Decode the file
# explicitly as UTF-8 so loading does not depend on the platform's default
# (locale) encoding.
with open("languages.json", encoding="utf-8") as f:
    lang2name = json.load(f)
# Datasets listed first (in this order) and tagged "(recommended)" in the UI.
suggested_datasets = [
    "librispeech_asr",
    "mozilla-foundation/common_voice_8_0",
    "mozilla-foundation/common_voice_11_0",
    "speech-recognition-community-v2/eval_data",
    "facebook/multilingual_librispeech"
]


def make_clickable(model_name):
    """Wrap a Hub model id in an HTML anchor that opens its page in a new tab.

    Args:
        model_name: model id such as ``"org/model"``.

    Returns:
        The ``<a>`` element as a string.
    """
    href = "".join(("https://huggingface.co/", model_name))
    return '<a target="_blank" href="' + href + '">' + model_name + "</a>"


def get_model_ids():
    """Return the ids of all Hub models tagged ``hf-asr-leaderboard``."""
    listing = HfApi().list_models(filter="hf-asr-leaderboard")
    return [model_info.modelId for model_info in listing]


def get_metadata(model_id):
    """Download a model's README.md and return the metadata parsed from it.

    Args:
        model_id: Hub model id.

    Returns:
        Whatever ``metadata_load`` returns for the downloaded README, or None
        (after printing a notice) if downloading or parsing raises, so one
        broken repository does not abort the whole leaderboard build.
    """
    try:
        readme_path = hf_hub_download(model_id, filename="README.md")
        return metadata_load(readme_path)
    except Exception:
        # Best effort: any failure (e.g. missing README.md, unparsable card)
        # skips this model. A bare ``except:`` would also swallow
        # KeyboardInterrupt / SystemExit.
        print(f"Model id: {model_id} is not great!")
        return None
    


def parse_metric_value(value):
    """Normalise a WER/CER value from model-card metadata.

    Accepted inputs:
      * strings such as ``"12.3"`` or ``"12.3%"``: the percent sign is
        removed and the number is used as-is;
      * floats below 1.1, which are assumed to be fractions (``0.123``) and
        are scaled to percent; other ints/floats are used as-is;
      * lists, of which only the first element is used (parsed the same way).

    Returns:
        The value rounded to 2 decimals, or None when it cannot be parsed
        (unparsable string, empty list, None or any other type).
    """
    if isinstance(value, list):
        # Some cards report several values; only the first one is used.
        return parse_metric_value(value[0]) if value else None
    if isinstance(value, str):
        try:
            # The original stripped "%" but discarded the result.
            value = float(value.replace("%", ""))
        except ValueError:
            return None
    elif not isinstance(value, (int, float)):
        # None, dicts, ... cannot be ranked.
        return None
    elif isinstance(value, float) and value < 1.1:
        # assuming that WER is given in 0.xx format
        value = 100 * value
    return round(value, 2)


def parse_metrics_rows(meta):
    """Yield leaderboard rows parsed from a model card's metadata.

    One row is yielded per entry of ``meta["model-index"][0]["results"]`` that
    has a dataset type and at least one parseable WER or CER metric. Each row
    is a dict with the keys ``dataset``, ``lang``, ``config`` and ``split``,
    plus ``wer`` and/or ``cer``.

    Malformed entries are skipped instead of raising. Examples: a missing
    dataset type, a missing/empty/non-string language, or non-dict metrics.
    One badly written card therefore cannot break the whole leaderboard build.

    Args:
        meta: metadata mapping loaded from the model card.

    Yields:
        dict: one row per usable evaluation result.
    """
    if "model-index" not in meta or "language" not in meta:
        return
    model_index = meta["model-index"]
    if not isinstance(model_index, list) or not model_index or not isinstance(model_index[0], dict):
        return
    for result in model_index[0].get("results") or []:
        if not isinstance(result, dict):
            continue
        dataset_info = result.get("dataset")
        metrics = result.get("metrics")
        if not isinstance(dataset_info, dict) or "type" not in dataset_info or not metrics:
            continue

        # Prefer the per-result language, fall back to the card-level one.
        args = dataset_info.get("args")
        if isinstance(args, dict) and "language" in args:
            lang = args["language"]
        else:
            lang = meta["language"]
            if isinstance(lang, list):
                lang = lang[0] if lang else None
        if not isinstance(lang, str):
            # The leaderboard is browsed per language; skip results without one.
            continue
        lang = aliases_lang.get(lang, lang)

        row = {
            "dataset": dataset_info["type"],
            "lang": lang,
            "config": dataset_info.get("config", lang),
            "split": dataset_info.get("split"),
        }
        for metric in metrics:
            if not isinstance(metric, dict):
                continue
            metric_type = str(metric.get("type", "")).lower().strip()
            if metric_type not in ("wer", "cer"):
                continue
            value = parse_metric_value(metric.get("value"))
            if value is None:
                continue
            if metric_type not in row or value < row[metric_type]:
                # overwrite the metric if the new value is lower (e.g. with LM)
                row[metric_type] = value
        if "wer" in row or "cer" in row:
            yield row


@st.cache(ttl=600)
def get_data():
    """Build the leaderboard table from every model tagged ``hf-asr-leaderboard``.

    The result is cached through ``st.cache(ttl=600)``.

    Returns:
        pandas.DataFrame with one row per (model, evaluation result) and the
        columns ``model_id``, ``dataset``, ``lang``, ``config``, ``split``,
        ``wer`` and ``cer``. These columns are always present, even when no
        model reports a given metric or nothing could be parsed. The page code
        indexes them unconditionally.
    """
    columns = ["model_id", "dataset", "lang", "config", "split", "wer", "cer"]
    data = []
    for model_id in tqdm(get_model_ids()):
        meta = get_metadata(model_id)
        if meta is None:
            continue
        # parse_metrics_rows only ever yields row dicts, never None.
        for row in parse_metrics_rows(meta):
            row["model_id"] = model_id
            data.append(row)
    # With explicit ``columns``, names missing from every row become all-NA
    # columns instead of being absent.
    return pd.DataFrame.from_records(data, columns=columns)


def sort_datasets(datasets):
    """Order dataset ids for the dataset dropdown.

    Ids listed in ``suggested_datasets`` come first, in that list's order;
    every other id follows, sorted by name.

    Returns:
        A new list of the dataset ids.
    """
    def rank(dataset_id):
        # position in the suggested list; non-suggested ids all share the last rank
        if dataset_id in suggested_datasets:
            return suggested_datasets.index(dataset_id)
        return len(suggested_datasets)

    ordered = sorted(datasets)
    # list.sort is stable, so ids with equal rank keep their name order
    ordered.sort(key=rank)
    return ordered


@st.cache(ttl=600)
def generate_dataset_info(datasets):
    """Build the sidebar markdown that links every dataset to its Hub page.

    Datasets found in ``suggested_datasets`` are tagged "*(recommended)*".

    Args:
        datasets: iterable of dataset ids.

    Returns:
        A markdown string: a header line followed by one bullet per dataset.
    """
    bullet_lines = []
    for dataset_id in datasets:
        bullet = f"* [{dataset_id}](https://hf.co/datasets/{dataset_id})"
        if dataset_id in suggested_datasets:
            bullet += " *(recommended)*"
        bullet_lines.append(bullet + "\n")
    raw_msg = (
        "\n    The models have been trained and/or evaluated on the following datasets:\n    "
        + "".join(bullet_lines)
    )
    # Strip every line, presumably so the header's indentation is not
    # rendered as indented (code-block) markdown.
    return "\n".join(line.strip() for line in raw_msg.split("\n"))


# Leaderboard rows (cached by get_data); missing values become "" for display.
dataframe = get_data().fillna("")

st.sidebar.image("logo.png", width=200)

st.markdown("# The 🤗 Speech Bench")

# Leaderboard-wide totals for the intro paragraph.
total_models = dataframe["model_id"].nunique()
total_datasets = dataframe["dataset"].nunique()
st.markdown(
    f"This is a leaderboard of **{total_models}** speech recognition models "
    f"and **{total_datasets}** datasets.\n\n"
    "⬅ Please select the language you want to find a model for from the dropdown on the left."
)

# Offer every language present in the data, ordered by its display name.
language_options = sorted(dataframe["lang"].unique(), key=lambda code: lang2name.get(code, code))
lang = st.sidebar.selectbox(
    "Language",
    language_options,
    format_func=lambda code: lang2name.get(code, code),
    index=0,
)
lang_df = dataframe.loc[dataframe["lang"] == lang]

sorted_datasets = sort_datasets(lang_df["dataset"].unique())

# Summary for the selected language.
lang_name = lang2name.get(lang, "")
num_models = len(lang_df["model_id"].unique())
num_datasets = len(lang_df["dataset"].unique())
text = (
    f"\nFor the `{lang}` ({lang_name}) language, there are currently `{num_models}` model(s) \n"
    f"trained on `{num_datasets}` dataset(s) available for `automatic-speech-recognition`.\n"
)
st.markdown(text)

st.sidebar.markdown(
    "\nChoose the dataset that is most relevant to your task and select it from the dropdown below:\n"
)

# Restrict the selected language's rows to the chosen dataset.
dataset = st.sidebar.selectbox("Dataset", sorted_datasets, index=0)
dataset_df = lang_df.loc[lang_df["dataset"] == dataset]

# Sidebar list linking every dataset available for this language.
text = generate_dataset_info(sorted_datasets)
st.sidebar.markdown(text)

# Rank by CER for the languages in cer_langs, by WER otherwise.
metric_col = "cer" if lang in cer_langs else "wer"
if metric_col not in dataset_df.columns:
    # no model on the whole leaderboard reports this metric
    dataset_df = dataset_df.assign(**{metric_col: float("nan")})
# `dataframe` was NaN-filled with "" for display. A model that reports only the
# other metric therefore has "" here, and mixing "" with floats makes the
# pivot/sort below raise TypeError. Coerce to numbers and keep only models that
# report the selected metric. `assign` also returns a copy, so the slice taken
# from `lang_df` is never modified.
dataset_df = dataset_df.assign(
    **{metric_col: pd.to_numeric(dataset_df[metric_col], errors="coerce")}
).dropna(subset=[metric_col])
if dataset_df["config"].nunique() > 1:
    # if there are more than one dataset config: one metric column per config
    dataset_df = dataset_df[["model_id", "config", metric_col]]
    dataset_df = dataset_df.pivot_table(index=["model_id"], columns=["config"], values=[metric_col])
    dataset_df = dataset_df.reset_index(level=0)
else:
    dataset_df = dataset_df[["model_id", metric_col]]
# lowest error first (NaN cells of the pivot sort last), then blank them for display
dataset_df = dataset_df.sort_values(dataset_df.columns[-1]).fillna("")

dataset_df = dataset_df.rename(
    columns={
        "model_id": "Model",
        "wer": "WER (lower is better)",
        "cer": "CER (lower is better)",
    },
)

st.markdown(
    "Please click on the model's name to be redirected to its model card which includes documentation and examples on how to use it."
)

# Show 1-based ranks instead of pandas' default 0-based index.
dataset_df = dataset_df.reset_index(drop=True)
dataset_df.index = dataset_df.index + 1

# Link every model id to its page on the Hub.
dataset_df["Model"] = dataset_df["Model"].map(make_clickable)

# escape=False keeps the <a> tags from make_clickable as real HTML.
st.write(
    dataset_df.to_html(escape=False).replace("<th>", '<th align="left">'),  # left-align the headers
    unsafe_allow_html=True,
)

# Explain the metric used to rank the table above.
if lang in cer_langs:
    metric_name, metric_title = "CER", "Char Error Rate"
else:
    metric_name, metric_title = "WER", "Word Error Rate"
# "\\*" yields a markdown-escaped (literal) asterisk. Writing "\*" relies on an
# invalid Python escape sequence, which warns (SyntaxWarning on 3.12+).
st.markdown(
    f"---\n\\* **{metric_name}** is [{metric_title}](https://huggingface.co/metrics/{metric_name.lower()})"
)

st.markdown(
    "Want to beat the Leaderboard? Don't see your speech recognition model show up here? "
    "Simply add the `hf-asr-leaderboard` tag to your model card alongside your evaluation metrics. "
    "Try our [Metrics Editor](https://huggingface.co/spaces/huggingface/speech-bench-metrics-editor) to get started!"
)