"""This section describes unitxt operators for tabular data.

These operators are specialized in handling tabular data.
Input table format is assumed as:
{
  "header": ["col1", "col2"],
  "rows": [["row11", "row12"], ["row21", "row22"], ["row31", "row32"]]
}

------------------------
"""
import random
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import (
    Any,
    Dict,
    List,
    Optional,
)

from .dict_utils import dict_get
from .operators import FieldOperator, StreamInstanceOperator


class SerializeTable(ABC, FieldOperator):
    """Abstract base for operators that flatten a table into a single string.

    The concrete output format is decided by each subclass. Every subclass must
    implement the three hooks declared here: one for the whole table, one for
    the header and one for an individual row.
    """

    @abstractmethod
    def serialize_table(self, table_content: Dict) -> str:
        """Serialize the whole table (header and rows) and return it as text."""

    @abstractmethod
    def process_header(self, header: List):
        """Serialize the list of column names."""

    @abstractmethod
    def process_row(self, row: List, row_index: int):
        """Serialize a single row; ``row_index`` identifies the row being processed."""


# Concrete classes implementing table serializers
class SerializeTableAsIndexedRowMajor(SerializeTable):
    """Indexed Row Major Table Serializer.

    Commonly used row major serialization format.
    Format:  col : col1 | col2 | col3 row 1 : val1 | val2 | val3 row 2 : val1 | ...
    """

    def process_value(self, table: Any) -> Any:
        """Serialize ``table`` and return the resulting string.

        The table is deep-copied first, so serialization operates on a private copy.
        """
        table_input = deepcopy(table)
        return self.serialize_table(table_content=table_input)

    def serialize_table(self, table_content: Dict) -> str:
        """Serialize a table given in the prescribed input format.

        Args:
            table_content (Dict) - table with a "header" list and a "rows" list of lists.

        Returns:
            The header followed by every row (numbered from 1), separated by single
            spaces, with leading/trailing whitespace stripped.

        Raises:
            AssertionError: if "header" or "rows" is missing or empty.
        """
        header = table_content.get("header", [])
        rows = table_content.get("rows", [])

        # Raised explicitly instead of using an ``assert`` statement so that the
        # validation is not stripped when Python runs with -O. The exception type
        # and message are the same as before.
        if not (header and rows):
            raise AssertionError("Incorrect input table format")

        # Header first, then rows sequentially starting from row 1.
        segments = [self.process_header(header)]
        segments.extend(
            self.process_row(row, row_index=i) for i, row in enumerate(rows, start=1)
        )

        return " ".join(segments).strip()

    def process_header(self, header: List):
        """Return the column names separated by ' | ' and prefixed with 'col : '."""
        # Column names are stringified so non-string headers (e.g. years) do not
        # make ``join`` fail.
        return "col : " + " | ".join(str(name) for name in header)

    def process_row(self, row: List, row_index: int):
        """Return the cell values separated by ' | ' and prefixed with 'row <index> : '."""
        # Every cell is stringified (not only ints/floats) so that values such as
        # None do not make ``join`` fail.
        serialized_row_str = " | ".join(str(value) for value in row)

        return f"row {row_index} : {serialized_row_str}"


class SerializeTableAsMarkdown(SerializeTable):
    """Markdown Table Serializer.

    Markdown table format is used in GitHub code primarily.
    Format:
    |col1|col2|col3|
    |---|---|---|
    |A|4|1|
    |I|2|1|
    ...
    """

    def process_value(self, table: Any) -> Any:
        """Serialize ``table`` and return the resulting string.

        The table is deep-copied first, so serialization operates on a private copy.
        """
        table_input = deepcopy(table)
        return self.serialize_table(table_content=table_input)

    def serialize_table(self, table_content: Dict) -> str:
        """Serialize a table given in the prescribed input format.

        Args:
            table_content (Dict) - table with a "header" list and a "rows" list of lists.

        Returns:
            The header line, the '---' separator line and one line per row, with
            leading/trailing whitespace (including the final newline) stripped.

        Raises:
            AssertionError: if "header" or "rows" is missing or empty.
        """
        header = table_content.get("header", [])
        rows = table_content.get("rows", [])

        # Raised explicitly instead of using an ``assert`` statement so that the
        # validation is not stripped when Python runs with -O. The exception type
        # and message are the same as before.
        if not (header and rows):
            raise AssertionError("Incorrect input table format")

        # Each piece already ends with a newline, so they are concatenated as-is.
        pieces = [self.process_header(header)]
        pieces.extend(
            self.process_row(row, row_index=i) for i, row in enumerate(rows, start=1)
        )

        return "".join(pieces).strip()

    def process_header(self, header: List):
        """Return the header line followed by the '---' separator line."""
        # Column names are stringified, like the cells in process_row, so
        # non-string headers do not make ``join`` fail.
        header_str = "|{}|\n".format("|".join(str(name) for name in header))
        header_str += "|{}|\n".format("|".join(["---"] * len(header)))
        return header_str

    def process_row(self, row: List, row_index: int):
        """Return one markdown table line for ``row``; ``row_index`` is not used."""
        return "|{}|\n".format("|".join(str(cell) for cell in row))


