# app.py — RapidTableDetection Gradio demo (Hugging Face Space).
import gradio as gr
import os
import cv2
from rapid_table_det.inference import TableDetector
from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img
# Example images offered in the demo's example picker.
example_images = [
    "images/doc1.png",
    "images/doc2.jpg",
    "images/doc3.jpg",
    "images/doc4.jpg",
    "images/doc5.jpg",
    "images/real1.jpg",
    "images/real2.jpeg",
    "images/real3.jpg",
    "images/real4.jpg",
    "images/real5.jpg",
]

# Dropdown display label -> list of model identifiers (each id is also the
# basename of its .onnx weight file under models/).
model_type_options = {
    "YOLO 目标检测": ["yolo_obj_det"],
    "Paddle 目标检测": ["paddle_obj_det"],
    "Paddle 目标检测 (量化)": ["paddle_obj_det_s"],
    "YOLO 语义分割": ["yolo_edge_det"],
    "YOLO 语义分割 (小型)": ["yolo_edge_det_s"],
    "Paddle 语义分割": ["paddle_edge_det"],
    "Paddle 语义分割 (量化)": ["paddle_edge_det_s"],
    "Paddle 方向分类": ["paddle_cls_det"],
}

# Flatten the label groups into the concrete per-stage model-id choices.
_obj_models = [m for label in ("YOLO 目标检测", "Paddle 目标检测", "Paddle 目标检测 (量化)")
               for m in model_type_options[label]]
_edge_models = [m for label in ("YOLO 语义分割", "YOLO 语义分割 (小型)",
                                "Paddle 语义分割", "Paddle 语义分割 (量化)")
                for m in model_type_options[label]]
_cls_models = list(model_type_options["Paddle 方向分类"])


def _build_detector(obj_m, edge_m, cls_m):
    # One detector per (obj, edge, cls) combination; weights live under models/.
    return TableDetector(
        obj_model_type=obj_m,
        edge_model_type=edge_m,
        cls_model_type=cls_m,
        obj_model_path=os.path.join("models", f"{obj_m}.onnx"),
        edge_model_path=os.path.join("models", f"{edge_m}.onnx"),
        cls_model_path=os.path.join("models", f"{cls_m}.onnx"),
    )


# Eagerly instantiate every model combination at import time so that a
# request never pays model-loading latency; keyed by the model-id triple.
preinitialized_detectors = {
    (obj_m, edge_m, cls_m): _build_detector(obj_m, edge_m, cls_m)
    for obj_m in _obj_models
    for edge_m in _edge_models
    for cls_m in _cls_models
}
# Downscale helper for display-sized output images.
def resize_image(image, max_size=640):
    """Shrink *image* so its longest side is at most *max_size* pixels.

    Images already within the limit are returned unchanged (same object);
    larger ones are downscaled preserving aspect ratio using INTER_AREA,
    which is the recommended interpolation for shrinking.
    """
    h, w = image.shape[:2]
    longest = max(h, w)
    if longest <= max_size:
        return image
    scale = max_size / longest
    target = (int(w * scale), int(h * scale))  # cv2.resize takes (width, height)
    return cv2.resize(image, target, interpolation=cv2.INTER_AREA)
# Inference entry point wired to the "run" button.
def run_inference(img_path, obj_model_type, edge_model_type, cls_model_type, det_accuracy, use_obj_det, use_edge_det,
                  use_cls_det):
    """Detect tables in *img_path* with the selected model combination.

    Returns a tuple of (annotated RGB image, list of perspective-corrected
    table crops, human-readable timing string).
    """
    # Look up the detector pre-built for this exact model triple.
    table_det = preinitialized_detectors[(obj_model_type, edge_model_type, cls_model_type)]
    result, elapse = table_det(
        img_path,
        det_accuracy=det_accuracy,
        use_obj_det=use_obj_det,
        use_edge_det=use_edge_det,
        use_cls_det=use_cls_det
    )
    # Load once and convert BGR -> RGB for display in Gradio.
    rgb = cv2.cvtColor(img_loader(img_path), cv2.COLOR_BGR2RGB)
    visual_img = rgb.copy()
    source_img = rgb.copy()
    extract_imgs = []
    for res in result:
        corners = (res["lt"], res["rt"], res["rb"], res["lb"])
        # Draw the detection box plus the top-left orientation marker.
        visual_img = visuallize(visual_img, res["box"], *corners)
        # Perspective-warp the table region out of an untouched copy.
        extract_imgs.append(extract_table_img(source_img.copy(), *corners))
    # Cap output sizes for the UI.
    visual_img = resize_image(visual_img)
    extract_imgs = [resize_image(im) for im in extract_imgs]
    obj_det_elapse, edge_elapse, rotate_det_elapse = elapse
    timing = f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}"
    return visual_img, extract_imgs, timing
