import functools
import os

import gradio as gr
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from minicons import cwe

from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams


# Hugging Face model ids for the language models selectable in the UI.
_HF_MODELS = {'BERT': 'bert-base-uncased',
              'ALBERT': 'albert-xxlarge-v2',
              'RoBERTa': 'roberta-base'}
# Hub repository holding the trained feature-norm classifiers and label files.
_CLASSIFIER_REPO_ID = "jwalanthi/semantic-feature-classifiers"


@functools.lru_cache(maxsize=len(_HF_MODELS))
def _load_lm(hf_model_id):
    """Return a minicons ``cwe.CWE`` wrapper for *hf_model_id*, built once per id."""
    return cwe.CWE(hf_model_id)


@functools.lru_cache(maxsize=16)
def _load_classifier(model_name_hf, norm_name_hf, layer):
    """Download (if needed) and load one classifier checkpoint and its labels.

    Returns ``(model, labels)``: the ``FeatureNormPredictor`` in eval mode and
    the feature names read line-by-line from the matching ``.txt`` file.
    Cached so repeated queries for the same LLM/norm/layer skip reloading.
    """
    subfolder = f"{model_name_hf}_models_all"
    name_hf = f"{model_name_hf}_to_{norm_name_hf}_layer{layer}"
    # None lets huggingface_hub fall back to its default credentials instead
    # of raising KeyError when the TOKEN env var is unset.
    token = os.environ.get('TOKEN')

    model_path = hf_hub_download(repo_id=_CLASSIFIER_REPO_ID, subfolder=subfolder,
                                 filename=f"{name_hf}.ckpt", use_auth_token=token)
    label_path = hf_hub_download(repo_id=_CLASSIFIER_REPO_ID, subfolder=subfolder,
                                 filename=f"{name_hf}.txt", use_auth_token=token)

    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path=model_path,
        map_location=None
    )
    model.eval()

    with open(label_path, "r", encoding="utf-8") as file:
        labels = [line.rstrip() for line in file]
    return model, labels


def predict(Sentence, Word, LLM, Norm, Layer):
    """Predict semantic feature-norm values for *Word* as used in *Sentence*.

    Args:
        Sentence: context sentence containing the target word.
        Word: target word; must occur (as a substring) in ``Sentence``.
        LLM: one of ``'BERT'``, ``'ALBERT'``, ``'RoBERTa'``.
        Norm: feature-norm set name; lower-cased to build the checkpoint name.
        Layer: layer whose representation is classified; coerced to ``int``
            in case the UI slider delivers it as a float.

    Returns:
        A text report listing every feature with a positive (ReLU-clipped)
        predicted value, highest first, or an ``"invalid input: ..."`` message.
    """
    if not Word:
        return "invalid input: word is empty"
    if Word not in (Sentence or ""):
        return "invalid input: word not in sentence"
    if LLM not in _HF_MODELS:
        return "invalid input: choose an LLM (" + ", ".join(_HF_MODELS) + ")"
    if not Norm:
        return "invalid input: choose a norm"
    # A float such as 3.0 would yield a "layer3.0" checkpoint name and an
    # invalid layer index.
    Layer = int(Layer)

    lm = _load_lm(_HF_MODELS[LLM])
    model, labels = _load_classifier(LLM.lower(), Norm.lower(), Layer)

    with torch.no_grad():  # inference only; no autograd graph needed
        emb = lm.extract_representation((Sentence, Word), layer=Layer)
        pred = torch.nn.functional.relu(model(emb)).squeeze(0)
    pred_list = pred.cpu().tolist()

    df = pd.DataFrame({'feature': labels, 'value': pred_list})
    df = df[df['value'] > 0].sort_values(by='value', ascending=False)

    output = [f"{feature}\t\t\t\t\t\t\t{value}"
              for feature, value in zip(df['feature'], df['value'])]
    return "All Positive Predicted Values:\n" + "\n".join(output)

# Gradio UI: the five inputs map positionally onto predict's parameters
# (Sentence, Word, LLM, Norm, Layer); the report is shown as plain text.
demo = gr.Interface(
    fn=predict,
    inputs=[
        "text",
        "text",
        gr.Radio(["BERT", "ALBERT", "RoBERTa"]),
        gr.Radio(["Binder", "McRae", "Buchanan"]),
        gr.Slider(0, 12, step=1),
    ],
    outputs=["text"],
)

# Launch exactly once, and only when executed as a script, so importing this
# module does not start a server.
if __name__ == "__main__":
    demo.launch()