import pandas as pd
import json
from typing import Dict, Any, Tuple
import os
from constants import (
    MODEL_NAME_MAP,
    DIMENSION_NAME_MAP,
    KEYWORD_NAME_MAP,
    MODEL_URLS,
    BASE_MODEL_GROUPS
)

class MEGABenchEvalDataLoader:
    """Load per-model evaluation results from disk and shape them into leaderboard tables.

    Results are read from ``<base_path>/<model_name>/summary_and_keyword_stats.json``.
    Each file may contain a ``"keyword_stats"`` entry (dimension -> keyword ->
    stats with ``"count"`` and ``"average_score"``) and a ``"model_summary"``
    entry (``"overall_score"`` plus ``"core"`` / ``"open"`` sub-dicts with
    ``"macro_mean_score"`` and ``"num_eval_tasks"``).

    Attributes set by ``__init__``:
        KEYWORD_DATA: model name -> ``"keyword_stats"`` payload.
        SUMMARY_DATA: model name -> ``"model_summary"`` payload.
        SUPER_GROUPS: display dimension name -> ordered keyword display names.
        MODEL_GROUPS: group name -> models of that group that were loaded.
        keyword_display_map: keyword display name -> raw keyword.
    """

    # File expected inside every model folder.
    _STATS_FILENAME = "summary_and_keyword_stats.json"
    # Model whose summary supplies the task counts shown in the headers when it
    # is loaded; any other loaded model is used as a fallback.
    _REFERENCE_MODEL = "GPT_4o"
    # Display order of the dimensions; dimensions not listed here are dropped.
    _SUPER_GROUP_ORDER = (
        "Application",
        "Skills",
        "Output Format",
        "Input Format",
        "Visual Input Number",
    )
    # Columns that are always present, ahead of the per-keyword columns.
    _BASE_COLUMNS = ("Models", "Overall", "Core", "Open-ended")

    def __init__(self, base_path):
        """Read every model folder under ``base_path`` and build the group indexes."""
        self.base_path = base_path
        # Load both model and summary data at once
        self.KEYWORD_DATA, self.SUMMARY_DATA = self._load_data()
        self.SUPER_GROUPS = self._initialize_super_groups()
        self.MODEL_GROUPS = self._initialize_model_groups()

    def _get_base_path(self) -> str:
        """Hook for subclasses; the base class always raises ``NotImplementedError``."""
        raise NotImplementedError("Subclasses must implement _get_base_path")

    def _load_data(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Read the stats file of every model folder under ``self.base_path``.

        Entries of ``base_path`` that do not contain the stats file (plain
        files, unrelated directories) are skipped instead of aborting the load.
        Folders are visited in sorted order so the resulting dict order, and
        therefore the sample model used for the super groups, is deterministic.

        Returns:
            ``(keyword_data, summary_data)``, both keyed by model folder name.
            A model only appears in a dict if its file has the matching key.
        """
        keyword_data: Dict[str, Any] = {}
        summary_data: Dict[str, Any] = {}

        for model_name in sorted(os.listdir(self.base_path)):
            stats_path = os.path.join(self.base_path, model_name, self._STATS_FILENAME)
            # isfile() is also False when `model_name` is not a directory.
            if not os.path.isfile(stats_path):
                continue
            with open(stats_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            if "keyword_stats" in data:
                keyword_data[model_name] = data["keyword_stats"]
            if "model_summary" in data:
                summary_data[model_name] = data["model_summary"]

        return keyword_data, summary_data

    def _initialize_super_groups(self):
        """Build the dimension -> keyword display names mapping.

        The structure (dimensions, keywords, task counts) is taken from the
        first loaded model. Each keyword is shown as ``"<name>(<count>)"`` and
        ``self.keyword_display_map`` is (re)built to map that display string
        back to the raw keyword.

        Returns:
            Dict of display dimension name -> display names sorted by task
            count (descending), ties broken by display name. Empty when no
            keyword data was loaded.

        Raises:
            KeyError: if a dimension of the sample model is missing from
                ``DIMENSION_NAME_MAP``.
        """
        groups = {}
        self.keyword_display_map = {}  # display name -> raw keyword

        if not self.KEYWORD_DATA:
            # Nothing loaded: avoid leaking StopIteration from next(iter(...)).
            return groups

        # Get a sample model to access the structure
        sample_stats = self.KEYWORD_DATA[next(iter(self.KEYWORD_DATA))]

        for dim, keywords in sample_stats.items():
            dim_name = DIMENSION_NAME_MAP[dim]
            # (display_name, count, keyword) tuples for sorting
            keyword_info = []
            for keyword, stats in keywords.items():
                task_count = stats["count"]
                original_name = KEYWORD_NAME_MAP.get(keyword, keyword)
                display_name = f"{original_name}({task_count})"
                keyword_info.append((display_name, task_count, keyword))

            # Sort by count (descending) and then by display name (for ties)
            keyword_info.sort(key=lambda info: (-info[1], info[0]))

            groups[dim_name] = [info[0] for info in keyword_info]
            for display_name, _, keyword in keyword_info:
                self.keyword_display_map[display_name] = keyword

        # Sort based on predefined order
        return {name: groups[name] for name in self._SUPER_GROUP_ORDER if name in groups}

    def _initialize_model_groups(self) -> Dict[str, list]:
        """Restrict ``BASE_MODEL_GROUPS`` to the models that were actually loaded.

        The ``"All"`` group is replaced by every loaded model in sorted order;
        other groups keep their configured order and are dropped when none of
        their models is available.
        """
        available_models = set(self.KEYWORD_DATA)

        filtered_groups = {}
        for group_name, models in BASE_MODEL_GROUPS.items():
            if group_name == "All":
                filtered_groups[group_name] = sorted(available_models)
                continue
            present = [model for model in models if model in available_models]
            if present:
                filtered_groups[group_name] = present

        return filtered_groups

    def _reference_summary(self) -> Dict[str, Any]:
        """Return the model summary used to read the total task counts.

        Prefers ``_REFERENCE_MODEL`` and falls back to the first loaded summary
        so the leaderboard still works when that model's results are absent.

        Raises:
            KeyError: if no model summary was loaded at all.
        """
        if self._REFERENCE_MODEL in self.SUMMARY_DATA:
            return self.SUMMARY_DATA[self._REFERENCE_MODEL]
        if not self.SUMMARY_DATA:
            raise KeyError("No model summary data loaded; cannot determine task counts")
        return next(iter(self.SUMMARY_DATA.values()))

    def get_df(self, selected_super_group: str, selected_model_group: str) -> pd.DataFrame:
        """Build the score table for one dimension and one model group.

        Args:
            selected_super_group: display dimension name (a key of ``SUPER_GROUPS``).
            selected_model_group: a key of ``MODEL_GROUPS``.

        Returns:
            DataFrame with columns ``Models``, ``Overall``, ``Core``,
            ``Open-ended`` followed by one column per keyword display name,
            sorted by ``Overall`` descending. Scores are multiplied by 100 and
            rounded to two decimals; a keyword a model has no entry for is
            ``None``. Models lacking keyword or summary data are left out. If
            no model qualifies, an empty frame with the same columns is
            returned.
        """
        original_dimension = get_original_dimension(selected_super_group)
        display_names = self.SUPER_GROUPS[selected_super_group]
        data = []

        for model in self.MODEL_GROUPS[selected_model_group]:
            if model not in self.KEYWORD_DATA or model not in self.SUMMARY_DATA:
                continue

            summary = self.SUMMARY_DATA[model]

            # Basic model information
            row = {
                "Models": get_display_model_name(model, as_link=True),
                "Overall": round(summary["overall_score"] * 100, 2),
                "Core": round(summary["core"]["macro_mean_score"] * 100, 2),
                "Open-ended": round(summary["open"]["macro_mean_score"] * 100, 2),
            }

            # Add dimension-specific scores; a missing dimension or keyword
            # yields None for that column.
            dimension_stats = self.KEYWORD_DATA[model].get(original_dimension, {})
            for display_name in display_names:
                stats = dimension_stats.get(self.keyword_display_map[display_name])
                if stats is None:
                    row[display_name] = None
                else:
                    row[display_name] = round(stats["average_score"] * 100, 2)

            data.append(row)

        if not data:
            # An empty DataFrame(data) has no "Overall" column and would make
            # sort_values raise; return an empty table with the expected shape.
            return pd.DataFrame(columns=list(self._BASE_COLUMNS) + list(display_names))

        df = pd.DataFrame(data)
        return df.sort_values(by="Overall", ascending=False)

    def get_leaderboard_data(self, selected_super_group: str, selected_model_group: str) -> Tuple[list, list]:
        """Return ``(headers, rows)`` for rendering the leaderboard.

        The ``Overall`` / ``Core`` / ``Open-ended`` headers are suffixed with
        task counts read from the reference model's summary (see
        ``_reference_summary``); keyword headers are the display names of the
        selected dimension. ``rows`` is the table from ``get_df`` as a list of
        lists in header order.

        Raises:
            KeyError: if no model summary was loaded.
        """
        df = self.get_df(selected_super_group, selected_model_group)

        # Get total task counts from the reference model's data
        reference = self._reference_summary()
        total_core_tasks = reference["core"]["num_eval_tasks"]
        total_open_tasks = reference["open"]["num_eval_tasks"]
        total_tasks = total_core_tasks + total_open_tasks

        # Define headers with task counts
        column_headers = {
            "Models": "Models",
            "Overall": f"Overall({total_tasks})",
            "Core": f"Core({total_core_tasks})",
            "Open-ended": f"Open-ended({total_open_tasks})",
        }

        # Rename the columns in DataFrame to match headers
        df = df.rename(columns=column_headers)

        headers = [column_headers[name] for name in self._BASE_COLUMNS]
        headers += self.SUPER_GROUPS[selected_super_group]

        data = df[headers].values.tolist()

        return headers, data


# Module-level helpers for translating between raw identifiers and display names.
def get_original_dimension(mapped_dimension):
    """Return the first key of ``DIMENSION_NAME_MAP`` whose value is ``mapped_dimension``.

    Raises:
        StopIteration: if no entry of the map has that value.
    """
    for raw_dimension, shown_dimension in DIMENSION_NAME_MAP.items():
        if shown_dimension == mapped_dimension:
            return raw_dimension
    # No match: same exception an exhausted next() call without default raises.
    raise StopIteration

def get_original_keyword(mapped_keyword):
    """Return the first key of ``KEYWORD_NAME_MAP`` whose value is ``mapped_keyword``.

    When no entry has that value, ``mapped_keyword`` itself is returned.
    """
    for raw_keyword, shown_keyword in KEYWORD_NAME_MAP.items():
        if shown_keyword == mapped_keyword:
            return raw_keyword
    return mapped_keyword

def get_display_model_name(model_name: str, as_link: bool = True) -> str:
    """Return the display label for ``model_name``, optionally wrapped in an HTML link.

    The label comes from ``MODEL_NAME_MAP`` and falls back to ``model_name``.
    An ``<a>`` tag is produced only when ``as_link`` is true and the model has
    an entry in ``MODEL_URLS``; otherwise the plain label is returned.
    """
    label = MODEL_NAME_MAP.get(model_name, model_name)
    if not as_link or model_name not in MODEL_URLS:
        return label
    url = MODEL_URLS[model_name]
    return f'<a href="{url}" target="_blank" style="text-decoration: none; color: #2196F3;">{label}</a>'