from functools import lru_cache

import datasets
import numpy as np
from huggingface_hub import HfApi

from src.utils import opt_to_index, get_test_target

def get_leaderboard_models_reload():
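    """Rebuild the list of accessible leaderboard models from the Hugging Face Hub.

    Models already listed in data/models.txt are trusted without an API call;
    every other "-details" dataset under the open-llm-leaderboard organization
    is probed once to verify it can be loaded. The resulting model list is
    written back to data/models.txt and returned sorted.
    """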
    api = HfApi()

    # Load prechecked models
    try:
        ungated_models = set(line.strip() for line in open("data/models.txt"))
    except FileNotFoundError:
        ungated_models = set()

    print(ungated_models)
    print(f"Number of prechecked models: {len(ungated_models)}")
    
    # List all datasets in the open-llm-leaderboard organization
    dataset_list = api.list_datasets(author="open-llm-leaderboard")
    
    models = []
    count_api_calls = 0
    for dataset in dataset_list:
        if dataset.id.endswith("-details"):
            # Format: "open-llm-leaderboard/<provider>__<model_name>-details"
            model_part = dataset.id.split("/")[-1].replace("-details", "")
            if "__" in model_part:
                provider, model = model_part.split("__", 1)
                model_name = f"{provider}/{model}"
            else:
                model_name = model_part

            # Only query the Hub if the model has not already been verified as accessible.
            if model_name not in ungated_models:
                try:
                    count_api_calls += 1
                    # Check that the dataset configs can be listed; skip the model if not.
                    datasets.get_dataset_config_names(dataset.id)
                except Exception:
                    continue  # Skip gated or otherwise inaccessible datasets

            models.append(model_name)
    
    print(f"API calls: {count_api_calls}")
    print(f"Number of models: {len(models)}")

    # Save model list as txt file
    with open("data/models.txt", "w") as f:
        for model in models:
            f.write(model + "\n")

    return sorted(models)


def get_leaderboard_models():
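    """Return the sorted list of prechecked (ungated) models from data/models.txt."""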
    # Load prechecked (ungated) models
    with open("data/models.txt", "r") as f:
        ungated_models = [line.strip() for line in f]

    return sorted(ungated_models)


@lru_cache(maxsize=1)
def get_leaderboard_models_cached():
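    """Cached wrapper around get_leaderboard_models() to avoid re-reading the file."""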
    return get_leaderboard_models()


def get_leaderboard_datasets(model_ids):
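    """Return the leaderboard tasks available for every model in `model_ids`.

    Tasks that are not multiple-choice or that currently do not work are
    filtered out. If `model_ids` is None, a static fallback list of known
    leaderboard tasks is returned without querying the Hub.
    """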
    if model_ids is None:
        return [
            'bbh_boolean_expressions', 'bbh_causal_judgement', 'bbh_date_understanding',
            'bbh_disambiguation_qa', 'bbh_formal_fallacies', 'bbh_geometric_shapes',
            'bbh_hyperbaton', 'bbh_logical_deduction_five_objects',
            'bbh_logical_deduction_seven_objects', 'bbh_logical_deduction_three_objects',
            'bbh_movie_recommendation', 'bbh_navigate', 'bbh_object_counting',
            'bbh_penguins_in_a_table', 'bbh_reasoning_about_colored_objects',
            'bbh_ruin_names', 'bbh_salient_translation_error_detection', 'bbh_snarks',
            'bbh_sports_understanding', 'bbh_temporal_sequences',
            'bbh_tracking_shuffled_objects_five_objects',
            'bbh_tracking_shuffled_objects_seven_objects',
            'bbh_tracking_shuffled_objects_three_objects', 'bbh_web_of_lies',
            'gpqa_diamond', 'gpqa_extended', 'gpqa_main', 'mmlu_pro',
            'musr_murder_mysteries', 'musr_object_placements', 'musr_team_allocation',
        ]

    # Map each model to its corresponding leaderboard version
    leaderboard_model_ids = [f"open-llm-leaderboard/{model_id.replace('/', '__')}-details" for model_id in model_ids]

    model_datasets = {}

    for model_id in leaderboard_model_ids:
        # Retrieve the list of available configuration names
        config_names = datasets.get_dataset_config_names(model_id)
        dataset_names = [name.split("__leaderboard_")[-1] for name in config_names]
        model_datasets[model_id] = set(dataset_names)

    # Compute the intersection of datasets across all models
    if model_datasets:
        common_datasets = set.intersection(*model_datasets.values())
    else:
        common_datasets = set()

    # Filter out datasets that are not MCQ or currently do not work
    ignore = ["bbh_temporal_sequences", "math_", "ifeval"]
    common_datasets = [
        dataset for dataset in common_datasets
        if not any(ignore_data in dataset for ignore_data in ignore)
    ]

    return sorted(common_datasets)
    

def filter_labels(dataset_name, doc):
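    """Extract integer ground-truth labels from a list of leaderboard docs.

    Uses the `answer_index` field when present; otherwise converts lettered
    options via opt_to_index, maps dataset-specific string targets
    (e.g. "True"/"False", "Yes"/"No", "valid"/"invalid") to binary indices,
    or casts numeric string targets to int.
    """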
    labels = []
    test_target, target_key = get_test_target(doc[0])
    if "answer_index" in doc[0].keys():
        labels = [d["answer_index"] for d in doc]
    elif test_target.startswith("(") or test_target.isalpha():
        labels = [opt_to_index(d[target_key]) for d in doc]
    elif dataset_name in ["bbh_boolean_expressions"]:
        for d in doc:
            if d[target_key] == "True":
                labels.append(1)
            elif d[target_key] == "False":
                labels.append(0)
    elif dataset_name in ["bbh_causal_judgement", "bbh_navigate", "bbh_web_of_lies"]:
        for d in doc:
            if d[target_key] == "Yes":
                labels.append(0)
            elif d[target_key] == "No":
                labels.append(1)
    elif dataset_name in ["bbh_formal_fallacies"]:
        for d in doc:
            if d[target_key] == "valid":
                labels.append(0)
            elif d[target_key] == "invalid":
                labels.append(1)
    elif dataset_name in ["bbh_sports_understanding"]:
        for d in doc:
            if d[target_key] == "yes":
                labels.append(0)
            elif d[target_key] == "no":
                labels.append(1)
    elif test_target.isdigit():
        labels = [int(d[target_key]) for d in doc]
    
    print(f"Number of labels: {len(labels)}")

    return labels


def filter_responses(data):
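    """Return per-example arrays of log probabilities, one value per answer option,
    taken from the first element of each filtered response."""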
    # Get log probabilities for each response
    log_probs = []

    for resp in data["filtered_resps"]:
        log_prob = np.array([float(option[0]) for option in resp])
        log_probs.append(log_prob)

    return log_probs


def load_run_data(model_name, dataset_name):
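    """Load per-option log probabilities and ground-truth labels for one model's
    run on one leaderboard dataset. Returns two empty lists if the details
    dataset cannot be loaded or parsed."""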
    try:
        model_name = model_name.replace("/", "__")

        data = datasets.load_dataset(
            "open-llm-leaderboard/" + model_name + "-details",
            name=model_name + "__leaderboard_" + dataset_name,
            split="latest",
        )
        data = data.sort("doc_id")
        data = data.to_dict()

        # Get ground truth labels and logits
        log_probs = filter_responses(data)
        labels = filter_labels(dataset_name, data["doc"])
        
    except Exception as e:
        print(f"Failed to load run data for {model_name} / {dataset_name}: {e}")
        log_probs = []
        labels = []

    return log_probs, labels


@lru_cache(maxsize=8)
def load_run_data_cached(model_name, dataset_name):
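    """Cached wrapper around load_run_data() for repeated model/dataset queries."""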
    return load_run_data(model_name, dataset_name)