File size: 8,081 Bytes
fa50974
679d3c5
 
190bae6
fa50974
6e73f0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190bae6
74c842e
679d3c5
6e73f0b
679d3c5
fa50974
679d3c5
 
 
 
 
 
 
a231a61
fa50974
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e059e1e
fa50974
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
679d3c5
fa50974
 
679d3c5
fa50974
 
 
 
679d3c5
fa50974
faf74d9
6e73f0b
 
 
 
 
 
679d3c5
 
fa50974
679d3c5
 
 
 
 
 
 
 
 
45440e8
 
63b70c6
679d3c5
2132edd
679d3c5
 
 
3ec7d0a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
from ultralytics import RTDETR
import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image
from PIL import Image, ImageDraw, ImageFont


from collections import defaultdict
from typing import List, Dict
import torch
from transformers import LayoutLMv3ForTokenClassification

# Load the LayoutLMv3 token-classification checkpoint used to predict reading
# order (LayoutReader style: one token per layout box, built by boxes2inputs).
# This runs at import time; from_pretrained fetches the weights from the
# Hugging Face Hub (or reuses the local cache).
layout_model = LayoutLMv3ForTokenClassification.from_pretrained("omarelsayeed/LayoutReader80Small")

# NOTE(review): MAX_LEN is not referenced in the visible code; boxes2inputs does
# not truncate the number of boxes to it.
MAX_LEN = 100
# Special token ids that boxes2inputs() places around the per-box tokens.
CLS_TOKEN_ID = 0  # first token of every sequence
UNK_TOKEN_ID = 3  # placeholder id repeated once per layout box
EOS_TOKEN_ID = 2  # last token of every sequence


def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
    """Build a batch-of-one LayoutLMv3 input dict from a list of layout boxes.

    Every box becomes one ``UNK_TOKEN_ID`` token. The sequence is wrapped by a
    ``CLS_TOKEN_ID`` and an ``EOS_TOKEN_ID`` token, both carrying the all-zero
    box ``[0, 0, 0, 0]``. All tokens are attended to.

    Args:
        boxes: ``[x1, y1, x2, y2]`` boxes. LayoutLMv3 looks these values up in
            its 2-D position embeddings, so they should already be integers
            scaled to the 0-1000 range (the caller is responsible for scaling).

    Returns:
        Dict with ``bbox`` of shape (1, N + 2, 4), and ``attention_mask`` and
        ``input_ids`` of shape (1, N + 2), all of dtype ``torch.long``.
    """
    n_boxes = len(boxes)
    bbox = [[0, 0, 0, 0], *boxes, [0, 0, 0, 0]]
    input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * n_boxes + [EOS_TOKEN_ID]
    attention_mask = [1] * (n_boxes + 2)
    return {
        # bbox values are used as embedding indices, so force an integer dtype.
        # Float coordinates (e.g. straight from a detector) would otherwise
        # yield a FloatTensor that the embedding lookup rejects.
        "bbox": torch.tensor([bbox], dtype=torch.long),
        "attention_mask": torch.tensor([attention_mask], dtype=torch.long),
        "input_ids": torch.tensor([input_ids], dtype=torch.long),
    }

