rishabhv471's picture
test
0599d82 verified
raw
history blame
2.95 kB
import asyncio
import string
from collections import Counter
from itertools import count, tee
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
import torch
from PIL import Image
from transformers import (DetrImageProcessor, TableTransformerForObjectDetection)
from vietocr.tool.config import Cfg
from vietocr.tool.predictor import Predictor
# --- Streamlit page setup -------------------------------------------------
# Turn off the 'deprecation.showPyplotGlobalUse' option.
# NOTE(review): presumably this silences the warning shown when st.pyplot()
# is called without an explicit figure -- confirm against the Streamlit
# version in use.
st.set_option('deprecation.showPyplotGlobalUse', False)
# Use the wide page layout.
st.set_page_config(layout='wide')
# Page heading followed by a credit line linking to the upstream project.
st.title("Table Detection and Table Structure Recognition")
st.write(
"Implemented by MSFT team: https://github.com/microsoft/table-transformer")
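# VietOCR config (optional). The lines below are currently disabled; uncomment
# them (choosing ONE of the two load_config_from_name lines) to create the
# `detector` object that pytess() uses.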
# config = Cfg.load_config_from_name('vgg_transformer')
# config = Cfg.load_config_from_name('vgg_seq2seq')
# config['cnn']['pretrained'] = False
# config['device'] = 'cpu'
# config['predictor']['beamsearch'] = False
# detector = Predictor(config)
# Pretrained Table Transformer checkpoints, loaded once at module level.
# The first is loaded from the "detection" checkpoint ...
table_detection_model = TableTransformerForObjectDetection.from_pretrained(
"microsoft/table-transformer-detection")
# ... and the second from the "structure-recognition" checkpoint.
table_recognition_model = TableTransformerForObjectDetection.from_pretrained(
"microsoft/table-transformer-structure-recognition")
def PIL_to_cv(pil_img):
    """Convert a PIL image into an OpenCV-style array.

    The image is first turned into a NumPy array and its channels are then
    reordered from RGB to BGR.
    """
    rgb_pixels = np.array(pil_img)
    # Swap the channel order RGB -> BGR.
    bgr_pixels = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
    return bgr_pixels
def cv_to_PIL(cv_img):
    """Convert an OpenCV-style BGR array into a PIL image.

    The channels are reordered from BGR to RGB before the array is wrapped
    with ``Image.fromarray``.
    """
    rgb_pixels = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
    pil_img = Image.fromarray(rgb_pixels)
    return pil_img
async def pytess(cell_pil_img, threshold: float = 0.5):
    """Recognise the text in an image, discarding low-confidence results.

    Args:
        cell_pil_img: Image handed unchanged to ``detector.predict``
            (presumably a PIL image of a single table cell -- confirm
            against the caller).
        threshold: Probabilities strictly below this value are rejected.

    Returns:
        The recognised text with surrounding whitespace stripped, or an
        empty string when the reported probability is below ``threshold``.
    """
    # NOTE(review): `detector` is a module-level name that is not created in
    # this function -- confirm it is defined before this coroutine runs.
    recognised, confidence = detector.predict(cell_pil_img, return_prob=True)
    return "" if confidence < threshold else recognised.strip()
def sharpen_image(pil_img):
    """Return a sharpened copy of a PIL image.

    The image is converted with ``PIL_to_cv``, convolved with a 3x3 kernel
    and converted back with ``cv_to_PIL``.
    """
    # Centre weight 9 with all eight neighbours at -1 (weights sum to 1).
    kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
    bgr = PIL_to_cv(pil_img)
    # A depth argument of -1 keeps the output depth the same as the input's.
    sharpened = cv2.filter2D(bgr, -1, kernel)
    return cv_to_PIL(sharpened)
def uniquify(seq, suffs=None):
    """Disambiguate repeated items of ``seq`` by appending numeric suffixes.

    Each value that occurs more than once receives suffixes from its own
    independent copy of ``suffs`` (``1``, ``2``, ... by default), in order of
    appearance. Values that occur exactly once are left untouched.

    Credit: https://stackoverflow.com/questions/30650474/python-rename-duplicates-in-list-with-progressive-numbers-without-sorting-list

    Args:
        seq: Mutable sequence of hashable items that support ``+= str``
            (i.e. strings). It is modified in place.
        suffs: Optional iterator supplying the suffix values. When omitted,
            a fresh ``itertools.count(1)`` is created for this call.

    Returns:
        The same ``seq`` object, after modification.
    """
    # A default of ``count(1)`` in the signature would be created once and
    # shared by every call, so later calls would continue numbering where the
    # previous one stopped. Build the default per call instead.
    if suffs is None:
        suffs = count(1)

    not_unique = [item for item, n in Counter(seq).items() if n > 1]
    # tee() gives every duplicated value its own independent suffix stream.
    suff_gens = dict(zip(not_unique, tee(suffs, len(not_unique))))

    for idx, item in enumerate(seq):
        suff_gen = suff_gens.get(item)
        if suff_gen is None:
            # Item occurs only once: nothing to append.
            continue
        seq[idx] += str(next(suff_gen))
    return seq
def binarizeBlur_image(pil_img):
    """Binarise a PIL image, blur the result, and return it as a PIL image.

    Steps: convert with ``PIL_to_cv``; apply an inverse binary threshold at
    150 (max value 255); apply a 5x5 Gaussian blur; invert the result
    (``255 - x``); convert back with ``cv_to_PIL``.
    """
    bgr = PIL_to_cv(pil_img)
    # cv2.threshold returns a pair; only the thresholded image (index 1) is used.
    _, inverted_binary = cv2.threshold(bgr, 150, 255, cv2.THRESH_BINARY_INV)
    blurred = cv2.GaussianBlur(inverted_binary, (5, 5), 0)
    # Undo the inversion introduced by THRESH_BINARY_INV.
    restored = 255 - blurred
    return cv_to_PIL(restored)
def td_postprocess(pil_img):
'''
Removes gray background from tables
'''
img = PIL_to_cv(pil_img)
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
mask = cv2.inRange(hsv, (0, 0, 100), (255, 5, 255)) # (0, 0, 100), (255, 5, 255)
nzmask = cv2.inRange(hsv, (0, 0, 5), (255, 255, 255)) # (0, 0,