"""Gradio app that flags misophonia trigger sounds in uploaded audio.

Instead of transcribing, it scores Whisper's encoder output against a set of
class names (zero-shot, via the local classify module) and shows a mock
intervention message when the strongest user-selected trigger sound exceeds
a sensitivity threshold.
"""
import gradio as gr
import whisper
from whisper.tokenizer import get_tokenizer

import classify


# pull in emotion detection
# --- Add element for specification
# pull in text classification
# --- Add custom labels
# --- Associate labels with radio elements
# add logic to initiate mock notification when detected
# pull in misophonia-specific model

# Reserved for caching multiple Whisper checkpoints once model selection is added
model_cache = {}


# static classes for now, but it would be best to have the user select from multiple, and to enter their own
class_options = {
    "misophonia": ["chewing", "breathing", "mouthsounds", "popping", "sneezing", "yawning", "smacking", "sniffling", "panting"]
}
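# These labels seed both the zero-shot scorer and the checkbox group in the UI below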

# Load the Whisper checkpoint and its tokenizer once at startup
# ("large" is the most accurate but also the heaviest option).
model = whisper.load_model("large")
tokenizer = get_tokenizer(multilingual=True)  # the "large" checkpoint is multilingual

def slider_logic(slider):
    """Map the 1-5 sensitivity slider to a softmax-probability threshold."""
    thresholds = {1: 0.98, 2: 0.88, 3: 0.78, 4: 0.68, 5: 0.58}
    # Fall back to the strictest threshold if the value is somehow out of range
    return thresholds.get(int(slider), 0.98)

def classify_toxicity(audio_file, selected_sounds, slider):
    """Score the uploaded clip against the trigger-sound classes and flag it
    when the strongest user-selected sound exceeds the sensitivity threshold."""
    threshold = slider_logic(slider)

    classify_anxiety = "misophonia"
    # Use the class list directly; joining and re-splitting on commas used to
    # leave a spurious empty-string class at the end.
    class_names = class_options.get(classify_anxiety, [])
    print("class names", class_names, "classify_anxiety", classify_anxiety)
    
    # Score each class name with Whisper's internal language model alone;
    # this text-only prior is subtracted from the audio-conditioned scores below.
    internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
        model=model,
        class_names=class_names,
        tokenizer=tokenizer,
    )
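    # Encode the clip, then score each class name conditioned on the audio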
    audio_features = classify.calculate_audio_features(audio_file, model)
    average_logprobs = classify.calculate_average_logprobs(
        model=model,
        audio_features=audio_features,
        class_names=class_names,
        tokenizer=tokenizer,
    )
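    # Subtract the text-only prior and softmax into one probability per class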
    average_logprobs -= internal_lm_average_logprobs
    scores = average_logprobs.softmax(-1).tolist()
    
    class_score_dict = {class_name: score for class_name, score in zip(class_names, scores)}

    # Keep only the scores for the sounds the user chose to monitor
    matching_label_score = {
        name: class_score_dict[name]
        for name in selected_sounds
        if name in class_score_dict
    }

    # max() on an empty dict raises, so guard against no matching selection
    affirm = " "
    if matching_label_score and max(matching_label_score.values()) > threshold:
        affirm = "Threshold Exceeded, initiate intervention"

    return class_score_dict, affirm
    
with gr.Blocks() as iface:
    with gr.Column():
        miso_sounds = gr.CheckboxGroup(class_options["misophonia"], label="Which sounds should trigger an intervention?")
        sense_slider = gr.Slider(minimum=1, maximum=5, step=1, label="How readily do you want the tool to intervene? 1 = in extreme cases and 5 = at every opportunity")
    with gr.Column():
        aud_input = gr.Audio(source="upload", type="filepath", label="Upload Audio File")
        submit_btn = gr.Button("Run")
    with gr.Column():
        out_class = gr.Label(label="Class scores")
        out_text = gr.Textbox(label="Intervention")
    submit_btn.click(fn=classify_toxicity, inputs=[aud_input, miso_sounds, sense_slider], outputs=[out_class, out_text])


iface.launch()