import gradio as gr

import os
import torch
import numpy as np
import pandas as pd
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from huggingface_hub import HfApi

from label_dicts import CAP_NUM_DICT, CAP_LABEL_NAMES

# Hugging Face access token read from the ``hf_read`` environment variable
# (presumably a read-scoped token). Importing this module raises KeyError if
# the variable is not set.
HF_TOKEN = os.environ["hf_read"]

# Choices offered in the "Language" dropdown of the interface.
languages = ["English", "Multilingual"]

# Human-readable domain label (shown in the "Domain" dropdown) -> domain slug
# that predict_cap() passes on to build_huggingface_path().
domains = {
    "media":                   "media",
    "social media":            "social",
    "parliamentary speech":    "parlspeech",
    "legislative documents":   "legislative",
    "executive speech":        "execspeech",
    "executive order":         "execorder",
    "party programs":          "party",
    "judiciary":               "judiciary",
    "budget":                  "budget",
    "public opinion":          "publicopinion",
    "local government agenda": "localgovernment",
}

def check_huggingface_path(checkpoint_path: str) -> bool:
    """Return True if ``checkpoint_path`` can be looked up on the Hugging Face Hub.

    The lookup uses the module-level ``HF_TOKEN``, so repositories that are
    only visible to that token are found as well.

    This is a best-effort probe. Any ``Exception`` raised while querying the
    Hub (for example, the repository does not exist or cannot be accessed)
    yields False, so callers can fall back to another model. Unlike the
    previous bare ``except:``, ``KeyboardInterrupt`` and ``SystemExit`` now
    propagate.

    Args:
        checkpoint_path: Hub repository id, e.g. ``"org/model-name"``.

    Returns:
        True if the model metadata could be fetched, False otherwise.
    """
    try:
        hf_api = HfApi(token=HF_TOKEN)
        hf_api.model_info(checkpoint_path, token=HF_TOKEN)
    except Exception:
        # Deliberately broad: any lookup failure just means "not available".
        return False
    return True

def build_huggingface_path(language: str, domain: str):
    """Choose the Hugging Face model repository to use for CAP prediction.

    Selection rules:

    * English input uses ``poltextlab/xlm-roberta-large-english-cap-v3``,
      provided that repository can be found on the Hub.
    * Otherwise the pooled multilingual model
      ``poltextlab/xlm-roberta-large-pooled-cap`` is used.

    Args:
        language: Language label from the UI, compared case-insensitively.
        domain: Domain slug (a value of ``domains``). It is currently not used
            for model selection and is kept for signature compatibility.

    Returns:
        The Hub repository id of the model to load.
    """
    base_path = "xlm-roberta-large"
    pooled_path = f"poltextlab/{base_path}-pooled-cap"

    language = language.lower()
    if language != "english":
        # The pooled model is also the fallback, so a Hub lookup would not
        # change the outcome; skip the network round-trip.
        return pooled_path

    lang_path = f"poltextlab/{base_path}-{language}-cap-v3"
    if check_huggingface_path(lang_path):
        return lang_path
    return pooled_path

def predict(text, model_id, tokenizer_id):
    """Classify ``text`` with the CAP sequence-classification model ``model_id``.

    The model and tokenizer are loaded from the Hub on every call; nothing is
    cached here.

    Args:
        text: Input text to classify (truncated to 256 tokens).
        model_id: Hub repository id of the sequence-classification model.
        tokenizer_id: Hub repository id of the tokenizer.

    Returns:
        A ``(label_scores, info_html)`` pair:

        * ``label_scores`` maps ``"[<code>] <label name>"`` to that class's
          softmax probability. Entries are inserted in descending order of
          probability.
        * ``info_html`` is an HTML paragraph naming and linking the model used.
    """
    cpu = torch.device("cpu")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        device_map="auto",
        offload_folder="offload",
        token=HF_TOKEN,
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    encoded = tokenizer(
        text,
        max_length=256,
        truncation=True,
        padding="do_not_pad",
        return_tensors="pt",
    ).to(cpu)
    model.eval()

    with torch.no_grad():
        scores = model(**encoded).logits

    probabilities = torch.nn.functional.softmax(scores, dim=1)
    probabilities = probabilities.cpu().numpy().flatten()

    # Build the label -> probability mapping from most to least likely class.
    ranked = {}
    for idx in np.argsort(probabilities)[::-1]:
        cap_code = CAP_NUM_DICT[idx]
        ranked[f"[{cap_code}] {CAP_LABEL_NAMES[cap_code]}"] = probabilities[idx]

    info_html = (
        '<p style="text-align: center; display: block">'
        f'Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.'
        '</p>'
    )
    return ranked, info_html

def predict_cap(text, language, domain):
    """Gradio callback: classify ``text`` for the selected language and domain.

    Args:
        text: Text entered by the user.
        language: Selected entry of ``languages``.
        domain: Selected human-readable key of ``domains``.

    Returns:
        The result of ``predict``: the label -> probability mapping and the
        HTML note naming the model that was used.

    Raises:
        gr.Error: If no language is selected, or the domain is missing or not
            a key of ``domains``. The user then sees a readable message
            instead of an unhandled AttributeError or KeyError.
    """
    if not language:
        raise gr.Error("Please select a language.")
    if domain not in domains:
        raise gr.Error("Please select a domain.")

    model_id = build_huggingface_path(language, domains[domain])
    tokenizer_id = "xlm-roberta-large"
    return predict(text, model_id, tokenizer_id)

# Gradio interface: text + language + domain in; the top-5 CAP labels and a
# note naming the model that made the prediction out.
demo = gr.Interface(
    fn=predict_cap,
    inputs=[
        gr.Textbox(lines=6, label="Input"),
        gr.Dropdown(languages, label="Language"),
        # Gradio documents ``choices`` as a list, so pass one rather than a
        # dict_keys view.
        gr.Dropdown(list(domains), label="Domain"),
    ],
    outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()],
)