import streamlit as st
import PIL
import cv2
import numpy as np
import pandas as pd
import torch
import os
import io
# import sys
# import json
from collections import OrderedDict, defaultdict
import xml.etree.ElementTree as ET
from tempfile import TemporaryDirectory
import xlsxwriter
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch

from paddleocr import PaddleOCR
# import pytesseract
# from pytesseract import Output
from fitz import Rect

import postprocess


st.set_page_config(page_title='Table Extraction Demo', layout='wide')


@st.experimental_singleton(ttl=3600)
def load_ocr_instance():
    ocr_instance = PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
    return ocr_instance


@st.experimental_singleton(ttl=3600)
def load_detection_model():
    detection_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/detection_wts.pt', force_reload=True, skip_validation=True, trust_repo=True)
    return detection_model


@st.experimental_singleton(ttl=3600)
def load_structure_model():
    structure_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/structure_wts.pt', force_reload=True, skip_validation=True, trust_repo=True)
    return structure_model


ocr_instance, detection_model, structure_model = load_ocr_instance(), load_detection_model(), load_structure_model()

detection_class_names = ['table', 'table rotated', 'no object']
structure_class_names = [
    'table', 'table column', 'table row', 'table column header',
    'table projected row header', 'table spanning cell', 'no object'
]

detection_class_map = {k: v for v, k in enumerate(detection_class_names)}
structure_class_map = {k: v for v, k in enumerate(structure_class_names)}

detection_class_thresholds = {
    'table': 0.5,
    'table rotated': 0.5,
    'no object': 10
}
structure_class_thresholds = {
    "table": 0.45,
    "table column": 0.6,
    "table row": 0.5,
    "table column header": 0.4,
    "table projected row header": 0.3,
    "table spanning cell": 0.5,
    "no object": 10
}


def PIL_to_cv(pil_img):
    return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)


def cv_to_PIL(cv_img):
    return PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))


def table_detection(pil_img, imgsz=640):
    image = PIL_to_cv(pil_img)
    pred = detection_model(image, size=imgsz)
    pred = pred.xywhn[0]
    result = pred.detach().cpu().numpy()
    return result


def table_structure(pil_img, imgsz=640):
    image = PIL_to_cv(pil_img)
    pred = structure_model(image, size=imgsz)
    pred = pred.xywhn[0]
    result = pred.detach().cpu().numpy()
    return result


def crop_image(pil_img, detection_result):
    crop_images = []
    image = PIL_to_cv(pil_img)
    width = image.shape[1]
    height = image.shape[0]
    # print(width, height)
    for idx, result in enumerate(detection_result):
        class_id = int(result[5])
        score = float(result[4])
        min_x = result[0]
        min_y = result[1]
        w = result[2]
        h = result[3]

        if score < detection_class_thresholds[detection_class_names[class_id]]:
            continue

        x1 = int((min_x - w / 2) * width)
        y1 = int((min_y - h / 2) * height)
        x2 = int((min_x + w / 2) * width)
        y2 = int((min_y + h / 2) * height)
        # print(x1, y1, x2, y2)

        padding_x = max(int(0.02 * width), 30)
        padding_y = max(int(0.02 * height), 30)

        x1_pad = max(0, x1 - padding_x)
        y1_pad = max(0, y1 - padding_y)
        x2_pad = min(width, x2 + padding_x)
        y2_pad = min(height, y2 + padding_y)

        crop_image = image[y1_pad:y2_pad, x1_pad:x2_pad, :]
        crop_image = cv_to_PIL(crop_image)
        if detection_class_names[class_id] == 'table rotated':
            crop_image = crop_image.rotate(270, expand=True)

        crop_images.append(crop_image)

        cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 0, 255), thickness=2)

        label = f'{detection_class_names[class_id]} {score:.2f}'

        lw = max(round(sum(image.shape) / 2 * 0.003), 2)
        fontScale = lw / 3
        thickness = max(lw - 1, 1)
        w_label, h_label = cv2.getTextSize(label, 0, fontScale=fontScale, thickness=thickness)[0]
        cv2.rectangle(image, (x1, y1), (x1 + w_label, y1 - h_label - 3), (0, 0, 255), -1, cv2.LINE_AA)
        cv2.putText(image, label, (x1, y1 - 2), cv2.FONT_HERSHEY_SIMPLEX, fontScale, (255, 255, 255), thickness=thickness, lineType=cv2.LINE_AA)

    return crop_images, cv_to_PIL(image)


