import gradio as gr
import os
import cv2
from rapid_table_det.inference import TableDetector
from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img

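# Example images bundled with the demo: doc*.png/jpg are document scans, real*.jpg are real-world photos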
example_images = [
    "images/doc1.png",
    "images/doc2.jpg",
    "images/doc3.jpg",
    "images/doc4.jpg",
    "images/doc5.jpg",
    "images/real1.jpg",
    "images/real2.jpeg",
    "images/real3.jpg",
    "images/real4.jpg",
    "images/real5.jpg"
]

# Model type options (display name -> model identifiers)
model_type_options = {
    "YOLO 目标检测": ["yolo_obj_det"],
    "Paddle 目标检测": ["paddle_obj_det"],
    "Paddle 目标检测 (量化)": ["paddle_obj_det_s"],
    "YOLO 语义分割": ["yolo_edge_det"],
    "YOLO 语义分割 (小型)": ["yolo_edge_det_s"],
    "Paddle 语义分割": ["paddle_edge_det"],
    "Paddle 语义分割 (量化)": ["paddle_edge_det_s"],
    "Paddle 方向分类": ["paddle_cls_det"]
}

# Pre-initialize a TableDetector instance for every supported model combination
preinitialized_detectors = {}

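# Assumes the matching .onnx weight files are present under ./models (see the *_model_path arguments below)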
# Group the option lists once so they can be reused for the pre-initialization loop and the dropdowns below
obj_model_types = (model_type_options["YOLO 目标检测"]
                   + model_type_options["Paddle 目标检测"]
                   + model_type_options["Paddle 目标检测 (量化)"])
edge_model_types = (model_type_options["YOLO 语义分割"]
                    + model_type_options["YOLO 语义分割 (小型)"]
                    + model_type_options["Paddle 语义分割"]
                    + model_type_options["Paddle 语义分割 (量化)"])
cls_model_types = model_type_options["Paddle 方向分类"]

for obj_model_type in obj_model_types:
    for edge_model_type in edge_model_types:
        for cls_model_type in cls_model_types:
            detector_key = (obj_model_type, edge_model_type, cls_model_type)
            preinitialized_detectors[detector_key] = TableDetector(
                obj_model_type=obj_model_type,
                edge_model_type=edge_model_type,
                cls_model_type=cls_model_type,
                obj_model_path=os.path.join("models", f"{obj_model_type}.onnx"),
                edge_model_path=os.path.join("models", f"{edge_model_type}.onnx"),
                cls_model_path=os.path.join("models", f"{cls_model_type}.onnx")
            )


# Resize an image so its longer side is at most max_size pixels
def resize_image(image, max_size=640):
    height, width = image.shape[:2]
    if max(height, width) > max_size:
        scale = max_size / max(height, width)
        new_height = int(height * scale)
        new_width = int(width * scale)
        image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
    return image


# Run table detection on an image and return the visualization, the extracted table crops, and timing info
def run_inference(img_path, obj_model_type, edge_model_type, cls_model_type, det_accuracy,
                  use_obj_det, use_edge_det, use_cls_det):
    detector_key = (obj_model_type, edge_model_type, cls_model_type)
    table_det = preinitialized_detectors[detector_key]
    result, elapse = table_det(
        img_path,
        det_accuracy=det_accuracy,
        use_obj_det=use_obj_det,
        use_edge_det=use_edge_det,
        use_cls_det=use_cls_det
    )

    # Load the image and convert BGR to RGB for display
    img = img_loader(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    extract_img = img.copy()
    visual_img = img.copy()
    extract_imgs = []

    for i, res in enumerate(result):
        box = res["box"]
        lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
        # Draw the detection box and mark the top-left corner / orientation
        visual_img = visuallize(visual_img, box, lt, rt, rb, lb)
        # Perspective-transform to crop out the table image
        wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb)
        extract_imgs.append(wrapped_img)
    # Downscale the outputs for display
    visual_img = resize_image(visual_img)
    extract_imgs = [resize_image(img) for img in extract_imgs]

    # Per-stage timings: object detection, edge segmentation, rotation classification
    obj_det_elapse, edge_elapse, rotate_det_elapse = elapse
    return visual_img, extract_imgs, f"obj_det_elapse={obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}"


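# Helper that unwraps a single extracted image so a downstream component receives one image instead of a list;
# it does not appear to be wired to any event below.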
def update_extract_outputs(visual_img, extract_imgs, time_info):
    if len(extract_imgs) == 1:
        return visual_img, extract_imgs[0], time_info
    else:
        return visual_img, extract_imgs, time_info


# Build the Gradio interface
with gr.Blocks(
        css="""
        .scrollable-container {
            overflow-x: auto;
            white-space: nowrap;
        }
        .header-links {
            text-align: center;
        }
        .header-links a {
            display: inline-block;
            text-align: center;
            margin-right: 10px;  /* spacing between badge links */
        }
    """
) as demo:
    gr.HTML(
        "<h1 style='text-align: center;'><a href='https://github.com/RapidAI/RapidTableDetection'>RapidTableDetection</a></h1>"
    )
    gr.HTML('''
        <div class="header-links">
            <a href=""><img src="https://img.shields.io/badge/Python->=3.8,<3.12-aff.svg"></a>
            <a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Mac%2C%20Win-pink.svg"></a>
            <a href="https://semver.org/"><img alt="SemVer2.0" src="https://img.shields.io/badge/SemVer-2.0-brightgreen"></a>
            <a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
            <a href="https://github.com/RapidAI/TableStructureRec/blob/c41bbd23898cb27a957ed962b0ffee3c74dfeff1/LICENSE"><img alt="GitHub" src="https://img.shields.io/badge/license-Apache 2.0-blue"></a>
        </div>
    ''')
    with gr.Row():
        with gr.Column(variant="panel", scale=1):
            img_input = gr.Image(label="Upload or Select Image", sources="upload", value="images/real1.jpg")

            # Example image selector
            examples = gr.Examples(
                examples=example_images,
                examples_per_page=len(example_images),
                inputs=img_input,
                fn=lambda x: x,  # simply return the selected image path
                outputs=img_input,
                cache_examples=False
            )

            obj_model_type = gr.Dropdown(
                choices=obj_model_types,
                value="yolo_obj_det",
                label="obj det model")
            edge_model_type = gr.Dropdown(
                choices=edge_model_types,
                value="yolo_edge_det",
                label="edge seg model")
            cls_model_type = gr.Dropdown(
                choices=cls_model_types,
                value="paddle_cls_det",
                label="direction cls model")

            det_accuracy = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="obj det confidence threshold")
            use_obj_det = gr.Checkbox(value=True, label="use obj det")
            use_edge_det = gr.Checkbox(value=True, label="use edge seg")
            use_cls_det = gr.Checkbox(value=True, label="use direction cls")

            run_button = gr.Button("run")

        with gr.Column(scale=2):
            visual_output = gr.Image(label="detection visualization")
            extract_outputs = gr.Gallery(label="extracted table images", object_fit="contain", columns=1, preview=True)
            time_output = gr.Textbox(label="elapsed time")

    run_button.click(
        fn=run_inference,
        inputs=[img_input, obj_model_type, edge_model_type, cls_model_type, det_accuracy, use_obj_det, use_edge_det,
                use_cls_det],
        outputs=[visual_output, extract_outputs, time_output]
    )

# Launch the Gradio app
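# If needed, launch() also accepts share=True, server_name, and server_port to expose the demo on a network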
demo.launch()