import ast
import onnx
import onnxruntime as ort
import cv2
from huggingface_hub import hf_hub_download
import numpy as np

# Download the model from the Hugging Face Hub
model = hf_hub_download(
    repo_id="wybxc/DocLayout-YOLO-DocStructBench-onnx",
    filename="doclayout_yolo_docstructbench_imgsz1024.onnx",
)
model = onnx.load(model)
metadata = {prop.key: prop.value for prop in model.metadata_props}

# names maps class ids to labels (including "abandon", "figure", "table",
# "isolate_formula" and "formula_caption", which the demo below filters on);
# stride is the model's spatial alignment stride
names = ast.literal_eval(metadata["names"])
stride = ast.literal_eval(metadata["stride"])

# Load the model with ONNX Runtime
session = ort.InferenceSession(model.SerializeToString())
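
# Optional sanity check: the input name "images" fed by inference() below can
# be read off the session itself via onnxruntime's get_inputs()
for inp in session.get_inputs():
    print(f"model input: {inp.name} shape={inp.shape} type={inp.type}")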


def resize_and_pad_image(image, new_shape, stride=32):
    """

    Resize and pad the image to the specified size, ensuring dimensions are multiples of stride.



    Parameters:

    - image: Input image

    - new_shape: Target size (integer or (height, width) tuple)

    - stride: Padding alignment stride, default 32



    Returns:

    - Processed image

    """
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    h, w = image.shape[:2]
    new_h, new_w = new_shape

    # Calculate scaling ratio
    r = min(new_h / h, new_w / w)
    resized_h, resized_w = int(round(h * r)), int(round(w * r))

    # Resize image
    image = cv2.resize(image, (resized_w, resized_h), interpolation=cv2.INTER_LINEAR)

    # Calculate padding size and align to stride multiple
    pad_w = (new_w - resized_w) % stride
    pad_h = (new_h - resized_h) % stride
    top, bottom = pad_h // 2, pad_h - pad_h // 2
    left, right = pad_w // 2, pad_w - pad_w // 2

    # Add padding
    image = cv2.copyMakeBorder(
        image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
    )

    return image


class YoloResult:
    """Detections for one image, plus the id-to-label mapping."""

    def __init__(self, boxes, names):
        self.boxes = [YoloBox(data=d) for d in boxes]
        self.names = names


class YoloBox:
    """One detection row: [x0, y0, x1, y1, confidence, class id]."""

    def __init__(self, data):
        self.xyxy = data[:4]
        self.conf = data[-2]
        self.cls = data[-1]

def inference(image):
    """

    Run inference on the input image.



    Parameters:

    - image: Input image, HWC format and RGB order



    Returns:

    - YoloResult object containing the predicted boxes and class names

    """

    # Preprocess image: convert the RGB input to BGR, then letterbox to a
    # square target derived from the image height, rounded down to a multiple
    # of the model stride
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    pix = resize_and_pad_image(image, new_shape=int(image.shape[0] / stride) * stride)
    pix = np.transpose(pix, (2, 0, 1))  # CHW
    pix = np.expand_dims(pix, axis=0)  # BCHW
    pix = pix.astype(np.float32) / 255.0  # Normalize to [0, 1]

    # Run inference
    preds = session.run(None, {"images": pix})[0]

    # Postprocess predictions: each output row is [x0, y0, x1, y1, conf, cls],
    # so a simple confidence threshold is the only filtering needed here
    preds = preds[preds[..., 4] > 0.25]
    return YoloResult(boxes=preds, names=names)
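
# The boxes returned by inference() live in the letterboxed frame produced by
# resize_and_pad_image(), not in the original image frame. A minimal sketch of
# the inverse mapping, mirroring the scale/pad math above (`unletterbox_xyxy`
# is a hypothetical helper, not part of the original pipeline):
def unletterbox_xyxy(xyxy, orig_shape, new_shape, stride=32):
    """Map a box from the letterboxed frame back to original image coordinates."""
    h, w = orig_shape[:2]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    # Recompute the scale and padding used by resize_and_pad_image
    r = min(new_shape[0] / h, new_shape[1] / w)
    resized_h, resized_w = int(round(h * r)), int(round(w * r))
    pad_w = (new_shape[1] - resized_w) % stride
    pad_h = (new_shape[0] - resized_h) % stride
    left, top = pad_w // 2, pad_h // 2
    # Undo the padding offset, then the scaling
    x0, y0, x1, y1 = xyxy
    return (x0 - left) / r, (y0 - top) / r, (x1 - left) / r, (y1 - top) / r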


if __name__ == "__main__":
    # Usage: python <this script> <path to page image>
    import sys
    import matplotlib.pyplot as plt

    image = cv2.imread(sys.argv[1])
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    layout = inference(image)

    # Rasterize the layout into a label bitmap: 1 = background, 0 = masked
    # classes, i + 2 = the i-th remaining box. Boxes are painted in a
    # vertically flipped frame and the bitmap is flipped back at the end.
    bitmap = np.ones(image.shape[:2], dtype=np.uint8)
    h, w = bitmap.shape
    vcls = ["abandon", "figure", "table", "isolate_formula", "formula_caption"]
    for i, d in enumerate(layout.boxes):
        x0, y0, x1, y1 = d.xyxy.squeeze()
        # Expand each box by one pixel, flip the y-axis, and clip to bounds
        x0, y0, x1, y1 = (
            np.clip(int(x0 - 1), 0, w - 1),
            np.clip(int(h - y1 - 1), 0, h - 1),
            np.clip(int(x1 + 1), 0, w - 1),
            np.clip(int(h - y0 + 1), 0, h - 1),
        )
        if layout.names[int(d.cls)] in vcls:
            bitmap[y0:y1, x0:x1] = 0
        else:
            bitmap[y0:y1, x0:x1] = i + 2
    bitmap = bitmap[::-1, :]  # Undo the vertical flip

    # Show the page next to the derived layout bitmap
    fig, ax = plt.subplots(1, 2, figsize=(10, 6))
    ax[0].imshow(image)
    ax[1].imshow(bitmap)
    plt.show()