def ocr(pil_img):
    image = PIL_to_cv(pil_img)
    result = ocr_instance.ocr(image)
    ocr_res = []

    for ps, (text, score) in result[0]:
        x1 = min(p[0] for p in ps)
        y1 = min(p[1] for p in ps)
        x2 = max(p[0] for p in ps)
        y2 = max(p[1] for p in ps)
        word_info = {
            'bbox': [x1, y1, x2, y2],
            'text': text
        }
        ocr_res.append(word_info)

    return ocr_res


def convert_stucture(page_tokens, pil_img, structure_result):
    image = PIL_to_cv(pil_img)

    width = image.shape[1]
    height = image.shape[0]
    # print(width, height)

    bboxes = []
    scores = []
    labels = []
    for idx, result in enumerate(structure_result):
        class_id = int(result[5])
        score = float(result[4])
        min_x = result[0]
        min_y = result[1]
        w = result[2]
        h = result[3]

        x1 = int((min_x - w / 2) * width)
        y1 = int((min_y - h / 2) * height)
        x2 = int((min_x + w / 2) * width)
        y2 = int((min_y + h / 2) * height)
        # print(x1, y1, x2, y2)

        bboxes.append([x1, y1, x2, y2])
        scores.append(score)
        labels.append(class_id)

    table_objects = []
    for bbox, score, label in zip(bboxes, scores, labels):
        table_objects.append({'bbox': bbox, 'score': score, 'label': label})
    # print('table_objects:', table_objects)

    table = {'objects': table_objects, 'page_num': 0}

    table_class_objects = [obj for obj in table_objects if obj['label'] == structure_class_map['table']]
    if len(table_class_objects) > 1:
        table_class_objects = sorted(table_class_objects, key=lambda x: x['score'], reverse=True)
    try:
        table_bbox = list(table_class_objects[0]['bbox'])
    except:
        table_bbox = (0, 0, 1000, 1000)
    # print('table_class_objects:', table_class_objects)
    # print('table_bbox:', table_bbox)

    tmp = Rect(table_bbox)
    for obj in table_objects:
        if structure_class_names[obj['label']] in ('table column', 'table row'):
            if postprocess.iob(obj['bbox'], table_bbox) >= 0.001:
                tmp.include_rect(obj['bbox'])
    table_bbox = (tmp[0], tmp[1], tmp[2], tmp[3])

    tokens_in_table = [token for token in page_tokens if postprocess.iob(token['bbox'], table_bbox) >= 0.001]
    # print('tokens_in_table:', tokens_in_table)

    table_structures, cells, confidence_score = postprocess.objects_to_cells(table, table_objects, tokens_in_table, structure_class_names, structure_class_thresholds)

    return table_structures, cells, confidence_score


def visualize_image(pil_img):
    plt.imshow(pil_img, interpolation='lanczos')
    plt.gcf().set_size_inches(10, 10)
    plt.axis('off')
    img_buf = io.BytesIO()
    plt.savefig(img_buf, bbox_inches='tight', dpi=150)
    plt.close()
    return PIL.Image.open(img_buf)


