import gradio as gr

import os
import torch
import numpy as np
import pandas as pd
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from huggingface_hub import HfApi
from huggingface_hub.utils._errors import RepositoryNotFoundError

from label_dicts import CAP_NUM_DICT, CAP_LABEL_NAMES

# Hugging Face access token, taken from the `hf_read` environment variable.
# Indexing os.environ (rather than .get) raises KeyError at import time if the
# variable is not set. The token is passed to the Hub API and to the model loader.
HF_TOKEN = os.environ["hf_read"]

# Language names passed as the choices of the "Language" dropdown.
# The selected name is forwarded to build_huggingface_path() to pick a model.
languages = [
    "Danish", "Dutch", "English", "French", "German",
    "Hungarian", "Italian", "Polish", "Portuguese", "Spanish",
    "Czech", "Slovak", "Norwegian",
]

# Maps the human-readable domain label (the keys are offered in the "Domain"
# dropdown) to the short domain code used when building model repository
# names and when looking up the domain column of the language/domain table.
domains = {
    "media": "media",
    "social media": "social",
    "parliamentary speech": "parlspeech",
    "legislative documents": "legislative",
    "executive speech": "execspeech",
    "executive order": "execorder",
    "party programs": "party",
    "judiciary": "judiciary",
    "budget": "budget",
    "public opinion": "publicopinion",
    "local government agenda": "localgovernment",
}

def check_huggingface_path(checkpoint_path: str):
    """Tell whether `checkpoint_path` names a model repository on the Hub.

    The lookup is made with HF_TOKEN. Returns True when the Hub returns model
    info for the repository and False when it raises RepositoryNotFoundError.
    Any other exception raised by the Hub client propagates to the caller.
    """
    client = HfApi(token=HF_TOKEN)
    try:
        client.model_info(checkpoint_path, token=HF_TOKEN)
    except RepositoryNotFoundError:
        return False
    return True

def build_huggingface_path(language: str, domain: str):
    """Choose the Hugging Face model repository for a language/domain pair.

    Two candidate repositories are derived from the arguments: a
    language+domain specific one and a language-only one. If
    ``language_domain_models.csv`` can be read, the cell at (row where the
    ``language`` column equals `language`, column named `domain`) is taken as
    a preference code: ``"L"`` selects the language-only repository, while
    ``"L-D"`` and ``"X"`` select the language+domain repository. The
    preferred repository is tried first, then the remaining candidates.
    Without a usable code the language+domain repository is tried before the
    language-only one.

    Args:
        language: Language name. Matching is case-insensitive; the lower-cased
            name is also what is interpolated into the repository id.
        domain: Short domain code (a value of the ``domains`` mapping).

    Returns:
        The id of the first candidate repository that exists on the Hub, or
        ``"poltextlab/xlm-roberta-large-pooled-cap"`` when none of them does.
    """
    fallback_path = "poltextlab/xlm-roberta-large-pooled-cap"
    base_path = "xlm-roberta-large"

    # The table's language column is compared in lower case, so the argument
    # has to be normalised too; otherwise a capitalised UI value such as
    # "Danish" can never match a row. Non-string input is left untouched.
    if isinstance(language, str):
        language = language.lower()

    lang_domain_path = f"poltextlab/{base_path}-{language}-{domain}-cap-v3"
    lang_path = f"poltextlab/{base_path}-{language}-cap-v3"

    path_map = {
        "L": lang_path,
        "L-D": lang_domain_path,
        "X": lang_domain_path,
    }

    # Best-effort lookup of the preference code: any problem with the table
    # (missing/empty file, missing column, non-text language column) simply
    # means "no preference".
    value = None
    try:
        table = pd.read_csv("language_domain_models.csv")
        # Lower-case the headers first so the "language" column is found
        # regardless of how it is capitalised in the file.
        table.columns = table.columns.str.lower()
        rows = table[table["language"].str.lower() == language]
        cells = rows.get(domain)
        if cells is not None and not cells.empty:
            value = cells.iloc[0]
    except (AttributeError, KeyError, FileNotFoundError, pd.errors.EmptyDataError):
        value = None

    if isinstance(value, str) and value in path_map:
        # Preferred repository first, then every other mapped repository.
        candidates = [path_map[value], *path_map.values()]
    else:
        candidates = [lang_domain_path, lang_path]

    # dict.fromkeys drops duplicates while keeping order, so each repository
    # is queried on the Hub at most once.
    for candidate in dict.fromkeys(candidates):
        if check_huggingface_path(candidate):
            return candidate

    # Always return a usable model id, including when a preference code was
    # found but none of the mapped repositories exist.
    return fallback_path

def predict(text, model_id, tokenizer_id):
    """Classify `text` with the given checkpoint and format the result.

    The model (`model_id`, loaded with HF_TOKEN) and the tokenizer
    (`tokenizer_id`) are loaded with `from_pretrained` on every call and run
    on the CPU. The input is truncated to 512 tokens.

    Returns:
        A tuple ``(output_pred, output_info)``: `output_pred` maps
        ``"[<code>] <label name>"`` strings to softmax probabilities, inserted
        from most to least probable; `output_info` is an HTML paragraph that
        links to the model used.
    """
    cpu = torch.device("cpu")
    classifier = AutoModelForSequenceClassification.from_pretrained(model_id, token=HF_TOKEN)
    text_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    classifier.to(cpu)

    encoded = text_tokenizer(
        text,
        max_length=512,
        truncation=True,
        padding="do_not_pad",
        return_tensors="pt",
    )
    encoded = encoded.to(cpu)
    classifier.eval()

    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        logits = classifier(**encoded).logits

    scores = torch.nn.functional.softmax(logits, dim=1)
    probs = scores.cpu().numpy().flatten()

    # Walk the class indices from highest to lowest probability so the dict
    # is populated in ranking order.
    output_pred = {}
    for class_idx in np.argsort(probs)[::-1]:
        code = CAP_NUM_DICT[class_idx]
        output_pred[f"[{code}] {CAP_LABEL_NAMES[code]}"] = probs[class_idx]

    output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
    return output_pred, output_info

def predict_cap(text, language, domain):
    """Interface callback: pick a model for the selection and classify `text`.

    Args:
        text: The text to classify.
        language: Language name, forwarded to build_huggingface_path().
        domain: Human-readable domain label; must be a key of `domains`.

    Returns:
        Whatever predict() returns: the label/probability mapping and the
        HTML note naming the model.

    Raises:
        gr.Error: if `domain` is not a key of `domains` (for instance when no
            domain was chosen), so the UI shows a readable message instead of
            a bare KeyError.
    """
    domain_code = domains.get(domain)
    if domain_code is None:
        raise gr.Error("Please select a domain from the list.")

    model_id = build_huggingface_path(language, domain_code)
    # Every checkpoint is tokenized with the base xlm-roberta-large tokenizer.
    tokenizer_id = "xlm-roberta-large"
    return predict(text, model_id, tokenizer_id)

# Gradio interface: a text box plus language and domain selectors feed
# predict_cap(); its two return values are rendered as a label widget limited
# to the top five classes and a Markdown block.
demo = gr.Interface(
    fn=predict_cap,
    inputs=[
        gr.Textbox(lines=6, label="Input"),
        gr.Dropdown(languages, label="Language"),
        # Pass a concrete list rather than a dict view: `choices` is
        # documented as a list, and a list stays serializable if the
        # component stores it as-is.
        gr.Dropdown(list(domains), label="Domain"),
    ],
    outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()],
)