import ast
import onnx
import onnxruntime as ort
import cv2
from huggingface_hub import hf_hub_download
import numpy as np

# Download the model from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="wybxc/DocLayout-YOLO-DocStructBench-onnx",
    filename="doclayout_yolo_docstructbench_imgsz1024.onnx",
)
model = onnx.load(model_path)

# Class names and stride are stored as metadata on the ONNX model
metadata = {prop.key: prop.value for prop in model.metadata_props}

names = ast.literal_eval(metadata["names"])
stride = ast.literal_eval(metadata["stride"])

# Load the model with ONNX Runtime
session = ort.InferenceSession(model.SerializeToString())
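
# Optional sanity check (an added sketch, not part of the upstream script):
# the export is assumed to expose a single float32 BCHW input named "images",
# which the feed dict in inference() below relies on.
input_meta = session.get_inputs()[0]
assert input_meta.name == "images", f"unexpected input name: {input_meta.name}"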


def resize_and_pad_image(image, new_shape, stride=32):
    """

    Resize and pad the image to the specified size, ensuring dimensions are multiples of stride.



    Parameters:

    - image: Input image

    - new_shape: Target size (integer or (height, width) tuple)

    - stride: Padding alignment stride, default 32



    Returns:

    - Processed image

    """
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    h, w = image.shape[:2]
    new_h, new_w = new_shape

    # Calculate scaling ratio
    r = min(new_h / h, new_w / w)
    resized_h, resized_w = int(round(h * r)), int(round(w * r))

    # Resize image
    image = cv2.resize(image, (resized_w, resized_h), interpolation=cv2.INTER_LINEAR)

    # Calculate padding size and align to stride multiple
    pad_w = (new_w - resized_w) % stride
    pad_h = (new_h - resized_h) % stride
    top, bottom = pad_h // 2, pad_h - pad_h // 2
    left, right = pad_w // 2, pad_w - pad_w // 2

    # Add padding
    image = cv2.copyMakeBorder(
        image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
    )

    return image
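
# Worked example (illustrative numbers, not from the upstream code): a 1000x700
# image with new_shape=1024 scales by r = min(1024/1000, 1024/700) = 1.024 to
# 1024x717, then the width is padded by (1024 - 717) % 32 = 19 px (9 left,
# 10 right), giving a 1024x736 result whose sides are both multiples of 32.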


def scale_boxes(img1_shape, boxes, img0_shape):
    """

    Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally

    specified in (img1_shape) to the shape of a different image (img0_shape).



    Args:

        img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).

        boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)

        img0_shape (tuple): the shape of the target image, in the format of (height, width).



    Returns:

        boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)

    """

    # Calculate scaling ratio
    gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])

    # Calculate padding size
    pad_x = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1)
    pad_y = round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)

    # Remove padding and scale boxes
    boxes[..., :4] = (boxes[..., :4] - [pad_x, pad_y, pad_x, pad_y]) / gain
    return boxes
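
# Worked example (illustrative, continuing the numbers above): mapping boxes
# from a 1024x736 letterboxed frame back to a 1000x700 original uses
# gain = min(1024/1000, 736/700) = 1.024, pad_x ≈ (736 - 700 * 1.024) / 2 ≈ 9.6,
# and pad_y = 0; each box is shifted left by the padding, then divided by the
# gain to land in original-image pixels.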


class YoloResult:
    """Detection results, sorted by confidence (highest first)."""

    def __init__(self, boxes, names):
        self.boxes = [YoloBox(data=d) for d in boxes]
        self.boxes = sorted(self.boxes, key=lambda x: x.conf, reverse=True)
        self.names = names


class YoloBox:
    """A single detection; each row is expected to be [x1, y1, x2, y2, conf, cls]."""

    def __init__(self, data):
        self.xyxy = data[:4]
        self.conf = data[-2]
        self.cls = data[-1]


def inference(image):
    """

    Run inference on the input image.



    Parameters:

    - image: Input image, HWC format and RGB order



    Returns:

    - YoloResult object containing the predicted boxes and class names

    """

    # Preprocess image
    orig_h, orig_w = image.shape[:2]
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    # Letterbox to a square whose side is the image height floored to a
    # multiple of the model stride
    pix = resize_and_pad_image(image, new_shape=int(image.shape[0] / stride) * stride)
    pix = np.transpose(pix, (2, 0, 1))  # CHW
    pix = np.expand_dims(pix, axis=0)  # BCHW
    pix = pix.astype(np.float32) / 255.0  # Normalize to [0, 1]
    new_h, new_w = pix.shape[2:]

    # Run inference
    preds = session.run(None, {"images": pix})[0]

    # Postprocess predictions
    preds = preds[preds[..., 4] > 0.25]
    preds[..., :4] = scale_boxes((new_h, new_w), preds[..., :4], (orig_h, orig_w))
    return YoloResult(boxes=preds, names=names)
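
# Minimal usage sketch (the __main__ block below is the fuller demo):
#   img = cv2.cvtColor(cv2.imread("page.png"), cv2.COLOR_BGR2RGB)
#   result = inference(img)
#   for box in result.boxes:
#       print(result.names[int(box.cls)], float(box.conf), box.xyxy)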


if __name__ == "__main__":
    import sys
    import matplotlib
    import matplotlib.pyplot as plt
    import matplotlib.colors as colors

    image_path = sys.argv[1]
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    layout = inference(image)

    # Build a label bitmap: background is 1, vcls boxes share value 0, and
    # every other box gets a unique id (i + 2)
    bitmap = np.ones(image.shape[:2], dtype=np.uint8)
    h, w = bitmap.shape
    vcls = ["abandon", "figure", "table", "isolate_formula", "formula_caption"]
    for i, d in enumerate(layout.boxes):
        x0, y0, x1, y1 = d.xyxy.squeeze()
        # Expand the box by 1 px and clamp to the image, working in a
        # vertically flipped frame (undone by the flip below)
        x0, y0, x1, y1 = (
            np.clip(int(x0 - 1), 0, w - 1),
            np.clip(int(h - y1 - 1), 0, h - 1),
            np.clip(int(x1 + 1), 0, w - 1),
            np.clip(int(h - y0 + 1), 0, h - 1),
        )
        if layout.names[int(d.cls)] in vcls:
            bitmap[y0:y1, x0:x1] = 0
        else:
            bitmap[y0:y1, x0:x1] = i + 2
    bitmap = bitmap[::-1, :]

    # map bitmap to color
    colormap = matplotlib.colormaps["Pastel1"]
    norm = colors.Normalize(vmin=bitmap.min(), vmax=bitmap.max())
    colored_bitmap = colormap(norm(bitmap))
    colored_bitmap = (colored_bitmap[:, :, :3] * 255).astype(np.uint8)

    # overlay bitmap on image
    image_with_bitmap = cv2.multiply(image, colored_bitmap, scale=1 / 255)

    # show the results
    fig, ax = plt.subplots(1, 3, figsize=(15, 6))
    ax[0].imshow(image)
    ax[1].imshow(bitmap, cmap="Pastel1")
    ax[2].imshow(image_with_bitmap)
    plt.show()
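
# To try it (a sketch; the script filename is whatever you saved this as):
#   python this_script.py path/to/page_image.png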