Spaces:
Build error
Build error
| import streamlit as st | |
| import PIL | |
| import cv2 | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import os | |
| import io | |
| # import sys | |
| # import json | |
| from collections import OrderedDict, defaultdict | |
| import xml.etree.ElementTree as ET | |
| from tempfile import TemporaryDirectory | |
| import xlsxwriter | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as patches | |
| from matplotlib.patches import Patch | |
| from paddleocr import PaddleOCR | |
| # import pytesseract | |
| # from pytesseract import Output | |
| import postprocess | |
| st.set_page_config(page_title='Table Extraction Demo', layout='wide') | |
def load_ocr_instance():
    """Create the PaddleOCR engine used for word-level text recognition.

    The engine is configured for English, without the angle classifier,
    and with GPU usage requested.
    """
    return PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
def load_detection_model():
    """Load the custom YOLOv5 table-detection weights through ``torch.hub``."""
    return torch.hub.load(
        'ultralytics/yolov5',
        'custom',
        'weights/detection_wts.pt',
        force_reload=True,
        skip_validation=True,
        trust_repo=True,
    )
def load_structure_model():
    """Load the custom YOLOv5 table-structure weights through ``torch.hub``."""
    return torch.hub.load(
        'ultralytics/yolov5',
        'custom',
        'weights/structure_wts.pt',
        force_reload=True,
        skip_validation=True,
        trust_repo=True,
    )
# Instantiate the OCR engine and both YOLOv5 models once, at import time,
# in the same order as before: OCR, detection, structure.
ocr_instance = load_ocr_instance()
detection_model = load_detection_model()
structure_model = load_structure_model()
# Class names, indexed by the class id read from each model's predictions.
detection_class_names = ['table', 'table rotated', 'no object']
structure_class_names = [
    'table',
    'table column',
    'table row',
    'table column header',
    'table projected row header',
    'table spanning cell',
    'no object',
]

# Reverse lookups: class name -> class id.
detection_class_map = dict(
    zip(detection_class_names, range(len(detection_class_names)))
)
structure_class_map = dict(
    zip(structure_class_names, range(len(structure_class_names)))
)

# Minimum score a prediction needs to be kept, per class.
# NOTE(review): 'no object' uses 10, which presumably can never be reached
# by a confidence score, i.e. that class is always discarded - confirm.
detection_class_thresholds = {
    'table': 0.5,
    'table rotated': 0.5,
    'no object': 10,
}
structure_class_thresholds = {
    "table": 0.45,
    "table column": 0.6,
    "table row": 0.5,
    "table column header": 0.4,
    "table projected row header": 0.3,
    "table spanning cell": 0.5,
    "no object": 10,
}
def PIL_to_cv(pil_img):
    """Convert a PIL image into an OpenCV-style BGR ``numpy`` array.

    The image is normalised to 3-channel RGB first. Without that step a
    grayscale or palette image (both are common for uploaded PNG files)
    becomes a 2-D array and ``cv2.cvtColor(..., COLOR_RGB2BGR)`` rejects it.

    Args:
        pil_img: Any PIL image.

    Returns:
        An ``H x W x 3`` array in BGR channel order.
    """
    rgb_pixels = np.array(pil_img.convert('RGB'))
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
def cv_to_PIL(cv_img):
    """Turn an OpenCV BGR array into a PIL image with RGB channel order."""
    rgb_pixels = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
    return PIL.Image.fromarray(rgb_pixels)
def table_detection(pil_img, imgsz=640):
    """Run the table-detection model on a page image.

    Args:
        pil_img: Page image as a PIL image.
        imgsz: Inference size forwarded to the model as ``size``.

    Returns:
        The first entry of the prediction's ``xywhn`` attribute, moved to
        the CPU and converted to a NumPy array.
    """
    bgr_page = PIL_to_cv(pil_img)
    prediction = detection_model(bgr_page, size=imgsz)
    return prediction.xywhn[0].cpu().numpy()