# truncate cell value to maximum allowed length
def truncate_cell(cell_value, max_len):
    """Return ``cell_value`` cut to ``max_len`` characters, or None if nothing changes.

    None is returned when the value is None, an int or a float, when it is
    blank after stripping whitespace, or when it already fits within ``max_len``.
    Callers treat a None result as "keep the original value".
    """
    # Missing and numeric cells are never truncated.
    if cell_value is None or isinstance(cell_value, (int, float)):
        return None

    # Only non-blank values longer than the limit are shortened.
    if cell_value.strip() != "" and len(cell_value) > max_len:
        return cell_value[:max_len]

    return None


class TruncateTableCells(StreamInstanceOperator):
    """Shorten over-long table cells in place to reduce the overall table length.

    Args:
        max_length (int) - maximum allowed length of a cell value (default 15).
        table (str) - field of the instance that holds the table.
        text_output (str) - optional field of the instance that holds the list of answers.
        use_query (bool) - forwarded to dict_get as use_dpath when reading both fields.

    For tasks whose answer is a cell value, an answer equal to a truncated cell is
    replaced by the truncated text as well, so table and answers stay aligned.
    """

    max_length: int = 15
    table: str = None
    text_output: Optional[str] = None
    use_query: bool = False

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Truncate the instance's table (and matching answers) and return the instance."""
        table = dict_get(instance, self.table, use_dpath=self.use_query)

        # Without a text_output field there are no answers to keep in sync.
        answers = (
            dict_get(instance, self.text_output, use_dpath=self.use_query)
            if self.text_output is not None
            else []
        )

        self.truncate_table(table_content=table, answers=answers)

        return instance

    def truncate_table(self, table_content: Dict, answers: Optional[List]):
        """Truncate cells of ``table_content`` in place and mirror the change in ``answers``."""
        # Maps each original cell value to the shortened text that replaced it.
        replacements = {}

        for row in table_content.get("rows", []):
            for col_idx, original in enumerate(row):
                shortened = truncate_cell(original, self.max_length)
                if shortened is None:
                    # truncate_cell returns None when the cell is left as it is.
                    continue
                replacements[original] = shortened
                row[col_idx] = shortened

        if answers is None:
            return

        # Answers equal to a truncated cell take the truncated text.
        for pos, answer in enumerate(answers):
            answers[pos] = replacements.get(answer, answer)


class TruncateTableRows(FieldOperator):
    """Limits table rows to specified limit by removing excess rows via random selection.

    Args:
        rows_to_keep (int) - number of rows to keep.
    """

    rows_to_keep: int = 10

    def process_value(self, table: Any) -> Any:
        """Return ``table`` with at most ``rows_to_keep`` rows."""
        return self.truncate_table_rows(table_content=table)

    def truncate_table_rows(self, table_content: Dict):
        """Randomly drop rows until only ``rows_to_keep`` remain.

        The surviving rows keep their original relative order. The "rows" entry of
        ``table_content`` is replaced in place and the same dict is returned.
        Selection uses the module-level ``random`` functions.

        Args:
            table_content (Dict) - table whose "rows" entry is a list of rows.

        Returns:
            ``table_content`` itself, unchanged when it already has no more than
            ``rows_to_keep`` rows.
        """
        rows = table_content.get("rows", [])
        num_rows = len(rows)

        # Nothing to do when the table is already small enough.
        if num_rows <= self.rows_to_keep:
            return table_content

        rows_to_delete = num_rows - self.rows_to_keep

        # The same random.sample call as before picks the rows to drop; the result
        # is held in a set so the membership test below is O(1) per row instead of
        # scanning a list for every row.
        deleted_rows_indices = set(random.sample(range(num_rows), rows_to_delete))

        table_content["rows"] = [
            row for i, row in enumerate(rows) if i not in deleted_rows_indices
        ]

        return table_content


