from collections import defaultdict
from typing import Dict, List

import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
from transformers import LayoutLMv3ForTokenClassification
from ultralytics import RTDETR
# Load the LayoutReader checkpoint (a LayoutLMv3 token classifier) for reading-order prediction
layout_model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
MAX_LEN = 100
CLS_TOKEN_ID = 0
UNK_TOKEN_ID = 3
EOS_TOKEN_ID = 2
def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
    """Wrap normalized boxes with CLS/EOS sentinels and build LayoutReader model inputs."""
    bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
    input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
    attention_mask = [1] * (len(boxes) + 2)
    return {
        "bbox": torch.tensor([bbox]),
        "attention_mask": torch.tensor([attention_mask]),
        "input_ids": torch.tensor([input_ids]),
    }
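
# A minimal shape-check sketch (illustrative only; not called by the app, and
# the two boxes are made up). For two boxes the sequences have length 4:
# CLS + one UNK token per box + EOS, with zero-boxes as the sentinels.
def _demo_boxes2inputs() -> None:
    demo = boxes2inputs([[100, 100, 500, 200], [100, 300, 500, 400]])
    assert demo["input_ids"].shape == (1, 4)  # [CLS, UNK, UNK, EOS]
    assert demo["bbox"].shape == (1, 4, 4)    # sentinel boxes are all zeros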
def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
    """
    Parse logits into a reading order: each box greedily takes its
    highest-scoring position, then conflicts are resolved by logit value.
    """
    logits = logits[1 : length + 1, :length]  # drop CLS/EOS rows and padding columns
    orders = logits.argsort(descending=False).tolist()
    ret = [o.pop() for o in orders]  # best (highest-logit) position per box
    while True:
        order_to_idxes = defaultdict(list)
        for idx, order in enumerate(ret):
            order_to_idxes[order].append(idx)
        # Keep only positions claimed by more than one box
        order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
        if not order_to_idxes:
            break
        # Resolve conflicts: the box with the highest logit keeps the position,
        # the rest fall back to their next-best candidate
        for order, idxes in order_to_idxes.items():
            idxes_to_logit = {idx: logits[idx, order] for idx in idxes}
            idxes_to_logit = sorted(idxes_to_logit.items(), key=lambda x: x[1], reverse=True)
            for idx, _ in idxes_to_logit[1:]:
                ret[idx] = orders[idx].pop()
    return ret
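
# A tiny worked example (illustrative only; the logits are fabricated, not
# model output). Box 1 scores highest for reading position 0 and box 0 for
# position 1, so the parsed order should be [1, 0].
def _demo_parse_logits() -> List[int]:
    fake_logits = torch.tensor([
        [0.0, 0.0],  # CLS row, dropped by the [1 : length + 1] slice
        [0.1, 0.9],  # box 0 prefers position 1
        [0.8, 0.2],  # box 1 prefers position 0
        [0.0, 0.0],  # EOS row, dropped
    ])
    return parse_logits(fake_logits, length=2)  # -> [1, 0]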
def get_orders(_, bounding_boxes):
    """
    Detects reading order for an Arabic (right-to-left) text layout, given
    bounding boxes in xyxy format. The first argument (the image) is unused;
    the signature matches the LayoutReader-based variant kept below for reference.
    Args:
    - bounding_boxes: List of tuples (x1, y1, x2, y2), where (x1, y1) is the
      top-left corner and (x2, y2) is the bottom-right corner of the bounding box.
    Returns:
    - A list of indices representing the reading order.
    """
    bounding_boxes = [tuple(b) for b in bounding_boxes]
    boxes = np.array(bounding_boxes)
    # Sort by vertical position (y1) first, then horizontal position (x1),
    # both ascending; the right-to-left ordering is applied per row below
    sorted_indices = np.lexsort((boxes[:, 0], boxes[:, 1]))
    # Group boxes into rows: a box joins a row if its y1 is within the tolerance
    rows = []
    tolerance = 10  # Tolerance for grouping elements into rows
    for idx in sorted_indices:
        placed = False
        for row in rows:
            if abs(row[-1][1] - boxes[idx][1]) < tolerance:
                row.append(boxes[idx])
                placed = True
                break
        if not placed:
            rows.append([boxes[idx]])
    # Within each row, sort by x1 descending (right-to-left)
    reading_order = []
    for row in rows:
        row.sort(key=lambda b: -b[0])
        reading_order.extend(row)
    # Map the reordered boxes back to their original indices
    return [bounding_boxes.index(tuple(box)) for box in reading_order]
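
# A hedged worked example with made-up coordinates: boxes 0 and 1 share a row
# (their y1 values differ by less than the 10 px tolerance), so the rightmost
# box (index 1) is read first, then box 0, then the box on the next row.
# >>> get_orders(None, [(10, 10, 50, 30), (200, 12, 260, 32), (10, 100, 50, 120)])
# [1, 0, 2]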
# Disabled alternative: LayoutReader-based reading order (kept for reference).
# def get_orders(image_path, boxes):
#     b = scale_and_normalize_boxes(boxes)
#     inputs = boxes2inputs(b)
#     inputs = {k: v.to(layout_model.device) for k, v in inputs.items()}  # Move inputs to model device
#     logits = layout_model(**inputs).logits.cpu().squeeze(0)  # Perform inference and get logits
#     orders = parse_logits(logits, len(b))
#     return orders
# Download the RT-DETR layout-analysis weights from the Hub and load the model
model_dir = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS") + "/rtdetr_1024_crops.pt"
model = RTDETR(model_dir)
def detect_layout(img, conf_threshold, iou_threshold):
    """Predicts layout regions in an image using the RT-DETR model with adjustable confidence and IoU thresholds."""
    results = model.predict(
        source=img,
        conf=conf_threshold,
        iou=iou_threshold,
        show_labels=True,
        show_conf=True,
        imgsz=1024,
        agnostic_nms=True,
        max_det=34,
        nms=True,
    )[0]
    bboxes = results.boxes.xyxy.cpu().tolist()
    classes = results.boxes.cls.cpu().tolist()
    mapping = {
        0: 'CheckBox',
        1: 'List',
        2: 'P',
        3: 'abandon',
        4: 'figure',
        5: 'gridless_table',
        6: 'handwritten_signature',
        7: 'qr_code',
        8: 'table',
        9: 'title',
    }
    classes = [mapping[int(i)] for i in classes]  # cls values come back as floats
    return bboxes, classes
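
# Example call (sketch only; "sample_page.png" is a hypothetical local file).
# The image is resized to 1024x1024 to match the inference resolution:
#
#   page = Image.open("sample_page.png").resize((1024, 1024))
#   boxes, labels = detect_layout(page, conf_threshold=0.25, iou_threshold=0.45)
#   print(list(zip(labels, boxes)))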
def draw_bboxes_on_image(image, bboxes, classes, reading_order):
    # Define a color map for each class name
    class_colors = {
        'CheckBox': 'orange',
        'List': 'blue',
        'P': 'green',
        'abandon': 'purple',
        'figure': 'cyan',
        'gridless_table': 'yellow',
        'handwritten_signature': 'magenta',
        'qr_code': 'red',
        'table': 'brown',
        'title': 'pink',
    }
    # Prepare to draw on the (already opened) PIL image
    draw = ImageDraw.Draw(image)
    # Try loading a TrueType font; fall back to the default bitmap font
    try:
        font = ImageFont.truetype("arial.ttf", 20)
        title_font = ImageFont.truetype("arial.ttf", 30)  # Larger font for titles
    except IOError:
        font = ImageFont.load_default(size=30)
        title_font = font  # Use the same font for titles if the custom font fails
    # Loop through the bounding boxes and corresponding labels
    for i in range(len(bboxes)):
        x1, y1, x2, y2 = bboxes[i]
        class_name = classes[i]
        order = reading_order[i]
        color = class_colors[class_name]
        # Titles get a thicker box and a larger label
        if class_name == 'title':
            box_thickness = 4
            label_font = title_font
        else:
            box_thickness = 2
            label_font = font
        # Draw the rectangle with the class color and box thickness
        draw.rectangle([x1, y1, x2, y2], outline=color, width=box_thickness)
        # Label the box with the class and reading order
        label = f"{class_name}-{order}"
        # Measure the label with textbbox() so it can be placed above the box
        bbox = draw.textbbox((x1, y1 - 20), label, font=label_font)
        label_height = bbox[3] - bbox[1]
        draw.text((x1, y1 - label_height), label, fill="black", font=label_font)
    # Return the annotated image (drawn in place)
    return image
def scale_and_normalize_boxes(bboxes, old_width=1024, old_height=1024, new_width=640, new_height=640, normalize_width=1000, normalize_height=1000):
    """
    Scales and normalizes bounding boxes from original dimensions to new dimensions.
    Args:
        bboxes (list of lists): List of bounding boxes in [x_min, y_min, x_max, y_max] format.
        old_width (int or float): Width of the original image.
        old_height (int or float): Height of the original image.
        new_width (int or float): Width of the scaled image.
        new_height (int or float): Height of the scaled image.
        normalize_width (int or float): Width of the normalization range (e.g., target resolution width).
        normalize_height (int or float): Height of the normalization range (e.g., target resolution height).
    Returns:
        list of lists: Scaled and normalized bounding boxes in [x_min, y_min, x_max, y_max] format.
    """
    scale_x = new_width / old_width
    scale_y = new_height / old_height

    def scale_and_normalize_single(bbox):
        x_min, y_min, x_max, y_max = bbox
        # Scale to the new dimensions
        x_min *= scale_x
        y_min *= scale_y
        x_max *= scale_x
        y_max *= scale_y
        # Normalize to the target range
        x_min = int(normalize_width * (x_min / new_width))
        y_min = int(normalize_height * (y_min / new_height))
        x_max = int(normalize_width * (x_max / new_width))
        y_max = int(normalize_height * (y_max / new_height))
        return [x_min, y_min, x_max, y_max]

    # Process all bounding boxes
    return [scale_and_normalize_single(bbox) for bbox in bboxes]
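
# Worked example with the defaults (1024 -> 640 -> 0..1000): a box covering the
# lower-right quadrant of a 1024x1024 page. Scaling gives 512 * 640/1024 = 320,
# and normalizing gives int(1000 * 320/640) = 500, so:
# >>> scale_and_normalize_boxes([[512, 512, 1024, 1024]])
# [[500, 500, 1000, 1000]]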
def is_inside(box1, box2):
    # Check if box1 is fully contained within box2
    return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]

def is_overlap(box1, box2):
    # Check if box1 overlaps box2: there is no overlap if one box lies
    # entirely to the left, right, above, or below the other
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2
    return not (x2 <= x3 or x4 <= x1 or y2 <= y3 or y4 <= y1)
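
# Quick sanity examples (made-up coordinates). Note that boxes sharing only an
# edge do not count as overlapping because the comparisons are strict:
# >>> is_inside([2, 2, 4, 4], [0, 0, 10, 10])
# True
# >>> is_overlap([0, 0, 5, 5], [4, 4, 8, 8])
# True
# >>> is_overlap([0, 0, 5, 5], [5, 0, 10, 5])  # touching edges only
# False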
def remove_overlapping_and_inside_boxes(boxes, classes):
    to_remove = []
    for i, box1 in enumerate(boxes):
        for j, box2 in enumerate(boxes):
            if i != j:
                if is_inside(box1, box2):
                    # Mark the smaller (inside) box for removal
                    to_remove.append(i)
                elif is_inside(box2, box1):
                    # Mark the smaller (inside) box for removal
                    to_remove.append(j)
                elif is_overlap(box1, box2):
                    # If the boxes overlap, mark the smaller one for removal
                    if (box2[2] - box2[0]) * (box2[3] - box2[1]) < (box1[2] - box1[0]) * (box1[3] - box1[1]):
                        to_remove.append(j)
                    else:
                        to_remove.append(i)
    # Deduplicate and delete from the highest index down so that the
    # remaining indices stay valid during removal
    to_remove = sorted(set(to_remove), reverse=True)
    for idx in to_remove:
        del boxes[idx]
        del classes[idx]
    return boxes, classes
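
# A hedged worked example: the second box sits entirely inside the first, so it
# (and its class label) is dropped while the enclosing box survives:
# >>> remove_overlapping_and_inside_boxes([[0, 0, 10, 10], [2, 2, 4, 4]], ['P', 'title'])
# ([[0, 0, 10, 10]], ['P'])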
def full_predictions(image, conf_threshold, iou_threshold):
    image = image.resize((1024, 1024))
    bboxes, classes = detect_layout(image, conf_threshold, iou_threshold)
    bboxes, classes = remove_overlapping_and_inside_boxes(bboxes, classes)
    orders = get_orders(image, bboxes)
    final_image = draw_bboxes_on_image(image, bboxes, classes, orders)
    return final_image
iface = gr.Interface(
    fn=full_predictions,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU threshold"),
    ],
    outputs=gr.Image(type="pil", label="Result"),
    title="Arabic Document Layout Analysis",
    description="Upload a document image for inference. An RT-DETR model detects layout regions, and a right-to-left heuristic assigns reading order.",
    examples=[
        ["kashida.png", 0.2, 0.45],
        ["image.jpg", 0.2, 0.45],
        ["Screenshot 2024-11-06 130230.png", 0.25, 0.45],
    ],
    theme=gr.themes.Default(),
)
if __name__ == "__main__":
    iface.launch()