def table_structure(pil_img, imgsz=640):
    """Run the table-structure model on a cropped table image.

    Args:
        pil_img: Table crop as a PIL image.
        imgsz: Inference size forwarded to the model as ``size``.

    Returns:
        The first entry of the prediction's ``xywhn`` attribute, moved to
        the CPU and converted to a NumPy array.
    """
    bgr_table = PIL_to_cv(pil_img)
    prediction = structure_model(bgr_table, size=imgsz)
    return prediction.xywhn[0].cpu().numpy()
def crop_image(pil_img, detection_result):
    """Crop every accepted table detection and draw all of them on the page.

    Args:
        pil_img: Page image as a PIL image.
        detection_result: Iterable of rows read as
            ``(x, y, w, h, score, class_id)``. The box values are multiplied
            by the image width/height, and ``(x, y)`` is used as the box
            centre.

    Returns:
        A tuple ``(crop_images, annotated_page)``: the list of padded table
        crops as PIL images ('table rotated' crops are rotated by 270
        degrees), and the page with a box and a ``"<class> <score>"`` label
        drawn for each accepted detection.
    """
    page = PIL_to_cv(pil_img)
    # Draw on a separate copy. Crops are sliced from the untouched `page`,
    # so a padded crop can never contain the box or label that was painted
    # for an earlier, neighbouring detection.
    annotated = page.copy()
    height, width = page.shape[0], page.shape[1]

    # These depend only on the page size, so compute them once.
    padding_x = max(int(0.02 * width), 30)
    padding_y = max(int(0.02 * height), 30)
    lw = max(round(sum(page.shape) / 2 * 0.003), 2)
    font_scale = lw / 3
    thickness = max(lw - 1, 1)

    crop_images = []
    for result in detection_result:
        class_name = detection_class_names[int(result[5])]
        score = float(result[4])
        if score < detection_class_thresholds[class_name]:
            continue

        center_x, center_y, w, h = result[0], result[1], result[2], result[3]
        x1 = int((center_x - w / 2) * width)
        y1 = int((center_y - h / 2) * height)
        x2 = int((center_x + w / 2) * width)
        y2 = int((center_y + h / 2) * height)

        # Pad the crop, clamped to the page bounds.
        x1_pad = max(0, x1 - padding_x)
        y1_pad = max(0, y1 - padding_y)
        x2_pad = min(width, x2 + padding_x)
        y2_pad = min(height, y2 + padding_y)

        table_crop = cv_to_PIL(page[y1_pad:y2_pad, x1_pad:x2_pad, :])
        if class_name == 'table rotated':
            table_crop = table_crop.rotate(270, expand=True)
        crop_images.append(table_crop)

        # Box, filled label background, then the label text.
        cv2.rectangle(annotated, (x1, y1), (x2, y2), color=(0, 0, 255), thickness=2)
        label = f'{class_name} {score:.2f}'
        w_label, h_label = cv2.getTextSize(
            label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=font_scale, thickness=thickness
        )[0]
        cv2.rectangle(
            annotated, (x1, y1), (x1 + w_label, y1 - h_label - 3),
            (255, 0, 0), -1, cv2.LINE_AA,
        )
        cv2.putText(
            annotated, label, (x1, y1 - 2), cv2.FONT_HERSHEY_SIMPLEX, font_scale,
            (255, 255, 255), thickness=thickness, lineType=cv2.LINE_AA,
        )
    return crop_images, cv_to_PIL(annotated)