def visualize_ocr(pil_img, ocr_result):
    plt.imshow(pil_img, interpolation='lanczos')
    plt.gcf().set_size_inches(20, 20)
    ax = plt.gca()

    for idx, result in enumerate(ocr_result):
        bbox = result['bbox']
        text = result['text']
        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=2, edgecolor='red', facecolor='none', linestyle='-')
        ax.add_patch(rect)
        ax.text(bbox[0], bbox[1], text, horizontalalignment='left', verticalalignment='bottom', color='blue', fontsize=7)

    plt.xticks([], [])
    plt.yticks([], [])

    plt.gcf().set_size_inches(10, 10)
    plt.axis('off')
    img_buf = io.BytesIO()
    plt.savefig(img_buf, bbox_inches='tight', dpi=150)
    plt.close()

    return PIL.Image.open(img_buf)


def get_bbox_decorations(data_type, label):
    if label == 0:
        if data_type == 'detection':
            return 'brown', 0.05, 3, '//'
        else:
            return 'brown', 0, 3, None
    elif label == 1:
        return 'red', 0.15, 2, None
    elif label == 2:
        return 'blue', 0.15, 2, None
    elif label == 3:
        return 'magenta', 0.2, 3, '//'
    elif label == 4:
        return 'cyan', 0.2, 4, '//'
    elif label == 5:
        return 'green', 0.2, 4, '\\\\'

    return 'gray', 0, 0, None


def visualize_structure(pil_img, structure_result):
    image = PIL_to_cv(pil_img)
    width = image.shape[1]
    height = image.shape[0]
    # print(width, height)

    plt.imshow(pil_img, interpolation='lanczos')
    plt.gcf().set_size_inches(20, 20)
    ax = plt.gca()

    for idx, result in enumerate(structure_result):
        class_id = int(result[5])
        score = float(result[4])
        min_x = result[0]
        min_y = result[1]
        w = result[2]
        h = result[3]

        if score < structure_class_thresholds[structure_class_names[class_id]]:
            continue

        x1 = int((min_x - w / 2) * width)
        y1 = int((min_y - h / 2) * height)
        x2 = int((min_x + w / 2) * width)
        y2 = int((min_y + h / 2) * height)
        # print(x1, y1, x2, y2)
        bbox = [x1, y1, x2, y2]

        color, alpha, linewidth, hatch = get_bbox_decorations('recognition', class_id)
        # Fill
        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
                                    linewidth=linewidth, alpha=alpha,
                                    edgecolor='none',facecolor=color,
                                    linestyle=None)
        ax.add_patch(rect)
        # Hatch
        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
                                    linewidth=1, alpha=0.4,
                                    edgecolor=color, facecolor='none',
                                    linestyle='--',hatch=hatch)
        ax.add_patch(rect)
        # Edge
        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
                                    linewidth=linewidth,
                                    edgecolor=color, facecolor='none',
                                    linestyle='--')
        ax.add_patch(rect)

    plt.xticks([], [])
    plt.yticks([], [])

    legend_elements = []
    for class_name in structure_class_names[:-1]:
        color, alpha, linewidth, hatch = get_bbox_decorations('recognition', structure_class_map[class_name])
        legend_elements.append(
            Patch(facecolor='none', edgecolor=color, linestyle='--', label=class_name, hatch=hatch)
        )

    plt.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.02), loc='upper center', borderaxespad=0,
                    fontsize=10, ncol=3)
    plt.gcf().set_size_inches(10, 10)
    plt.axis('off')
    img_buf = io.BytesIO()
    plt.savefig(img_buf, bbox_inches='tight', dpi=150)
    plt.close()

    return PIL.Image.open(img_buf)


