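# Streamlit demo app: detect tables in an uploaded image with a YOLOv5 detection model,
# recognize their row/column structure with a second YOLOv5 model, read the cell text with
# PaddleOCR, and export each table as a pandas DataFrame/CSV and as HTML.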
import streamlit as st
import PIL
import cv2
import numpy as np
import pandas as pd
import torch
# import sys
# import json
from collections import OrderedDict, defaultdict
import xml.etree.ElementTree as ET
from paddleocr import PaddleOCR
import pytesseract
from pytesseract import Output

import postprocess

ocr_instance = PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
detection_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/detection_wts.pt', force_reload=True)
structure_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/structure_wts.pt', force_reload=True)

imgsz = 640

detection_class_names = ['table', 'table rotated']
structure_class_names = [
    'table', 'table column', 'table row', 'table column header',
    'table projected row header', 'table spanning cell', 'no object'
]
structure_class_map = {k: v for v, k in enumerate(structure_class_names)}
structure_class_thresholds = {
    'table': 0.5,
    'table column': 0.5,
    'table row': 0.5,
    'table column header': 0.25,
    'table projected row header': 0.25,
    'table spanning cell': 0.25,
    'no object': 10
}
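
# Conversion helpers between PIL images (RGB) and OpenCV arrays (BGR).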
def PIL_to_cv(pil_img):
    return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)


def cv_to_PIL(cv_img):
    return PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))


def table_detection(pil_img):
    image = PIL_to_cv(pil_img)
    pred = detection_model(image, size=imgsz)
    pred = pred.xywhn[0]
    result = pred.cpu().numpy()
    return result


def table_structure(pil_img):
    image = PIL_to_cv(pil_img)
    pred = structure_model(image, size=imgsz)
    pred = pred.xywhn[0]
    result = pred.cpu().numpy()
    return result
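
# Crop every detected table from the full image, expanding each box by 2% of the image
# size on all sides, and draw the detection rectangles on a copy for visualization.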
def crop_image(pil_img, detection_result):
    crop_images = []
    image = PIL_to_cv(pil_img)
    width = image.shape[1]
    height = image.shape[0]
    # print(width, height)
    for i, result in enumerate(detection_result):
        class_id = int(result[5])
        score = float(result[4])
        min_x = result[0]
        min_y = result[1]
        w = result[2]
        h = result[3]
        x1 = max(0, int((min_x - w / 2 - 0.02) * width))
        y1 = max(0, int((min_y - h / 2 - 0.02) * height))
        x2 = min(width, int((min_x + w / 2 + 0.02) * width))
        y2 = min(height, int((min_y + h / 2 + 0.02) * height))
        # print(x1, y1, x2, y2)
        crop_image = image[y1:y2, x1:x2, :]
        crop_images.append(cv_to_PIL(crop_image))
        cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))
    return crop_images, cv_to_PIL(image)
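
# Run PaddleOCR on a cropped table and reduce each detected text polygon to an
# axis-aligned bounding box plus its text (the token dicts passed to convert_structure).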
def ocr(pil_img):
    image = PIL_to_cv(pil_img)
    result = ocr_instance.ocr(image)
    ocr_res = []
    for ps, (text, score) in result[0]:
        x1 = min(p[0] for p in ps)
        y1 = min(p[1] for p in ps)
        x2 = max(p[0] for p in ps)
        y2 = max(p[1] for p in ps)
        word_info = {
            'bbox': [x1, y1, x2, y2],
            'text': text
        }
        ocr_res.append(word_info)
    return ocr_res
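
# Convert normalized structure predictions to pixel-space objects, keep the OCR tokens
# that fall inside the best-scoring 'table' box, and let postprocess.objects_to_cells
# assemble rows and columns into cells.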
def convert_structure(page_tokens, pil_img, structure_result):
    image = PIL_to_cv(pil_img)
    width = image.shape[1]
    height = image.shape[0]
    # print(width, height)
    bboxes = []
    scores = []
    labels = []
    for i, result in enumerate(structure_result):
        class_id = int(result[5])
        score = float(result[4])
        min_x = result[0]
        min_y = result[1]
        w = result[2]
        h = result[3]
        x1 = int((min_x - w / 2) * width)
        y1 = int((min_y - h / 2) * height)
        x2 = int((min_x + w / 2) * width)
        y2 = int((min_y + h / 2) * height)
        # print(x1, y1, x2, y2)
        bboxes.append([x1, y1, x2, y2])
        scores.append(score)
        labels.append(class_id)
    table_objects = []
    for bbox, score, label in zip(bboxes, scores, labels):
        table_objects.append({'bbox': bbox, 'score': score, 'label': label})
    # print('table_objects:', table_objects)
    table = {'objects': table_objects, 'page_num': 0}
    table_class_objects = [obj for obj in table_objects if obj['label'] == structure_class_map['table']]
    if len(table_class_objects) > 1:
        table_class_objects = sorted(table_class_objects, key=lambda x: x['score'], reverse=True)
    try:
        table_bbox = list(table_class_objects[0]['bbox'])
    except IndexError:
        # no 'table' object detected; fall back to a large default bbox
        table_bbox = (0, 0, 1000, 1000)
    # print('table_class_objects:', table_class_objects)
    # print('table_bbox:', table_bbox)
    tokens_in_table = [token for token in page_tokens if postprocess.iob(token['bbox'], table_bbox) >= 0.5]
    # print('tokens_in_table:', tokens_in_table)
    table_structures, cells, confidence_score = postprocess.objects_to_cells(table, table_objects, tokens_in_table, structure_class_names, structure_class_thresholds)
    return table_structures, cells, confidence_score
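
# Visualization helpers: draw OCR word boxes, structure boxes above their class
# thresholds, and final cell boxes on copies of the table image.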
def visualize_ocr(pil_img, ocr_result):
    image = PIL_to_cv(pil_img)
    for i, res in enumerate(ocr_result):
        bbox = res['bbox']
        x1 = int(bbox[0])
        y1 = int(bbox[1])
        x2 = int(bbox[2])
        y2 = int(bbox[3])
        cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))
    return cv_to_PIL(image)


def visualize_structure(pil_img, structure_result):
    image = PIL_to_cv(pil_img)
    width = image.shape[1]
    height = image.shape[0]
    # print(width, height)
    for i, result in enumerate(structure_result):
        class_id = int(result[5])
        score = float(result[4])
        min_x = result[0]
        min_y = result[1]
        w = result[2]
        h = result[3]
        x1 = int((min_x - w / 2) * width)
        y1 = int((min_y - h / 2) * height)
        x2 = int((min_x + w / 2) * width)
        y2 = int((min_y + h / 2) * height)
        # print(x1, y1, x2, y2)
        if score >= structure_class_thresholds[structure_class_names[class_id]]:
            cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 0, 255))
            # cv2.putText(image, str(i)+'-'+str(class_id), (x1-10, y1), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0,0,255))
    return cv_to_PIL(image)


def visualize_cells(pil_img, cells):
    image = PIL_to_cv(pil_img)
    for i, cell in enumerate(cells):
        bbox = cell['bbox']
        x1 = int(bbox[0])
        y1 = int(bbox[1])
        x2 = int(bbox[2])
        y2 = int(bbox[3])
        cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))
    return cv_to_PIL(image)
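
# Tesseract-based OCR path and image clean-up helpers. These are only referenced by the
# commented-out extract_text_from_cells(pil_img, cells) variant further below; the active
# pipeline takes cell text from the PaddleOCR spans instead.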
def pytess(cell_pil_img):
    return ' '.join(pytesseract.image_to_data(cell_pil_img, output_type=Output.DICT, config='-c tessedit_char_blacklist=œ˜â€œï¬â™Ã©œ¢!|”?«“¥ --tessdata-dir tessdata --oem 3 --psm 6')['text']).strip()


def resize(pil_img, size=1800):
    length_x, width_y = pil_img.size
    factor = max(1, size / length_x)
    size = int(factor * length_x), int(factor * width_y)
    # PIL.Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the equivalent filter
    pil_img = pil_img.resize(size, PIL.Image.LANCZOS)
    return pil_img, factor


def image_smoothening(img):
    ret1, th1 = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY)
    ret2, th2 = cv2.threshold(th1, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    blur = cv2.GaussianBlur(th2, (1, 1), 0)
    ret3, th3 = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return th3


def remove_noise_and_smooth(pil_img):
    img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
    filtered = cv2.adaptiveThreshold(img.astype(np.uint8), 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 41, 3)
    kernel = np.ones((1, 1), np.uint8)
    opening = cv2.morphologyEx(filtered, cv2.MORPH_OPEN, kernel)
    closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)
    img = image_smoothening(img)
    or_image = cv2.bitwise_or(img, closing)
    pil_img = PIL.Image.fromarray(or_image)
    return pil_img
# def extract_text_from_cells(pil_img, cells):
#     pil_img, factor = resize(pil_img)
#     #pil_img = remove_noise_and_smooth(pil_img)
#     #display(pil_img)
#     for cell in cells:
#         bbox = [x * factor for x in cell['bbox']]
#         cell_pil_img = pil_img.crop(bbox)
#         #cell_pil_img = remove_noise_and_smooth(cell_pil_img)
#         #cell_pil_img = tess_prep(cell_pil_img)
#         cell['cell text'] = pytess(cell_pil_img)
#     return cells


def extract_text_from_cells(cells, sep=' '):
    for cell in cells:
        spans = cell['spans']
        text = ''
        for span in spans:
            if 'text' in span:
                text += span['text'] + sep
        cell['cell_text'] = text
    return cells
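
# Build a dense (num_rows x num_columns) grid of cell text, collapse multi-row headers
# into single column names joined with ' | ', and return a DataFrame plus its CSV dump.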
def cells_to_csv(cells):
    if len(cells) > 0:
        num_columns = max([max(cell['column_nums']) for cell in cells]) + 1
        num_rows = max([max(cell['row_nums']) for cell in cells]) + 1
    else:
        # keep the caller's `df, csv_result = cells_to_csv(...)` unpacking working when there are no cells
        return None, None

    header_cells = [cell for cell in cells if cell['header']]
    if len(header_cells) > 0:
        max_header_row = max([max(cell['row_nums']) for cell in header_cells])
    else:
        max_header_row = -1

    table_array = np.empty([num_rows, num_columns], dtype='object')
    if len(cells) > 0:
        for cell in cells:
            for row_num in cell['row_nums']:
                for column_num in cell['column_nums']:
                    table_array[row_num, column_num] = cell['cell_text']

    header = table_array[:max_header_row + 1, :]
    flattened_header = []
    for col in header.transpose():
        flattened_header.append(' | '.join(OrderedDict.fromkeys(col)))
    df = pd.DataFrame(table_array[max_header_row + 1:, :], index=None, columns=flattened_header)

    return df, df.to_csv(index=None)
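
# Serialize cells into an HTML <table>, emitting colspan/rowspan attributes for spanning
# cells and <thead>/<th> elements for header rows.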
def cells_to_html(cells):
    cells = sorted(cells, key=lambda k: min(k['column_nums']))
    cells = sorted(cells, key=lambda k: min(k['row_nums']))

    table = ET.Element('table')
    current_row = -1

    for cell in cells:
        this_row = min(cell['row_nums'])

        attrib = {}
        colspan = len(cell['column_nums'])
        if colspan > 1:
            attrib['colspan'] = str(colspan)
        rowspan = len(cell['row_nums'])
        if rowspan > 1:
            attrib['rowspan'] = str(rowspan)
        if this_row > current_row:
            current_row = this_row
            if cell['header']:
                cell_tag = 'th'
                row = ET.SubElement(table, 'thead')
            else:
                cell_tag = 'td'
                row = ET.SubElement(table, 'tr')
        tcell = ET.SubElement(row, cell_tag, attrib=attrib)
        tcell.text = cell['cell_text']

    return str(ET.tostring(table, encoding='unicode', short_empty_elements=False))
# def cells_to_html(cells):
#     for cell in cells:
#         cell['column_nums'].sort()
#         cell['row_nums'].sort()
#     n_cols = max(cell['column_nums'][-1] for cell in cells) + 1
#     n_rows = max(cell['row_nums'][-1] for cell in cells) + 1
#     html_code = ''
#     for r in range(n_rows):
#         r_cells = [cell for cell in cells if cell['row_nums'][0] == r]
#         r_cells.sort(key=lambda x: x['column_nums'][0])
#         r_html = ''
#         for cell in r_cells:
#             rowspan = cell['row_nums'][-1] - cell['row_nums'][0] + 1
#             colspan = cell['column_nums'][-1] - cell['column_nums'][0] + 1
#             r_html += f'<td rowspan='{rowspan}' colspan='{colspan}'>{escape(cell['text'])}</td>'
#         html_code += f'<tr>{r_html}</tr>'
#     html_code = '''<html>
#                    <head>
#                    <meta charset='UTF-8'>
#                    <style>
#                    table, th, td {
#                      border: 1px solid black;
#                      font-size: 10px;
#                    }
#                    </style>
#                    </head>
#                    <body>
#                    <table frame='hsides' rules='groups' width='100%%'>
#                      %s
#                    </table>
#                    </body>
#                    </html>''' % html_code
#     soup = bs(html_code)
#     html_code = soup.prettify()
#     return html_code
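
# Streamlit UI: upload an image, run table detection, then for every detected table show
# the crop, OCR boxes, structure boxes, cell boxes, and the extracted table as a
# DataFrame/CSV download and rendered HTML.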
def main():
    st.set_page_config(layout='wide')
    st.title('Table Extraction Demo')
    st.write('\n')

    tabs = st.tabs(
        ['Table Detection', 'Table Structure Recognition']
    )

    filename = st.file_uploader('Upload image', type=['png', 'jpeg', 'jpg'])

    if st.button('Analyze image'):
        if filename is None:
            st.sidebar.write('Please upload an image')
        else:
            print(filename)
            pil_img = PIL.Image.open(filename)

            detection_result = table_detection(pil_img)
            crop_images, vis_det_img = crop_image(pil_img, detection_result)

            with tabs[0]:
                st.image(vis_det_img)

            with tabs[1]:
                str_cols = st.columns((len(crop_images), ) * 5)
                str_cols[0].subheader('Table image')
                str_cols[1].subheader('OCR result')
                str_cols[2].subheader('Structure result')
                str_cols[3].subheader('Cells result')
                str_cols[4].subheader('CSV result')

                for i, img in enumerate(crop_images):
                    ocr_result = ocr(img)
                    structure_result = table_structure(img)
                    table_structures, cells, confidence_score = convert_structure(ocr_result, img, structure_result)
                    cells = extract_text_from_cells(cells)
                    html_result = cells_to_html(cells)
                    df, csv_result = cells_to_csv(cells)
                    # print(df)

                    vis_ocr_img = visualize_ocr(img, ocr_result)
                    vis_str_img = visualize_structure(img, structure_result)
                    vis_cells_img = visualize_cells(img, cells)

                    str_cols[0].image(img)
                    str_cols[1].image(vis_ocr_img)
                    str_cols[2].image(vis_str_img)
                    str_cols[3].image(vis_cells_img)
                    try:
                        str_cols[4].dataframe(df)
                    except Exception:
                        pass
                    str_cols[4].download_button('Download table', csv_result, f'table-{i}.csv', 'text/csv', key=f'download-csv-{i}')

                    st.write('\n')
                    st.markdown(html_result, unsafe_allow_html=True)


if __name__ == '__main__':
    main()