def ocr(pil_img):
    """Run OCR on an image and return one axis-aligned box per text line.

    Args:
        pil_img: Image as a PIL image.

    Returns:
        A list of ``{'bbox': [x1, y1, x2, y2], 'text': str}`` dicts, where
        the box is the bounding rectangle of the polygon points reported for
        that text. The list is empty when nothing was recognised.
    """
    image = PIL_to_cv(pil_img)
    result = ocr_instance.ocr(image)

    # Some PaddleOCR releases report an image without text as `[None]`
    # (or an empty list); treat both as "no tokens" instead of crashing.
    lines = result[0] if result else None
    if not lines:
        return []

    ocr_res = []
    for points, (text, _score) in lines:
        xs = [point[0] for point in points]
        ys = [point[1] for point in points]
        ocr_res.append({
            'bbox': [min(xs), min(ys), max(xs), max(ys)],
            'text': text,
        })
    return ocr_res
def convert_stucture(page_tokens, pil_img, structure_result):
    """Turn raw structure predictions plus OCR tokens into table cells.

    Args:
        page_tokens: OCR tokens, each a dict with at least a ``'bbox'`` entry
            in pixel coordinates of ``pil_img``.
        pil_img: The table crop as a PIL image; only its size is used.
        structure_result: Iterable of rows read as
            ``(x, y, w, h, score, class_id)``. The box values are multiplied
            by the image width/height, and ``(x, y)`` is used as the box
            centre.

    Returns:
        ``(table_structures, cells, confidence_score)`` exactly as returned
        by ``postprocess.objects_to_cells``.
    """
    width, height = pil_img.size

    table_objects = []
    for result in structure_result:
        center_x, center_y, w, h = result[0], result[1], result[2], result[3]
        table_objects.append({
            'bbox': [
                int((center_x - w / 2) * width),
                int((center_y - h / 2) * height),
                int((center_x + w / 2) * width),
                int((center_y + h / 2) * height),
            ],
            'score': float(result[4]),
            'label': int(result[5]),
        })
    table = {'objects': table_objects, 'page_num': 0}

    table_class_objects = [
        obj for obj in table_objects if obj['label'] == structure_class_map['table']
    ]
    if table_class_objects:
        # Keep the highest-scoring 'table' box.
        best_table = max(table_class_objects, key=lambda obj: obj['score'])
        table_bbox = list(best_table['bbox'])
    else:
        # No 'table' object predicted: fall back to the whole image rather
        # than an arbitrary fixed-size box, so no token is dropped just
        # because the crop is larger than that box.
        table_bbox = [0, 0, width, height]

    tokens_in_table = [
        token for token in page_tokens
        if postprocess.iob(token['bbox'], table_bbox) >= 0.5
    ]
    table_structures, cells, confidence_score = postprocess.objects_to_cells(
        table, table_objects, tokens_in_table,
        structure_class_names, structure_class_thresholds,
    )
    return table_structures, cells, confidence_score
def visualize_image(pil_img):
    """Render ``pil_img`` with matplotlib (axes hidden) and return the saved figure as a PIL image."""
    plt.imshow(pil_img, interpolation='lanczos')
    figure = plt.gcf()
    figure.set_size_inches(10, 10)
    plt.axis('off')

    # Save into memory and reopen the bytes as an image.
    buffer = io.BytesIO()
    plt.savefig(buffer, bbox_inches='tight', dpi=150)
    plt.close()
    return PIL.Image.open(buffer)
def visualize_ocr(pil_img, ocr_result):
    """Draw every OCR box and its text over ``pil_img`` and return the figure as a PIL image.

    Args:
        pil_img: Image the OCR was run on.
        ocr_result: Iterable of dicts with ``'bbox'`` (``[x1, y1, x2, y2]``)
            and ``'text'`` entries.
    """
    plt.imshow(pil_img, interpolation='lanczos')
    plt.gcf().set_size_inches(20, 20)
    axes = plt.gca()

    for token in ocr_result:
        box = token['bbox']
        left, top = box[0], box[1]
        outline = patches.Rectangle(
            box[:2], box[2] - left, box[3] - top,
            linewidth=2, edgecolor='red', facecolor='none', linestyle='-',
        )
        axes.add_patch(outline)
        axes.text(
            left, top, token['text'],
            horizontalalignment='left', verticalalignment='bottom',
            color='blue', fontsize=7,
        )

    plt.xticks([], [])
    plt.yticks([], [])
    plt.gcf().set_size_inches(10, 10)
    plt.axis('off')

    # Save into memory and reopen the bytes as an image.
    buffer = io.BytesIO()
    plt.savefig(buffer, bbox_inches='tight', dpi=150)
    plt.close()
    return PIL.Image.open(buffer)