def update_extract_outputs(visual_img, extract_imgs, time_info):
    """Unwrap a single extracted image from its list; pass other lists through."""
    single = len(extract_imgs) == 1
    return visual_img, (extract_imgs[0] if single else extract_imgs), time_info
# Build the Gradio UI: left column holds the image input and model choices,
# right column shows the detection results.
with gr.Blocks(
    css="""
    .scrollable-container {
        overflow-x: auto;
        white-space: nowrap;
    }
    .header-links {
        text-align: center;
    }
    .header-links a {
        display: inline-block;
        text-align: center;
        margin-right: 10px; /* 调整间距 */
    }
    """
) as demo:
    # Title bar linking back to the project repository.
    gr.HTML(
        "<h1 style='text-align: center;'><a href='https://github.com/RapidAI/RapidTableDetection'>RapidTableDetection</a></h1>"
    )
    # Status badges: supported Python versions, OS, versioning, code style, license.
    gr.HTML('''
    <div class="header-links">
        <a href=""><img src="https://img.shields.io/badge/Python->=3.8,<3.12-aff.svg"></a>
        <a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Mac%2C%20Win-pink.svg"></a>
        <a href="https://semver.org/"><img alt="SemVer2.0" src="https://img.shields.io/badge/SemVer-2.0-brightgreen"></a>
        <a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
        <a href="https://github.com/RapidAI/TableStructureRec/blob/c41bbd23898cb27a957ed962b0ffee3c74dfeff1/LICENSE"><img alt="GitHub" src="https://img.shields.io/badge/license-Apache 2.0-blue"></a>
    </div>
    ''')
    with gr.Row():
        # Left panel: image input, model selection and pipeline-stage toggles.
        with gr.Column(variant="panel", scale=1):
            img_input = gr.Image(label="Upload or Select Image", sources="upload", value="images/real1.jpg")
            # Example image picker.
            examples = gr.Examples(
                examples=example_images,
                examples_per_page=len(example_images),
                inputs=img_input,
                fn=lambda x: x,  # identity: simply forwards the selected image path
                outputs=img_input,
                cache_examples=False
            )
            # Stage-model dropdowns; choices come from the module-level option map.
            obj_model_type = gr.Dropdown(
                choices=model_type_options["YOLO 目标检测"] + model_type_options["Paddle 目标检测"] +
                        model_type_options["Paddle 目标检测 (量化)"],
                value="yolo_obj_det",
                label="obj det model")
            edge_model_type = gr.Dropdown(
                choices=model_type_options["YOLO 语义分割"] + model_type_options["YOLO 语义分割 (小型)"] +
                        model_type_options["Paddle 语义分割"] + model_type_options["Paddle 语义分割 (量化)"],
                value="yolo_edge_det",
                label="edge seg model")
            cls_model_type = gr.Dropdown(choices=model_type_options["Paddle 方向分类"],
                                         value="paddle_cls_det",
                                         label="direction cls model")
            # Confidence threshold passed to run_inference as det_accuracy.
            det_accuracy = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="目标检测置信度阈值")
            use_obj_det = gr.Checkbox(value=True, label="use obj det")
            use_edge_det = gr.Checkbox(value=True, label="use edge seg")
            use_cls_det = gr.Checkbox(value=True, label="use direction cls")
            run_button = gr.Button("run")
        # Right panel: annotated image, extracted-table gallery and timings.
        with gr.Column(scale=2):
            visual_output = gr.Image(label="output visualize")
            extract_outputs = gr.Gallery(label="extracted images", object_fit="contain", columns=1, preview=True)
            time_output = gr.Textbox(label="elapsed")
    # Wire the run button to the inference function.
    run_button.click(
        fn=run_inference,
        inputs=[img_input, obj_model_type, edge_model_type, cls_model_type, det_accuracy, use_obj_det, use_edge_det,
                use_cls_det],
        outputs=[visual_output, extract_outputs, time_output]
    )
# Launch the Gradio app (blocking call; default host/port).
demo.launch()