def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
    """
    Parse logits to determine the reading order.

    Args:
        logits: 2-D score matrix for one sequence (the caller squeezes the
            batch dimension). Row 0 is the CLS token (see boxes2inputs), and
            rows 1..length correspond to the boxes. The first ``length``
            columns are scores for each candidate reading-order position.
        length: number of boxes.

    Returns:
        ``ret`` where ``ret[i]`` is the position assigned to box ``i``. The
        loop only exits once no position is shared by two boxes, so the
        returned positions are distinct.
    """
    # Drop the CLS row, keep one row per box, and restrict the columns to
    # positions 0..length-1.
    logits = logits[1: length + 1, :length]
    # Per row, candidate positions sorted by ascending score, so each list's
    # last element is that box's best remaining position.
    orders = logits.argsort(descending=False).tolist()
    # Initial greedy pick: each box takes its highest-scoring position.
    # pop() also removes it from that box's remaining candidates.
    ret = [o.pop() for o in orders]
    # Repeat until every position is claimed by at most one box.
    while True:
        order_to_idxes = defaultdict(list)
        for idx, order in enumerate(ret):
            order_to_idxes[order].append(idx)
        # Keep only positions claimed by more than one box
        order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
        if not order_to_idxes:
            break
        # Resolve conflicts
        for order, idxes in order_to_idxes.items():
            # Rank the competing boxes by their score for this position, highest first.
            idxes_to_logit = {idx: logits[idx, order] for idx in idxes}
            idxes_to_logit = sorted(idxes_to_logit.items(), key=lambda x: x[1], reverse=True)
            # The top scorer keeps the position. Every other box falls back to
            # its next-best remaining candidate. That may create new
            # collisions, which the next pass of the while-loop resolves.
            for idx, _ in idxes_to_logit[1:]:
                ret[idx] = orders[idx].pop()

    return ret

def scale_boxes_to_1000(boxes, width, height):
    """Scale pixel-space ``[x1, y1, x2, y2]`` boxes to LayoutLMv3's 0-1000 integer grid.

    x coordinates are scaled by ``1000 / width`` and y coordinates by
    ``1000 / height``. Each result is rounded to an int and clamped to [0, 1000].
    """
    # Guard against a zero-sized image so the scale factors stay finite.
    x_scale = 1000.0 / max(float(width), 1.0)
    y_scale = 1000.0 / max(float(height), 1.0)

    def to_grid(value, scale):
        return min(max(int(round(value * scale)), 0), 1000)

    return [
        [to_grid(x1, x_scale), to_grid(y1, y_scale), to_grid(x2, x_scale), to_grid(y2, y_scale)]
        for x1, y1, x2, y2 in boxes
    ]


def get_orders(image_path, boxes):
    """Predict the reading order of ``boxes`` with the LayoutReader model.

    Args:
        image_path: the page image the boxes were detected on. This can be a
            PIL image (as passed by full_predictions), an array with a
            ``shape`` of (height, width, ...), or a path PIL can open. Only
            its size is used, to normalise the boxes.
        boxes: ``[x1, y1, x2, y2]`` boxes in pixel coordinates of that image.

    Returns:
        One entry per box: that box's position in the predicted reading order.
    """
    if not boxes:
        return []

    # LayoutLMv3 expects integer box coordinates on a 0-1000 grid. Raw detector
    # output is float pixels and can exceed the position-embedding table.
    if isinstance(image_path, Image.Image):
        width, height = image_path.size
    elif hasattr(image_path, "shape"):
        height, width = image_path.shape[:2]
    else:
        with Image.open(image_path) as page:
            width, height = page.size

    inputs = boxes2inputs(scale_boxes_to_1000(boxes, width, height))
    inputs = {k: v.to(layout_model.device) for k, v in inputs.items()}  # Move inputs to model device
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        logits = layout_model(**inputs).logits.cpu().squeeze(0)
    return parse_logits(logits, len(boxes))


# Fetch the Arabic document-layout checkpoint repo from the Hugging Face Hub
# (snapshot_download returns the local snapshot directory) and load the
# RT-DETR weights file from it with Ultralytics. This runs at import time.
# Despite its name, model_dir is the path to the .pt file, not a directory.
model_dir = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS") + "/rtdetr_1024_crops.pt"
model = RTDETR(model_dir)