def get_bbox_decorations(data_type, label):
    """Return the drawing style ``(color, alpha, linewidth, hatch)`` for a class id.

    Args:
        data_type: ``'detection'`` selects a lightly filled, hatched style
            for label 0; any other value gives label 0 an unfilled style.
            It has no effect on other labels.
        label: Integer class id.

    Returns:
        A 4-tuple of colour name, fill alpha, line width and hatch pattern
        (``None`` for no hatch). Unknown labels get a gray, zero-width style.
    """
    if label == 0:
        fill_alpha, hatch = (0.05, '//') if data_type == 'detection' else (0, None)
        return 'brown', fill_alpha, 3, hatch

    styles_by_label = {
        1: ('red', 0.15, 2, None),
        2: ('blue', 0.15, 2, None),
        3: ('magenta', 0.2, 3, '//'),
        4: ('cyan', 0.2, 4, '//'),
        5: ('green', 0.2, 4, '\\\\'),
    }
    return styles_by_label.get(label, ('gray', 0, 0, None))
def visualize_structure(pil_img, structure_result):
    """Overlay the accepted structure predictions on the table image.

    Args:
        pil_img: Table crop as a PIL image.
        structure_result: Iterable of rows read as
            ``(x, y, w, h, score, class_id)``; box values are multiplied by
            the image width/height and ``(x, y)`` is used as the box centre.
            Rows scoring below their class threshold are skipped.

    Returns:
        The rendered figure, including a class legend, as a PIL image.
    """
    bgr_table = PIL_to_cv(pil_img)
    height, width = bgr_table.shape[0], bgr_table.shape[1]

    plt.imshow(pil_img, interpolation='lanczos')
    plt.gcf().set_size_inches(20, 20)
    axes = plt.gca()

    for row in structure_result:
        class_id = int(row[5])
        score = float(row[4])
        if score < structure_class_thresholds[structure_class_names[class_id]]:
            continue

        center_x, center_y, box_w, box_h = row[0], row[1], row[2], row[3]
        left = int((center_x - box_w / 2) * width)
        top = int((center_y - box_h / 2) * height)
        right = int((center_x + box_w / 2) * width)
        bottom = int((center_y + box_h / 2) * height)
        origin = [left, top]
        span_x, span_y = right - left, bottom - top

        color, alpha, linewidth, hatch = get_bbox_decorations('recognition', class_id)
        # Three stacked rectangles per box: fill, hatch, then dashed edge.
        layers = (
            dict(linewidth=linewidth, alpha=alpha,
                 edgecolor='none', facecolor=color, linestyle=None),
            dict(linewidth=1, alpha=0.4,
                 edgecolor=color, facecolor='none', linestyle='--', hatch=hatch),
            dict(linewidth=linewidth,
                 edgecolor=color, facecolor='none', linestyle='--'),
        )
        for layer in layers:
            axes.add_patch(patches.Rectangle(origin, span_x, span_y, **layer))

    plt.xticks([], [])
    plt.yticks([], [])

    # One legend entry per class, leaving out the last name in the list.
    legend_elements = []
    for class_name in structure_class_names[:-1]:
        color, _alpha, _linewidth, hatch = get_bbox_decorations(
            'recognition', structure_class_map[class_name]
        )
        legend_elements.append(
            Patch(facecolor='none', edgecolor=color, linestyle='--',
                  label=class_name, hatch=hatch)
        )
    plt.legend(
        handles=legend_elements, bbox_to_anchor=(0.5, -0.02), loc='upper center',
        borderaxespad=0, fontsize=10, ncol=3,
    )
    plt.gcf().set_size_inches(10, 10)
    plt.axis('off')

    # Save into memory and reopen the bytes as an image.
    buffer = io.BytesIO()
    plt.savefig(buffer, bbox_inches='tight', dpi=150)
    plt.close()
    return PIL.Image.open(buffer)
