import streamlit as st
import PIL
import cv2
import numpy as np
import pandas as pd
import torch
import os
import io
# import sys
# import json
from collections import OrderedDict, defaultdict
import xml.etree.ElementTree as ET
from tempfile import TemporaryDirectory
import xlsxwriter
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from paddleocr import PaddleOCR
# import pytesseract
# from pytesseract import Output

import postprocess

ocr_instance = PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
detection_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/detection_wts.pt',
                                 force_reload=True, skip_validation=True, trust_repo=True)
structure_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/structure_wts.pt',
                                 force_reload=True, skip_validation=True, trust_repo=True)

imgsz = 640

detection_class_names = ['table', 'table rotated']
structure_class_names = [
    'table', 'table column', 'table row', 'table column header',
    'table projected row header', 'table spanning cell', 'no object'
]
structure_class_map = {k: v for v, k in enumerate(structure_class_names)}
structure_class_thresholds = {
    "table": 0.42,
    "table column": 0.56,
    "table row": 0.5,
    "table column header": 0.38,
    "table projected row header": 0.27,
    "table spanning cell": 0.4,
    "no object": 10
}


def PIL_to_cv(pil_img):
    return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)


def cv_to_PIL(cv_img):
    return PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))


def table_detection(pil_img):
    image = PIL_to_cv(pil_img)
    pred = detection_model(image, size=imgsz)
    pred = pred.xywhn[0]
    result = pred.cpu().numpy()
    return result


def table_structure(pil_img):
    image = PIL_to_cv(pil_img)
    pred = structure_model(image, size=imgsz)
    pred = pred.xywhn[0]
    result = pred.cpu().numpy()
    return result


def crop_image(pil_img, detection_result, padding=30):
    crop_images = []
    image = PIL_to_cv(pil_img)
    width = image.shape[1]
    height = image.shape[0]
    # print(width, height)
    for i, result in enumerate(detection_result):
        class_id = int(result[5])
        score = float(result[4])
        min_x = result[0]
        min_y = result[1]
        w = result[2]
        h = result[3]

        x1 = int((min_x - w / 2) * width)
        y1 = int((min_y - h / 2) * height)
        x2 = int((min_x + w / 2) * width)
        y2 = int((min_y + h / 2) * height)
        # print(x1, y1, x2, y2)

        x1_pad = max(0, x1 - padding)
        y1_pad = max(0, y1 - padding)
        x2_pad = min(width, x2 + padding)
        y2_pad = min(height, y2 + padding)

        crop_image = image[y1_pad:y2_pad, x1_pad:x2_pad, :]
        crop_image = cv_to_PIL(crop_image)
        if class_id == 1:  # table rotated
            crop_image = crop_image.rotate(270, expand=True)
        crop_images.append(crop_image)

        cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 0, 255))
        cv2.putText(image, f'{score:.2f}', (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5, color=(255, 0, 0))

    return crop_images, cv_to_PIL(image)


def ocr(pil_img):
    image = PIL_to_cv(pil_img)
    result = ocr_instance.ocr(image)
    ocr_res = []
    for ps, (text, score) in result[0]:
        x1 = min(p[0] for p in ps)
        y1 = min(p[1] for p in ps)
        x2 = max(p[0] for p in ps)
        y2 = max(p[1] for p in ps)
        word_info = {
            'bbox': [x1, y1, x2, y2],
            'text': text
        }
        ocr_res.append(word_info)
    return ocr_res
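
# Hedged usage sketch (not called by the app): shows how the detection, cropping,
# and OCR helpers above are expected to chain together. The image path is a
# hypothetical example, not a file shipped with this module.
def _demo_detect_crop_ocr(image_path='example_page.jpg'):
    pil_img = PIL.Image.open(image_path).convert('RGB')
    detection_result = table_detection(pil_img)       # normalized [x, y, w, h, score, class] per table
    crops, vis_img = crop_image(pil_img, detection_result)
    tokens_per_table = [ocr(crop) for crop in crops]  # list of {'bbox', 'text'} dicts per crop
    return crops, tokens_per_table, vis_img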
def convert_stucture(page_tokens, pil_img, structure_result):
    image = PIL_to_cv(pil_img)
    width = image.shape[1]
    height = image.shape[0]
    # print(width, height)

    bboxes = []
    scores = []
    labels = []
    for i, result in enumerate(structure_result):
        class_id = int(result[5])
        score = float(result[4])
        min_x = result[0]
        min_y = result[1]
        w = result[2]
        h = result[3]

        x1 = int((min_x - w / 2) * width)
        y1 = int((min_y - h / 2) * height)
        x2 = int((min_x + w / 2) * width)
        y2 = int((min_y + h / 2) * height)
        # print(x1, y1, x2, y2)

        bboxes.append([x1, y1, x2, y2])
        scores.append(score)
        labels.append(class_id)

    table_objects = []
    for bbox, score, label in zip(bboxes, scores, labels):
        table_objects.append({'bbox': bbox, 'score': score, 'label': label})
    # print('table_objects:', table_objects)

    table = {'objects': table_objects, 'page_num': 0}

    table_class_objects = [obj for obj in table_objects if obj['label'] == structure_class_map['table']]
    if len(table_class_objects) > 1:
        table_class_objects = sorted(table_class_objects, key=lambda x: x['score'], reverse=True)
    try:
        table_bbox = list(table_class_objects[0]['bbox'])
    except IndexError:
        # No 'table' object detected; fall back to a bbox covering the whole crop.
        table_bbox = (0, 0, 1000, 1000)
    # print('table_class_objects:', table_class_objects)
    # print('table_bbox:', table_bbox)

    tokens_in_table = [token for token in page_tokens if postprocess.iob(token['bbox'], table_bbox) >= 0.5]
    # print('tokens_in_table:', tokens_in_table)

    table_structures, cells, confidence_score = postprocess.objects_to_cells(
        table, table_objects, tokens_in_table, structure_class_names, structure_class_thresholds)

    return table_structures, cells, confidence_score


def visualize_ocr(pil_img, ocr_result):
    image = PIL_to_cv(pil_img)
    for i, res in enumerate(ocr_result):
        bbox = res['bbox']
        x1 = int(bbox[0])
        y1 = int(bbox[1])
        x2 = int(bbox[2])
        y2 = int(bbox[3])
        cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))
        cv2.putText(image, res['text'], (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.25, color=(255, 0, 0))
    return cv_to_PIL(image)


def get_bbox_decorations(data_type, label):
    # Returns (color, alpha, linewidth, hatch) for each structure class id.
    if label == 0:
        if data_type == 'detection':
            return 'brown', 0.05, 3, '//'
        else:
            return 'brown', 0, 3, None
    elif label == 1:
        return 'red', 0.15, 2, None
    elif label == 2:
        return 'blue', 0.15, 2, None
    elif label == 3:
        return 'magenta', 0.2, 3, '//'
    elif label == 4:
        return 'cyan', 0.2, 4, '//'
    elif label == 5:
        return 'green', 0.2, 4, '\\\\'

    return 'gray', 0, 0, None


def visualize_structure(pil_img, structure_result):
    image = PIL_to_cv(pil_img)
    width = image.shape[1]
    height = image.shape[0]
    # print(width, height)

    fig, ax = plt.subplots(1)
    ax.imshow(pil_img, interpolation='lanczos')

    for i, result in enumerate(structure_result):
        class_id = int(result[5])
        score = float(result[4])
        min_x = result[0]
        min_y = result[1]
        w = result[2]
        h = result[3]

        x1 = int((min_x - w / 2) * width)
        y1 = int((min_y - h / 2) * height)
        x2 = int((min_x + w / 2) * width)
        y2 = int((min_y + h / 2) * height)
        # print(x1, y1, x2, y2)

        bbox = [x1, y1, x2, y2]
        if score >= structure_class_thresholds[structure_class_names[class_id]]:
            # cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))
            # cv2.putText(image, str(i)+'-'+str(class_id), (x1-10, y1), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255))
            color, alpha, linewidth, hatch = get_bbox_decorations('recognition', class_id)
            # Fill
            rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
                                     linewidth=linewidth, alpha=alpha,
                                     edgecolor='none', facecolor=color, linestyle=None)
            ax.add_patch(rect)
            # Hatch
            rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
                                     linewidth=1, alpha=0.4,
                                     edgecolor=color, facecolor='none', linestyle='--', hatch=hatch)
            ax.add_patch(rect)
            # Edge
            rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
                                     linewidth=linewidth,
                                     edgecolor=color, facecolor='none', linestyle="--")
            ax.add_patch(rect)

    plt.axis('off')
    img_buf = io.BytesIO()
    plt.savefig(img_buf, bbox_inches='tight', dpi=1000)

    return PIL.Image.open(img_buf)
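
# Hedged sketch (assumes `crop` is one table crop and `tokens` its OCR output,
# e.g. from _demo_detect_crop_ocr above): run the structure model, convert the raw
# detections into cells, and render the structure overlay.
def _demo_structure_to_cells(crop, tokens):
    structure_result = table_structure(crop)
    table_structures, cells, confidence_score = convert_stucture(tokens, crop, structure_result)
    structure_vis = visualize_structure(crop, structure_result)
    return cells, confidence_score, structure_vis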
def visualize_cells(pil_img, cells):
    fig, ax = plt.subplots(1)
    ax.imshow(pil_img, interpolation='lanczos')

    for i, cell in enumerate(cells):
        bbox = cell['bbox']
        if cell['header']:
            alpha = 0.3
        else:
            alpha = 0.125
        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
                                 linewidth=1, edgecolor='none', facecolor="magenta", alpha=alpha)
        ax.add_patch(rect)
        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
                                 linewidth=1, edgecolor="magenta", facecolor='none',
                                 linestyle="--", alpha=0.08, hatch='///')
        ax.add_patch(rect)
        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
                                 linewidth=1, edgecolor="magenta", facecolor='none', linestyle="--")
        ax.add_patch(rect)

    plt.axis('off')
    img_buf = io.BytesIO()
    plt.savefig(img_buf, bbox_inches='tight', dpi=1000)

    return PIL.Image.open(img_buf)


# def pytess(cell_pil_img):
#     return ' '.join(pytesseract.image_to_data(cell_pil_img, output_type=Output.DICT,
#         config='-c tessedit_char_blacklist=œ˜â€œï¬â™Ã©œ¢!|”?«“¥ --tessdata-dir tessdata --oem 3 --psm 6')['text']).strip()

# def resize(pil_img, size=1800):
#     length_x, width_y = pil_img.size
#     factor = max(1, size / length_x)
#     size = int(factor * length_x), int(factor * width_y)
#     pil_img = pil_img.resize(size, PIL.Image.ANTIALIAS)
#     return pil_img, factor

# def image_smoothening(img):
#     ret1, th1 = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY)
#     ret2, th2 = cv2.threshold(th1, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
#     blur = cv2.GaussianBlur(th2, (1, 1), 0)
#     ret3, th3 = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
#     return th3

# def remove_noise_and_smooth(pil_img):
#     img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
#     filtered = cv2.adaptiveThreshold(img.astype(np.uint8), 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 41, 3)
#     kernel = np.ones((1, 1), np.uint8)
#     opening = cv2.morphologyEx(filtered, cv2.MORPH_OPEN, kernel)
#     closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)
#     img = image_smoothening(img)
#     or_image = cv2.bitwise_or(img, closing)
#     pil_img = PIL.Image.fromarray(or_image)
#     return pil_img

# def extract_text_from_cells(pil_img, cells):
#     pil_img, factor = resize(pil_img)
#     #pil_img = remove_noise_and_smooth(pil_img)
#     #display(pil_img)
#     for cell in cells:
#         bbox = [x * factor for x in cell['bbox']]
#         cell_pil_img = pil_img.crop(bbox)
#         #cell_pil_img = remove_noise_and_smooth(cell_pil_img)
#         #cell_pil_img = tess_prep(cell_pil_img)
#         cell['cell text'] = pytess(cell_pil_img)
#     return cells


def extract_text_from_cells(cells, sep=' '):
    for cell in cells:
        spans = cell['spans']
        text = ''
        for span in spans:
            if 'text' in span:
                text += span['text'] + sep
        cell['cell_text'] = text
    return cells


def cells_to_csv(cells):
    if len(cells) > 0:
        num_columns = max([max(cell['column_nums']) for cell in cells]) + 1
        num_rows = max([max(cell['row_nums']) for cell in cells]) + 1
    else:
        return

    header_cells = [cell for cell in cells if cell['header']]
    if len(header_cells) > 0:
        max_header_row = max([max(cell['row_nums']) for cell in header_cells])
    else:
        max_header_row = -1

    table_array = np.empty([num_rows, num_columns], dtype='object')
    if len(cells) > 0:
        for cell in cells:
            for row_num in cell['row_nums']:
                for column_num in cell['column_nums']:
                    table_array[row_num, column_num] = cell['cell_text']

    header = table_array[:max_header_row+1, :]
    flattened_header = []
    for col in header.transpose():
        flattened_header.append(' | '.join(OrderedDict.fromkeys(col)))
    df = pd.DataFrame(table_array[max_header_row+1:, :], index=None, columns=flattened_header)

    return df, df.to_csv(index=None)
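
# Hedged sketch: fill each cell's text from its OCR spans, then flatten the grid
# into a DataFrame/CSV. Assumes `cells` came from convert_stucture above (each cell
# carries 'spans', 'row_nums', 'column_nums', and 'header' keys).
def _demo_cells_to_dataframe(cells):
    cells = extract_text_from_cells(cells)
    csv_result = cells_to_csv(cells)
    if csv_result is None:  # cells_to_csv returns None when there are no cells
        return None, ''
    df, csv_string = csv_result
    return df, csv_string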
def cells_to_html(cells):
    cells = sorted(cells, key=lambda k: min(k['column_nums']))
    cells = sorted(cells, key=lambda k: min(k['row_nums']))

    table = ET.Element('table')
    current_row = -1

    for cell in cells:
        this_row = min(cell['row_nums'])

        attrib = {}
        colspan = len(cell['column_nums'])
        if colspan > 1:
            attrib['colspan'] = str(colspan)
        rowspan = len(cell['row_nums'])
        if rowspan > 1:
            attrib['rowspan'] = str(rowspan)
        if this_row > current_row:
            current_row = this_row
            if cell['header']:
                cell_tag = 'th'
                row = ET.SubElement(table, 'tr')
            else:
                cell_tag = 'td'
                row = ET.SubElement(table, 'tr')
        tcell = ET.SubElement(row, cell_tag, attrib=attrib)
        tcell.text = cell['cell_text']

    return str(ET.tostring(table, encoding='unicode', short_empty_elements=False))


# def cells_to_html(cells):
#     for cell in cells:
#         cell['column_nums'].sort()
#         cell['row_nums'].sort()
#     n_cols = max(cell['column_nums'][-1] for cell in cells) + 1
#     n_rows = max(cell['row_nums'][-1] for cell in cells) + 1
#     html_code = ''
#     for r in range(n_rows):
#         r_cells = [cell for cell in cells if cell['row_nums'][0] == r]
#         r_cells.sort(key=lambda x: x['column_nums'][0])
#         r_html = ''
#         for cell in r_cells:
#             rowspan = cell['row_nums'][-1] - cell['row_nums'][0] + 1
#             colspan = cell['column_nums'][-1] - cell['column_nums'][0] + 1
#             r_html += f'