import re
from datetime import datetime, timedelta
from typing import Dict, List, Optional

import gradio as gr
import numpy as np
import pandas as pd

from constants import (
    COLM_ICO,
    CORL_ICO,
    DATASET_ARXIV_SCAN_PAPERS,
    DATASET_COMMUNITY_SCIENCE,
    DATASET_CONFERENCE_PAPERS,
    DATASET_PAPER_CENTRAL,
    DEFAULT_ICO,
    MICCAI24ICO,
    NEURIPS_ICO,
)
from utils import load_and_process

class PaperCentral:
    """
    A class to manage and process paper data for display in a Gradio Dataframe component.
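
    Example (illustrative sketch; assumes the Hugging Face datasets referenced
    in ``constants`` are reachable):

        paper_central = PaperCentral()
        update = paper_central.filter(hf_options=["models"])
        # `update` is a gr.update(...) ready to be passed to a gr.Dataframe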
    """

    CONFERENCES_ICONS = {
        "NeurIPS2024 D&B": NEURIPS_ICO,
        "NeurIPS2024": NEURIPS_ICO,
        "EMNLP2024": 'https://aclanthology.org/aclicon.ico',
        "CoRL2024": CORL_ICO,
        "ACMMM2024": "https://2024.acmmm.org/favicon.ico",
        "MICCAI2024": MICCAI24ICO,
        "COLM2024": COLM_ICO,
        "COLING2024": 'https://aclanthology.org/aclicon.ico',
        "CVPR2024": "https://openaccess.thecvf.com/favicon.ico",
        "ACL2024": 'https://aclanthology.org/aclicon.ico',
        "ACL2023": 'https://aclanthology.org/aclicon.ico',
        "CVPR2023": "https://openaccess.thecvf.com/favicon.ico",
        "ECCV2024": "https://openaccess.thecvf.com/favicon.ico",
        "EMNLP2023": 'https://aclanthology.org/aclicon.ico',
        "NAACL2023": 'https://aclanthology.org/aclicon.ico',
        "NeurIPS2023": NEURIPS_ICO,
        "NeurIPS2023 D&B": NEURIPS_ICO,
    }

    CONFERENCES = list(CONFERENCES_ICONS.keys())

    # Class-level constants defining columns and their data types
    COLUMNS_START_PAPER_PAGE: List[str] = [
        'date',
        'arxiv_id',
        'paper_page',
        'title',
    ]

    COLUMNS_ORDER_PAPER_PAGE: List[str] = [
        'chat_with_paper',
        'date',
        'arxiv_id',
        'paper_page',
        'num_models',
        'num_datasets',
        'num_spaces',
        'upvotes',
        'num_comments',
        'github',
        'github_stars',
        'project_page',
        'conference_name',
        'id',
        'type',
        'proceedings',
        'title',
        'authors',
    ]

    DATATYPES: Dict[str, str] = {
        'date': 'str',
        'arxiv_id': 'markdown',
        'paper_page': 'markdown',
        'upvotes': 'number',
        'num_comments': 'number',
        'num_models': 'markdown',
        'num_datasets': 'markdown',
        'num_spaces': 'markdown',
        'github': 'markdown',
        'title': 'str',
        'proceedings': 'markdown',
        'conference_name': 'str',
        'id': 'str',
        'type': 'str',
        'authors': 'str',
        'github_stars': 'number',
        'project_page': 'markdown',
        'chat_with_paper': 'markdown',
    }

    # Mapping for renaming columns for display purposes
    COLUMN_RENAME_MAP: Dict[str, str] = {
        'num_models': 'models',
        'num_spaces': 'spaces',
        'num_datasets': 'datasets',
        'github': 'GitHub',
        'github_stars': 'GitHub⭐',
        'num_comments': '💬',
        'upvotes': '👍',
        'chat_with_paper': 'Chat',
    }

    def __init__(self):
        """
        Initialize the PaperCentral class by loading and processing the datasets.
        """
        self.df_raw: pd.DataFrame = self.get_df()
        self.df_prettified: pd.DataFrame = self.prettify(self.df_raw)

    @staticmethod
    def get_columns_order(columns: List[str]) -> List[str]:
        """
        Get columns ordered according to COLUMNS_ORDER_PAPER_PAGE.

        Args:
            columns (List[str]): List of column names to order.

        Returns:
            List[str]: Ordered list of column names.
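
        Example (illustrative; the column names are arbitrary):
            >>> PaperCentral.get_columns_order(['title', 'date', 'unknown'])
            ['date', 'title']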
        """
        return [c for c in PaperCentral.COLUMNS_ORDER_PAPER_PAGE if c in columns]

    @staticmethod
    def get_columns_datatypes(columns: List[str]) -> List[str]:
        """
        Get data types for the specified columns.

        Args:
            columns (List[str]): List of column names.

        Returns:
            List[str]: List of data types corresponding to the columns.
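
        Example (illustrative; the column names are arbitrary):
            >>> PaperCentral.get_columns_datatypes(['date', 'arxiv_id', 'upvotes'])
            ['str', 'markdown', 'number']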
        """
        return [PaperCentral.DATATYPES[c] for c in columns]

    @staticmethod
    def get_df() -> pd.DataFrame:
        """
        Load and merge datasets to create the raw DataFrame.

        Returns:
            pd.DataFrame: The merged and processed DataFrame.
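
        Note:
            Papers dated on a weekend are shifted to the following Monday; for
            example, a Saturday date of 2024-10-05 becomes 2024-10-07.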
        """
        # Load datasets
        paper_central_df: pd.DataFrame = load_and_process(DATASET_PAPER_CENTRAL)[
            ['arxiv_id', 'categories', 'primary_category', 'date', 'upvotes', 'num_comments', 'github', 'num_models',
             'num_datasets', 'num_spaces', 'id', 'proceedings', 'type',
             'conference_name', 'title', 'paper_page', 'authors', 'github_stars', 'project_page']
        ]

        # If the arXiv publication date falls on a weekend, shift it to the following Monday
        def adjust_date(dt):
            if dt.weekday() == 5:  # Saturday
                return dt + pd.Timedelta(days=2)
            elif dt.weekday() == 6:  # Sunday
                return dt + pd.Timedelta(days=1)
            else:
                return dt

        # Convert 'date' column to datetime
        paper_central_df['date'] = pd.to_datetime(paper_central_df['date'], format='%Y-%m-%d')
        paper_central_df['date'] = paper_central_df['date'].apply(adjust_date)
        paper_central_df['date'] = paper_central_df['date'].dt.strftime('%Y-%m-%d')

        return paper_central_df

    @staticmethod
    def format_df_date(df: pd.DataFrame, date_column: str = "date") -> pd.DataFrame:
        """
        Format the date column in the DataFrame to 'YYYY-MM-DD'.

        Args:
            df (pd.DataFrame): The DataFrame to format.
            date_column (str): The name of the date column.

        Returns:
            pd.DataFrame: The DataFrame with the formatted date column.
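
        Example (illustrative; the input value is arbitrary):
            >>> frame = pd.DataFrame({'date': ['2024-10-05 12:34:56']})
            >>> PaperCentral.format_df_date(frame)['date'].iloc[0]
            '2024-10-05'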
        """
        df.loc[:, date_column] = pd.to_datetime(df[date_column]).dt.strftime('%Y-%m-%d')
        return df

    @staticmethod
    def prettify(df: pd.DataFrame) -> pd.DataFrame:
        """
        Prettify the DataFrame by adding markdown links and sorting.

        Args:
            df (pd.DataFrame): The DataFrame to prettify.

        Returns:
            pd.DataFrame: The prettified DataFrame.
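
        Example (illustrative, minimal two-column frame with a made-up arxiv id):
            >>> frame = pd.DataFrame({'arxiv_id': ['2410.12345'], 'paper_page': ['2410.12345']})
            >>> PaperCentral.prettify(frame)['paper_page'].iloc[0]
            '🤗[paper_page](https://huggingface.co/papers/2410.12345)'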
        """

        def update_row(row: pd.Series) -> pd.Series:
            """
            Update a row by adding markdown links to 'paper_page' and 'arxiv_id' columns.

            Args:
                row (pd.Series): A row from the DataFrame.

            Returns:
                pd.Series: The updated row.
            """
            # Process 'num_models' column
            if (
                    'num_models' in row and pd.notna(row['num_models']) and row["arxiv_id"]
                    and float(row['num_models']) > 0
            ):
                num_models = int(float(row['num_models']))
                row['num_models'] = (
                    f"[{num_models}](https://huggingface.co/models?other=arxiv:{row['arxiv_id']})"
                )

            if (
                    'num_datasets' in row and pd.notna(row['num_datasets']) and row["arxiv_id"]
                    and float(row['num_datasets']) > 0
            ):
                num_datasets = int(float(row['num_datasets']))
                row['num_datasets'] = (
                    f"[{num_datasets}](https://huggingface.co/datasets?other=arxiv:{row['arxiv_id']})"
                )

            if (
                    'num_spaces' in row and pd.notna(row['num_spaces']) and row["arxiv_id"]
                    and float(row['num_spaces']) > 0
            ):
                num_spaces = int(float(row['num_spaces']))
                row['num_spaces'] = (
                    f"[{num_spaces}](https://huggingface.co/spaces?other=arxiv:{row['arxiv_id']})"
                )

            if 'proceedings' in row and pd.notna(row['proceedings']) and row['proceedings']:
                image_url = PaperCentral.CONFERENCES_ICONS.get(row["conference_name"], DEFAULT_ICO)

                style = "display:inline-block; vertical-align:middle; width: 16px; height:16px"
                row['proceedings'] = (
                    f"<img src='{image_url}' style='{style}'/>"
                    f"<a href='{row['proceedings']}'>proc_page</a>"
                )

            # Rewrite 'paper_page' and 'arxiv_id' last: the artifact links built
            # above rely on the raw 'arxiv_id' value.
            # Add markdown link to 'paper_page' if it exists
            if 'paper_page' in row and pd.notna(row['paper_page']) and row['paper_page']:
                row['paper_page'] = f"🤗[paper_page](https://huggingface.co/papers/{row['paper_page']})"

            # Add image and link to 'arxiv_id' if it exists
            if 'arxiv_id' in row and pd.notna(row['arxiv_id']) and row['arxiv_id']:
                image_url = "https://arxiv.org/static/browse/0.3.4/images/icons/favicon-16x16.png"
                style = "display:inline-block; vertical-align:middle;"
                row['arxiv_id'] = (
                    f"<img src='{image_url}' style='{style}'/>"
                    f"<a href='https://arxiv.org/abs/{row['arxiv_id']}'>arxiv_page</a>"
                )

            # Add image and link to 'github' if it exists
            if 'github' in row and pd.notna(row['github']) and row["github"]:
                image_url = "https://github.githubassets.com/favicons/favicon.png"
                style = "display:inline-block; vertical-align:middle;width:16px;"
                row['github'] = (
                    f"<img src='{image_url}' style='{style}'/>"
                    f"<a href='{row['github']}'>github</a>"
                )

            if 'project_page' in row and pd.notna(row['project_page']) and row["project_page"]:
                row['project_page'] = (
                    f"<a href='{row['project_page']}'>{row['project_page']}</a>"
                )

            return row

        df = df.copy()

        # Apply the update_row function to each row
        prettified_df: pd.DataFrame = df.apply(update_row, axis=1)
        return prettified_df

    def rename_columns_for_display(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Rename columns in the DataFrame according to COLUMN_RENAME_MAP for display purposes.

        Args:
            df (pd.DataFrame): The DataFrame whose columns need to be renamed.

        Returns:
            pd.DataFrame: The DataFrame with renamed columns.
        """
        return df.rename(columns=self.COLUMN_RENAME_MAP)

    def filter(
            self,
            selected_date: Optional[str] = None,
            cat_options: Optional[List[str]] = None,
            hf_options: Optional[List[str]] = None,
            conference_options: Optional[List[str]] = None,
            author_search_input: Optional[str] = None,
            title_search_input: Optional[str] = None,
            date_range_option: Optional[str] = None,
    ) -> gr.update:
        """
        Filter the raw DataFrame by date, category, Hugging Face, and conference
        options, then return a gr.update carrying the display-ready DataFrame and
        its matching column datatypes.
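
        Example (illustrative; assumes the source datasets are reachable):
            paper_central = PaperCentral()
            update = paper_central.filter(
                selected_date="2024-10-01",
                hf_options=["models", "datasets"],
            )
            # `update` carries the filtered frame plus the matching column datatypes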
        """
        filtered_df: pd.DataFrame = self.df_raw.copy()

        # Start with the initial columns to display
        columns_to_show: List[str] = PaperCentral.COLUMNS_START_PAPER_PAGE.copy()

        # Handle title search
        if title_search_input:
            if 'title' not in columns_to_show:
                columns_to_show.append('title')

            search_string = title_search_input.lower()

            def title_match(title):
                if isinstance(title, str):
                    return search_string in title.lower()
                else:
                    return False

            filtered_df = filtered_df[filtered_df['title'].apply(title_match)]

        # Handle author search
        if author_search_input:
            if 'authors' not in columns_to_show:
                columns_to_show.append('authors')

            search_string = author_search_input.lower()

            def author_matches(authors_list):
                if authors_list is None or len(authors_list) == 0:
                    return False
                if isinstance(authors_list, (list, tuple, pd.Series, np.ndarray)):
                    return any(
                        isinstance(author, str) and search_string in author.lower()
                        for author in authors_list
                    )
                elif isinstance(authors_list, str):
                    return search_string in authors_list.lower()
                else:
                    return False

            filtered_df = filtered_df[filtered_df['authors'].apply(author_matches)]

        # Handle category options
        if cat_options:
            if "(ALL)" in cat_options:
                # "(ALL)" selected: keep every category, no filtering needed
                pass
            else:
                # Proceed with filtering based on selected categories
                options = [o.replace(".*", "") for o in cat_options]
                conference_filter = pd.Series(False, index=filtered_df.index)
                for option in options:
                    conference_filter |= (
                            filtered_df['primary_category'].notna() &
                            filtered_df['primary_category'].str.contains(option, case=False)
                    )
                filtered_df = filtered_df[conference_filter]

        # Handle date filtering
        if not conference_options:
            if date_range_option:
                today = datetime.now()
                if date_range_option == "This week":
                    start_date = (today - timedelta(days=7)).strftime('%Y-%m-%d')
                    end_date = today.strftime('%Y-%m-%d')
                elif date_range_option == "This month":
                    start_date = (today - timedelta(days=30)).strftime('%Y-%m-%d')
                    end_date = today.strftime('%Y-%m-%d')
                elif date_range_option == "This year":
                    start_date = (today - timedelta(days=365)).strftime('%Y-%m-%d')
                    end_date = today.strftime('%Y-%m-%d')
                else:
                    # "All time" or an unrecognised option: no date filtering
                    start_date = None
                    end_date = None

                if start_date and end_date:
                    filtered_df = filtered_df[
                        (filtered_df['date'] >= start_date) & (filtered_df['date'] <= end_date)
                        ]
                else:
                    pass  # No date filtering for "All time"
            elif selected_date:
                selected_date = pd.to_datetime(selected_date).strftime('%Y-%m-%d')
                filtered_df = filtered_df[filtered_df['date'] == selected_date]

        # Handle Hugging Face options
        if hf_options:
            # Convert artifact-count columns to integers, treating missing or
            # non-numeric values as 0
            for col in ('num_datasets', 'num_models', 'num_spaces'):
                filtered_df[col] = pd.to_numeric(filtered_df[col], errors='coerce').fillna(0).astype(int)

            if "🤗 artifacts" in hf_options:
                filtered_df = filtered_df[
                    (filtered_df['paper_page'] != "") & (filtered_df['paper_page'].notna())
                    ]
                if 'upvotes' not in columns_to_show:
                    columns_to_show.append('upvotes')
                if 'num_comments' not in columns_to_show:
                    columns_to_show.append('num_comments')
                if 'num_models' not in columns_to_show:
                    columns_to_show.append('num_models')
                if 'num_datasets' not in columns_to_show:
                    columns_to_show.append('num_datasets')
                if 'num_spaces' not in columns_to_show:
                    columns_to_show.append('num_spaces')

                filtered_df = filtered_df[
                    (filtered_df['num_datasets'] > 0) |
                    (filtered_df['num_models'] > 0) |
                    (filtered_df['num_spaces'] > 0)
                    ]

            if "datasets" in hf_options:
                if 'num_datasets' not in columns_to_show:
                    columns_to_show.append('num_datasets')
                filtered_df = filtered_df[filtered_df['num_datasets'] != 0]

            if "models" in hf_options:
                if 'num_models' not in columns_to_show:
                    columns_to_show.append('num_models')
                filtered_df = filtered_df[filtered_df['num_models'] != 0]

            if "spaces" in hf_options:
                if 'num_spaces' not in columns_to_show:
                    columns_to_show.append('num_spaces')
                filtered_df = filtered_df[filtered_df['num_spaces'] != 0]

            if "github" in hf_options:
                if 'github' not in columns_to_show:
                    columns_to_show.append('github')
                    columns_to_show.append('github_stars')
                filtered_df = filtered_df[(filtered_df['github'] != "") & (filtered_df['github'].notnull())]

            if "project page" in hf_options:
                if 'project_page' not in columns_to_show:
                    columns_to_show.append('project_page')
                filtered_df = filtered_df[(filtered_df['project_page'] != "") & (filtered_df['project_page'].notnull())]

        # Apply conference filtering
        if conference_options:
            columns_to_show = [col for col in columns_to_show if col not in ["date", "arxiv_id"]]

            if 'conference_name' not in columns_to_show:
                columns_to_show.append('conference_name')
            if 'proceedings' not in columns_to_show:
                columns_to_show.append('proceedings')
            if 'type' not in columns_to_show:
                columns_to_show.append('type')
            if 'id' not in columns_to_show:
                columns_to_show.append('id')

            if "ALL" in conference_options:
                filtered_df = filtered_df[
                    filtered_df['conference_name'].notna() & (filtered_df['conference_name'] != "")
                    ]

            other_conferences = [conf for conf in conference_options if conf != "ALL"]
            if other_conferences:
                conference_filter = pd.Series(False, index=filtered_df.index)
                for conference in other_conferences:
                    conference_filter |= (
                            filtered_df['conference_name'].notna() &
                            (filtered_df['conference_name'].str.lower() == conference.lower())
                    )
                filtered_df = filtered_df[conference_filter]

            if any(conf in ["NeurIPS2024 D&B", "NeurIPS2024"] for conf in conference_options):
                def create_chat_link(row):
                    proceedings = row["proceedings"]
                    # Guard against missing proceedings URLs before regex matching
                    if not isinstance(proceedings, str):
                        return ""
                    match = re.search(r'id=([^&]+)', proceedings)
                    if match:
                        neurips_id = match.group(1)
                        return f'<a href="/?tab=tab-chat-with-paper&paper_id={neurips_id}" id="custom_button" target="_blank" rel="noopener noreferrer" aria-disabled="false">✨ Chat with paper</a>'
                    return ""

                # Add the "chat_with_paper" column
                filtered_df['chat_with_paper'] = filtered_df.apply(create_chat_link, axis=1)
                if 'chat_with_paper' not in columns_to_show:
                    columns_to_show.append('chat_with_paper')

        # Prettify the DataFrame
        filtered_df = self.prettify(filtered_df)

        # Ensure columns are ordered according to COLUMNS_ORDER_PAPER_PAGE
        columns_in_order: List[str] = PaperCentral.get_columns_order(columns_to_show)
        filtered_df = filtered_df[columns_in_order]

        # Rename columns for display
        filtered_df = self.rename_columns_for_display(filtered_df)

        # Get the corresponding data types for the columns
        new_datatypes: List[str] = [
            PaperCentral.DATATYPES.get(self._get_original_column_name(col), 'str') for col in filtered_df.columns
        ]

        # Sort rows to display entries with 'paper_page' first
        if 'paper_page' in filtered_df.columns:
            filtered_df['has_paper_page'] = filtered_df['paper_page'].notna() & (filtered_df['paper_page'] != "")
            filtered_df.sort_values(by='has_paper_page', ascending=False, inplace=True)
            filtered_df.drop(columns='has_paper_page', inplace=True)

        # Return an update object to modify the Dataframe component
        return gr.update(value=filtered_df, datatype=new_datatypes)

    def _get_original_column_name(self, display_column_name: str) -> str:
        """
        Retrieve the original column name given a display column name.

        Args:
            display_column_name (str): The display name of the column.

        Returns:
            str: The original name of the column.
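
        Example:
            'models' maps back to 'num_models'; a name without a rename entry is
            returned unchanged.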
        """
        inverse_map = {v: k for k, v in self.COLUMN_RENAME_MAP.items()}
        return inverse_map.get(display_column_name, display_column_name)
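

# Illustrative wiring sketch (an assumption, not part of the original application).
# It shows how `PaperCentral.filter` could feed a `gr.Dataframe` through the
# `gr.update` it returns; the real app's layout, inputs, and event wiring may differ.
if __name__ == "__main__":
    paper_central = PaperCentral()

    def refresh(selected_date: str, hf_options: List[str]):
        # Map the UI inputs onto the relevant keyword arguments of `filter`.
        return paper_central.filter(selected_date=selected_date, hf_options=hf_options)

    with gr.Blocks() as demo:
        date_input = gr.Textbox(
            label="Date (YYYY-MM-DD)",
            value=datetime.now().strftime('%Y-%m-%d'),
        )
        hf_options_input = gr.CheckboxGroup(
            choices=["🤗 artifacts", "datasets", "models", "spaces", "github", "project page"],
            label="Hugging Face filters",
        )
        table = gr.Dataframe(interactive=False)

        # Populate the table on load and refresh it whenever an input changes.
        demo.load(refresh, inputs=[date_input, hf_options_input], outputs=table)
        date_input.change(refresh, inputs=[date_input, hf_options_input], outputs=table)
        hf_options_input.change(refresh, inputs=[date_input, hf_options_input], outputs=table)

    demo.launch()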