__all__ = ['block', 'make_clickable_model', 'make_clickable_user', 'get_submissions']

import gradio as gr
import pandas as pd
import json
import tempfile
import re

from constants import *
from src.auto_leaderboard.model_metadata_type import ModelType

global data_component, filter_component

def validate_model_size(s):
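    """Return the model-size string if it matches 'number+B' (e.g. '7B'), otherwise '-'."""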
    pattern = r'^\d+B$|^-$'
    if re.match(pattern, s):
        return s
    else:
        return '-'

def upload_file(files):
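    """Return the local paths of the uploaded files."""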
    file_paths = [file.name for file in files]
    return file_paths

def prediction_analyse(prediction_content):
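    """Score SEED-Bench v1 predictions.

    Expects one JSON object per line with 'question_id' and 'prediction' fields,
    compares each against ./file/SEED-Bench-1.json, and returns a dict mapping
    question_type_id (1-12) to {'correct': int, 'total': int}.
    """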
    predictions = prediction_content.split("\n")

    # Load the ground-truth JSON file
    with open("./file/SEED-Bench-1.json", "r") as file:
        ground_truth_data = json.load(file)["questions"]

    # Index the ground-truth items by question_id
    ground_truth = {item["question_id"]: item for item in ground_truth_data}

    # Initialize per-question-type counters
    results = {i: {"correct": 0, "total": 0} for i in range(1, 13)}

    # Iterate over predictions and tally correct/total counts per question_type_id
    for prediction in predictions:
        prediction = prediction.strip()
        if not prediction:
            continue
        try:
            prediction = json.loads(prediction)
        except json.JSONDecodeError:
            print(f"Warning: Skipping invalid JSON data in line: {prediction}")
            continue
        question_id = prediction["question_id"]
        if question_id not in ground_truth:
            continue
        gt_item = ground_truth[question_id]
        question_type_id = gt_item["question_type_id"]

        if prediction["prediction"] == gt_item["answer"]:
            results[question_type_id]["correct"] += 1

        results[question_type_id]["total"] += 1
    
    return results

def prediction_analyse_v2(prediction_content):
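    """Score SEED-Bench v2 predictions.

    Same line format as prediction_analyse, but scored against
    ./file/SEED-Bench-2.json and aggregated over question_type_id 1-27.
    """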
    predictions = prediction_content.split("\n")

    # Load the ground-truth JSON file
    with open("./file/SEED-Bench-2.json", "r") as file:
        ground_truth_data = json.load(file)["questions"]

    # Index the ground-truth items by question_id
    ground_truth = {item["question_id"]: item for item in ground_truth_data}

    # Initialize per-question-type counters
    results = {i: {"correct": 0, "total": 0} for i in range(1, 28)}

    # Iterate over predictions and tally correct/total counts per question_type_id
    for prediction in predictions:
        prediction = prediction.strip()
        if not prediction:
            continue
        try:
            prediction = json.loads(prediction)
        except json.JSONDecodeError:
            print(f"Warning: Skipping invalid JSON data in line: {prediction}")
            continue
        question_id = prediction["question_id"]
        if question_id not in ground_truth:
            continue
        gt_item = ground_truth[question_id]
        question_type_id = gt_item["question_type_id"]

        if prediction["prediction"] == gt_item["answer"]:
            results[question_type_id]["correct"] += 1

        results[question_type_id]["total"] += 1
    
    return results


def add_new_eval(
    input_file,
    model_name_textbox: str,
    revision_name_textbox: str,
    model_type: str,
    model_link: str,
    model_size: str,
    benchmark_version: str,
    LLM_type: str,
    LLM_name_textbox: str,
    Evaluation_dimension: str,
    Evaluation_dimension_2: str,
    Evaluation_method: str
):
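    """Score an uploaded prediction file and append (or update) a row in the leaderboard CSV.

    Dispatches to the v1 or v2 scorer based on benchmark_version, computes the
    per-dimension and average accuracies, and writes the result to CSV_DIR or
    CSV_V2_DIR. Returns an error string for an empty upload, otherwise 0.
    """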
    if input_file is None:
        return "Error! Empty file!"
    else:
        model_size = validate_model_size(model_size)
        # v1 evaluation
        if benchmark_version == 'v1':
            content = input_file.decode("utf-8")
            prediction = prediction_analyse(content)
            csv_data = pd.read_csv(CSV_DIR)

            Start_dimension, End_dimension = 1, 13
            if Evaluation_dimension == 'Image':
                End_dimension = 10
            elif Evaluation_dimension == 'Video':
                Start_dimension = 10
            each_task_accuracy = {i: round(prediction[i]["correct"] / prediction[i]["total"] * 100, 1) if Start_dimension <= i < End_dimension else 0 for i in range(1, 13)}

            # Compute average accuracies for image, video, and overall
            total_correct_image = sum(prediction[i]["correct"] for i in range(1, 10))
            total_correct_video = sum(prediction[i]["correct"] for i in range(10, 13))

            total_image = sum(prediction[i]["total"] for i in range(1, 10))
            total_video = sum(prediction[i]["total"] for i in range(10, 13))

            if Evaluation_dimension != 'Video':
                average_accuracy_image = round(total_correct_image / total_image * 100, 1)
            else:
                average_accuracy_image = 0
            
            if Evaluation_dimension != 'Image':
                average_accuracy_video = round(total_correct_video / total_video * 100, 1)
            else:
                average_accuracy_video = 0
            
            if Evaluation_dimension == 'All':
                overall_accuracy = round((total_correct_image + total_correct_video) / (total_image + total_video) * 100, 1)
            else:
                overall_accuracy = 0

            if LLM_type == 'Other':
                LLM_name = LLM_name_textbox
            else:
                LLM_name = LLM_type
            
            if revision_name_textbox == '':
                col = csv_data.shape[0]
                model_name = model_name_textbox
            else:
                model_name = revision_name_textbox
                model_name_list = csv_data['Model']
                name_list = [name.split(']')[0].lstrip('[') for name in model_name_list]  # handles both plain names and markdown links
                if revision_name_textbox not in name_list:
                    col = csv_data.shape[0]
                else:
                    col = name_list.index(revision_name_textbox)    
            
            if model_link != '':
                model_name = '[' + model_name + '](' + model_link + ')'

            # add new data
            new_data = [
                model_type, 
                model_name, 
                LLM_name,
                model_size,
                Evaluation_method,
                overall_accuracy,
                average_accuracy_image,
                average_accuracy_video,
                each_task_accuracy[1],
                each_task_accuracy[2],
                each_task_accuracy[3],
                each_task_accuracy[4],
                each_task_accuracy[5],
                each_task_accuracy[6],
                each_task_accuracy[7],
                each_task_accuracy[8],
                each_task_accuracy[9],
                each_task_accuracy[10],
                each_task_accuracy[11],
                each_task_accuracy[12], 
                ]
            csv_data.loc[col] = new_data
            csv_data.to_csv(CSV_DIR, index=False)
        # v2 evaluation
        else:
            content = input_file.decode("utf-8")
            prediction = prediction_analyse_v2(content)
            csv_data = pd.read_csv(CSV_V2_DIR)

            Start_dimension, End_dimension = 1, 28
            if Evaluation_dimension_2 == 'L1':
                End_dimension = 23
            elif Evaluation_dimension_2 == 'L2':
                End_dimension = 25
            elif Evaluation_dimension_2 == 'L3':
                End_dimension = 28
            each_task_accuracy = {i: round(prediction[i]["correct"] / prediction[i]["total"] * 100, 1) if Start_dimension <= i < End_dimension else 0 for i in range(1, 28)}
            average_p1 = round(sum(each_task_accuracy[key] for key in range(1,23)) / 22, 1)

            if Evaluation_dimension_2 == 'L2':
                average_p2 = round(sum(each_task_accuracy[key] for key in range(23,25)) / 2, 1)
                average_p3 = 0
            else:
                average_p2 = round(sum(each_task_accuracy[key] for key in range(23,25)) / 2, 1)
                average_p3 = round(sum(each_task_accuracy[key] for key in range(25,28)) / 3, 1)
            
            if LLM_type == 'Other':
                LLM_name = LLM_name_textbox
            else:
                LLM_name = LLM_type
            
            if revision_name_textbox == '':
                col = csv_data.shape[0]
                model_name = model_name_textbox
            else:
                model_name = revision_name_textbox
                model_name_list = csv_data['Model']
                name_list = [name.split(']')[0].lstrip('[') for name in model_name_list]  # handles both plain names and markdown links
                if revision_name_textbox not in name_list:
                    col = csv_data.shape[0]
                else:
                    col = name_list.index(revision_name_textbox)    
            
            if model_link != '':
                model_name = '[' + model_name + '](' + model_link + ')'

            # add new data
            new_data = [
                model_name, 
                LLM_name, 
                model_size,
                Evaluation_method,
                average_p1,
                average_p2,
                average_p3,
                each_task_accuracy[1],
                each_task_accuracy[2],
                each_task_accuracy[3],
                each_task_accuracy[4],
                each_task_accuracy[5],
                each_task_accuracy[6],
                each_task_accuracy[7],
                each_task_accuracy[8],
                each_task_accuracy[9],
                each_task_accuracy[10],
                each_task_accuracy[11],
                each_task_accuracy[12], 
                each_task_accuracy[13], 
                each_task_accuracy[14], 
                each_task_accuracy[15], 
                each_task_accuracy[16], 
                each_task_accuracy[17], 
                each_task_accuracy[18], 
                each_task_accuracy[19], 
                each_task_accuracy[20], 
                each_task_accuracy[21], 
                each_task_accuracy[22], 
                each_task_accuracy[23], 
                each_task_accuracy[24], 
                each_task_accuracy[25], 
                each_task_accuracy[26], 
                each_task_accuracy[27]
                ]
            csv_data.loc[col] = new_data
            csv_data.to_csv(CSV_V2_DIR, index=False)
    return 0

def get_baseline_df():
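    """Load the v1 leaderboard CSV, sorted by 'Avg. All', with the default visible columns."""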
    df = pd.read_csv(CSV_DIR)
    df = df.sort_values(by="Avg. All", ascending=False)
    present_columns = MODEL_INFO + checkbox_group.value
    df = df[present_columns]
    return df

def get_baseline_v2_df():
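    """Load the v2 leaderboard CSV, sorted by 'Avg. P1', with the default visible columns."""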
    df = pd.read_csv(CSV_V2_DIR)
    df = df.sort_values(by="Avg. P1", ascending=False)
    present_columns = MODEL_INFO_V2 + checkbox_group_v2.value
    df = df[present_columns]
    return df

def get_all_df():
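    """Load the full v1 leaderboard CSV, sorted by 'Avg. All'."""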
    df = pd.read_csv(CSV_DIR)
    df = df.sort_values(by="Avg. All", ascending=False)
    return df

def get_all_v2_df():
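    """Load the full v2 leaderboard CSV, sorted by 'Avg. P1'."""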
    df = pd.read_csv(CSV_V2_DIR)
    df = df.sort_values(by="Avg. P1", ascending=False)
    return df

block = gr.Blocks()


with block:
    gr.Markdown(
        LEADERBORAD_INTRODUCTION
    )
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 SEED Benchmark v2", elem_id="seed-benchmark-tab-table", id=0):
            with gr.Row():
                with gr.Accordion("Citation", open=False):
                    citation_button = gr.Textbox(
                        value=CITATION_BUTTON_TEXT,
                        label=CITATION_BUTTON_LABEL,
                        elem_id="citation-button",
                    ).style(show_copy_button=True)
    
            gr.Markdown(
                TABLE_INTRODUCTION
            )

            # selection for column part:
            checkbox_group_v2 = gr.CheckboxGroup(
                choices=TASK_V2_INFO,
                value=AVG_V2_INFO,
                label="Evaluation Dimension",
                interactive=True,
            )

            # selection for model size part:
            model_size_v2 = gr.CheckboxGroup(
                choices=MODEL_SIZE,
                value=MODEL_SIZE,
                label="Model Size",
                interactive=True,
            )

            # selection for evaluation method part:
            evaluation_method_v2 = gr.CheckboxGroup(
                choices=EVALUATION_METHOD,
                value=EVALUATION_METHOD,
                label="Evaluation Method",
                interactive=True,
            )
            
            # Create the results dataframe component
            data_component_v2 = gr.components.Dataframe(
                value=get_baseline_v2_df, 
                headers=COLUMN_V2_NAMES,
                type="pandas", 
                datatype=DATA_TITILE_V2_TYPE,
                interactive=False,
                visible=True,
                )
            
            def on_filter_model_size_method_v2_change(selected_model_size, selected_evaluation_method, selected_columns):
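                """Re-filter the v2 table by model size, evaluation method, and selected columns."""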

                updated_data = get_all_v2_df()
                # Filter rows by model size and evaluation method
                def custom_filter(row, model_size_filters, evaluation_method_filters):
                    model_size = row['Model Size']
                    evaluation_method = row['Evaluation Method']

                    if model_size == '-':
                        size_filter = '-' in model_size_filters
                    elif 'B' in model_size:
                        size = float(model_size.replace('B', ''))
                        size_filter = ('>=10B' in model_size_filters and size >= 10) or ('<10B' in model_size_filters and size < 10)
                    else:
                        size_filter = False

                    method_filter = evaluation_method in evaluation_method_filters

                    return size_filter and method_filter

                # Apply the row filter
                mask = updated_data.apply(custom_filter, axis=1, model_size_filters=selected_model_size, evaluation_method_filters=selected_evaluation_method)
                updated_data = updated_data[mask]

                # columns:
                selected_columns = [item for item in TASK_V2_INFO if item in selected_columns]
                present_columns = MODEL_INFO_V2 + selected_columns
                updated_data = updated_data[present_columns]
                updated_data = updated_data.sort_values(by=selected_columns[0], ascending=False)
                updated_headers = present_columns
                update_datatype = [DATA_TITILE_V2_TYPE[COLUMN_V2_NAMES.index(x)] for x in updated_headers]

                filter_component = gr.components.Dataframe(
                    value=updated_data, 
                    headers=updated_headers,
                    type="pandas", 
                    datatype=update_datatype,
                    interactive=False,
                    visible=True,
                    )
        
                return filter_component.value

            model_size_v2.change(fn=on_filter_model_size_method_v2_change, inputs=[model_size_v2, evaluation_method_v2, checkbox_group_v2], outputs=data_component_v2)
            evaluation_method_v2.change(fn=on_filter_model_size_method_v2_change, inputs=[model_size_v2, evaluation_method_v2, checkbox_group_v2], outputs=data_component_v2)
            checkbox_group_v2.change(fn=on_filter_model_size_method_v2_change, inputs=[model_size_v2, evaluation_method_v2, checkbox_group_v2], outputs=data_component_v2)

        # table seed-bench-v1
        with gr.TabItem("🏅 SEED Benchmark v1", elem_id="seed-benchmark-tab-table", id=1):
            with gr.Row():
                with gr.Accordion("Citation", open=False):
                    citation_button = gr.Textbox(
                        value=CITATION_BUTTON_TEXT,
                        label=CITATION_BUTTON_LABEL,
                        elem_id="citation-button",
                    ).style(show_copy_button=True)
    
            gr.Markdown(
                TABLE_INTRODUCTION
            )

            # selection for column part:
            checkbox_group = gr.CheckboxGroup(
                choices=TASK_INFO,
                value=AVG_INFO,
                label="Evaluation Dimension",
                interactive=True,
            )

            # selection for model size part:
            model_size = gr.CheckboxGroup(
                choices=MODEL_SIZE,
                value=MODEL_SIZE,
                label="Model Size",
                interactive=True,
            )

            # selection for evaluation method part:
            evaluation_method = gr.CheckboxGroup(
                choices=EVALUATION_METHOD,
                value=EVALUATION_METHOD,
                label="Evaluation Method",
                interactive=True,
            )

            # Create the results dataframe component
            data_component = gr.components.Dataframe(
                value=get_baseline_df, 
                headers=COLUMN_NAMES,
                type="pandas", 
                datatype=DATA_TITILE_TYPE,
                interactive=False,
                visible=True,
                )
    
            def on_filter_model_size_method_change(selected_model_size, selected_evaluation_method, selected_columns):
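                """Re-filter the v1 table by model size, evaluation method, and selected columns."""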

                updated_data = get_all_df()
                # Filter rows by model size and evaluation method
                def custom_filter(row, model_size_filters, evaluation_method_filters):
                    model_size = row['Model Size']
                    evaluation_method = row['Evaluation Method']

                    if model_size == '-':
                        size_filter = '-' in model_size_filters
                    elif 'B' in model_size:
                        size = float(model_size.replace('B', ''))
                        size_filter = ('>=10B' in model_size_filters and size >= 10) or ('<10B' in model_size_filters and size < 10)
                    else:
                        size_filter = False

                    method_filter = evaluation_method in evaluation_method_filters

                    return size_filter and method_filter

                # Apply the row filter
                mask = updated_data.apply(custom_filter, axis=1, model_size_filters=selected_model_size, evaluation_method_filters=selected_evaluation_method)
                updated_data = updated_data[mask]

                # columns:
                selected_columns = [item for item in TASK_INFO if item in selected_columns]
                present_columns = MODEL_INFO + selected_columns
                updated_data = updated_data[present_columns]
                updated_data = updated_data.sort_values(by=selected_columns[0], ascending=False)
                updated_headers = present_columns
                update_datatype = [DATA_TITILE_TYPE[COLUMN_NAMES.index(x)] for x in updated_headers]

                filter_component = gr.components.Dataframe(
                    value=updated_data, 
                    headers=updated_headers,
                    type="pandas", 
                    datatype=update_datatype,
                    interactive=False,
                    visible=True,
                    )
        
                return filter_component.value

            model_size.change(fn=on_filter_model_size_method_change, inputs=[model_size, evaluation_method, checkbox_group], outputs=data_component)
            evaluation_method.change(fn=on_filter_model_size_method_change, inputs=[model_size, evaluation_method, checkbox_group], outputs=data_component)
            checkbox_group.change(fn=on_filter_model_size_method_change, inputs=[model_size, evaluation_method, checkbox_group], outputs=data_component)

        # About tab
        with gr.TabItem("📝 About", elem_id="seed-benchmark-tab-table", id=2):
            gr.Markdown(LEADERBORAD_INFO, elem_classes="markdown-text")
        
        # Submission tab
        with gr.TabItem("🚀 Submit here! ", elem_id="seed-benchmark-tab-table", id=3):
            gr.Markdown(LEADERBORAD_INTRODUCTION, elem_classes="markdown-text")

            with gr.Row():
                gr.Markdown(SUBMIT_INTRODUCTION, elem_classes="markdown-text")

            with gr.Row():
                gr.Markdown("# ✉️✨ Submit your model evaluation json file here!", elem_classes="markdown-text")

            with gr.Row():
                with gr.Column():
                    model_name_textbox = gr.Textbox(
                        label="Model name", placeholder="LLaMA-7B"
                        )
                    revision_name_textbox = gr.Textbox(
                        label="Revision Model Name", placeholder="LLaMA-7B"
                    )
                    model_type = gr.Dropdown(
                        choices=[                         
                            "LLM",
                            "ImageLLM",
                            "VideoLLM",
                            "Other", 
                        ], 
                        label="Model type", 
                        multiselect=False,
                        value="ImageLLM",
                        interactive=True,
                    )
                    model_link = gr.Textbox(
                        label="Model Link", placeholder="https://huggingface.co/decapoda-research/llama-7b-hf"
                    )
                    model_size = gr.Textbox(
                        label="Model size", placeholder="7B(Input content format must be 'number+B' or '-', default is '-')"
                    )
                    benchmark_version= gr.Dropdown(
                        choices=["v1", "v2"],
                        label="Benchmark version", 
                        multiselect=False,
                        value="v1",
                        interactive=True,
                    )

                with gr.Column():
                    LLM_type = gr.Dropdown(
                        choices=["Vicuna-7B", "Flan-T5-XL", "LLaMA-7B", "Other"],
                        label="LLM type", 
                        multiselect=False,
                        value="LLaMA-7B",
                        interactive=True,
                    )
                    LLM_name_textbox = gr.Textbox(
                        label="LLM model (for Other)",
                        placeholder="LLaMA-13B"
                    )
                    Evaluation_dimension = gr.Dropdown(
                        choices=["All", "Image", "Video"],
                        label="Evaluation dimension for SEED-Bench 1(for evaluate SEED-Bench 1)", 
                        multiselect=False,
                        value="All",
                        interactive=True,
                    )
                    Evaluation_dimension_2 = gr.Dropdown(
                        choices=["L1", "L2", "L3"],
                        label="Evaluation dimension for SEED-Bench 2(for evaluate SEED-Bench 2)", 
                        multiselect=False,
                        value="L2",
                        interactive=True,
                    )
                    Evaluation_method = gr.Dropdown(
                        choices=EVALUATION_METHOD,
                        label="Evaluation method", 
                        multiselect=False,
                        value=EVALUATION_METHOD[0],
                        interactive=True,
                    )

            with gr.Column():

                input_file = gr.inputs.File(label="Click to upload a JSON file", file_count="single", type='binary')
                submit_button = gr.Button("Submit Eval")
    
                submission_result = gr.Markdown()
                submit_button.click(
                    add_new_eval,
                    inputs = [
                        input_file,
                        model_name_textbox,
                        revision_name_textbox,
                        model_type,
                        model_link,
                        model_size,
                        benchmark_version,
                        LLM_type,
                        LLM_name_textbox,
                        Evaluation_dimension,
                        Evaluation_dimension_2,
                        Evaluation_method
                    ],
                )


    def refresh_data():
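        """Reload both leaderboard tables from disk."""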
        value1 = get_baseline_df()
        value2 = get_baseline_v2_df()

        return value1, value2

    with gr.Row():
        data_run = gr.Button("Refresh")
        data_run.click(
            refresh_data, outputs=[data_component, data_component_v2]
        )


block.launch()