def detect_layout(img, conf_threshold, iou_threshold):
    """Detect layout regions in a page image with the RT-DETR layout model.

    Args:
        img: image source forwarded to ``model.predict`` (the Gradio UI passes
            a PIL image).
        conf_threshold: minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold forwarded to ``model.predict``.
            NOTE(review): RT-DETR is NMS-free, so Ultralytics' RT-DETR
            predictor may ignore this (and ``agnostic_nms``/``nms``); confirm.

    Returns:
        ``(bboxes, classes)``: a list of ``[x1, y1, x2, y2]`` float boxes and
        the matching list of class-name strings.
    """
    results = model.predict(
        source=img,
        conf=conf_threshold,
        iou=iou_threshold,
        show_labels=True,
        show_conf=True,
        imgsz=1024,
        agnostic_nms=True,
        max_det=34,
        nms=True,
    )[0]
    bboxes = results.boxes.xyxy.cpu().tolist()
    class_ids = results.boxes.cls.cpu().tolist()
    # Class index -> label name for this checkpoint.
    mapping = {0: 'CheckBox',
               1: 'List',
               2: 'P',
               3: 'abandon',
               4: 'figure',
               5: 'gridless_table',
               6: 'handwritten_signature',
               7: 'qr_code',
               8: 'table',
               9: 'title'}
    # Class ids may come back as floats (e.g. 2.0). Cast explicitly instead of
    # relying on 2.0 hashing equal to 2.
    classes = [mapping[int(class_id)] for class_id in class_ids]
    return bboxes, classes
    

def draw_bboxes_on_image(image_path, bboxes, classes, reading_order):
    """Draw labelled layout boxes on an RGB copy of the page image.

    Each box is outlined in its class color and labelled ``"<class>-<order>"``
    just above its top-left corner. ``title`` boxes get a thicker outline and
    the larger title font.

    Args:
        image_path: a PIL image (as passed by the Gradio UI) or a path that
            ``PIL.Image.open`` can read.
        bboxes: ``[x1, y1, x2, y2]`` boxes in image pixel coordinates.
        classes: class name per box; each must be a key of ``class_colors``.
        reading_order: reading-order position per box.

    Returns:
        A new RGB ``PIL.Image`` with the annotations. The input image is left
        unmodified.
    """
    # Define a color map for each class name
    class_colors = {
        'CheckBox': 'orange',
        'List': 'blue',
        'P': 'green',
        'abandon': 'purple',
        'figure': 'cyan',
        'gridless_table': 'yellow',
        'handwritten_signature': 'magenta',
        'qr_code': 'red',
        'table': 'brown',
        'title': 'pink'
    }

    # Work on an RGB copy so the caller's image is not drawn on and the class
    # colors render as colors even for grayscale/palette inputs.
    if isinstance(image_path, Image.Image):
        image = image_path.convert("RGB")
    else:
        with Image.open(image_path) as source:
            image = source.convert("RGB")
    draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype("arial.ttf", 20)
        title_font = ImageFont.truetype("arial.ttf", 30)  # Larger font for titles
    except OSError:
        # arial.ttf is often missing (e.g. on Linux hosts). load_default(size=...)
        # needs Pillow >= 10.1; older versions only offer the fixed bitmap font.
        try:
            font = ImageFont.load_default(size=30)
        except TypeError:
            font = ImageFont.load_default()
        title_font = font

    for (x1, y1, x2, y2), class_name, order in zip(bboxes, classes, reading_order):
        color = class_colors[class_name]
        is_title = class_name == 'title'
        box_thickness = 4 if is_title else 2
        label_font = title_font if is_title else font

        draw.rectangle([x1, y1, x2, y2], outline=color, width=box_thickness)

        label = f"{class_name}-{order}"
        _, text_top, _, text_bottom = draw.textbbox((0, 0), label, font=label_font)
        label_height = text_bottom - text_top
        # Place the label above the box, but never above the image's top edge,
        # where it would be clipped.
        draw.text((x1, max(0, y1 - label_height)), label, fill="black", font=label_font)

    return image
from PIL import Image, ImageDraw

def is_inside(box1, box2):
    """Return True if ``box1`` lies entirely within ``box2`` (shared edges allowed)."""
    return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]


def is_overlap(box1, box2):
    """Return True if the interiors of ``box1`` and ``box2`` intersect.

    Assumes ``x1 <= x2`` and ``y1 <= y2``. Boxes that only share an edge or a
    corner do not count as overlapping.
    """
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2

    # No overlap if one box is to the left, right, above, or below the other box
    return not (x2 <= x3 or x4 <= x1 or y2 <= y3 or y4 <= y1)