def visualize_cells(pil_img, cells):
    plt.imshow(pil_img, interpolation='lanczos')
    plt.gcf().set_size_inches(20, 20)
    ax = plt.gca()

    for cell in cells:
        bbox = cell['bbox']

        if cell['header']:
            facecolor = (1, 0, 0.45)
            edgecolor = (1, 0, 0.45)
            alpha = 0.3
            linewidth = 2
            hatch='//////'
        elif cell['subheader']:
            facecolor = (0.95, 0.6, 0.1)
            edgecolor = (0.95, 0.6, 0.1)
            alpha = 0.3
            linewidth = 2
            hatch='//////'
        else:
            facecolor = (0.3, 0.74, 0.8)
            edgecolor = (0.3, 0.7, 0.6)
            alpha = 0.3
            linewidth = 2
            hatch='\\\\\\\\\\\\'

        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth,
                                    edgecolor='none',facecolor=facecolor, alpha=0.1)
        ax.add_patch(rect)
        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth,
                                    edgecolor=edgecolor,facecolor='none',linestyle='-', alpha=alpha)
        ax.add_patch(rect)
        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=0,
                                    edgecolor=edgecolor,facecolor='none',linestyle='-', hatch=hatch, alpha=0.2)
        ax.add_patch(rect)

    plt.xticks([], [])
    plt.yticks([], [])

    legend_elements = [Patch(facecolor=(0.3, 0.74, 0.8), edgecolor=(0.3, 0.7, 0.6),
                                label='Data cell', hatch='\\\\\\\\\\\\', alpha=0.3),
                        Patch(facecolor=(1, 0, 0.45), edgecolor=(1, 0, 0.45),
                                label='Column header cell', hatch='//////', alpha=0.3),
                        Patch(facecolor=(0.95, 0.6, 0.1), edgecolor=(0.95, 0.6, 0.1),
                                label='Projected row header cell', hatch='//////', alpha=0.3)]
    plt.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.02), loc='upper center', borderaxespad=0,
                    fontsize=10, ncol=3)
    plt.gcf().set_size_inches(10, 10)
    plt.axis('off')
    img_buf = io.BytesIO()
    plt.savefig(img_buf, bbox_inches='tight', dpi=150)
    plt.close()

    return PIL.Image.open(img_buf)


# def pytess(cell_pil_img):
#     return ' '.join(pytesseract.image_to_data(cell_pil_img, output_type=Output.DICT, config='-c tessedit_char_blacklist=œ˜â€œï¬â™Ã©œ¢!|”?«“¥ --tessdata-dir tessdata --oem 3 --psm 6')['text']).strip()


# def resize(pil_img, size=1800):
#     length_x, width_y = pil_img.size
#     factor = max(1, size / length_x)
#     size = int(factor * length_x), int(factor * width_y)
#     pil_img = pil_img.resize(size, PIL.Image.ANTIALIAS)
#     return pil_img, factor


# def image_smoothening(img):
#     ret1, th1 = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY)
#     ret2, th2 = cv2.threshold(th1, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
#     blur = cv2.GaussianBlur(th2, (1, 1), 0)
#     ret3, th3 = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
#     return th3


# def remove_noise_and_smooth(pil_img):
#     img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
#     filtered = cv2.adaptiveThreshold(img.astype(np.uint8), 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 41, 3)
#     kernel = np.ones((1, 1), np.uint8)
#     opening = cv2.morphologyEx(filtered, cv2.MORPH_OPEN, kernel)
#     closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)
#     img = image_smoothening(img)
#     or_image = cv2.bitwise_or(img, closing)
#     pil_img = PIL.Image.fromarray(or_image)
#     return pil_img


# def extract_text_from_cells(pil_img, cells):
#     pil_img, factor = resize(pil_img)
#     #pil_img = remove_noise_and_smooth(pil_img)
#     #display(pil_img)
#     for cell in cells:
#         bbox = [x * factor for x in cell['bbox']]
#         cell_pil_img = pil_img.crop(bbox)
#         #cell_pil_img = remove_noise_and_smooth(cell_pil_img)
#         #cell_pil_img = tess_prep(cell_pil_img)
#         cell['cell text'] = pytess(cell_pil_img)
#     return cells


