import gradio as gr
from transformers import pipeline, AutoTokenizer
from turkish_lm_tuner import T5ForClassification
import os

# Retrieve Hugging Face authentication token from environment variables
hf_auth_token = os.getenv('HF_AUTH_TOKEN')
# Only report whether a token is present. The token is a secret and stdout
# ends up in application logs, so its value must never be printed.
print(f"HF_AUTH_TOKEN configured: {hf_auth_token is not None}")

# Example inputs for the different tasks.
# Each variable is a list of example rows handed to `gr.Examples`; every row
# here holds a single value, the input text. The texts are sample inputs for
# the detectors, not statements made by the application.
binary_classification_examples = [["Yahudi terörüne karşı protestolar kararlılıkla devam ediyor."]]
categorization_examples = [["Ermeni zulmü sırasında hayatını kaybeden kadınlar anısına dikilen anıt ziyarete açıldı."]]
# Used by the "Target Detection" tab and reused by the "Multi Detection" tab.
target_detection_examples = [["Dün 5 bin suriyeli enik doğmuştur zaten Türkiyede aq 5 bin suriyelinin gitmesi çok çok az"]]

# Markdown shown in the "About" tab (rendered with `gr.Markdown`).
APP_DESCRIPTION = """
## Hate Speech Detection in Turkish News

This tool performs hate speech detection across several tasks, including binary classification, categorization, and target detection. Choose a model and input text to analyze its hatefulness, categorize it, or detect targets of hate speech.
"""

# Markdown rendered below the tabs; placeholder text until a real citation exists.
APP_CITATION = """
For citation, please refer to the tool's documentation.
"""

def inference_t5(input_text, selected_model):
    """Classify ``input_text`` with the TURNA checkpoint and return the logits.

    Args:
        input_text: Text to classify; handed to the tokenizer unchanged.
        selected_model: Model name chosen in the UI. It is accepted for
            signature compatibility but not used: the checkpoint below is
            loaded regardless of the selection.

    Returns:
        The ``logits`` attribute of the model's output for the tokenized text.
    """
    # One identifier for both loads so model and tokenizer cannot drift apart.
    checkpoint = "gokceuludogan/turna_tr_hateprint_w0.1_new_"
    model = T5ForClassification.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # Tokenize the text supplied by the caller, as PyTorch tensors.
    encoded = tokenizer(input_text, return_tensors='pt')
    return model(**encoded).logits

    
# Functions for model-based tasks
def perform_binary_classification(input_text, selected_model):
    """Classify ``input_text`` as hateful or not using the selected model.

    Args:
        input_text: Text to analyze.
        selected_model: Model name under the ``gokceuludogan/`` namespace.
            Names containing ``'turna'`` are routed to ``inference_t5``;
            all others are loaded through ``transformers.pipeline``.

    Returns:
        The result of ``inference_t5`` for TURNA models, otherwise the first
        element of the pipeline output.

    Raises:
        ValueError: If no model was selected (``None`` or empty string).
    """
    # Fail fast with a clear message instead of trying to download the
    # non-existent repository "gokceuludogan/None".
    if not selected_model:
        raise ValueError(
            "No model selected: choose a model before running binary classification."
        )

    if 'turna' in selected_model:
        return inference_t5(input_text, selected_model)

    classifier = pipeline(model=f'gokceuludogan/{selected_model}')
    return classifier(input_text)[0]

def perform_categorization(input_text, selected_model):
    """Categorize the hate speech in ``input_text`` using the selected model.

    Args:
        input_text: Text to analyze.
        selected_model: Model name under the ``gokceuludogan/`` namespace,
            loaded through ``transformers.pipeline``.

    Returns:
        The first element of the pipeline output.

    Raises:
        ValueError: If no model was selected (``None`` or empty string).
    """
    # A Radio without a selection delivers None; reject it here rather than
    # attempting to load the non-existent repository "gokceuludogan/None".
    if not selected_model:
        raise ValueError(
            "No model selected: choose a model before running categorization."
        )

    classifier = pipeline(model=f'gokceuludogan/{selected_model}')
    return classifier(input_text)[0]

def perform_target_detection(input_text):
    """Return the text generated by the target-detection model for ``input_text``.

    A pipeline is built for the target checkpoint on each call (no task is
    passed explicitly) and the ``'generated_text'`` field of the first
    result is returned.
    """
    target_checkpoint = 'gokceuludogan/turna_generation_tr_hateprint_target'
    generator = pipeline(model=target_checkpoint)
    outputs = generator(input_text)
    first_output = outputs[0]
    return first_output['generated_text']