def visualize_cells(pil_img, cells):
    """Overlay the extracted cells on the table image, coloured by cell kind.

    Args:
        pil_img: Table crop as a PIL image.
        cells: Iterable of dicts with ``'bbox'`` (``[x1, y1, x2, y2]``),
            ``'header'`` and ``'subheader'`` entries.

    Returns:
        The rendered figure, including a legend, as a PIL image.
    """
    # (facecolor, edgecolor, hatch) per cell kind.
    header_style = ((1, 0, 0.45), (1, 0, 0.45), '//////')
    subheader_style = ((0.95, 0.6, 0.1), (0.95, 0.6, 0.1), '//////')
    data_style = ((0.3, 0.74, 0.8), (0.3, 0.7, 0.6), '\\\\\\\\\\\\')
    alpha = 0.3
    linewidth = 2

    plt.imshow(pil_img, interpolation='lanczos')
    plt.gcf().set_size_inches(20, 20)
    axes = plt.gca()

    for cell in cells:
        box = cell['bbox']
        if cell['header']:
            facecolor, edgecolor, hatch = header_style
        elif cell['subheader']:
            facecolor, edgecolor, hatch = subheader_style
        else:
            facecolor, edgecolor, hatch = data_style

        origin = box[:2]
        span_x, span_y = box[2] - box[0], box[3] - box[1]
        # Faint fill, solid outline, then a hatch layer without a border.
        axes.add_patch(patches.Rectangle(
            origin, span_x, span_y, linewidth=linewidth,
            edgecolor='none', facecolor=facecolor, alpha=0.1))
        axes.add_patch(patches.Rectangle(
            origin, span_x, span_y, linewidth=linewidth,
            edgecolor=edgecolor, facecolor='none', linestyle='-', alpha=alpha))
        axes.add_patch(patches.Rectangle(
            origin, span_x, span_y, linewidth=0,
            edgecolor=edgecolor, facecolor='none', linestyle='-', hatch=hatch, alpha=0.2))

    plt.xticks([], [])
    plt.yticks([], [])

    legend_elements = [
        Patch(facecolor=(0.3, 0.74, 0.8), edgecolor=(0.3, 0.7, 0.6),
              label='Data cell', hatch='\\\\\\\\\\\\', alpha=0.3),
        Patch(facecolor=(1, 0, 0.45), edgecolor=(1, 0, 0.45),
              label='Column header cell', hatch='//////', alpha=0.3),
        Patch(facecolor=(0.95, 0.6, 0.1), edgecolor=(0.95, 0.6, 0.1),
              label='Projected row header cell', hatch='//////', alpha=0.3),
    ]
    plt.legend(
        handles=legend_elements, bbox_to_anchor=(0.5, -0.02), loc='upper center',
        borderaxespad=0, fontsize=10, ncol=3,
    )
    plt.gcf().set_size_inches(10, 10)
    plt.axis('off')

    # Save into memory and reopen the bytes as an image.
    buffer = io.BytesIO()
    plt.savefig(buffer, bbox_inches='tight', dpi=150)
    plt.close()
    return PIL.Image.open(buffer)
