# src/display/utils.py

from dataclasses import dataclass
from enum import Enum
from typing import Any, List

from src.about import Tasks

@dataclass
class ColumnContent:
    """Describe one column of the leaderboard table.

    Instances are collected, in display order, in the module-level ``COLUMNS`` list.
    """

    name: str  # column key; collected into COLS (e.g. "model_name", or a task's benchmark id)
    type: Any  # Python type of the cell values (str, float, int or bool in this module)
    label: str  # human-readable header text, e.g. "Model"
    description: str  # longer explanation of the column
    hidden: bool = False  # NOTE(review): never set or read in this module; presumably consumed by the UI layer — confirm
    displayed_by_default: bool = True  # NOTE(review): not read here; presumably whether the column is visible initially
    never_hidden: bool = False  # set only for "model_name" here; presumably keeps that column from being toggled off — confirm

# Leaderboard columns, in display order: the model name and overall average
# come first, followed by one accuracy column per entry in ``Tasks``.
COLUMNS: List[ColumnContent] = [
    ColumnContent(
        name="model_name",
        type=str,
        label="Model",
        description="Model name",
        never_hidden=True,
    ),
    ColumnContent(
        name="average",
        type=float,
        label="Average Accuracy (%)",
        description="Average accuracy across all subjects",
    ),
]

# Per-task accuracy columns: the task's ``benchmark`` id is the column key and
# its ``col_name`` supplies the header and description text.
for task in Tasks:
    COLUMNS.append(
        ColumnContent(
            name=task.value.benchmark,
            label=f"{task.value.col_name} (%)",
            description=f"Accuracy on {task.value.col_name}",
            type=float,
        )
    )

# Model metadata columns, appended after the per-task score columns. The
# example values in each description match the ModelType / WeightType /
# Precision enums defined later in this module.
COLUMNS.extend([
    ColumnContent(
        name="model_type",
        type=str,
        label="Model Type",
        # Lists the ModelType categories rather than network architectures.
        description="Type of the model (e.g., pretrained, fine-tuned, instruction-tuned, RL-tuned)",
        displayed_by_default=True,
    ),
    ColumnContent(
        name="weight_type",
        type=str,
        label="Weight Type",
        description="Type of model weights (e.g., Original, Delta, Adapter)",
        displayed_by_default=True,
    ),
    ColumnContent(
        name="precision",
        type=str,
        label="Precision",
        description="Precision of the model weights (e.g., float16)",
        displayed_by_default=True,
    ),
    ColumnContent(
        name="license",
        type=str,
        label="License",
        description="License of the model",
        displayed_by_default=True,
    ),
    ColumnContent(
        name="likes",
        type=int,
        label="Likes",
        description="Number of likes on the Hugging Face Hub",
        displayed_by_default=True,
    ),
    ColumnContent(
        name="still_on_hub",
        type=bool,
        label="Available on the Hub",
        description="Whether the model is still available on the Hugging Face Hub",
        displayed_by_default=True,
    ),
])

# Column-key lists used by the rest of the application.
COLS = [col.name for col in COLUMNS]
# Keys of the per-task accuracy columns. These are derived from ``Tasks`` (the
# same source, in the same order, that the per-task entries of COLUMNS are
# built from) rather than by excluding a hand-maintained list of metadata
# column names. An exclusion list would silently misclassify any metadata
# column added later as a benchmark.
BENCHMARK_COLS = [task.value.benchmark for task in Tasks]

# Schema of the evaluation-queue table shown in the submission tab.
@dataclass(frozen=True)
class EvalQueueColumn:
    """Schema entry for one column of the evaluation-queue table.

    Immutable (``frozen=True``): fields cannot be reassigned after creation.
    """

    name: str  # column key, gathered into EVAL_COLS
    type: Any  # Python type of the cell values, gathered into EVAL_TYPES
    label: str  # header text shown to users
    description: str  # longer explanation of the column