def perform_multi_detection(input_text):
    """Return the text generated by the multi-task model for ``input_text``.

    A pipeline is built for the multi checkpoint on each call (no task is
    passed explicitly) and the ``'generated_text'`` field of the first
    result is returned.
    """
    multi_checkpoint = 'gokceuludogan/turna_generation_tr_hateprint_multi'
    generator = pipeline(model=multi_checkpoint)
    outputs = generator(input_text)
    first_output = outputs[0]
    return first_output['generated_text']

# Gradio interface: an "About" tab followed by one tab per task, each wiring
# an input textbox and a button to the matching perform_* function.
with gr.Blocks(theme="abidlabs/Lime") as hate_speech_demo:

    # Main description
    with gr.Tab("About"):
        gr.Markdown(APP_DESCRIPTION)

    # Binary Classification Tab
    with gr.Tab("Binary Classification"):
        gr.Markdown("Analyze the hatefulness of a given text using selected models.")
        with gr.Column():
            model_choice_binary = gr.Radio(
                choices=[
                    "turna_tr_hateprint",
                    "turna_tr_hateprint_5e6_w0.1_",
                    "berturk_tr_hateprint_w0.1",
                    "berturk_tr_hateprint_w0.1_b128",
                ],
                label="Select Model",
                value="turna_tr_hateprint",
            )
            text_input_binary = gr.Textbox(label="Input Text")
            classify_button = gr.Button("Analyze")
            classification_output = gr.Textbox(label="Classification Result")
            classify_button.click(
                perform_binary_classification,
                inputs=[text_input_binary, model_choice_binary],
                outputs=classification_output,
            )
            # NOTE(review): each example row has one column (the text) while
            # two input components are listed - confirm the installed Gradio
            # version accepts rows shorter than `inputs`.
            gr.Examples(
                examples=binary_classification_examples,
                inputs=[text_input_binary, model_choice_binary],
                outputs=classification_output,
                fn=perform_binary_classification,
            )

    # Hate Speech Categorization Tab
    with gr.Tab("Hate Speech Categorization"):
        gr.Markdown("Categorize the hate speech type in the provided text.")
        with gr.Column():
            model_choice_category = gr.Radio(
                choices=["berturk_tr_hateprint_cat_w0.1_b128", "berturk_tr_hateprint_cat_w0.1"],
                label="Select Model",
                # Pre-select a model, as the binary tab does, so pressing
                # "Categorize" never submits an empty (None) selection.
                value="berturk_tr_hateprint_cat_w0.1_b128",
            )
            text_input_category = gr.Textbox(label="Input Text")
            categorize_button = gr.Button("Categorize")
            categorization_output = gr.Textbox(label="Categorization Result")
            categorize_button.click(
                perform_categorization,
                inputs=[text_input_category, model_choice_category],
                outputs=categorization_output,
            )
            # NOTE(review): same one-column-example / two-input layout as the
            # binary tab - confirm against the installed Gradio version.
            gr.Examples(
                examples=categorization_examples,
                inputs=[text_input_category, model_choice_category],
                outputs=categorization_output,
                fn=perform_categorization,
            )

    # Target Detection Tab
    with gr.Tab("Target Detection"):
        gr.Markdown("Detect the targets of hate speech in the provided text.")
        with gr.Column():
            text_input_target = gr.Textbox(label="Input Text")
            target_button = gr.Button("Detect Targets")
            target_output = gr.Textbox(label="Target Detection Result")
            target_button.click(
                perform_target_detection,
                inputs=[text_input_target],
                outputs=target_output,
            )
            gr.Examples(
                examples=target_detection_examples,
                inputs=[text_input_target],
                outputs=target_output,
                fn=perform_target_detection,
            )

    # Multi Detection Tab
    with gr.Tab("Multi Detection"):
        gr.Markdown("Detect hate speech, its category, and its targets in the text.")
        with gr.Column():
            text_input_multi = gr.Textbox(label="Input Text")
            multi_button = gr.Button("Detect All")
            multi_output = gr.Textbox(label="Multi Detection Result")
            multi_button.click(
                perform_multi_detection,
                inputs=[text_input_multi],
                outputs=multi_output,
            )
            # This tab reuses the target-detection example text.
            gr.Examples(
                examples=target_detection_examples,
                inputs=[text_input_multi],
                outputs=multi_output,
                fn=perform_multi_detection,
            )

    # Citation Section
    gr.Markdown(APP_CITATION)

# Launch the application
hate_speech_demo.launch()