| # def pytess(cell_pil_img): | |
| # return ' '.join(pytesseract.image_to_data(cell_pil_img, output_type=Output.DICT, config='-c tessedit_char_blacklist=œ˜â€œï¬â™Ã©œ¢!|”?«“¥ --tessdata-dir tessdata --oem 3 --psm 6')['text']).strip() | |
| # def resize(pil_img, size=1800): | |
| # length_x, width_y = pil_img.size | |
| # factor = max(1, size / length_x) | |
| # size = int(factor * length_x), int(factor * width_y) | |
| # pil_img = pil_img.resize(size, PIL.Image.ANTIALIAS) | |
| # return pil_img, factor | |
| # def image_smoothening(img): | |
| # ret1, th1 = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY) | |
| # ret2, th2 = cv2.threshold(th1, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) | |
| # blur = cv2.GaussianBlur(th2, (1, 1), 0) | |
| # ret3, th3 = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) | |
| # return th3 | |
| # def remove_noise_and_smooth(pil_img): | |
| # img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY) | |
| # filtered = cv2.adaptiveThreshold(img.astype(np.uint8), 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 41, 3) | |
| # kernel = np.ones((1, 1), np.uint8) | |
| # opening = cv2.morphologyEx(filtered, cv2.MORPH_OPEN, kernel) | |
| # closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel) | |
| # img = image_smoothening(img) | |
| # or_image = cv2.bitwise_or(img, closing) | |
| # pil_img = PIL.Image.fromarray(or_image) | |
| # return pil_img | |
| # def extract_text_from_cells(pil_img, cells): | |
| # pil_img, factor = resize(pil_img) | |
| # #pil_img = remove_noise_and_smooth(pil_img) | |
| # #display(pil_img) | |
| # for cell in cells: | |
| # bbox = [x * factor for x in cell['bbox']] | |
| # cell_pil_img = pil_img.crop(bbox) | |
| # #cell_pil_img = remove_noise_and_smooth(cell_pil_img) | |
| # #cell_pil_img = tess_prep(cell_pil_img) | |
| # cell['cell text'] = pytess(cell_pil_img) | |
| # return cells | |
def extract_text_from_cells(cells, sep=' '):
    """Fill ``cell['cell_text']`` from the text of each cell's spans.

    Spans without a ``'text'`` key are ignored. The texts are joined with
    ``sep`` placed only *between* them, so the result carries no trailing
    separator that would otherwise leak into the exported table.

    Args:
        cells: List of dicts, each with a ``'spans'`` list of dicts.
        sep: Separator placed between consecutive span texts.

    Returns:
        The same ``cells`` list, modified in place.
    """
    for cell in cells:
        cell['cell_text'] = sep.join(
            span['text'] for span in cell['spans'] if 'text' in span
        )
    return cells
def cells_to_csv(cells):
    """Lay the cells out on a grid and export it as a DataFrame and CSV text.

    Every cell's text is written to each ``(row, column)`` slot it covers.
    Rows up to the last row touched by a header cell form the header; for
    each column the distinct header texts are joined with ``' | '`` to give
    the column name.

    Args:
        cells: List of dicts with ``'row_nums'``, ``'column_nums'``,
            ``'header'`` and ``'cell_text'`` entries.

    Returns:
        ``(df, csv_text)``, or ``None`` when ``cells`` is empty.
    """
    if not cells:
        return None

    num_columns = max(max(cell['column_nums']) for cell in cells) + 1
    num_rows = max(max(cell['row_nums']) for cell in cells) + 1
    max_header_row = max(
        (max(cell['row_nums']) for cell in cells if cell['header']),
        default=-1,
    )

    table_array = np.empty([num_rows, num_columns], dtype='object')
    for cell in cells:
        for row_num in cell['row_nums']:
            for column_num in cell['column_nums']:
                table_array[row_num, column_num] = cell['cell_text']

    header = table_array[:max_header_row + 1, :]
    # Slots that no cell covers are still None; skip them so that a gap in
    # the header does not make str.join raise TypeError.
    flattened_header = [
        ' | '.join(OrderedDict.fromkeys(text for text in col if text is not None))
        for col in header.transpose()
    ]
    df = pd.DataFrame(
        table_array[max_header_row + 1:, :], index=None, columns=flattened_header
    )
    return df, df.to_csv(index=None)
