# data_viewer.py
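"""Gradio viewer for browsing samples of an image-preference dataset.

Each sample pairs an image and query with a chosen/rejected response pair,
human rankings, and judge metadata (field names follow the
MMInstruction/VRewardBench schema used below).
"""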
import base64
import json
from functools import lru_cache
from io import BytesIO

import gradio as gr
from datasets import load_dataset
from PIL import Image


@lru_cache(maxsize=1)
def load_cached_dataset(dataset_name, split):
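    """Fetch a dataset split from the Hub, memoizing the most recent call.

    maxsize=1 keeps only the last (dataset_name, split) pair in memory, so
    switching datasets evicts the previous one rather than accumulating copies.
    """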
    return load_dataset(dataset_name, split=split)


def base64_to_image(base64_string):
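    """Decode a base64-encoded image string into a PIL Image."""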
    img_data = base64.b64decode(base64_string)
    return Image.open(BytesIO(img_data))


def get_responses(responses, rankings):
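    """Split a response pair into (chosen, rejected) using human rankings.

    Rank 0 marks the preferred response and rank 1 the rejected one; both
    arguments may arrive JSON-encoded. For example:

        >>> get_responses(["good answer", "bad answer"], [0, 1])
        ('good answer', 'bad answer')
    """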
    if isinstance(responses, str):
        responses = json.loads(responses)
    if isinstance(rankings, str):
        rankings = json.loads(rankings)

    chosen = next((resp for resp, rank in zip(responses, rankings) if rank == 0), "No chosen response")
    rejected = next((resp for resp, rank in zip(responses, rankings) if rank == 1), "No rejected response")

    return chosen, rejected


def load_and_display_sample(dataset_name, split, idx):
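    """Load one sample and return the 13 values bound to the UI outputs.

    The index is clamped to [0, len(dataset) - 1], so out-of-range input
    never raises.
    """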
    try:
        dataset = load_cached_dataset(dataset_name, split)
        max_idx = len(dataset) - 1
        idx = min(max(0, int(idx)), max_idx)

        sample = dataset[idx]

        # Process image
        image = base64_to_image(sample["image"])

        # Get responses
        chosen_response, rejected_response = get_responses(sample["response"], sample["human_ranking"])

        # Process JSON data
        models = json.loads(sample["models"]) if isinstance(sample["models"], str) else sample["models"]
        meta = json.loads(sample["meta"]) if isinstance(sample["meta"], str) else sample["meta"]
        error_analysis = (
            json.loads(sample["human_error_analysis"])
            if isinstance(sample["human_error_analysis"], str)
            else sample["human_error_analysis"]
        )

        return (
            image,  # image
            sample["id"],  # sample_id
            chosen_response,  # chosen_response
            rejected_response,  # rejected_response
            sample["judge"],  # judge
            sample["query_source"],  # query_source
            sample["query"],  # query
            json.dumps(models, indent=2),  # models_json
            json.dumps(meta, indent=2),  # meta_json
            sample["rationale"],  # rationale
            json.dumps(error_analysis, indent=2),  # error_analysis_json
            sample["ground_truth"],  # ground_truth
            f"Total samples: {len(dataset)}",  # total_samples
        )
    except Exception as e:
        # Chain the original exception so the full traceback is preserved.
        raise gr.Error(f"Error loading dataset: {e}") from e


def create_data_viewer():
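    """Build the viewer layout; must be called inside an active gr.Blocks context.

    Returns the (dataset_name, dataset_split, sample_idx) input components so a
    caller can wire additional behavior onto them.
    """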
    # Pre-fetch the first sample so every component below renders with real values.
    initial_dataset_name = "MMInstruction/VRewardBench"
    initial_split = "test"
    initial_idx = 0
    initial_data = load_and_display_sample(initial_dataset_name, initial_split, initial_idx)

    with gr.Column():
        with gr.Row():
            dataset_name = gr.Textbox(label="Dataset Name", value=initial_dataset_name, interactive=True)
            dataset_split = gr.Radio(choices=["test"], value=initial_split, label="Dataset Split")
            sample_idx = gr.Number(label="Sample Index", value=initial_idx, minimum=0, step=1, interactive=True)
            total_samples = gr.Textbox(label="Total Samples", value=initial_data[12], interactive=False)

        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Sample Image", type="pil", value=initial_data[0])

            with gr.Column():
                sample_id = gr.Textbox(label="Sample ID", value=initial_data[1], interactive=False)
                chosen_response = gr.TextArea(label="Chosen Response ✅", value=initial_data[2], interactive=False)
                rejected_response = gr.TextArea(label="Rejected Response ❌", value=initial_data[3], interactive=False)

        with gr.Row():
            judge = gr.Textbox(label="Judge", value=initial_data[4], interactive=False)
            query_source = gr.Textbox(label="Query Source", value=initial_data[5], interactive=False)
            query = gr.Textbox(label="Query", value=initial_data[6], interactive=False)

        with gr.Row():
            with gr.Column():
                models_json = gr.JSON(label="Models", value=json.loads(initial_data[7]))
                meta_json = gr.JSON(label="Meta", value=json.loads(initial_data[8]))
                rationale = gr.TextArea(label="Rationale", value=initial_data[9], interactive=False)

            with gr.Column():
                error_analysis_json = gr.JSON(label="Human Error Analysis", value=json.loads(initial_data[10]))
                ground_truth = gr.TextArea(label="Ground Truth", value=initial_data[11], interactive=False)

        # Auto-update when any input changes
        for input_component in [dataset_name, dataset_split, sample_idx]:
            input_component.change(
                fn=load_and_display_sample,
                inputs=[dataset_name, dataset_split, sample_idx],
                outputs=[
                    image,
                    sample_id,
                    chosen_response,
                    rejected_response,
                    judge,
                    query_source,
                    query,
                    models_json,
                    meta_json,
                    rationale,
                    error_analysis_json,
                    ground_truth,
                    total_samples,
                ],
            )

    return dataset_name, dataset_split, sample_idx
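

# Minimal standalone launcher -- a usage sketch, since this module only defines
# create_data_viewer and how it is mounted in a larger app is an assumption.
if __name__ == "__main__":
    with gr.Blocks(title="VRewardBench Data Viewer") as demo:
        create_data_viewer()
    demo.launch()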