import html
import os
from collections import defaultdict
from typing import Any, Dict, List, Tuple

import pandas as pd

# Default stylesheet for the generated HTML, loaded once at import time from
# ``tapas-styles.css`` next to this module.
dirname = os.path.dirname(__file__)
css_filename = os.path.join(dirname, "tapas-styles.css")
# Decode explicitly as UTF-8 so the result does not depend on the platform's
# default locale encoding.
with open(css_filename, encoding="utf-8") as f:
    css = f.read()


def HTMLBody(table_html: str, css_styles: str = css) -> str:
    """
    Wrap an HTML fragment in a complete HTML document with embedded CSS.

    Args:
        table_html (:obj:`str`):
            HTML markup (e.g. a ``<table>`` element) inserted verbatim inside
            the body's ``tokenized-text`` div.

        css_styles (:obj:`str`, `optional`):
            CSS inserted verbatim into a ``<style>`` element in the head.
            Defaults to the contents of ``tapas-styles.css`` as read when this
            module was imported.

    Returns:
        :obj:`str`: An HTML string with style markup
    """
    return f"""
    <html>
        <head>
            <style>
                {css_styles}
            </style>
        </head>
        <body>
            <div class="tokenized-text" dir=auto>
            {table_html}
            </div>
        </body>
    </html>
    """