def extract_text_from_cells(cells, sep=' '):
    for cell in cells:
        spans = cell['spans']
        text = ''
        for span in spans:
            if 'text' in span:
                text += span['text'] + sep
        cell['cell_text'] = text
    return cells


def cells_to_csv(cells):
    if len(cells) > 0:
        num_columns = max([max(cell['column_nums']) for cell in cells]) + 1
        num_rows = max([max(cell['row_nums']) for cell in cells]) + 1
    else:
        return

    header_cells = [cell for cell in cells if cell['header']]
    if len(header_cells) > 0:
        max_header_row = max([max(cell['row_nums']) for cell in header_cells])
    else:
        max_header_row = -1

    table_array = np.empty([num_rows, num_columns], dtype='object')
    if len(cells) > 0:
        for cell in cells:
            for row_num in cell['row_nums']:
                for column_num in cell['column_nums']:
                    table_array[row_num, column_num] = cell['cell_text']

    header = table_array[:max_header_row+1,:]
    flattened_header = []
    for col in header.transpose():
        flattened_header.append(' | '.join(OrderedDict.fromkeys(col)))
    df = pd.DataFrame(table_array[max_header_row+1:,:], index=None, columns=flattened_header)

    return df, df.to_csv(index=None)


def cells_to_html(cells):
    cells = sorted(cells, key=lambda k: min(k['column_nums']))
    cells = sorted(cells, key=lambda k: min(k['row_nums']))

    table = ET.Element('table')
    current_row = -1

    for cell in cells:
        this_row = min(cell['row_nums'])

        attrib = {}
        colspan = len(cell['column_nums'])
        if colspan > 1:
            attrib['colspan'] = str(colspan)
        rowspan = len(cell['row_nums'])
        if rowspan > 1:
            attrib['rowspan'] = str(rowspan)
        if this_row > current_row:
            current_row = this_row
            if cell['header']:
                cell_tag = 'th'
                row = ET.SubElement(table, 'tr')
            else:
                cell_tag = 'td'
                row = ET.SubElement(table, 'tr')
        tcell = ET.SubElement(row, cell_tag, attrib=attrib)
        tcell.text = cell['cell_text']

    return str(ET.tostring(table, encoding='unicode', short_empty_elements=False))


# def cells_to_html(cells):
#     for cell in cells:
#         cell['column_nums'].sort()
#         cell['row_nums'].sort()
#     n_cols = max(cell['column_nums'][-1] for cell in cells) + 1
#     n_rows = max(cell['row_nums'][-1] for cell in cells) + 1
#     html_code = ''
#     for r in range(n_rows):
#         r_cells = [cell for cell in cells if cell['row_nums'][0] == r]
#         r_cells.sort(key=lambda x: x['column_nums'][0])
#         r_html = ''
#         for cell in r_cells:
#             rowspan = cell['row_nums'][-1] - cell['row_nums'][0] + 1
#             colspan = cell['column_nums'][-1] - cell['column_nums'][0] + 1
#             r_html += f'<td rowspan='{rowspan}' colspan='{colspan}'>{escape(cell['text'])}</td>'
#         html_code += f'<tr>{r_html}</tr>'
#     html_code = '''<html>
#                    <head>
#                    <meta charset='UTF-8'>
#                    <style>
#                    table, th, td {
#                      border: 1px solid black;
#                      font-size: 10px;
#                    }
#                    </style>
#                    </head>
#                    <body>
#                    <table frame='hsides' rules='groups' width='100%%'>
#                      %s
#                    </table>
#                    </body>
#                    </html>''' % html_code
#     soup = bs(html_code)
#     html_code = soup.prettify()
#     return html_code