class SerializeTableRowAsText(StreamInstanceOperator):
    """Serializes a table row as text of the form "<field> is <value>, ...".

    Args:
        fields (str) - names of the instance fields to include; they are iterated one by one.
        to_field (str) - name of the instance field that receives the serialized text.
        max_cell_length (int) - optional limit applied to each value via truncate_cell.
    """

    fields: str
    to_field: str
    max_cell_length: Optional[int] = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Write the serialized row into ``instance[to_field]`` and return the instance."""
        pieces = []
        for field in self.fields:
            value = dict_get(instance, field, use_dpath=False)

            if self.max_cell_length is not None:
                shortened = truncate_cell(value, self.max_cell_length)
                # truncate_cell returns None when the value is left as it is.
                value = value if shortened is None else shortened

            # Every entry, including the last one, ends with ", ".
            pieces.append(field + " is " + str(value) + ", ")

        instance[self.to_field] = "".join(pieces)
        return instance


class SerializeTableRowAsList(StreamInstanceOperator):
    """Serializes a table row as a list of the form "<field>: <value>, ...".

    Args:
        fields (str) - names of the instance fields to include; they are iterated one by one.
        to_field (str) - name of the instance field that receives the serialized text.
        max_cell_length (int) - optional limit applied to each value via truncate_cell.
    """

    fields: str
    to_field: str
    max_cell_length: Optional[int] = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Write the serialized row into ``instance[to_field]`` and return the instance."""
        pieces = []
        for field in self.fields:
            value = dict_get(instance, field, use_dpath=False)

            if self.max_cell_length is not None:
                shortened = truncate_cell(value, self.max_cell_length)
                # truncate_cell returns None when the value is left as it is.
                value = value if shortened is None else shortened

            # Every entry, including the last one, ends with ", ".
            pieces.append(field + ": " + str(value) + ", ")

        instance[self.to_field] = "".join(pieces)
        return instance


class SerializeTriples(FieldOperator):
    """Serializes (subject, relation, object) triples into a flat sequence.

    The relation of every triple is lower-cased; triples are separated by ' | '.

    Sample input in expected format:
    [[ "First Clearing", "LOCATION", "On NYS 52 1 Mi. Youngsville" ], [ "On NYS 52 1 Mi. Youngsville", "CITY_OR_TOWN", "Callicoon, New York"]]

    Sample output:
    First Clearing : location : On NYS 52 1 Mi. Youngsville | On NYS 52 1 Mi. Youngsville : city_or_town : Callicoon, New York

    """

    def process_value(self, tripleset: Any) -> Any:
        """Return the serialized form of ``tripleset``."""
        return self.serialize_triples(tripleset)

    def serialize_triples(self, tripleset) -> str:
        """Join every triple as 'subject : relation : object', separated by ' | '."""
        rendered = []
        for subj, rel, obj in tripleset:
            rendered.append(f"{subj} : {rel.lower()} : {obj}")
        return " | ".join(rendered)


class SerializeKeyValPairs(FieldOperator):
    """Serializes a mapping of key, value pairs into a flat sequence.

    Sample input in expected format: {"name": "Alex", "age": 31, "sex": "M"}
    Sample output: name is Alex, age is 31, sex is M
    """

    def process_value(self, kvpairs: Any) -> Any:
        """Return the serialized form of ``kvpairs``."""
        return self.serialize_kvpairs(kvpairs)

    def serialize_kvpairs(self, kvpairs) -> str:
        """Return '<key> is <value>' entries joined by ', ' (empty string for no pairs)."""
        return ", ".join(f"{key} is {value}" for key, value in kvpairs.items())


class ListToKeyValPairs(StreamInstanceOperator):
    """Maps a list of keys and a list of values into key:value pairs.

    ``fields[0]`` names the instance field holding the keys and ``fields[1]`` the
    field holding the values; the resulting dict is stored under ``to_field``.
    Pairing stops at the shorter of the two lists.

    Sample instance: {"keys": ["name", "age", "sex"], "values": ["Alex", 31, "M"]}
    With fields=["keys", "values"], to_field receives: {"name": "Alex", "age": 31, "sex": "M"}
    """

    fields: List[str]
    to_field: str
    use_query: bool = False

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Build the key:value dict, store it in ``instance[to_field]`` and return the instance."""
        key_field, value_field = self.fields[0], self.fields[1]
        keylist = dict_get(instance, key_field, use_dpath=self.use_query)
        valuelist = dict_get(instance, value_field, use_dpath=self.use_query)

        # A key that occurs more than once keeps the value paired with its last occurrence.
        instance[self.to_field] = dict(zip(keylist, valuelist))

        return instance