class TapasVisualizer:
    """Render a table as HTML that shows how a tokenizer splits each cell.

    Calling an instance tokenizes a ``pd.DataFrame`` and returns a full HTML
    page. Within each cell, text covered by a token is wrapped in a span with
    classes ``token`` plus ``even-token``/``odd-token`` (alternating by the
    token's index within the cell). Text not covered by any token gets
    ``non-token``. Each row ends with two extra cells: that row's token count
    and the running total.
    """

    def __init__(self, tokenizer) -> None:
        # NOTE(review): presumably a Hugging Face ``TapasTokenizer``. This class
        # only needs it to be callable on a DataFrame (returning ``input_ids``
        # and ``token_type_ids``) and to expose ``_convert_id_to_token``.
        self.tokenizer = tokenizer

    def normalize_token_str(self, token_str: str) -> str:
        """Strip the WordPiece continuation marker ``##`` from the token start."""
        # Only the leading marker has meaning; stripping "##" anywhere would
        # also mangle tokens that legitimately contain it.
        if token_str.startswith("##"):
            return token_str[2:]
        return token_str

    def style_span(self, span_text: str, css_classes: List[str]) -> str:
        """Wrap ``span_text`` in a ``<span>`` element carrying ``css_classes``.

        The text is HTML-escaped so cell contents such as ``a<b`` or ``AT&T``
        render literally instead of being interpreted as markup.
        """
        classes = " ".join(css_classes)
        return f'<span class="{classes}" >{html.escape(span_text)}</span>'

    @staticmethod
    def _find_token(text: str, token: str, start: int) -> int:
        """Return the first index >= ``start`` where ``token`` matches, else -1."""
        # Compare slice by slice rather than searching ``text.lower()``:
        # lowercasing can change the length of some characters, which would
        # misalign indices into the original text.
        width = len(token)
        for pos in range(start, len(text)):
            if text[pos:pos + width].lower() == token:
                return pos
        return -1

    def text_to_html(self, org_text: str, tokens: List[str]) -> List[str]:
        """Create html spans based on the original text and its tokens.

        Note: The tokens need to be in same order as in the original text.
        Tokens are compared against the lowercased text, because the tokenizer
        may perform lowercasing.

        Args:
            org_text (str): Original string before tokenization
            tokens (List[str]): The tokens of org_text

        Returns:
            List[str]: ``<span>`` strings that, concatenated, reproduce
            ``org_text`` (HTML-escaped). Matched tokens are styled as tokens and
            all other text as ``non-token``. A token that cannot be found in the
            remaining text (e.g. ``[UNK]``) is skipped instead of swallowing the
            rest of the text.
        """
        if not tokens:
            # The cell has no tokens (e.g. it was cut off): show its text as
            # uncovered instead of dropping it from the visualization.
            print(f'Empty tokens for: {org_text}')
            return [self.style_span(org_text, ['non-token'])] if org_text else []

        spans = []
        last_end = 0  # end of the most recently matched token

        for token_id, raw_token in enumerate(tokens):
            token = self.normalize_token_str(raw_token)
            match_start = self._find_token(org_text, token, last_end)
            if match_start < 0:
                continue
            if match_start != last_end:
                # There was token-less text (probably whitespace) in the middle
                spans.append(self.style_span(org_text[last_end:match_start], ['non-token']))
            parity = 'even-token' if token_id % 2 == 0 else 'odd-token'
            match_end = match_start + len(token)
            spans.append(self.style_span(org_text[match_start:match_end], ['token', parity]))
            last_end = match_end

        if last_end != len(org_text):
            # Keep whatever follows the final matched token.
            spans.append(self.style_span(org_text[last_end:], ['non-token']))

        return spans

    def cells_to_html(self,
                      cell_vals: List[List[str]],
                      cell_tokens: Dict,
                      row_id_start: int=0,
                      cell_element: str="td",
                      cumulative_cnt: int=0,
                      table_html: str="") -> Tuple[str, int]:
        """Render rows of cells as ``<tr>`` elements appended to ``table_html``.

        Args:
            cell_vals: Rows of cell values. Row ``i`` is looked up with row id
                ``row_id_start + i``; columns are numbered from 1.
            cell_tokens: Maps ``(row_id, col_id)`` to that cell's token strings.
                Missing keys are treated as cells without tokens.
            row_id_start: Row id of the first row in ``cell_vals``.
            cell_element: Tag used for the cells (``"th"`` or ``"td"``).
            cumulative_cnt: Running token count carried over from earlier rows.
            table_html: HTML already produced for earlier rows.

        Returns:
            Tuple[str, int]: The extended HTML and the updated running count.
        """
        count_td = '<td style="border: none;" align="right">'
        row_htmls = [table_html]
        for row_id, row in enumerate(cell_vals, start=row_id_start):
            cell_htmls = []
            row_token_cnt = 0
            for col_id, cell in enumerate(row, start=1):
                cur_cell_tokens = cell_tokens.get((row_id, col_id), [])
                # str() so non-string values (e.g. numeric column names) render.
                cell_html = "".join(self.text_to_html(str(cell), cur_cell_tokens))
                cell_htmls.append(f"<{cell_element}>{cell_html}</{cell_element}>")
                row_token_cnt += len(cur_cell_tokens)
            cumulative_cnt += row_token_cnt
            for count in (row_token_cnt, cumulative_cnt):
                count_span = self.style_span(str(count), ["non-token", "count"])
                cell_htmls.append(f"{count_td}{count_span}</td>")
            row_htmls.append(f"<tr>{''.join(cell_htmls)}</tr>")

        return "".join(row_htmls), cumulative_cnt

    def __call__(self, table: pd.DataFrame) -> str:
        """Tokenize ``table`` and return an HTML page visualizing its tokens."""
        tokenized = self.tokenizer(table)

        # (row_id, col_id) -> token strings of that cell, in order.
        cell_tokens = defaultdict(list)
        for input_id, type_ids in zip(tokenized['input_ids'], tokenized['token_type_ids']):
            # 'prev_label', 'column_rank', 'inv_column_rank', 'numeric_relation' not required
            segment_id, col_id, row_id, *_ = type_ids
            # Only segment-1 tokens are attributed to table cells here. int()
            # turns possible numpy/tensor scalars into plain ints so the keys
            # hash and compare like the int row/column ids used for lookup.
            if int(segment_id) == 1:
                token_text = self.tokenizer._convert_id_to_token(int(input_id))
                cell_tokens[(int(row_id), int(col_id))].append(token_text)

        # The header is row 0; data rows start at 1.
        table_html, cumulative_cnt = self.cells_to_html(cell_vals=[table.columns],
                                                        cell_tokens=cell_tokens,
                                                        row_id_start=0,
                                                        cell_element="th",
                                                        cumulative_cnt=0,
                                                        table_html="")

        table_html, _ = self.cells_to_html(cell_vals=table.values,
                                           cell_tokens=cell_tokens,
                                           row_id_start=1,
                                           cell_element="td",
                                           cumulative_cnt=cumulative_cnt,
                                           table_html=table_html)

        return HTMLBody(f'<table>{table_html}</table>')