def remove_overlapping_and_inside_boxes(boxes, classes):
    """Drop boxes that lie inside, or overlap, a larger box.

    For every pair of boxes:
      * a box lying inside another box is dropped;
      * of two overlapping boxes (see ``is_overlap``), the one with the
        smaller area is dropped.
    For two identical boxes, or two overlapping boxes of exactly equal area,
    only the one appearing first in ``boxes`` is kept.

    Every box is compared against the full original list, so a box can be
    dropped because of a box that is itself dropped.

    Args:
        boxes: ``[x1, y1, x2, y2]`` boxes.
        classes: class label per box, parallel to ``boxes``.

    Returns:
        ``(kept_boxes, kept_classes)`` as new lists in the original order. The
        input lists are not modified.
    """
    def area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    def is_suppressed(i):
        box_i = boxes[i]
        for j, box_j in enumerate(boxes):
            if j == i:
                continue
            inside_ij = is_inside(box_i, box_j)
            inside_ji = is_inside(box_j, box_i)
            if inside_ij and inside_ji:
                # Identical boxes: keep only the earliest. Without this
                # tie-break, both copies would be dropped.
                if j < i:
                    return True
            elif inside_ij:
                return True
            elif inside_ji:
                # box_j is the contained one; it is dropped, not box_i.
                continue
            elif is_overlap(box_i, box_j):
                area_i, area_j = area(box_i), area(box_j)
                # Smaller box loses; equal areas keep the earlier box.
                if area_i < area_j or (area_i == area_j and j < i):
                    return True
        return False

    keep = [i for i in range(len(boxes)) if not is_suppressed(i)]
    return [boxes[i] for i in keep], [classes[i] for i in keep]
def full_predictions(IMAGE_PATH, conf_threshold, iou_threshold):
    """Run the whole pipeline on one page image and return the annotated image.

    Steps: detect layout regions, drop contained/overlapping boxes, predict the
    reading order of the remaining boxes, then draw each box labelled
    ``"<class>-<order>"``.

    Args:
        IMAGE_PATH: despite the name, the page image from the Gradio input (a
            PIL image). ``None`` (nothing uploaded) yields ``None``.
        conf_threshold: detection confidence threshold.
        iou_threshold: IoU threshold forwarded to the detector.

    Returns:
        The annotated image, or ``None`` when no image was supplied.
    """
    if IMAGE_PATH is None:
        # Without this guard, model.predict(source=None) would run on
        # Ultralytics' bundled sample images instead. Drawing would then fail
        # on the None image.
        return None
    bboxes, classes = detect_layout(IMAGE_PATH, conf_threshold, iou_threshold)
    bboxes, classes = remove_overlapping_and_inside_boxes(bboxes, classes)
    orders = get_orders(IMAGE_PATH, bboxes)
    return draw_bboxes_on_image(IMAGE_PATH, bboxes, classes, orders)


# Gradio UI: an image plus two detector thresholds in, the annotated image out.
# NOTE(review): the IoU slider is forwarded to RT-DETR, which is NMS-free and
# may ignore it; confirm before relying on it.
iface = gr.Interface(
    fn=full_predictions,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU threshold"),
    ],
    outputs=gr.Image(type="pil", label="Result"),
    title="Ultralytics Gradio",
    # The previous text claimed a YOLO11n model; the app actually uses an
    # RT-DETR detector plus a LayoutLMv3-based reading-order model.
    description=(
        "Upload a document image. An RT-DETR model detects layout regions and a "
        "LayoutLMv3-based LayoutReader model predicts their reading order; each "
        "box is labelled as <class>-<order>."
    ),
    examples=[
        ["kashida.png", 0.2, 0.45],
        ["image.jpg", 0.2, 0.45],
        ["Screenshot 2024-11-06 130230.png", 0.25, 0.45],
    ],
    theme=gr.themes.Default(),
)

if __name__ == "__main__":
    iface.launch()