def cells_to_html(cells):
    """Serialise the cells as an HTML ``<table>`` string.

    Cells are visited in order of their first row, then first column. A new
    ``<tr>`` is opened whenever a cell starts on a later row than any cell
    seen so far; the ``th``/``td`` tag is chosen from the ``'header'`` flag
    of the cell that opens the row and reused for the rest of that row.
    ``colspan``/``rowspan`` attributes are emitted only for spans above one.

    Args:
        cells: List of dicts with ``'row_nums'``, ``'column_nums'``,
            ``'header'`` and ``'cell_text'`` entries.

    Returns:
        The table markup as a string.
    """
    ordered = sorted(
        cells, key=lambda c: (min(c['row_nums']), min(c['column_nums']))
    )
    table = ET.Element('table')
    current_row = -1

    for cell in ordered:
        spans = (
            ('colspan', len(cell['column_nums'])),
            ('rowspan', len(cell['row_nums'])),
        )
        attrib = {name: str(count) for name, count in spans if count > 1}

        first_row = min(cell['row_nums'])
        if first_row > current_row:
            current_row = first_row
            cell_tag = 'th' if cell['header'] else 'td'
            row = ET.SubElement(table, 'tr')

        tcell = ET.SubElement(row, cell_tag, attrib=attrib)
        tcell.text = cell['cell_text']

    return str(ET.tostring(table, encoding='unicode', short_empty_elements=False))
| # def cells_to_html(cells): | |
| # for cell in cells: | |
| # cell['column_nums'].sort() | |
| # cell['row_nums'].sort() | |
| # n_cols = max(cell['column_nums'][-1] for cell in cells) + 1 | |
| # n_rows = max(cell['row_nums'][-1] for cell in cells) + 1 | |
| # html_code = '' | |
| # for r in range(n_rows): | |
| # r_cells = [cell for cell in cells if cell['row_nums'][0] == r] | |
| # r_cells.sort(key=lambda x: x['column_nums'][0]) | |
| # r_html = '' | |
| # for cell in r_cells: | |
| # rowspan = cell['row_nums'][-1] - cell['row_nums'][0] + 1 | |
| # colspan = cell['column_nums'][-1] - cell['column_nums'][0] + 1 | |
| # r_html += f'<td rowspan='{rowspan}' colspan='{colspan}'>{escape(cell['text'])}</td>' | |
| # html_code += f'<tr>{r_html}</tr>' | |
| # html_code = '''<html> | |
| # <head> | |
| # <meta charset='UTF-8'> | |
| # <style> | |
| # table, th, td { | |
| # border: 1px solid black; | |
| # font-size: 10px; | |
| # } | |
| # </style> | |
| # </head> | |
| # <body> | |
| # <table frame='hsides' rules='groups' width='100%%'> | |
| # %s | |
| # </table> | |
| # </body> | |
| # </html>''' % html_code | |
| # soup = bs(html_code) | |
| # html_code = soup.prettify() | |
| # return html_code | |
def cells_to_excel(cells, file_path):
    """Write the cells to a single-sheet ``.xlsx`` workbook.

    Cells covering one grid slot are written directly; cells covering
    several rows and/or columns become merged ranges. All cells are centred
    horizontally and vertically. The sheet is named ``'Table'``.

    Args:
        cells: List of dicts with ``'row_nums'``, ``'column_nums'`` and
            ``'cell_text'`` entries (zero-based grid indices).
        file_path: Destination path of the workbook.
    """
    ordered = sorted(
        cells, key=lambda c: (min(c['row_nums']), min(c['column_nums']))
    )

    # The context manager closes (and thereby saves) the workbook even if a
    # write below raises.
    with xlsxwriter.Workbook(file_path) as workbook:
        cell_format = workbook.add_format({'align': 'center', 'valign': 'vcenter'})
        worksheet = workbook.add_worksheet(name='Table')

        for cell in ordered:
            start_row, end_row = min(cell['row_nums']), max(cell['row_nums'])
            start_col, end_col = min(cell['column_nums']), max(cell['column_nums'])

            if start_row == end_row and start_col == end_col:
                worksheet.write(start_row, start_col, cell['cell_text'], cell_format)
            else:
                # Numeric row/column addressing avoids converting column
                # indices to letters by hand, which only worked up to 'ZZ'.
                worksheet.merge_range(
                    start_row, start_col, end_row, end_col,
                    cell['cell_text'], cell_format,
                )