def cells_to_excel(cells, file_path):

    def int2xlsx(i):
        if i < 26:
            return chr(i + 65)
        return f'{chr(i // 26 + 64)}{chr(i % 26 + 65)}'

    cells = sorted(cells, key=lambda k: min(k['column_nums']))
    cells = sorted(cells, key=lambda k: min(k['row_nums']))

    workbook = xlsxwriter.Workbook(file_path)

    cell_format = workbook.add_format(
        {'align': 'center', 'valign': 'vcenter'}
    )

    worksheet = workbook.add_worksheet(name='Table')

    table_start_index = 0

    for cell in cells:
        start_row = min(cell['row_nums'])
        end_row = max(cell['row_nums'])
        start_col = min(cell['column_nums'])
        end_col = max(cell['column_nums'])
        if start_row == end_row and start_col == end_col:
            worksheet.write(
                table_start_index + start_row,
                start_col,
                cell['cell_text'],
                cell_format,
            )
        else:
            if start_col == end_col and start_row == end_row:
                excel_index = f'{int2xlsx(table_start_index + start_col)}{table_start_index + start_row + 1}'
            else:
                excel_index = f'{int2xlsx(table_start_index + start_col)}{table_start_index + start_row + 1}:{int2xlsx(table_start_index + end_col)}{table_start_index + end_row + 1}'
            worksheet.merge_range(
                excel_index, cell['cell_text'], cell_format
            )

    workbook.close()


def main():

    st.title('Table Extraction Demo')

    filename = st.file_uploader('Upload image', type=['png', 'jpeg', 'jpg'])

    if st.button('Analyze image'):

        if filename is None:
            st.write('Please upload an image')

        else:
            tabs = st.tabs(
                ['Table Detection', 'Table Structure Recognition', 'Extracted Table(s)']
            )

            print(filename)
            pil_img = PIL.Image.open(filename)

            detection_result = table_detection(pil_img)
            crop_images, vis_det_img = crop_image(pil_img, detection_result)

            all_cells = []

            with tabs[0]:
                st.header('Table Detection')
                st.image(vis_det_img)

            with tabs[1]:
                st.header('Table Structure Recognition')

                str_cols = st.columns(4)
                str_cols[0].subheader('Table image')
                str_cols[1].subheader('OCR result')
                str_cols[2].subheader('Structure result')
                str_cols[3].subheader('Cells result')

                for idx, img in enumerate(crop_images):
                    str_cols = st.columns(4)

                    vis_img = visualize_image(img)
                    str_cols[0].image(vis_img)

                    ocr_result = ocr(img)
                    vis_ocr_img = visualize_ocr(img, ocr_result)
                    str_cols[1].image(vis_ocr_img)

                    structure_result = table_structure(img)
                    vis_str_img = visualize_structure(img, structure_result)
                    str_cols[2].image(vis_str_img)

                    table_structures, cells, confidence_score = convert_stucture(ocr_result, img, structure_result)
                    cells = extract_text_from_cells(cells)
                    vis_cells_img = visualize_cells(img, cells)
                    str_cols[3].image(vis_cells_img)

                    all_cells.append(cells)

                    #df, csv_result = cells_to_csv(cells)
                    #print(df)

            with tabs[2]:
                st.header('Extracted Table(s)')
                for idx, col in enumerate(st.columns(len(all_cells))):
                    with col:
                        if len(all_cells) > 1:
                            st.header(f'Table {idx + 1}')

                        with TemporaryDirectory() as temp_dir_path:
                            df = None
                            xlsx_path = os.path.join(temp_dir_path, f'debug_{idx}.xlsx')
                            cells_to_excel(all_cells[idx], xlsx_path)
                            with open(xlsx_path, 'rb') as ref:
                                df = pd.read_excel(ref)
                                st.dataframe(df)
                                st.download_button(
                                    'Download Excel File',
                                    ref,
                                    file_name=f'output_{idx}.xlsx',
                                )

                for idx, cells in enumerate(all_cells):
                    html_result = cells_to_html(cells)
                    st.subheader(f'HTML Table {idx + 1}')
                    st.markdown(html_result, unsafe_allow_html=True)


if __name__ == '__main__':
    main()