# Queue columns in display order, declared as (name, type, label, description).
EVAL_QUEUE_COLUMNS: List[EvalQueueColumn] = [
    EvalQueueColumn(name=key, type=value_type, label=header, description=blurb)
    for key, value_type, header, blurb in (
        ("model", str, "Model", "Model name"),
        ("revision", str, "Revision", "Model revision or commit hash"),
        ("private", bool, "Private", "Is the model private?"),
        ("precision", str, "Precision", "Precision of the model weights"),
        ("weight_type", str, "Weight Type", "Type of model weights"),
        ("status", str, "Status", "Evaluation status"),
    )
]

# Parallel lists of the queue columns' keys and value types, in the same order.
EVAL_COLS, EVAL_TYPES = (
    [column.name for column in EVAL_QUEUE_COLUMNS],
    [column.type for column in EVAL_QUEUE_COLUMNS],
)

# Model information
@dataclass
class ModelDetails:
    """Name and display symbol of one ``ModelType`` category."""

    name: str  # category name rendered by ModelType.to_str, e.g. "fine-tuned"
    display_name: str = ""  # NOTE(review): never set or read in this module — confirm it is still needed
    symbol: str = ""  # emoji rendered by ModelType.to_str before the name


class ModelType(Enum):
    """Category of a submitted model, each paired with a display emoji.

    Member values are ``ModelDetails`` instances holding the category name and
    its symbol.
    """

    PT = ModelDetails(name="pretrained", symbol="🟢")
    FT = ModelDetails(name="fine-tuned", symbol="🔶")
    # U+2B55 HEAVY LARGE CIRCLE. This symbol was previously stored as mojibake
    # (its UTF-8 bytes mis-decoded as cp1252). from_str still accepts that
    # spelling, because the old value made to_str emit it.
    IFT = ModelDetails(name="instruction-tuned", symbol="⭕")
    RL = ModelDetails(name="RL-tuned", symbol="🟦")
    Unknown = ModelDetails(name="", symbol="?")

    def to_str(self, separator=" "):
        """Return the symbol and name joined by *separator*, e.g. ``"🔶 fine-tuned"``."""
        return f"{self.value.symbol}{separator}{self.value.name}"

    @staticmethod
    def from_str(type_str):
        """Map a model-type string to a member.

        Each category matches if its name or its emoji occurs anywhere in
        *type_str*, so both ``"fine-tuned"`` and ``"🔶 fine-tuned"`` work.
        Categories are tried in a fixed order and the first match wins.
        Returns ``ModelType.Unknown`` when nothing matches or *type_str* is not
        a string (e.g. ``None``).
        """
        if not isinstance(type_str, str):
            # Substring tests below would raise TypeError on None and similar values.
            return ModelType.Unknown
        if "fine-tuned" in type_str or "🔶" in type_str:
            return ModelType.FT
        if "pretrained" in type_str or "🟢" in type_str:
            return ModelType.PT
        if "RL-tuned" in type_str or "🟦" in type_str:
            return ModelType.RL
        # "\u00e2\u00ad\u2022" is the legacy mojibake of "⭕" that the IFT symbol
        # used to hold; keep recognising strings produced with it.
        if (
            "instruction-tuned" in type_str
            or "⭕" in type_str
            or "\u00e2\u00ad\u2022" in type_str
        ):
            return ModelType.IFT
        return ModelType.Unknown

class WeightType(Enum):
    """Type of model weights: Adapter, Original, or Delta.

    Each member's value is the same string as its name.
    """

    Adapter = "Adapter"
    Original = "Original"
    Delta = "Delta"

class Precision(Enum):
    """Precision of the model weights; ``Unknown`` covers anything unrecognised."""

    float16 = "float16"
    bfloat16 = "bfloat16"
    Unknown = "Unknown"

    @staticmethod
    def from_str(precision_str):
        """Return the member matching a precision string.

        Accepts both the bare dtype name (``"float16"``) and the
        ``torch.``-prefixed spelling (``"torch.float16"``). Any other value maps
        to ``Precision.Unknown``.
        """
        # Tuple membership compares with ==, like the spelled-out checks it replaces.
        spellings_by_member = (
            (Precision.float16, ("torch.float16", "float16")),
            (Precision.bfloat16, ("torch.bfloat16", "bfloat16")),
        )
        for member, spellings in spellings_by_member:
            if precision_str in spellings:
                return member
        return Precision.Unknown