def main():
    """Streamlit entry point: upload an image, then detect, parse and export its tables.

    After the 'Analyze image' button is pressed, three tabs are filled:
    the annotated detection result, the per-table OCR / structure / cell
    visualisations, and the extracted tables (as a dataframe with an Excel
    download, followed by an HTML rendering).
    """
    st.title('Table Extraction Demo')
    filename = st.file_uploader('Upload image', type=['png', 'jpeg', 'jpg'])

    if not st.button('Analyze image'):
        return
    if filename is None:
        st.write('Please upload an image')
        return

    tabs = st.tabs(
        ['Table Detection', 'Table Structure Recognition', 'Extracted Table(s)']
    )
    pil_img = PIL.Image.open(filename)
    detection_result = table_detection(pil_img)
    crop_images, vis_det_img = crop_image(pil_img, detection_result)
    all_cells = []

    with tabs[0]:
        st.header('Table Detection')
        st.image(vis_det_img)

    with tabs[1]:
        st.header('Table Structure Recognition')
        str_cols = st.columns(4)
        str_cols[0].subheader('Table image')
        str_cols[1].subheader('OCR result')
        str_cols[2].subheader('Structure result')
        str_cols[3].subheader('Cells result')

        for img in crop_images:
            str_cols = st.columns(4)

            str_cols[0].image(visualize_image(img))

            ocr_result = ocr(img)
            str_cols[1].image(visualize_ocr(img, ocr_result))

            structure_result = table_structure(img)
            str_cols[2].image(visualize_structure(img, structure_result))

            _table_structures, cells, _confidence_score = convert_stucture(
                ocr_result, img, structure_result
            )
            cells = extract_text_from_cells(cells)
            str_cols[3].image(visualize_cells(img, cells))

            all_cells.append(cells)

    with tabs[2]:
        st.header('Extracted Table(s)')
        if not all_cells:
            # st.columns() rejects a count of zero, so stop before calling it.
            st.write('No tables were detected in the image')
            return

        for idx, col in enumerate(st.columns(len(all_cells))):
            with col:
                if len(all_cells) > 1:
                    st.header(f'Table {idx + 1}')

                # Build the workbook in a temporary directory and keep only
                # its bytes, so nothing depends on an open file handle.
                with TemporaryDirectory() as temp_dir_path:
                    xlsx_path = os.path.join(temp_dir_path, f'debug_{idx}.xlsx')
                    cells_to_excel(all_cells[idx], xlsx_path)
                    with open(xlsx_path, 'rb') as ref:
                        xlsx_bytes = ref.read()

                df = pd.read_excel(io.BytesIO(xlsx_bytes))
                st.dataframe(df)
                st.download_button(
                    'Download Excel File',
                    xlsx_bytes,
                    file_name=f'output_{idx}.xlsx',
                    # Explicit key: the same button is created once per table.
                    key=f'download_xlsx_{idx}',
                )

        for idx, cells in enumerate(all_cells):
            html_result = cells_to_html(cells)
            st.subheader(f'HTML Table {idx + 1}')
            st.markdown(html_result, unsafe_allow_html=True)


if __name__ == '__main__':
    main()