Spaces:
Build error
Build error
Add cache
Browse files
app.py
CHANGED
|
@@ -23,11 +23,15 @@ from paddleocr import PaddleOCR
|
|
| 23 |
import postprocess
|
| 24 |
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
|
|
|
|
| 31 |
|
| 32 |
detection_class_names = ['table', 'table rotated', 'no object']
|
| 33 |
structure_class_names = [
|
|
@@ -62,7 +66,7 @@ def cv_to_PIL(cv_img):
|
|
| 62 |
return PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
|
| 63 |
|
| 64 |
|
| 65 |
-
def table_detection(pil_img):
|
| 66 |
image = PIL_to_cv(pil_img)
|
| 67 |
pred = detection_model(image, size=imgsz)
|
| 68 |
pred = pred.xywhn[0]
|
|
@@ -70,7 +74,7 @@ def table_detection(pil_img):
|
|
| 70 |
return result
|
| 71 |
|
| 72 |
|
| 73 |
-
def table_structure(pil_img):
|
| 74 |
image = PIL_to_cv(pil_img)
|
| 75 |
pred = structure_model(image, size=imgsz)
|
| 76 |
pred = pred.xywhn[0]
|
|
|
|
| 23 |
import postprocess
|
| 24 |
|
| 25 |
|
| 26 |
+
@st.cache_resource(ttl=3600)
|
| 27 |
+
def load_models():
|
| 28 |
+
ocr_instance = PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
|
| 29 |
+
detection_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/detection_wts.pt', force_reload=True, skip_validation=True, trust_repo=True)
|
| 30 |
+
structure_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/structure_wts.pt', force_reload=True, skip_validation=True, trust_repo=True)
|
| 31 |
+
return ocr_instance, detection_model, structure_model
|
| 32 |
|
| 33 |
+
|
| 34 |
+
ocr_instance, detection_model, structure_model = load_models()
|
| 35 |
|
| 36 |
detection_class_names = ['table', 'table rotated', 'no object']
|
| 37 |
structure_class_names = [
|
|
|
|
| 66 |
return PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
|
| 67 |
|
| 68 |
|
| 69 |
+
def table_detection(pil_img, imgsz=640):
|
| 70 |
image = PIL_to_cv(pil_img)
|
| 71 |
pred = detection_model(image, size=imgsz)
|
| 72 |
pred = pred.xywhn[0]
|
|
|
|
| 74 |
return result
|
| 75 |
|
| 76 |
|
| 77 |
+
def table_structure(pil_img, imgsz=640):
|
| 78 |
image = PIL_to_cv(pil_img)
|
| 79 |
pred = structure_model(image, size=imgsz)
|
| 80 |
pred = pred.xywhn[0]
|