put online demo
- .gitignore +13 -0
- OCR.py +415 -0
- demo_streamlit.py +339 -0
- display.py +181 -0
- eval.py +649 -0
- flask.py +6 -0
- htlm_webpage.py +141 -0
- packages.txt +1 -0
- requirements.txt +10 -0
- toXML.py +351 -0
- train.py +394 -0
- utils.py +936 -0
.gitignore
ADDED
@@ -0,0 +1,13 @@
__pycache__/
temp/
VISION_KEY.json
*.pth
.streamlit/secrets.toml
backup/
OCR.py
ADDED
@@ -0,0 +1,415 @@

import os
from azure.ai.vision.imageanalysis import ImageAnalysisClient
from azure.ai.vision.imageanalysis.models import VisualFeatures
from azure.core.credentials import AzureKeyCredential
import time
import numpy as np
import networkx as nx
from eval import iou
from utils import class_dict, proportion_inside
import json
from utils import rescale_boxes as rescale
import streamlit as st

VISION_KEY = st.secrets["VISION_KEY"]
VISION_ENDPOINT = st.secrets["VISION_ENDPOINT"]

"""
# If local execution
with open("VISION_KEY.json", "r") as json_file:
    json_data = json.load(json_file)

# Step 2: Parse the JSON data (this is done by json.load automatically)
VISION_KEY = json_data["VISION_KEY"]
VISION_ENDPOINT = json_data["VISION_ENDPOINT"]
"""


def sample_ocr_image_file(image_data):
    # Set the values of your computer vision endpoint and computer vision key
    # as environment variables:
    try:
        endpoint = VISION_ENDPOINT
        key = VISION_KEY
    except KeyError:
        print("Missing environment variable 'VISION_ENDPOINT' or 'VISION_KEY'")
        print("Set them before running this sample.")
        exit()

    # Create an Image Analysis client
    client = ImageAnalysisClient(
        endpoint=endpoint,
        credential=AzureKeyCredential(key)
    )

    # Extract text (OCR) from an image stream. This will be a synchronous (blocking) call.
    result = client.analyze(
        image_data=image_data,
        visual_features=[VisualFeatures.READ]
    )

    return result


def text_prediction(image):
    # transform the image into a byte array
    image.save('temp.jpg')
    with open('temp.jpg', 'rb') as f:
        image_data = f.read()
    ocr_result = sample_ocr_image_file(image_data)
    # delete the temporary image
    os.remove('temp.jpg')
    return ocr_result

def filter_text(ocr_result, threshold=0.5):
    words_to_cancel = {"+",".",",","#","@","!","?","(",")","[","]","{","}","<",">","/","\\","|","-","_","=","&","^","%","$","£","€","¥","¢","¤","§","©","®","™","°","±","×","÷","¶","∆","∏","∑","∞","√","∫","≈","≠","≤","≥","≡","∼"}
    # Add every other one-letter word to the list of words to cancel, except 'I' and 'a'
    for letter in "bcdefghjklmnopqrstuvwxyz1234567890": # All lowercase letters except 'a'
        words_to_cancel.add(letter)
        words_to_cancel.add("i")
        words_to_cancel.add(letter.upper()) # Add the uppercase version as well
    characters_to_cancel = {"+", "<", ">"} # Characters to cancel

    list_of_lines = []

    for block in ocr_result['readResult']['blocks']:
        for line in block['lines']:
            line_text = []
            x_min, y_min = float('inf'), float('inf')
            x_max, y_max = float('-inf'), float('-inf')
            for word in line['words']:
                if word['text'] in words_to_cancel or any(disallowed_char in word['text'] for disallowed_char in characters_to_cancel):
                    continue
                if word['confidence'] > threshold:
                    if word['text']:
                        line_text.append(word['text'])
                        x = [point['x'] for point in word['boundingPolygon']]
                        y = [point['y'] for point in word['boundingPolygon']]
                        x_min = min(x_min, min(x))
                        y_min = min(y_min, min(y))
                        x_max = max(x_max, max(x))
                        y_max = max(y_max, max(y))
            if line_text: # If there are valid words in the line
                list_of_lines.append({
                    'text': ' '.join(line_text),
                    'boundingBox': [x_min, y_min, x_max, y_max]
                })

    list_text = []
    list_bbox = []
    for i in range(len(list_of_lines)):
        list_text.append(list_of_lines[i]['text'])
    for i in range(len(list_of_lines)):
        list_bbox.append(list_of_lines[i]['boundingBox'])

    list_of_lines = [list_bbox, list_text]

    return list_of_lines


def get_box_points(box):
    """Returns all critical points of a box: corners and midpoints of edges."""
    xmin, ymin, xmax, ymax = box
    return np.array([
        [xmin, ymin],               # Bottom-left corner
        [xmax, ymin],               # Bottom-right corner
        [xmin, ymax],               # Top-left corner
        [xmax, ymax],               # Top-right corner
        [(xmin + xmax) / 2, ymin],  # Midpoint of bottom edge
        [(xmin + xmax) / 2, ymax],  # Midpoint of top edge
        [xmin, (ymin + ymax) / 2],  # Midpoint of left edge
        [xmax, (ymin + ymax) / 2]   # Midpoint of right edge
    ])

def min_distance_between_boxes(box1, box2):
    """Computes the minimum distance between two boxes considering all critical points."""
    points1 = get_box_points(box1)
    points2 = get_box_points(box2)

    min_dist = float('inf')
    for point1 in points1:
        for point2 in points2:
            dist = np.linalg.norm(point1 - point2)
            if dist < min_dist:
                min_dist = dist
    return min_dist


def is_inside(box1, box2):
    """Check if the center of box1 is inside box2."""
    x_center = (box1[0] + box1[2]) / 2
    y_center = (box1[1] + box1[3]) / 2
    return box2[0] <= x_center <= box2[2] and box2[1] <= y_center <= box2[3]

def are_close(box1, box2, threshold=50):
    """Determines if boxes are close based on their corners and center points."""
    corners1 = np.array([
        [box1[0], box1[1]], [box1[0], box1[3]], [box1[2], box1[1]], [box1[2], box1[3]],
        [(box1[0]+box1[2])/2, box1[1]], [(box1[0]+box1[2])/2, box1[3]],
        [box1[0], (box1[1]+box1[3])/2], [box1[2], (box1[1]+box1[3])/2]
    ])
    corners2 = np.array([
        [box2[0], box2[1]], [box2[0], box2[3]], [box2[2], box2[1]], [box2[2], box2[3]],
        [(box2[0]+box2[2])/2, box2[1]], [(box2[0]+box2[2])/2, box2[3]],
        [box2[0], (box2[1]+box2[3])/2], [box2[2], (box2[1]+box2[3])/2]
    ])
    for c1 in corners1:
        for c2 in corners2:
            if np.linalg.norm(c1 - c2) < threshold:
                return True
    return False

def find_closest_box(text_box, all_boxes, labels, threshold, iou_threshold=0.5):
    """Find the closest box to the given text box within a specified threshold."""
    min_distance = float('inf')
    closest_index = None

    # check if the text is inside a sequenceFlow
    for j in range(len(all_boxes)):
        if proportion_inside(text_box, all_boxes[j]) > iou_threshold and labels[j] == list(class_dict.values()).index('sequenceFlow'):
            return j

    for i, box in enumerate(all_boxes):
        # Compute the center of both boxes
        center_text = np.array([(text_box[0] + text_box[2]) / 2, (text_box[1] + text_box[3]) / 2])
        center_box = np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2])

        # Calculate Euclidean distance between centers
        distance = np.linalg.norm(center_text - center_box)

        # Update closest box if this box is nearer
        if distance < min_distance:
            min_distance = distance
            closest_index = i

    # Check if the closest box found is within the acceptable threshold
    if min_distance < threshold:
        return closest_index

    return None


def is_vertical(box):
    """Determine if the text in the bounding box is vertically aligned."""
    width = box[2] - box[0]
    height = box[3] - box[1]
    return (height > 2*width)

def group_texts(task_boxes, text_boxes, texts, min_dist=50, iou_threshold=0.8, percentage_thresh=0.8):
    """Maps text boxes to task boxes and groups texts within each task based on proximity."""
    G = nx.Graph()

    # Map each text box to the nearest task box
    task_to_texts = {i: [] for i in range(len(task_boxes))}
    information_texts = [] # texts not inside any task box
    text_to_task_mapped = [False] * len(text_boxes)

    for idx, text_box in enumerate(text_boxes):
        mapped = False
        for jdx, task_box in enumerate(task_boxes):
            if proportion_inside(text_box, task_box) > iou_threshold:
                task_to_texts[jdx].append(idx)
                text_to_task_mapped[idx] = True
                mapped = True
                break
        if not mapped:
            information_texts.append(idx)

    all_grouped_texts = []
    sentence_boxes = [] # Store the bounding box for each sentence

    # Process texts for each task
    for task_texts in task_to_texts.values():
        G.clear()
        for i in task_texts:
            G.add_node(i)
            for j in task_texts:
                if i != j and are_close(text_boxes[i], text_boxes[j]) and not is_vertical(text_boxes[i]) and not is_vertical(text_boxes[j]):
                    G.add_edge(i, j)

        groups = list(nx.connected_components(G))
        for group in groups:
            group = list(group)
            lines = {}
            for idx in group:
                y_center = (text_boxes[idx][1] + text_boxes[idx][3]) / 2
                found_line = False
                for line in lines:
                    if abs(y_center - line) < (text_boxes[idx][3] - text_boxes[idx][1]) / 2:
                        lines[line].append(idx)
                        found_line = True
                        break
                if not found_line:
                    lines[y_center] = [idx]

            sorted_lines = sorted(lines.keys())
            grouped_texts = []
            min_x = min_y = float('inf')
            max_x = max_y = -float('inf')

            for line in sorted_lines:
                sorted_indices = sorted(lines[line], key=lambda idx: text_boxes[idx][0])
                line_text = ' '.join(texts[idx] for idx in sorted_indices)
                grouped_texts.append(line_text)

                for idx in sorted_indices:
                    box = text_boxes[idx]
                    min_x = min(min_x-5, box[0]-5)
                    min_y = min(min_y-5, box[1]-5)
                    max_x = max(max_x+5, box[2]+5)
                    max_y = max(max_y+5, box[3]+5)

            all_grouped_texts.append(' '.join(grouped_texts))
            sentence_boxes.append([min_x, min_y, max_x, max_y])

    # Group information texts
    G.clear()
    info_sentence_boxes = []

    for i in information_texts:
        G.add_node(i)
        for j in information_texts:
            if i != j and are_close(text_boxes[i], text_boxes[j], percentage_thresh * min_dist) and not is_vertical(text_boxes[i]) and not is_vertical(text_boxes[j]):
                G.add_edge(i, j)

    info_groups = list(nx.connected_components(G))
    information_grouped_texts = []
    for group in info_groups:
        group = list(group)
        lines = {}
        for idx in group:
            y_center = (text_boxes[idx][1] + text_boxes[idx][3]) / 2
            found_line = False
            for line in lines:
                if abs(y_center - line) < (text_boxes[idx][3] - text_boxes[idx][1]) / 2:
                    lines[line].append(idx)
                    found_line = True
                    break
            if not found_line:
                lines[y_center] = [idx]

        sorted_lines = sorted(lines.keys())
        grouped_texts = []
        min_x = min_y = float('inf')
        max_x = max_y = -float('inf')

        for line in sorted_lines:
            sorted_indices = sorted(lines[line], key=lambda idx: text_boxes[idx][0])
            line_text = ' '.join(texts[idx] for idx in sorted_indices)
            grouped_texts.append(line_text)

            for idx in sorted_indices:
                box = text_boxes[idx]
                min_x = min(min_x, box[0])
                min_y = min(min_y, box[1])
                max_x = max(max_x, box[2])
                max_y = max(max_y, box[3])

        information_grouped_texts.append(' '.join(grouped_texts))
        info_sentence_boxes.append([min_x, min_y, max_x, max_y])

    return all_grouped_texts, sentence_boxes, information_grouped_texts, info_sentence_boxes


def mapping_text(full_pred, text_pred, print_sentences=False, percentage_thresh=0.6, scale=1.0, iou_threshold=0.5):

    ########### REWORK THIS FUNCTION ###########
    # rework the function so that it first takes the elements that are inside the tasks and then uses a distance threshold for the other elements
    # or otherwise compute the distance between all the elements, not only the tasks

    # Example usage
    boxes = rescale(scale, full_pred['boxes'])

    min_dist = 200
    labels = full_pred['labels']
    avoid = [list(class_dict.values()).index('pool'), list(class_dict.values()).index('lane'), list(class_dict.values()).index('sequenceFlow'), list(class_dict.values()).index('messageFlow'), list(class_dict.values()).index('dataAssociation')]
    for i in range(len(boxes)):
        box1 = boxes[i]
        if labels[i] in avoid:
            continue
        for j in range(i + 1, len(boxes)):
            box2 = boxes[j]
            if labels[j] in avoid:
                continue
            dist = min_distance_between_boxes(box1, box2)
            min_dist = min(min_dist, dist)

    #print("Minimum distance between boxes:", min_dist)

    text_pred[0] = rescale(scale, text_pred[0])
    task_boxes = [box for i, box in enumerate(boxes) if full_pred['labels'][i] == list(class_dict.values()).index('task')]
    grouped_sentences, sentence_bounding_boxes, info_texts, info_boxes = group_texts(task_boxes, text_pred[0], text_pred[1], min_dist=min_dist)
    BPMN_id = set(full_pred['BPMN_id']) # This ensures uniqueness of task names
    text_mapping = {id: '' for id in BPMN_id}

    if print_sentences:
        for sentence, box in zip(grouped_sentences, sentence_bounding_boxes):
            print("Task-related Text:", sentence)
            print("Bounding Box:", box)
        print("Information Texts:", info_texts)
        print("Information Bounding Boxes:", info_boxes)

    # Map the grouped sentences to the corresponding task
    for i in range(len(sentence_bounding_boxes)):
        for j in range(len(boxes)):
            if proportion_inside(sentence_bounding_boxes[i], boxes[j]) > iou_threshold and full_pred['labels'][j] == list(class_dict.values()).index('task'):
                text_mapping[full_pred['BPMN_id'][j]] = grouped_sentences[i]

    # Map the grouped sentences to the corresponding pool
    for i in range(len(info_boxes)):
        if is_vertical(info_boxes[i]):
            for j in range(len(boxes)):
                if proportion_inside(info_boxes[i], boxes[j]) > 0 and full_pred['labels'][j] == list(class_dict.values()).index('pool'):
                    print("Text:", info_texts[i], "associate with ", full_pred['BPMN_id'][j])
                    bpmn_id = full_pred['BPMN_id'][j]
                    # Append new text or create new entry if not existing
                    if bpmn_id in text_mapping:
                        text_mapping[bpmn_id] += " " + info_texts[i] # Append text with a space in between
                    else:
                        text_mapping[bpmn_id] = info_texts[i]
                    info_texts[i] = '' # Clear the text to avoid re-use

    # Map the grouped sentences to the corresponding object
    for i in range(len(info_boxes)):
        if is_vertical(info_boxes[i]):
            continue # Skip if the text is vertical
        for j in range(len(boxes)):
            if info_texts[i] == '':
                continue # Skip if there's no text
            if (proportion_inside(info_boxes[i], boxes[j]) > 0 or are_close(info_boxes[i], boxes[j], threshold=percentage_thresh*min_dist)) and (full_pred['labels'][j] == list(class_dict.values()).index('event')
                    or full_pred['labels'][j] == list(class_dict.values()).index('messageEvent')
                    or full_pred['labels'][j] == list(class_dict.values()).index('timerEvent')
                    or full_pred['labels'][j] == list(class_dict.values()).index('dataObject')):
                bpmn_id = full_pred['BPMN_id'][j]
                # Append new text or create new entry if not existing
                if bpmn_id in text_mapping:
                    text_mapping[bpmn_id] += " " + info_texts[i] # Append text with a space in between
                else:
                    text_mapping[bpmn_id] = info_texts[i]
                info_texts[i] = '' # Clear the text to avoid re-use

    # Map the grouped sentences to the corresponding flow
    for i in range(len(info_boxes)):
        if info_texts[i] == '' or is_vertical(info_boxes[i]):
            continue # Skip if there's no text
        # Find the closest box within the defined threshold
        closest_index = find_closest_box(info_boxes[i], boxes, full_pred['labels'], threshold=4*min_dist)
        if closest_index is not None and (full_pred['labels'][closest_index] == list(class_dict.values()).index('sequenceFlow') or full_pred['labels'][closest_index] == list(class_dict.values()).index('messageFlow')):
            bpmn_id = full_pred['BPMN_id'][closest_index]
            # Append new text or create new entry if not existing
            if bpmn_id in text_mapping:
                text_mapping[bpmn_id] += " " + info_texts[i] # Append text with a space in between
            else:
                text_mapping[bpmn_id] = info_texts[i]
            info_texts[i] = '' # Clear the text to avoid re-use

    if print_sentences:
        print("Text Mapping:", text_mapping)
        print("Information Texts left:", info_texts)

    return text_mapping
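For orientation, a minimal sketch of how the helpers above are chained (the same sequence the Streamlit demo below uses); the image path is illustrative, and `full_pred` is assumed to come from `eval.full_prediction`:

    from PIL import Image
    from OCR import text_prediction, filter_text, mapping_text

    image = Image.open("diagram.png").convert("RGB")    # hypothetical input diagram
    ocr_result = text_prediction(image)                 # Azure Read OCR over the image
    text_pred = filter_text(ocr_result, threshold=0.5)  # [boxes, texts] for confident, meaningful words
    # With a detection result full_pred from eval.full_prediction:
    # text_mapping = mapping_text(full_pred, text_pred, percentage_thresh=0.5)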
demo_streamlit.py
ADDED
@@ -0,0 +1,339 @@
import streamlit as st
import streamlit.components.v1 as components
from PIL import Image
import torch
from torchvision.transforms import functional as F
from PIL import Image, ImageEnhance
from htlm_webpage import display_bpmn_xml
import gc
import psutil

from OCR import text_prediction, filter_text, mapping_text, rescale
from train import prepare_model
from utils import draw_annotations, create_loader, class_dict, arrow_dict, object_dict
from toXML import calculate_pool_bounds, add_diagram_elements
from pathlib import Path
from toXML import create_bpmn_object, create_flow_element
import xml.etree.ElementTree as ET
import numpy as np
from display import draw_stream
from eval import full_prediction
from streamlit_image_comparison import image_comparison
from xml.dom import minidom
from streamlit_cropper import st_cropper
from streamlit_drawable_canvas import st_canvas
from utils import find_closest_object
from train import get_faster_rcnn_model, get_arrow_model
import gdown

def get_memory_usage():
    process = psutil.Process()
    mem_info = process.memory_info()
    return mem_info.rss / (1024 ** 2) # Return memory usage in MB

def clear_memory():
    st.session_state.clear()
    gc.collect()

# Function to read XML content from a file
def read_xml_file(filepath):
    """ Read XML content from a file """
    with open(filepath, 'r', encoding='utf-8') as file:
        return file.read()

# Function to modify bounding box positions based on the given sizes
def modif_box_pos(pred, size):
    for i, (x1, y1, x2, y2) in enumerate(pred['boxes']):
        center = [(x1 + x2) / 2, (y1 + y2) / 2]
        label = class_dict[pred['labels'][i]]
        if label in size:
            pred['boxes'][i] = [center[0] - size[label][0] / 2, center[1] - size[label][1] / 2, center[0] + size[label][0] / 2, center[1] + size[label][1] / 2]
    return pred

# Function to create a BPMN XML file from prediction results
def create_XML(full_pred, text_mapping, scale):
    namespaces = {
        'bpmn': 'http://www.omg.org/spec/BPMN/20100524/MODEL',
        'bpmndi': 'http://www.omg.org/spec/BPMN/20100524/DI',
        'di': 'http://www.omg.org/spec/DD/20100524/DI',
        'dc': 'http://www.omg.org/spec/DD/20100524/DC',
        'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
    }

    size_elements = {
        'start': (54, 54),
        'task': (150, 120),
        'message': (54, 54),
        'messageEvent': (54, 54),
        'end': (54, 54),
        'exclusiveGateway': (75, 75),
        'event': (54, 54),
        'parallelGateway': (75, 75),
        'sequenceFlow': (225, 15),
        'pool': (375, 150),
        'lane': (300, 150),
        'dataObject': (60, 90),
        'dataStore': (90, 90),
        'subProcess': (180, 135),
        'eventBasedGateway': (75, 75),
        'timerEvent': (60, 60),
    }

    definitions = ET.Element('bpmn:definitions', {
        'xmlns:xsi': namespaces['xsi'],
        'xmlns:bpmn': namespaces['bpmn'],
        'xmlns:bpmndi': namespaces['bpmndi'],
        'xmlns:di': namespaces['di'],
        'xmlns:dc': namespaces['dc'],
        'targetNamespace': "http://example.bpmn.com",
        'id': "simpleExample"
    })

    # Create BPMN collaboration element
    collaboration = ET.SubElement(definitions, 'bpmn:collaboration', id='collaboration_1')

    # Create BPMN process elements
    process = []
    for idx in range(len(full_pred['pool_dict'].items())):
        process_id = f'process_{idx+1}'
        process.append(ET.SubElement(definitions, 'bpmn:process', id=process_id, isExecutable='false', name=text_mapping[full_pred['BPMN_id'][list(full_pred['pool_dict'].keys())[idx]]]))

    bpmndi = ET.SubElement(definitions, 'bpmndi:BPMNDiagram', id='BPMNDiagram_1')
    bpmnplane = ET.SubElement(bpmndi, 'bpmndi:BPMNPlane', id='BPMNPlane_1', bpmnElement='collaboration_1')

    full_pred['boxes'] = rescale(scale, full_pred['boxes'])

    # Add diagram elements for each pool
    for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
        pool_id = f'participant_{idx+1}'
        pool = ET.SubElement(collaboration, 'bpmn:participant', id=pool_id, processRef=f'process_{idx+1}', name=text_mapping[full_pred['BPMN_id'][list(full_pred['pool_dict'].keys())[idx]]])

        # Calculate the bounding box for the pool
        if len(keep_elements) == 0:
            min_x, min_y, max_x, max_y = full_pred['boxes'][pool_index]
            pool_width = max_x - min_x
            pool_height = max_y - min_y
        else:
            min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred, keep_elements, size_elements)
            pool_width = max_x - min_x + 100 # Adding padding
            pool_height = max_y - min_y + 100 # Adding padding

        add_diagram_elements(bpmnplane, pool_id, min_x - 50, min_y - 50, pool_width, pool_height)

    # Create BPMN elements for each pool
    for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
        create_bpmn_object(process[idx], bpmnplane, text_mapping, definitions, size_elements, full_pred, keep_elements)

    # Create message flow elements
    message_flows = [i for i, label in enumerate(full_pred['labels']) if class_dict[label] == 'messageFlow']
    for idx in message_flows:
        create_flow_element(bpmnplane, text_mapping, idx, size_elements, full_pred, collaboration, message=True)

    # Create sequence flow elements
    for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
        for i in keep_elements:
            if full_pred['labels'][i] == list(class_dict.values()).index('sequenceFlow'):
                create_flow_element(bpmnplane, text_mapping, i, size_elements, full_pred, process[idx], message=False)

    # Generate pretty XML string
    tree = ET.ElementTree(definitions)
    rough_string = ET.tostring(definitions, 'utf-8')
    reparsed = minidom.parseString(rough_string)
    pretty_xml_as_string = reparsed.toprettyxml(indent=" ")

    full_pred['boxes'] = rescale(1/scale, full_pred['boxes'])

    return pretty_xml_as_string


# Function to load the models only once and use session state to keep track of it
def load_models():
    with st.spinner('Loading model...'):
        model_object = get_faster_rcnn_model(len(object_dict))
        model_arrow = get_arrow_model(len(arrow_dict), 2)

        url_arrow = 'https://drive.google.com/uc?id=1xwfvo7BgDWz-1jAiJC1DCF0Wp8YlFNWt'
        url_object = 'https://drive.google.com/uc?id=1GiM8xOXG6M6r8J9HTOeMJz9NKu7iumZi'

        # Define paths to save models
        output_arrow = 'model_arrow.pth'
        output_object = 'model_object.pth'

        # Download models using gdown
        if not Path(output_arrow).exists():
            # Download models using gdown
            gdown.download(url_arrow, output_arrow, quiet=False)
        else:
            print('Model arrow downloaded from local')
        if not Path(output_object).exists():
            gdown.download(url_object, output_object, quiet=False)
        else:
            print('Model object downloaded from local')

        # Load models
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model_arrow.load_state_dict(torch.load(output_arrow, map_location=device))
        model_object.load_state_dict(torch.load(output_object, map_location=device))
        st.session_state.model_loaded = True
        st.session_state.model_arrow = model_arrow
        st.session_state.model_object = model_object

# Function to prepare the image for processing
def prepare_image(image, pad=True, new_size=(1333, 1333)):
    original_size = image.size
    # Calculate scale to fit the new size while maintaining aspect ratio
    scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
    new_scaled_size = (int(original_size[0] * scale), int(original_size[1] * scale))
    # Resize image to new scaled size
    image = F.resize(image, (new_scaled_size[1], new_scaled_size[0]))

    if pad:
        enhancer = ImageEnhance.Brightness(image)
        image = enhancer.enhance(1.5) # Adjust the brightness if necessary
        # Pad the resized image to make it exactly the desired size
        padding = [0, 0, new_size[0] - new_scaled_size[0], new_size[1] - new_scaled_size[1]]
        image = F.pad(image, padding, fill=200, padding_mode='edge')

    return new_scaled_size, image

# Function to display various options for image annotation
def display_options(image, score_threshold):
    col1, col2, col3, col4, col5 = st.columns(5)
    with col1:
        write_class = st.toggle("Write Class", value=True)
        draw_keypoints = st.toggle("Draw Keypoints", value=True)
        draw_boxes = st.toggle("Draw Boxes", value=True)
    with col2:
        draw_text = st.toggle("Draw Text", value=False)
        write_text = st.toggle("Write Text", value=False)
        draw_links = st.toggle("Draw Links", value=False)
    with col3:
        write_score = st.toggle("Write Score", value=True)
        write_idx = st.toggle("Write Index", value=False)
    with col4:
        # Define options for the dropdown menu
        dropdown_options = [list(class_dict.values())[i] for i in range(len(class_dict))]
        dropdown_options[0] = 'all'
        selected_option = st.selectbox("Show class", dropdown_options)

    # Draw the annotated image with selected options
    annotated_image = draw_stream(
        np.array(image), prediction=st.session_state.prediction, text_predictions=st.session_state.text_pred,
        draw_keypoints=draw_keypoints, draw_boxes=draw_boxes, draw_links=draw_links, draw_twins=False, draw_grouped_text=draw_text,
        write_class=write_class, write_text=write_text, keypoints_correction=True, write_idx=write_idx, only_print=selected_option,
        score_threshold=score_threshold, write_score=write_score, resize=True, return_image=True, axis=True
    )

    # Display the original and annotated images side by side
    image_comparison(
        img1=annotated_image,
        img2=image,
        label1="Annotated Image",
        label2="Original Image",
        starting_position=99,
        width=1000,
    )

# Function to perform inference on the uploaded image using the loaded models
def perform_inference(model_object, model_arrow, image, score_threshold):
    _, uploaded_image = prepare_image(image, pad=False)

    img_tensor = F.to_tensor(prepare_image(image.convert('RGB'))[1])

    # Display original image
    if 'image_placeholder' not in st.session_state:
        image_placeholder = st.empty() # Create an empty placeholder
        image_placeholder.image(uploaded_image, caption='Original Image', width=1000)

    # Prediction
    _, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=0.5)

    # Perform OCR on the uploaded image
    ocr_results = text_prediction(uploaded_image)

    # Filter and map OCR results to prediction results
    st.session_state.text_pred = filter_text(ocr_results, threshold=0.5)
    st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=0.5)

    # Remove the original image display
    image_placeholder.empty()

    # Force garbage collection
    gc.collect()

@st.cache_data
def get_image(uploaded_file):
    return Image.open(uploaded_file).convert('RGB')

def main():
    st.set_page_config(layout="wide")
    st.title("BPMN model recognition demo")

    # Display current memory usage
    memory_usage = get_memory_usage()
    print(f"Current memory usage: {memory_usage:.2f} MB")

    # Initialize the session state for storing pool bounding boxes
    if 'pool_bboxes' not in st.session_state:
        st.session_state.pool_bboxes = []

    # Load the models using the defined function
    if 'model_object' not in st.session_state or 'model_arrow' not in st.session_state:
        clear_memory()
        load_models()

    model_arrow = st.session_state.model_arrow
    model_object = st.session_state.model_object

    # Create the layout for the app
    col1, col2 = st.columns(2)
    with col1:
        # Create a file uploader for the user to upload an image
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    # Display the uploaded image if the user has uploaded an image
    if uploaded_file is not None:
        original_image = get_image(uploaded_file)
        col1, col2 = st.columns(2)

        # Create a cropper to allow the user to crop the image and display the cropped image
        with col1:
            cropped_image = st_cropper(original_image, realtime_update=True, box_color='#0000FF', should_resize_image=True, default_coords=(30, original_image.size[0]-30, 30, original_image.size[1]-30))
        with col2:
            st.image(cropped_image, caption="Cropped Image", use_column_width=False, width=500)

        # Display the options for the user to set the score threshold and scale
        if cropped_image is not None:
            col1, col2, col3 = st.columns(3)
            with col1:
                score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
            with col2:
                st.session_state.scale = st.slider("Set scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)

            # Launch the prediction when the user clicks the button
            if st.button("Launch Prediction"):
                st.session_state.crop_image = cropped_image
                with st.spinner('Processing...'):
                    perform_inference(model_object, model_arrow, st.session_state.crop_image, score_threshold)
                    st.session_state.prediction = modif_box_pos(st.session_state.prediction, object_dict)

                print('Detection completed!')

    # If the prediction has been made and the user has uploaded an image, display the options for the user to annotate the image
    if 'prediction' in st.session_state and uploaded_file is not None:
        st.success('Detection completed!')
        display_options(st.session_state.crop_image, score_threshold)

        #if st.session_state.prediction_up==True:
        st.session_state.bpmn_xml = create_XML(st.session_state.prediction.copy(), st.session_state.text_mapping, st.session_state.scale)

        display_bpmn_xml(st.session_state.bpmn_xml)

        # Force garbage collection after display
        gc.collect()

if __name__ == "__main__":
    print('Starting the app...')
    main()
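As a usage note, `create_XML` returns the pretty-printed BPMN XML as a string, so outside of Streamlit it can simply be written to disk. A minimal sketch, assuming `prediction` and `text_mapping` come from `eval.full_prediction` and `OCR.mapping_text` as above (the output filename is illustrative):

    from demo_streamlit import create_XML

    bpmn_xml = create_XML(prediction.copy(), text_mapping, scale=1.0)  # copy() because boxes are rescaled in place
    with open("diagram.bpmn", "w", encoding="utf-8") as f:
        f.write(bpmn_xml)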
display.py
ADDED
@@ -0,0 +1,181 @@
from utils import draw_annotations, create_loader, class_dict, resize_boxes, resize_keypoints, find_other_keypoint
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from OCR import group_texts


def draw_stream(image,
                prediction=None,
                text_predictions=None,
                class_dict=class_dict,
                draw_keypoints=False,
                draw_boxes=False,
                draw_text=False,
                draw_links=False,
                draw_twins=False,
                draw_grouped_text=False,
                write_class=False,
                write_score=False,
                write_text=False,
                score_threshold=0.4,
                write_idx=False,
                keypoints_correction=False,
                new_size=(1333, 1333),
                only_print=None,
                axis=False,
                return_image=False,
                resize=False):
    """
    Draws annotations on images including bounding boxes, keypoints, links, and text.

    Parameters:
    - image (np.array): The image on which annotations will be drawn.
    - prediction (dict): Prediction data from a model.
    - text_predictions (tuple): OCR text predictions containing bounding boxes and texts.
    - class_dict (dict): Mapping from class IDs to class names.
    - draw_keypoints (bool): Flag to draw keypoints.
    - draw_boxes (bool): Flag to draw bounding boxes.
    - draw_text (bool): Flag to draw text annotations.
    - draw_links (bool): Flag to draw links between annotations.
    - draw_twins (bool): Flag to draw twin keypoints.
    - write_class (bool): Flag to write class names near the annotations.
    - write_score (bool): Flag to write scores near the annotations.
    - write_text (bool): Flag to write OCR recognized text.
    - score_threshold (float): Threshold for scores above which annotations will be drawn.
    - only_print (str): Specific class name to filter annotations by.
    - resize (bool): Whether to resize annotations to fit the image size.
    """

    # Convert image to RGB (if not already in that format)
    if prediction is None:
        image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()

    image_copy = image.copy()
    scale = max(image.shape[0], image.shape[1]) / 1000

    original_size = (image.shape[0], image.shape[1])
    # Calculate scale to fit the new size while maintaining aspect ratio
    scale_ = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
    new_scaled_size = (int(original_size[0] * scale_), int(original_size[1] * scale_))

    for i in range(len(prediction['boxes'])):
        box = prediction['boxes'][i]
        x1, y1, x2, y2 = box
        if resize:
            x1, y1, x2, y2 = resize_boxes(np.array([box]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
        score = prediction['scores'][i]
        if score < score_threshold:
            continue
        if draw_boxes:
            if only_print is not None and only_print != 'all':
                if prediction['labels'][i] != list(class_dict.values()).index(only_print):
                    continue
            cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 0), int(2*scale))
        if write_score:
            cv2.putText(image_copy, str(round(score, 2)), (int(x1), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (100, 100, 255), 2)
        if write_idx:
            cv2.putText(image_copy, str(i), (int(x1) + int(15*scale), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, 2*scale, (0, 0, 0), 2)

        if write_class and 'labels' in prediction:
            class_id = prediction['labels'][i]
            cv2.putText(image_copy, class_dict[class_id], (int(x1), int(y1) - int(2*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (255, 100, 100), 2)

    # Draw keypoints if available
    if draw_keypoints and 'keypoints' in prediction:
        for i in range(len(prediction['keypoints'])):
            kp = prediction['keypoints'][i]
            for j in range(kp.shape[0]):
                if prediction['labels'][i] != list(class_dict.values()).index('sequenceFlow') and prediction['labels'][i] != list(class_dict.values()).index('messageFlow') and prediction['labels'][i] != list(class_dict.values()).index('dataAssociation'):
                    continue

                score = prediction['scores'][i]
                if score < score_threshold:
                    continue
                x, y, v = np.array(kp[j])
                x, y, v = resize_keypoints(np.array([kp[j]]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
                if j == 0:
                    cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (0, 0, 255), -1)
                else:
                    cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (255, 0, 0), -1)

    # Draw text predictions if available
    if (draw_text or write_text) and text_predictions is not None:
        for i in range(len(text_predictions[0])):
            x1, y1, x2, y2 = text_predictions[0][i]
            text = text_predictions[1][i]
            if resize:
                x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
            if draw_text:
                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2*scale))
            if write_text:
                cv2.putText(image_copy, text, (int(x1 + int(2*scale)), int((y1+y2)/2)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (0, 0, 0), 2)

    '''Draws links between objects based on the full prediction data.'''
    # check if keypoints detected are the same
    if draw_twins and prediction is not None:
        # Pre-calculate indices for performance
        circle_color = (0, 255, 0) # Green color for the circle
        circle_radius = int(10 * scale) # Circle radius scaled by image scale

        for idx, (key1, key2) in enumerate(prediction['keypoints']):
            if prediction['labels'][idx] not in [list(class_dict.values()).index('sequenceFlow'),
                                                 list(class_dict.values()).index('messageFlow'),
                                                 list(class_dict.values()).index('dataAssociation')]:
                continue
            # Calculate the Euclidean distance between the two keypoints
            distance = np.linalg.norm(key1[:2] - key2[:2])
            if distance < 10:
                x_new, y_new, x, y = find_other_keypoint(idx, prediction)
                cv2.circle(image_copy, (int(x), int(y)), circle_radius, circle_color, -1)
                cv2.circle(image_copy, (int(x_new), int(y_new)), circle_radius, (0, 0, 0), -1)

    # Draw links between objects
    if draw_links == True and prediction is not None:
        for i, (start_idx, end_idx) in enumerate(prediction['links']):
            if start_idx is None or end_idx is None:
                continue
            start_box = prediction['boxes'][start_idx]
            start_box = resize_boxes(np.array([start_box]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
            end_box = prediction['boxes'][end_idx]
            end_box = resize_boxes(np.array([end_box]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
            current_box = prediction['boxes'][i]
            current_box = resize_boxes(np.array([current_box]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
            # Calculate the center of each bounding box
            start_center = ((start_box[0] + start_box[2]) // 2, (start_box[1] + start_box[3]) // 2)
            end_center = ((end_box[0] + end_box[2]) // 2, (end_box[1] + end_box[3]) // 2)
            current_center = ((current_box[0] + current_box[2]) // 2, (current_box[1] + current_box[3]) // 2)
            # Draw a line between the centers of the connected objects
            cv2.line(image_copy, (int(start_center[0]), int(start_center[1])), (int(current_center[0]), int(current_center[1])), (0, 0, 255), int(2*scale))
            cv2.line(image_copy, (int(current_center[0]), int(current_center[1])), (int(end_center[0]), int(end_center[1])), (255, 0, 0), int(2*scale))

    if draw_grouped_text and prediction is not None:
        task_boxes = [box for i, box in enumerate(prediction['boxes']) if prediction['labels'][i] == list(class_dict.values()).index('task')]
        grouped_sentences, sentence_bounding_boxes, info_texts, info_boxes = group_texts(task_boxes, text_predictions[0], text_predictions[1], percentage_thresh=1)
        for i in range(len(info_boxes)):
            x1, y1, x2, y2 = info_boxes[i]
            x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
            cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2*scale))
        for i in range(len(sentence_bounding_boxes)):
            x1, y1, x2, y2 = sentence_bounding_boxes[i]
            x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
            cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2*scale))

    if return_image:
        return image_copy
    else:
        # Display the image
        plt.figure(figsize=(12, 12))
        plt.imshow(image_copy)
        if axis == False:
            plt.axis('off')
        plt.show()
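A minimal sketch of calling `draw_stream` outside Streamlit, mirroring `display_options` in the demo above; `image` is assumed to be a PIL image and `prediction` / `text_pred` the outputs of `eval.full_prediction` and `OCR.filter_text`:

    import numpy as np
    from display import draw_stream

    annotated = draw_stream(
        np.array(image),
        prediction=prediction,
        text_predictions=text_pred,
        draw_boxes=True,
        draw_keypoints=True,
        write_class=True,
        score_threshold=0.5,
        resize=True,
        return_image=True,  # return a numpy image instead of plotting with matplotlib
    )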
eval.py
ADDED
|
@@ -0,0 +1,649 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from utils import class_dict, object_dict, arrow_dict, find_closest_object, find_other_keypoint, filter_overlap_boxes, iou
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from toXML import create_BPMN_id
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def non_maximum_suppression(boxes, scores, labels=None, iou_threshold=0.5):
|
| 11 |
+
idxs = np.argsort(scores) # Sort the boxes according to their scores in ascending order
|
| 12 |
+
selected_boxes = []
|
| 13 |
+
|
| 14 |
+
while len(idxs) > 0:
|
| 15 |
+
last = len(idxs) - 1
|
| 16 |
+
i = idxs[last]
|
| 17 |
+
|
| 18 |
+
# Skip if the label is a lane
|
| 19 |
+
if labels is not None and class_dict[labels[i]] == 'lane':
|
| 20 |
+
selected_boxes.append(i)
|
| 21 |
+
idxs = np.delete(idxs, last)
|
| 22 |
+
continue
|
| 23 |
+
|
| 24 |
+
selected_boxes.append(i)
|
| 25 |
+
|
| 26 |
+
# Find the intersection of the box with the rest
|
| 27 |
+
suppress = [last]
|
| 28 |
+
for pos in range(0, last):
|
| 29 |
+
j = idxs[pos]
|
| 30 |
+
if iou(boxes[i], boxes[j]) > iou_threshold:
|
| 31 |
+
suppress.append(pos)
|
| 32 |
+
|
| 33 |
+
idxs = np.delete(idxs, suppress)
|
| 34 |
+
|
| 35 |
+
# Return only the boxes that were selected
|
| 36 |
+
return selected_boxes
|
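A quick sanity check of this NMS helper, as a hedged sketch that is not part of the commit: feed it a couple of hand-made boxes and confirm that only the higher-scoring box of an overlapping pair survives (it assumes this module and the repo's utils.py are importable).

```python
# Illustrative sketch (not part of the commit): two overlapping boxes plus one
# separate box; the higher-scoring box of the overlapping pair should survive.
import numpy as np
from eval import non_maximum_suppression

boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60]], dtype=float)
scores = np.array([0.9, 0.6, 0.8])
keep = non_maximum_suppression(boxes, scores, iou_threshold=0.5)
print(sorted(keep))  # expected: [0, 2]
```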
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def keypoint_correction(keypoints, boxes, labels, model_dict=arrow_dict, distance_treshold=15):
|
| 40 |
+
for idx, (key1, key2) in enumerate(keypoints):
|
| 41 |
+
if labels[idx] not in [list(model_dict.values()).index('sequenceFlow'),
|
| 42 |
+
list(model_dict.values()).index('messageFlow'),
|
| 43 |
+
list(model_dict.values()).index('dataAssociation')]:
|
| 44 |
+
continue
|
| 45 |
+
# Calculate the Euclidean distance between the two keypoints
|
| 46 |
+
distance = np.linalg.norm(key1[:2] - key2[:2])
|
| 47 |
+
if distance < distance_treshold:
|
| 48 |
+
print('Key modified for index:', idx)
|
| 49 |
+
x_new,y_new, x,y = find_other_keypoint(idx, keypoints, boxes)
|
| 50 |
+
keypoints[idx][0][:2] = [x_new,y_new]
|
| 51 |
+
keypoints[idx][1][:2] = [x,y]
|
| 52 |
+
|
| 53 |
+
return keypoints
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
|
| 57 |
+
model.eval()
|
| 58 |
+
with torch.no_grad():
|
| 59 |
+
image_tensor = image.unsqueeze(0).to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
|
| 60 |
+
predictions = model(image_tensor)
|
| 61 |
+
|
| 62 |
+
boxes = predictions[0]['boxes'].cpu().numpy()
|
| 63 |
+
labels = predictions[0]['labels'].cpu().numpy()
|
| 64 |
+
scores = predictions[0]['scores'].cpu().numpy()
|
| 65 |
+
|
| 66 |
+
idx = np.where(scores > score_threshold)[0]
|
| 67 |
+
boxes = boxes[idx]
|
| 68 |
+
scores = scores[idx]
|
| 69 |
+
labels = labels[idx]
|
| 70 |
+
|
| 71 |
+
selected_boxes = non_maximum_suppression(boxes, scores, labels=labels, iou_threshold=iou_threshold)
|
| 72 |
+
|
| 73 |
+
#find the dominant orientation of the tasks by comparing box sizes, and delete the ones that do not match it
|
| 74 |
+
vertical = 0
|
| 75 |
+
for i in range(len(labels)):
|
| 76 |
+
if labels[i] != list(object_dict.values()).index('task'):
|
| 77 |
+
continue
|
| 78 |
+
if boxes[i][2]-boxes[i][0] < boxes[i][3]-boxes[i][1]:
|
| 79 |
+
vertical += 1
|
| 80 |
+
horizontal = len(labels) - vertical
|
| 81 |
+
for i in range(len(labels)):
|
| 82 |
+
if labels[i] != list(object_dict.values()).index('task'):
|
| 83 |
+
continue
|
| 84 |
+
|
| 85 |
+
if vertical < horizontal:
|
| 86 |
+
if boxes[i][2]-boxes[i][0] < boxes[i][3]-boxes[i][1]:
|
| 87 |
+
#find the element in the list and remove it
|
| 88 |
+
if i in selected_boxes:
|
| 89 |
+
selected_boxes.remove(i)
|
| 90 |
+
elif vertical > horizontal:
|
| 91 |
+
if boxes[i][2]-boxes[i][0] > boxes[i][3]-boxes[i][1]:
|
| 92 |
+
#find the element in the list and remove it
|
| 93 |
+
if i in selected_boxes:
|
| 94 |
+
selected_boxes.remove(i)
|
| 95 |
+
else:
|
| 96 |
+
pass
|
| 97 |
+
|
| 98 |
+
boxes = boxes[selected_boxes]
|
| 99 |
+
scores = scores[selected_boxes]
|
| 100 |
+
labels = labels[selected_boxes]
|
| 101 |
+
|
| 102 |
+
prediction = {
|
| 103 |
+
'boxes': boxes,
|
| 104 |
+
'scores': scores,
|
| 105 |
+
'labels': labels,
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
image = image.permute(1, 2, 0).cpu().numpy()
|
| 109 |
+
image = (image * 255).astype(np.uint8)
|
| 110 |
+
|
| 111 |
+
return image, prediction
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def arrow_prediction(model, image, score_threshold=0.5, iou_threshold=0.5, distance_treshold=15):
|
| 115 |
+
model.eval()
|
| 116 |
+
with torch.no_grad():
|
| 117 |
+
image_tensor = image.unsqueeze(0).to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
|
| 118 |
+
predictions = model(image_tensor)
|
| 119 |
+
|
| 120 |
+
boxes = predictions[0]['boxes'].cpu().numpy()
|
| 121 |
+
labels = predictions[0]['labels'].cpu().numpy() + (len(object_dict) - 1)
|
| 122 |
+
scores = predictions[0]['scores'].cpu().numpy()
|
| 123 |
+
keypoints = predictions[0]['keypoints'].cpu().numpy()
|
| 124 |
+
|
| 125 |
+
idx = np.where(scores > score_threshold)[0]
|
| 126 |
+
boxes = boxes[idx]
|
| 127 |
+
scores = scores[idx]
|
| 128 |
+
labels = labels[idx]
|
| 129 |
+
keypoints = keypoints[idx]
|
| 130 |
+
|
| 131 |
+
selected_boxes = non_maximum_suppression(boxes, scores, iou_threshold=iou_threshold)
|
| 132 |
+
boxes = boxes[selected_boxes]
|
| 133 |
+
scores = scores[selected_boxes]
|
| 134 |
+
labels = labels[selected_boxes]
|
| 135 |
+
keypoints = keypoints[selected_boxes]
|
| 136 |
+
|
| 137 |
+
keypoints = keypoint_correction(keypoints, boxes, labels, class_dict, distance_treshold=distance_treshold)
|
| 138 |
+
|
| 139 |
+
prediction = {
|
| 140 |
+
'boxes': boxes,
|
| 141 |
+
'scores': scores,
|
| 142 |
+
'labels': labels,
|
| 143 |
+
'keypoints': keypoints,
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
image = image.permute(1, 2, 0).cpu().numpy()
|
| 147 |
+
image = (image * 255).astype(np.uint8)
|
| 148 |
+
|
| 149 |
+
return image, prediction
|
| 150 |
+
|
| 151 |
+
def mix_predictions(objects_pred, arrow_pred):
|
| 152 |
+
# Initialize the list of lists for keypoints
|
| 153 |
+
object_keypoints = []
|
| 154 |
+
|
| 155 |
+
# Number of boxes
|
| 156 |
+
num_boxes = len(objects_pred['boxes'])
|
| 157 |
+
|
| 158 |
+
# Iterate over the number of boxes
|
| 159 |
+
for _ in range(num_boxes):
|
| 160 |
+
# Each box has 2 keypoints, both initialized to [0, 0, 0]
|
| 161 |
+
keypoints = [[0, 0, 0], [0, 0, 0]]
|
| 162 |
+
object_keypoints.append(keypoints)
|
| 163 |
+
|
| 164 |
+
#concatenate the two predictions
|
| 165 |
+
boxes = np.concatenate((objects_pred['boxes'], arrow_pred['boxes']))
|
| 166 |
+
labels = np.concatenate((objects_pred['labels'], arrow_pred['labels']))
|
| 167 |
+
scores = np.concatenate((objects_pred['scores'], arrow_pred['scores']))
|
| 168 |
+
keypoints = np.concatenate((object_keypoints, arrow_pred['keypoints']))
|
| 169 |
+
|
| 170 |
+
return boxes, labels, scores, keypoints
|
| 171 |
+
|
| 172 |
+
def regroup_elements_by_pool(boxes, labels, class_dict):
|
| 173 |
+
"""
|
| 174 |
+
Regroups elements by the pool they belong to, and creates a single new pool for elements that are not in any existing pool.
|
| 175 |
+
|
| 176 |
+
Parameters:
|
| 177 |
+
- boxes (list): List of bounding boxes.
|
| 178 |
+
- labels (list): List of labels corresponding to each bounding box.
|
| 179 |
+
- class_dict (dict): Dictionary mapping class indices to class names.
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
- dict: A dictionary where each key is a pool's index and the value is a list of elements within that pool.
|
| 183 |
+
"""
|
| 184 |
+
# Initialize a dictionary to hold the elements in each pool
|
| 185 |
+
pool_dict = {}
|
| 186 |
+
|
| 187 |
+
# Identify the bounding boxes of the pools
|
| 188 |
+
pool_indices = [i for i, label in enumerate(labels) if (class_dict[label.item()] == 'pool')]
|
| 189 |
+
pool_boxes = [boxes[i] for i in pool_indices]
|
| 190 |
+
|
| 191 |
+
if not pool_indices:
|
| 192 |
+
# If no pool is detected, create a single pool containing all elements
|
| 193 |
+
labels = np.append(labels, list(class_dict.values()).index('pool'))
|
| 194 |
+
pool_dict[len(labels)-1] = list(range(len(boxes)))
|
| 195 |
+
else:
|
| 196 |
+
# Initialize each pool index with an empty list
|
| 197 |
+
for pool_index in pool_indices:
|
| 198 |
+
pool_dict[pool_index] = []
|
| 199 |
+
|
| 200 |
+
# Initialize a list for elements not in any pool
|
| 201 |
+
elements_not_in_pool = []
|
| 202 |
+
|
| 203 |
+
# Iterate over all elements
|
| 204 |
+
for i, box in enumerate(boxes):
|
| 205 |
+
if i in pool_indices or class_dict[labels[i]] == 'messageFlow':
|
| 206 |
+
continue # Skip pool boxes themselves and messageFlow elements
|
| 207 |
+
assigned_to_pool = False
|
| 208 |
+
for j, pool_box in enumerate(pool_boxes):
|
| 209 |
+
# Check if the element is within the pool's bounding box
|
| 210 |
+
if (box[0] >= pool_box[0] and box[1] >= pool_box[1] and
|
| 211 |
+
box[2] <= pool_box[2] and box[3] <= pool_box[3]):
|
| 212 |
+
pool_index = pool_indices[j]
|
| 213 |
+
pool_dict[pool_index].append(i)
|
| 214 |
+
assigned_to_pool = True
|
| 215 |
+
break
|
| 216 |
+
if not assigned_to_pool:
|
| 217 |
+
if class_dict[labels[i]] != 'messageFlow' and class_dict[labels[i]] != 'lane':
|
| 218 |
+
elements_not_in_pool.append(i)
|
| 219 |
+
|
| 220 |
+
if elements_not_in_pool:
|
| 221 |
+
new_pool_index = max(pool_dict.keys()) + 1
|
| 222 |
+
labels = np.append(labels, list(class_dict.values()).index('pool'))
|
| 223 |
+
pool_dict[new_pool_index] = elements_not_in_pool
|
| 224 |
+
|
| 225 |
+
# Separate empty pools
|
| 226 |
+
non_empty_pools = {k: v for k, v in pool_dict.items() if v}
|
| 227 |
+
empty_pools = {k: v for k, v in pool_dict.items() if not v}
|
| 228 |
+
|
| 229 |
+
# Merge non-empty pools followed by empty pools
|
| 230 |
+
pool_dict = {**non_empty_pools, **empty_pools}
|
| 231 |
+
|
| 232 |
+
return pool_dict, labels
|
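To illustrate the grouping behaviour described in the docstring above, here is a hedged sketch (not part of the commit) with one pool containing a task plus one stray task; it assumes 'pool' and 'task' are values of utils.class_dict, as they are used elsewhere in this file.

```python
import numpy as np
from utils import class_dict
from eval import regroup_elements_by_pool

pool_label = list(class_dict.values()).index('pool')   # assumption: 'pool' is in class_dict
task_label = list(class_dict.values()).index('task')   # assumption: 'task' is in class_dict

boxes = np.array([[0, 0, 400, 200],      # pool
                  [50, 50, 150, 120],    # task inside the pool
                  [500, 50, 600, 120]])  # task outside every pool
labels = np.array([pool_label, task_label, task_label])

pool_dict, labels = regroup_elements_by_pool(boxes, labels, class_dict)
print(pool_dict)  # e.g. {0: [1], 1: [2]}, the stray task ends up in a new pool
```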
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def create_links(keypoints, boxes, labels, class_dict):
|
| 236 |
+
best_points = []
|
| 237 |
+
links = []
|
| 238 |
+
for i in range(len(labels)):
|
| 239 |
+
if labels[i]==list(class_dict.values()).index('sequenceFlow') or labels[i]==list(class_dict.values()).index('messageFlow'):
|
| 240 |
+
closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels)
|
| 241 |
+
closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels)
|
| 242 |
+
if closest1 is not None and closest2 is not None:
|
| 243 |
+
best_points.append([point_start, point_end])
|
| 244 |
+
links.append([closest1, closest2])
|
| 245 |
+
else:
|
| 246 |
+
best_points.append([None,None])
|
| 247 |
+
links.append([None,None])
|
| 248 |
+
|
| 249 |
+
for i in range(len(labels)):
|
| 250 |
+
if labels[i]==list(class_dict.values()).index('dataAssociation'):
|
| 251 |
+
closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels)
|
| 252 |
+
closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels)
|
| 253 |
+
if closest1 is not None and closest2 is not None:
|
| 254 |
+
best_points[i] = ([point_start, point_end])
|
| 255 |
+
links[i] = ([closest1, closest2])
|
| 256 |
+
|
| 257 |
+
return links, best_points
|
| 258 |
+
|
| 259 |
+
def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
|
| 260 |
+
|
| 261 |
+
for pool_index, elements in pool_dict.items():
|
| 262 |
+
print(f"Pool {pool_index} contains elements: {elements}")
|
| 263 |
+
#check if each link is in the same pool
|
| 264 |
+
for i in range(len(flow_links)):
|
| 265 |
+
if labels[i] == list(class_dict.values()).index('sequenceFlow'):
|
| 266 |
+
id1, id2 = flow_links[i]
|
| 267 |
+
if id1 is not None and id2 is not None:
|
| 268 |
+
if id1 in elements and id2 in elements:
|
| 269 |
+
continue
|
| 270 |
+
elif id1 not in elements and id2 not in elements:
|
| 271 |
+
continue
|
| 272 |
+
else:
|
| 273 |
+
print('change the link from sequenceFlow to messageFlow')
|
| 274 |
+
labels[i]=list(class_dict.values()).index('messageFlow')
|
| 275 |
+
|
| 276 |
+
return labels, flow_links
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def last_correction(boxes, labels, scores, keypoints, links, best_points, pool_dict):
|
| 280 |
+
|
| 281 |
+
#delete pools that contain only messageFlow elements
|
| 282 |
+
delete_pool = []
|
| 283 |
+
for pool_index, elements in pool_dict.items():
|
| 284 |
+
if all([labels[i] == list(class_dict.values()).index('messageFlow') for i in elements]):
|
| 285 |
+
if len(elements) > 0:
|
| 286 |
+
delete_pool.append(pool_dict[pool_index])
|
| 287 |
+
print(f"Pool {pool_index} contains only messageFlow elements, deleting it")
|
| 288 |
+
|
| 289 |
+
#sort index
|
| 290 |
+
delete_pool = sorted(delete_pool, reverse=True)
|
| 291 |
+
for pool in delete_pool:
|
| 292 |
+
index = list(pool_dict.keys())[list(pool_dict.values()).index(pool)]
|
| 293 |
+
del pool_dict[index]
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
delete_elements = []
|
| 297 |
+
# Check if there is an arrow that has the same links
|
| 298 |
+
for i in range(len(labels)):
|
| 299 |
+
for j in range(i+1, len(labels)):
|
| 300 |
+
if labels[i] == list(class_dict.values()).index('sequenceFlow') and labels[j] == list(class_dict.values()).index('sequenceFlow'):
|
| 301 |
+
if links[i] == links[j]:
|
| 302 |
+
print(f'element {i} and {j} have the same links')
|
| 303 |
+
if scores[i] > scores[j]:
|
| 304 |
+
print('delete element', j)
|
| 305 |
+
delete_elements.append(j)
|
| 306 |
+
else:
|
| 307 |
+
print('delete element', i)
|
| 308 |
+
delete_elements.append(i)
|
| 309 |
+
|
| 310 |
+
boxes = np.delete(boxes, delete_elements, axis=0)
|
| 311 |
+
labels = np.delete(labels, delete_elements)
|
| 312 |
+
scores = np.delete(scores, delete_elements)
|
| 313 |
+
keypoints = np.delete(keypoints, delete_elements, axis=0)
|
| 314 |
+
links = np.delete(links, delete_elements, axis=0)
|
| 315 |
+
best_points = [point for i, point in enumerate(best_points) if i not in delete_elements]
|
| 316 |
+
|
| 317 |
+
#also delete the element in the pool_dict
|
| 318 |
+
for pool_index, elements in pool_dict.items():
|
| 319 |
+
pool_dict[pool_index] = [i for i in elements if i not in delete_elements]
|
| 320 |
+
|
| 321 |
+
return boxes, labels, scores, keypoints, links, best_points, pool_dict
|
| 322 |
+
|
| 323 |
+
def give_link_to_element(links, labels):
|
| 324 |
+
#give events their links so that BPMN id creation can distinguish start, intermediate and end events
|
| 325 |
+
for i in range(len(links)):
|
| 326 |
+
if labels[i] == list(class_dict.values()).index('sequenceFlow'):
|
| 327 |
+
id1, id2 = links[i]
|
| 328 |
+
if id1 is not None and id2 is not None:
|
| 329 |
+
links[id1][1] = i
|
| 330 |
+
links[id2][0] = i
|
| 331 |
+
return links
|
| 332 |
+
|
| 333 |
+
def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_threshold=0.5, resize=True, distance_treshold=15):
|
| 334 |
+
model_object.eval() # Set the model to evaluation mode
|
| 335 |
+
model_arrow.eval() # Set the model to evaluation mode
|
| 336 |
+
|
| 337 |
+
# Run object and arrow detection on the image
|
| 338 |
+
with torch.no_grad(): # Disable gradient calculation for inference
|
| 339 |
+
_, objects_pred = object_prediction(model_object, image, score_threshold=score_threshold, iou_threshold=iou_threshold)
|
| 340 |
+
_, arrow_pred = arrow_prediction(model_arrow, image, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_treshold=distance_treshold)
|
| 341 |
+
|
| 342 |
+
#print('Object prediction:', objects_pred)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
boxes, labels, scores, keypoints = mix_predictions(objects_pred, arrow_pred)
|
| 346 |
+
|
| 347 |
+
# Regroup elements by pool
|
| 348 |
+
pool_dict, labels = regroup_elements_by_pool(boxes,labels, class_dict)
|
| 349 |
+
# Create links between elements
|
| 350 |
+
flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
|
| 351 |
+
#Correct the labels of sequenceFlows that cross multiple pools
|
| 352 |
+
labels, flow_links = correction_labels(boxes, labels, class_dict, pool_dict, flow_links)
|
| 353 |
+
#give events their links so that BPMN id creation can distinguish start, intermediate and end events
|
| 354 |
+
flow_links = give_link_to_element(flow_links, labels)
|
| 355 |
+
|
| 356 |
+
boxes,labels,scores,keypoints,flow_links,best_points,pool_dict = last_correction(boxes,labels,scores,keypoints,flow_links,best_points, pool_dict)
|
| 357 |
+
|
| 358 |
+
image = image.permute(1, 2, 0).cpu().numpy()
|
| 359 |
+
image = (image * 255).astype(np.uint8)
|
| 360 |
+
idx = []
|
| 361 |
+
for i in range(len(labels)):
|
| 362 |
+
idx.append(i)
|
| 363 |
+
bpmn_id = [class_dict[labels[i]] for i in range(len(labels))]
|
| 364 |
+
|
| 365 |
+
data = {
|
| 366 |
+
'image': image,
|
| 367 |
+
'idx': idx,
|
| 368 |
+
'boxes': boxes,
|
| 369 |
+
'labels': labels,
|
| 370 |
+
'scores': scores,
|
| 371 |
+
'keypoints': keypoints,
|
| 372 |
+
'links': flow_links,
|
| 373 |
+
'best_points': best_points,
|
| 374 |
+
'pool_dict': pool_dict,
|
| 375 |
+
'BPMN_id': bpmn_id,
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
# give a unique BPMN id to each element
|
| 379 |
+
data = create_BPMN_id(data)
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
return image, data
|
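For orientation, a hedged end-to-end sketch of calling full_prediction on one diagram image (not part of the commit): model_object and model_arrow are assumed to be already-loaded detection models, as the Streamlit demo in this commit uses, and "diagram.png" is a placeholder path.

```python
import torch
from PIL import Image
from torchvision.transforms import functional as F
from eval import full_prediction

# Assumptions: model_object and model_arrow are loaded elsewhere; the path is a placeholder.
image = F.to_tensor(Image.open("diagram.png").convert("RGB"))  # CHW float tensor in [0, 1]
with torch.no_grad():
    annotated, data = full_prediction(model_object, model_arrow, image,
                                      score_threshold=0.5, iou_threshold=0.5)

print(data['BPMN_id'])  # one generated id per detected element
print(data['links'])    # (source_idx, target_idx) pairs for the detected flows
```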
| 384 |
+
|
| 385 |
+
def evaluate_model_by_class(pred_boxes, true_boxes, pred_labels, true_labels, model_dict, iou_threshold=0.5):
|
| 386 |
+
# Initialize dictionaries to hold per-class counts
|
| 387 |
+
class_tp = {cls: 0 for cls in model_dict.values()}
|
| 388 |
+
class_fp = {cls: 0 for cls in model_dict.values()}
|
| 389 |
+
class_fn = {cls: 0 for cls in model_dict.values()}
|
| 390 |
+
|
| 391 |
+
# Track which true boxes have been matched
|
| 392 |
+
matched = [False] * len(true_boxes)
|
| 393 |
+
|
| 394 |
+
# Check each prediction against true boxes
|
| 395 |
+
for pred_box, pred_label in zip(pred_boxes, pred_labels):
|
| 396 |
+
match_found = False
|
| 397 |
+
for idx, (true_box, true_label) in enumerate(zip(true_boxes, true_labels)):
|
| 398 |
+
if not matched[idx] and pred_label == true_label:
|
| 399 |
+
if iou(np.array(pred_box), np.array(true_box)) >= iou_threshold:
|
| 400 |
+
class_tp[model_dict[pred_label]] += 1
|
| 401 |
+
matched[idx] = True
|
| 402 |
+
match_found = True
|
| 403 |
+
break
|
| 404 |
+
if not match_found:
|
| 405 |
+
class_fp[model_dict[pred_label]] += 1
|
| 406 |
+
|
| 407 |
+
# Count false negatives
|
| 408 |
+
for idx, (true_box, true_label) in enumerate(zip(true_boxes, true_labels)):
|
| 409 |
+
if not matched[idx]:
|
| 410 |
+
class_fn[model_dict[true_label]] += 1
|
| 411 |
+
|
| 412 |
+
# Calculate precision, recall, and F1-score per class
|
| 413 |
+
class_precision = {}
|
| 414 |
+
class_recall = {}
|
| 415 |
+
class_f1_score = {}
|
| 416 |
+
|
| 417 |
+
for cls in model_dict.values():
|
| 418 |
+
precision = class_tp[cls] / (class_tp[cls] + class_fp[cls]) if class_tp[cls] + class_fp[cls] > 0 else 0
|
| 419 |
+
recall = class_tp[cls] / (class_tp[cls] + class_fn[cls]) if class_tp[cls] + class_fn[cls] > 0 else 0
|
| 420 |
+
f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0
|
| 421 |
+
|
| 422 |
+
class_precision[cls] = precision
|
| 423 |
+
class_recall[cls] = recall
|
| 424 |
+
class_f1_score[cls] = f1_score
|
| 425 |
+
|
| 426 |
+
return class_precision, class_recall, class_f1_score
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def keypoints_mesure(pred_boxes, pred_box, true_boxes, true_box, pred_keypoints, true_keypoints, distance_threshold=5):
|
| 430 |
+
result = 0
|
| 431 |
+
reverted = False
|
| 432 |
+
#find the position of keypoints in the list
|
| 433 |
+
idx = np.where(pred_boxes == pred_box)[0][0]
|
| 434 |
+
idx2 = np.where(true_boxes == true_box)[0][0]
|
| 435 |
+
|
| 436 |
+
keypoint1_pred = pred_keypoints[idx][0]
|
| 437 |
+
keypoint1_true = true_keypoints[idx2][0]
|
| 438 |
+
keypoint2_pred = pred_keypoints[idx][1]
|
| 439 |
+
keypoint2_true = true_keypoints[idx2][1]
|
| 440 |
+
|
| 441 |
+
distance1 = np.linalg.norm(keypoint1_pred[:2] - keypoint1_true[:2])
|
| 442 |
+
distance2 = np.linalg.norm(keypoint2_pred[:2] - keypoint2_true[:2])
|
| 443 |
+
distance3 = np.linalg.norm(keypoint1_pred[:2] - keypoint2_true[:2])
|
| 444 |
+
distance4 = np.linalg.norm(keypoint2_pred[:2] - keypoint1_true[:2])
|
| 445 |
+
|
| 446 |
+
if distance1 < distance_threshold:
|
| 447 |
+
result += 1
|
| 448 |
+
if distance2 < distance_threshold:
|
| 449 |
+
result += 1
|
| 450 |
+
if distance3 < distance_threshold or distance4 < distance_threshold:
|
| 451 |
+
reverted = True
|
| 452 |
+
|
| 453 |
+
return result, reverted
|
| 454 |
+
|
| 455 |
+
def evaluate_single_image(pred_boxes, true_boxes, pred_labels, true_labels, pred_keypoints, true_keypoints, iou_threshold=0.5, distance_threshold=5):
|
| 456 |
+
tp, fp, fn = 0, 0, 0
|
| 457 |
+
key_t, key_f = 0, 0
|
| 458 |
+
labels_t, labels_f = 0, 0
|
| 459 |
+
reverted_tot = 0
|
| 460 |
+
|
| 461 |
+
matched_true_boxes = set()
|
| 462 |
+
for pred_idx, (pred_box, pred_label) in enumerate(zip(pred_boxes, pred_labels)):
|
| 463 |
+
match_found = False
|
| 464 |
+
for true_idx, true_box in enumerate(true_boxes):
|
| 465 |
+
if true_idx in matched_true_boxes:
|
| 466 |
+
continue
|
| 467 |
+
iou_val = iou(pred_box, true_box)
|
| 468 |
+
if iou_val >= iou_threshold:
|
| 469 |
+
if true_keypoints is not None and pred_keypoints is not None:
|
| 470 |
+
key_result, reverted = keypoints_mesure(pred_boxes, pred_box, true_boxes, true_box, pred_keypoints, true_keypoints, distance_threshold)
|
| 471 |
+
key_t += key_result
|
| 472 |
+
key_f += 2 - key_result
|
| 473 |
+
if reverted:
|
| 474 |
+
reverted_tot += 1
|
| 475 |
+
|
| 476 |
+
match_found = True
|
| 477 |
+
matched_true_boxes.add(true_idx)
|
| 478 |
+
if pred_label == true_labels[true_idx]:
|
| 479 |
+
labels_t += 1
|
| 480 |
+
else:
|
| 481 |
+
labels_f += 1
|
| 482 |
+
tp += 1
|
| 483 |
+
break
|
| 484 |
+
if not match_found:
|
| 485 |
+
fp += 1
|
| 486 |
+
|
| 487 |
+
fn = len(true_boxes) - tp
|
| 488 |
+
|
| 489 |
+
return tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted_tot
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def pred_4_evaluation(model, loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=5, key_correction=True, model_type='object'):
|
| 493 |
+
model.eval()
|
| 494 |
+
tp, fp, fn = 0, 0, 0
|
| 495 |
+
labels_t, labels_f = 0, 0
|
| 496 |
+
key_t, key_f = 0, 0
|
| 497 |
+
reverted = 0
|
| 498 |
+
|
| 499 |
+
with torch.no_grad():
|
| 500 |
+
for images, targets_im in tqdm(loader, desc="Testing... "): # Wrap the loader with tqdm
|
| 501 |
+
devices = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
| 502 |
+
images = [image.to(devices) for image in images]
|
| 503 |
+
targets = [{k: v.clone().detach().to(devices) for k, v in t.items()} for t in targets_im]
|
| 504 |
+
|
| 505 |
+
predictions = model(images)
|
| 506 |
+
|
| 507 |
+
for target, prediction in zip(targets, predictions):
|
| 508 |
+
true_boxes = target['boxes'].cpu().numpy()
|
| 509 |
+
true_labels = target['labels'].cpu().numpy()
|
| 510 |
+
if 'keypoints' in target:
|
| 511 |
+
true_keypoints = target['keypoints'].cpu().numpy()
|
| 512 |
+
|
| 513 |
+
pred_boxes = prediction['boxes'].cpu().numpy()
|
| 514 |
+
scores = prediction['scores'].cpu().numpy()
|
| 515 |
+
pred_labels = prediction['labels'].cpu().numpy()
|
| 516 |
+
if 'keypoints' in prediction:
|
| 517 |
+
pred_keypoints = prediction['keypoints'].cpu().numpy()
|
| 518 |
+
|
| 519 |
+
selected_boxes = non_maximum_suppression(pred_boxes, scores, iou_threshold=iou_threshold)
|
| 520 |
+
pred_boxes = pred_boxes[selected_boxes]
|
| 521 |
+
scores = scores[selected_boxes]
|
| 522 |
+
pred_labels = pred_labels[selected_boxes]
|
| 523 |
+
if 'keypoints' in prediction:
|
| 524 |
+
pred_keypoints = pred_keypoints[selected_boxes]
|
| 525 |
+
|
| 526 |
+
filtered_boxes = []
|
| 527 |
+
filtered_labels = []
|
| 528 |
+
filtered_keypoints = []
|
| 529 |
+
if 'keypoints' not in prediction:
|
| 530 |
+
#create a list of zeros of length equal to the number of boxes
|
| 531 |
+
pred_keypoints = [np.zeros((2, 3)) for _ in range(len(pred_boxes))]
|
| 532 |
+
|
| 533 |
+
for box, score, label, keypoints in zip(pred_boxes, scores, pred_labels, pred_keypoints):
|
| 534 |
+
if score >= score_threshold:
|
| 535 |
+
filtered_boxes.append(box)
|
| 536 |
+
filtered_labels.append(label)
|
| 537 |
+
if 'keypoints' in prediction:
|
| 538 |
+
filtered_keypoints.append(keypoints)
|
| 539 |
+
|
| 540 |
+
if key_correction and ('keypoints' in prediction):
|
| 541 |
+
filtered_keypoints = keypoint_correction(filtered_keypoints, filtered_boxes, filtered_labels)
|
| 542 |
+
|
| 543 |
+
if 'keypoints' not in target:
|
| 544 |
+
filtered_keypoints = None
|
| 545 |
+
true_keypoints = None
|
| 546 |
+
tp_img, fp_img, fn_img, labels_t_img, labels_f_img, key_t_img, key_f_img, reverted_img = evaluate_single_image(
|
| 547 |
+
filtered_boxes, true_boxes, filtered_labels, true_labels, filtered_keypoints, true_keypoints, iou_threshold, distance_threshold)
|
| 548 |
+
|
| 549 |
+
tp += tp_img
|
| 550 |
+
fp += fp_img
|
| 551 |
+
fn += fn_img
|
| 552 |
+
labels_t += labels_t_img
|
| 553 |
+
labels_f += labels_f_img
|
| 554 |
+
key_t += key_t_img
|
| 555 |
+
key_f += key_f_img
|
| 556 |
+
reverted += reverted_img
|
| 557 |
+
|
| 558 |
+
return tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted
|
| 559 |
+
|
| 560 |
+
def main_evaluation(model, test_loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=5, key_correction=True, model_type = 'object'):
|
| 561 |
+
|
| 562 |
+
tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted = pred_4_evaluation(model, test_loader, score_threshold, iou_threshold, distance_threshold, key_correction, model_type)
|
| 563 |
+
|
| 564 |
+
labels_precision = labels_t / (labels_t + labels_f) if (labels_t + labels_f) > 0 else 0
|
| 565 |
+
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
|
| 566 |
+
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
|
| 567 |
+
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
| 568 |
+
if model_type == 'arrow':
|
| 569 |
+
key_accuracy = key_t / (key_t + key_f) if (key_t + key_f) > 0 else 0
|
| 570 |
+
reverted_accuracy = reverted / (key_t + key_f) if (key_t + key_f) > 0 else 0
|
| 571 |
+
else:
|
| 572 |
+
key_accuracy = 0
|
| 573 |
+
reverted_accuracy = 0
|
| 574 |
+
|
| 575 |
+
return labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
def evaluate_model_by_class_single_image(pred_boxes, true_boxes, pred_labels, true_labels, class_tp, class_fp, class_fn, model_dict, iou_threshold=0.5):
|
| 580 |
+
matched_true_boxes = set()
|
| 581 |
+
for pred_idx, (pred_box, pred_label) in enumerate(zip(pred_boxes, pred_labels)):
|
| 582 |
+
match_found = False
|
| 583 |
+
for true_idx, (true_box, true_label) in enumerate(zip(true_boxes, true_labels)):
|
| 584 |
+
if true_idx in matched_true_boxes:
|
| 585 |
+
continue
|
| 586 |
+
if pred_label == true_label and iou(np.array(pred_box), np.array(true_box)) >= iou_threshold:
|
| 587 |
+
class_tp[model_dict[pred_label]] += 1
|
| 588 |
+
matched_true_boxes.add(true_idx)
|
| 589 |
+
match_found = True
|
| 590 |
+
break
|
| 591 |
+
if not match_found:
|
| 592 |
+
class_fp[model_dict[pred_label]] += 1
|
| 593 |
+
|
| 594 |
+
for idx, true_label in enumerate(true_labels):
|
| 595 |
+
if idx not in matched_true_boxes:
|
| 596 |
+
class_fn[model_dict[true_label]] += 1
|
| 597 |
+
|
| 598 |
+
def pred_4_evaluation_per_class(model, loader, score_threshold=0.5, iou_threshold=0.5):
|
| 599 |
+
model.eval()
|
| 600 |
+
with torch.no_grad():
|
| 601 |
+
for images, targets_im in tqdm(loader, desc="Testing... "):
|
| 602 |
+
devices = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
| 603 |
+
images = [image.to(devices) for image in images]
|
| 604 |
+
targets = [{k: v.clone().detach().to(devices) for k, v in t.items()} for t in targets_im]
|
| 605 |
+
|
| 606 |
+
predictions = model(images)
|
| 607 |
+
|
| 608 |
+
for target, prediction in zip(targets, predictions):
|
| 609 |
+
true_boxes = target['boxes'].cpu().numpy()
|
| 610 |
+
true_labels = target['labels'].cpu().numpy()
|
| 611 |
+
|
| 612 |
+
pred_boxes = prediction['boxes'].cpu().numpy()
|
| 613 |
+
scores = prediction['scores'].cpu().numpy()
|
| 614 |
+
pred_labels = prediction['labels'].cpu().numpy()
|
| 615 |
+
|
| 616 |
+
idx = np.where(scores > score_threshold)[0]
|
| 617 |
+
pred_boxes = pred_boxes[idx]
|
| 618 |
+
scores = scores[idx]
|
| 619 |
+
pred_labels = pred_labels[idx]
|
| 620 |
+
|
| 621 |
+
selected_boxes = non_maximum_suppression(pred_boxes, scores, iou_threshold=iou_threshold)
|
| 622 |
+
pred_boxes = pred_boxes[selected_boxes]
|
| 623 |
+
scores = scores[selected_boxes]
|
| 624 |
+
pred_labels = pred_labels[selected_boxes]
|
| 625 |
+
|
| 626 |
+
yield pred_boxes, true_boxes, pred_labels, true_labels
|
| 627 |
+
|
| 628 |
+
def evaluate_model_by_class(model, test_loader, model_dict, score_threshold=0.5, iou_threshold=0.5):
|
| 629 |
+
class_tp = {cls: 0 for cls in model_dict.values()}
|
| 630 |
+
class_fp = {cls: 0 for cls in model_dict.values()}
|
| 631 |
+
class_fn = {cls: 0 for cls in model_dict.values()}
|
| 632 |
+
|
| 633 |
+
for pred_boxes, true_boxes, pred_labels, true_labels in pred_4_evaluation_per_class(model, test_loader, score_threshold, iou_threshold):
|
| 634 |
+
evaluate_model_by_class_single_image(pred_boxes, true_boxes, pred_labels, true_labels, class_tp, class_fp, class_fn, model_dict, iou_threshold)
|
| 635 |
+
|
| 636 |
+
class_precision = {}
|
| 637 |
+
class_recall = {}
|
| 638 |
+
class_f1_score = {}
|
| 639 |
+
|
| 640 |
+
for cls in model_dict.values():
|
| 641 |
+
precision = class_tp[cls] / (class_tp[cls] + class_fp[cls]) if class_tp[cls] + class_fp[cls] > 0 else 0
|
| 642 |
+
recall = class_tp[cls] / (class_tp[cls] + class_fn[cls]) if class_tp[cls] + class_fn[cls] > 0 else 0
|
| 643 |
+
f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0
|
| 644 |
+
|
| 645 |
+
class_precision[cls] = precision
|
| 646 |
+
class_recall[cls] = recall
|
| 647 |
+
class_f1_score[cls] = f1_score
|
| 648 |
+
|
| 649 |
+
return class_precision, class_recall, class_f1_score
|
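The evaluation entry points above can be driven from a test DataLoader; the following hedged sketch (not part of the commit) assumes a model_arrow and test_loader built along the lines of train.py elsewhere in this commit.

```python
from eval import main_evaluation

# Assumptions: model_arrow and test_loader are created elsewhere (see train.py).
labels_p, precision, recall, f1, key_acc, reverted_acc = main_evaluation(
    model_arrow, test_loader,
    score_threshold=0.5, iou_threshold=0.5,
    distance_threshold=5, key_correction=True, model_type='arrow')
print(f"precision={precision:.3f} recall={recall:.3f} f1={f1:.3f} keypoint_acc={key_acc:.3f}")
```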
flask.py
ADDED
|
@@ -0,0 +1,6 @@
| 1 |
+
from flask import Flask
|
| 2 |
+
app = Flask(__name__)
|
| 3 |
+
|
| 4 |
+
@app.route("/")
|
| 5 |
+
def hello():
|
| 6 |
+
return "Hello World!\n"
|
htlm_webpage.py
ADDED
|
@@ -0,0 +1,141 @@
| 1 |
+
import streamlit as st
|
| 2 |
+
import streamlit.components.v1 as components
|
| 3 |
+
|
| 4 |
+
def display_bpmn_xml(bpmn_xml):
|
| 5 |
+
html_template = f"""
|
| 6 |
+
<!DOCTYPE html>
|
| 7 |
+
<html>
|
| 8 |
+
<head>
|
| 9 |
+
<meta charset="UTF-8">
|
| 10 |
+
<title>BPMN Modeler</title>
|
| 11 |
+
<link rel="stylesheet" href="https://unpkg.com/bpmn-js/dist/assets/diagram-js.css">
|
| 12 |
+
<link rel="stylesheet" href="https://unpkg.com/bpmn-js/dist/assets/bpmn-font/css/bpmn-embedded.css">
|
| 13 |
+
<script src="https://unpkg.com/bpmn-js/dist/bpmn-modeler.development.js"></script>
|
| 14 |
+
<style>
|
| 15 |
+
html, body {{
|
| 16 |
+
height: 100%;
|
| 17 |
+
padding: 0;
|
| 18 |
+
margin: 0;
|
| 19 |
+
font-family: Arial, sans-serif;
|
| 20 |
+
display: flex;
|
| 21 |
+
flex-direction: column;
|
| 22 |
+
overflow: hidden;
|
| 23 |
+
}}
|
| 24 |
+
#button-container {{
|
| 25 |
+
padding: 10px;
|
| 26 |
+
background-color: #ffffff;
|
| 27 |
+
border-bottom: 1px solid #ddd;
|
| 28 |
+
display: flex;
|
| 29 |
+
justify-content: flex-start;
|
| 30 |
+
gap: 10px;
|
| 31 |
+
}}
|
| 32 |
+
#save-button, #download-button {{
|
| 33 |
+
background-color: #4CAF50;
|
| 34 |
+
color: white;
|
| 35 |
+
border: none;
|
| 36 |
+
padding: 10px 20px;
|
| 37 |
+
text-align: center;
|
| 38 |
+
text-decoration: none;
|
| 39 |
+
display: inline-block;
|
| 40 |
+
font-size: 16px;
|
| 41 |
+
margin: 4px 2px;
|
| 42 |
+
cursor: pointer;
|
| 43 |
+
border-radius: 8px;
|
| 44 |
+
}}
|
| 45 |
+
#download-button {{
|
| 46 |
+
background-color: #008CBA;
|
| 47 |
+
}}
|
| 48 |
+
#canvas-container {{
|
| 49 |
+
flex: 1;
|
| 50 |
+
position: relative;
|
| 51 |
+
background-color: #FBFBFB;
|
| 52 |
+
overflow: hidden; /* Prevent scrolling */
|
| 53 |
+
display: flex;
|
| 54 |
+
justify-content: center;
|
| 55 |
+
align-items: center;
|
| 56 |
+
}}
|
| 57 |
+
#canvas {{
|
| 58 |
+
height: 100%;
|
| 59 |
+
width: 100%;
|
| 60 |
+
position: relative;
|
| 61 |
+
}}
|
| 62 |
+
</style>
|
| 63 |
+
</head>
|
| 64 |
+
<body>
|
| 65 |
+
<div id="button-container">
|
| 66 |
+
<button id="save-button">Save as BPMN</button>
|
| 67 |
+
<button id="download-button">Save as XML</button>
|
| 68 |
+
<button id="download-button">Save as Vizi</button>
|
| 69 |
+
</div>
|
| 70 |
+
<div id="canvas-container">
|
| 71 |
+
<div id="canvas"></div>
|
| 72 |
+
</div>
|
| 73 |
+
<script>
|
| 74 |
+
var bpmnModeler = new BpmnJS({{
|
| 75 |
+
container: '#canvas'
|
| 76 |
+
}});
|
| 77 |
+
|
| 78 |
+
async function openDiagram(bpmnXML) {{
|
| 79 |
+
try {{
|
| 80 |
+
await bpmnModeler.importXML(bpmnXML);
|
| 81 |
+
bpmnModeler.get('canvas').zoom('fit-viewport');
|
| 82 |
+
bpmnModeler.get('canvas').zoom(0.8); // Adjust this value for zooming out
|
| 83 |
+
}} catch (err) {{
|
| 84 |
+
console.error('Error rendering BPMN diagram', err);
|
| 85 |
+
}}
|
| 86 |
+
}}
|
| 87 |
+
|
| 88 |
+
async function saveDiagram() {{
|
| 89 |
+
try {{
|
| 90 |
+
const result = await bpmnModeler.saveXML({{ format: true }});
|
| 91 |
+
const xml = result.xml;
|
| 92 |
+
const blob = new Blob([xml], {{ type: 'text/xml' }});
|
| 93 |
+
const url = URL.createObjectURL(blob);
|
| 94 |
+
const a = document.createElement('a');
|
| 95 |
+
a.href = url;
|
| 96 |
+
a.download = 'diagram.bpmn';
|
| 97 |
+
document.body.appendChild(a);
|
| 98 |
+
a.click();
|
| 99 |
+
document.body.removeChild(a);
|
| 100 |
+
}} catch (err) {{
|
| 101 |
+
console.error('Error saving BPMN diagram', err);
|
| 102 |
+
}}
|
| 103 |
+
}}
|
| 104 |
+
|
| 105 |
+
async function downloadXML() {{
|
| 106 |
+
const xml = `{bpmn_xml}`;
|
| 107 |
+
const blob = new Blob([xml], {{ type: 'text/xml' }});
|
| 108 |
+
const url = URL.createObjectURL(blob);
|
| 109 |
+
const a = document.createElement('a');
|
| 110 |
+
a.href = url;
|
| 111 |
+
a.download = 'diagram.xml';
|
| 112 |
+
document.body.appendChild(a);
|
| 113 |
+
a.click();
|
| 114 |
+
document.body.removeChild(a);
|
| 115 |
+
}}
|
| 116 |
+
|
| 117 |
+
document.getElementById('save-button').addEventListener('click', saveDiagram);
|
| 118 |
+
document.getElementById('download-button').addEventListener('click', downloadXML);
|
| 119 |
+
|
| 120 |
+
// Ensure the canvas is focused to capture keyboard events
|
| 121 |
+
document.getElementById('canvas').focus();
|
| 122 |
+
|
| 123 |
+
// Add event listeners for keyboard shortcuts
|
| 124 |
+
document.addEventListener('keydown', function(event) {{
|
| 125 |
+
if (event.ctrlKey && event.key === 'z') {{
|
| 126 |
+
bpmnModeler.get('commandStack').undo();
|
| 127 |
+
}} else if (event.key === 'Delete' || event.key === 'Backspace') {{
|
| 128 |
+
bpmnModeler.get('selection').get().forEach(function(element) {{
|
| 129 |
+
bpmnModeler.get('modeling').removeElements([element]);
|
| 130 |
+
}});
|
| 131 |
+
}}
|
| 132 |
+
}});
|
| 133 |
+
|
| 134 |
+
openDiagram(`{bpmn_xml}`);
|
| 135 |
+
</script>
|
| 136 |
+
</body>
|
| 137 |
+
</html>
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
components.html(html_template, height=1000, width=1500)
|
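A hedged sketch of how this component can be called from a Streamlit script run with `streamlit run` (not part of the commit; the BPMN file path is a placeholder, any valid BPMN 2.0 XML string works).

```python
import streamlit as st
from htlm_webpage import display_bpmn_xml  # module name as committed

st.title("BPMN rendering demo")
with open("diagram.bpmn", encoding="utf-8") as f:  # placeholder path
    bpmn_xml = f.read()
display_bpmn_xml(bpmn_xml)
```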
packages.txt
ADDED
|
@@ -0,0 +1 @@
| 1 |
+
libgl1-mesa-glx
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
| 1 |
+
yamlu==0.0.17
|
| 2 |
+
tqdm==4.66.4
|
| 3 |
+
torchvision==0.18.0
|
| 4 |
+
azure-ai-vision-imageanalysis==1.0.0b2
|
| 5 |
+
streamlit==1.35.0
|
| 6 |
+
streamlit-image-comparison==0.0.4
|
| 7 |
+
streamlit-cropper==0.2.2
|
| 8 |
+
streamlit-drawable-canvas==0.9.3
|
| 9 |
+
opencv-python
|
| 10 |
+
gdown
|
toXML.py
ADDED
|
@@ -0,0 +1,351 @@
| 1 |
+
import xml.etree.ElementTree as ET
|
| 2 |
+
from utils import class_dict
|
| 3 |
+
|
| 4 |
+
def rescale(scale, boxes):
|
| 5 |
+
for i in range(len(boxes)):
|
| 6 |
+
boxes[i] = [boxes[i][0]*scale,
|
| 7 |
+
boxes[i][1]*scale,
|
| 8 |
+
boxes[i][2]*scale,
|
| 9 |
+
boxes[i][3]*scale]
|
| 10 |
+
return boxes
|
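A one-line check of rescale, as a hedged sketch that is not part of the commit:

```python
from toXML import rescale

boxes = [[10, 20, 30, 40]]
print(rescale(2.0, boxes))  # [[20.0, 40.0, 60.0, 80.0]]
```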
| 11 |
+
|
| 12 |
+
def create_BPMN_id(data):
|
| 13 |
+
enums = {
|
| 14 |
+
'end_event': 1,
|
| 15 |
+
'start_event': 1,
|
| 16 |
+
'task': 1,
|
| 17 |
+
'sequenceFlow': 1,
|
| 18 |
+
'messageFlow': 1,
|
| 19 |
+
'message_event': 1,
|
| 20 |
+
'exclusiveGateway': 1,
|
| 21 |
+
'parallelGateway': 1,
|
| 22 |
+
'dataAssociation': 1,
|
| 23 |
+
'pool': 1,
|
| 24 |
+
'dataObject': 1,
|
| 25 |
+
'timerEvent': 1
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
BPMN_name = [class_dict[label] for label in data['labels']]
|
| 29 |
+
|
| 30 |
+
for idx, Bpmn_id in enumerate(BPMN_name):
|
| 31 |
+
if Bpmn_id == 'event':
|
| 32 |
+
if data['links'][idx][0] is not None and data['links'][idx][1] is None:
|
| 33 |
+
key = 'end_event'
|
| 34 |
+
elif data['links'][idx][0] is None and data['links'][idx][1] is not None:
|
| 35 |
+
key = 'start_event'
|
| 36 |
+
else:
|
| 37 |
+
key = {
|
| 38 |
+
'task': 'task',
|
| 39 |
+
'dataObject': 'dataObject',
|
| 40 |
+
'sequenceFlow': 'sequenceFlow',
|
| 41 |
+
'messageFlow': 'messageFlow',
|
| 42 |
+
'messageEvent': 'message_event',
|
| 43 |
+
'exclusiveGateway': 'exclusiveGateway',
|
| 44 |
+
'parallelGateway': 'parallelGateway',
|
| 45 |
+
'dataAssociation': 'dataAssociation',
|
| 46 |
+
'pool': 'pool',
|
| 47 |
+
'timerEvent': 'timerEvent'
|
| 48 |
+
}.get(Bpmn_id, None)
|
| 49 |
+
|
| 50 |
+
if key:
|
| 51 |
+
data['BPMN_id'][idx] = f'{key}_{enums[key]}'
|
| 52 |
+
enums[key] += 1
|
| 53 |
+
|
| 54 |
+
return data
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def add_diagram_elements(parent, element_id, x, y, width, height):
|
| 59 |
+
"""Utility to add BPMN diagram notation for elements."""
|
| 60 |
+
shape = ET.SubElement(parent, 'bpmndi:BPMNShape', attrib={
|
| 61 |
+
'bpmnElement': element_id,
|
| 62 |
+
'id': element_id + '_di'
|
| 63 |
+
})
|
| 64 |
+
bounds = ET.SubElement(shape, 'dc:Bounds', attrib={
|
| 65 |
+
'x': str(x),
|
| 66 |
+
'y': str(y),
|
| 67 |
+
'width': str(width),
|
| 68 |
+
'height': str(height)
|
| 69 |
+
})
|
| 70 |
+
|
| 71 |
+
def add_diagram_edge(parent, element_id, waypoints):
|
| 72 |
+
"""Utility to add BPMN diagram notation for sequence flows."""
|
| 73 |
+
edge = ET.SubElement(parent, 'bpmndi:BPMNEdge', attrib={
|
| 74 |
+
'bpmnElement': element_id,
|
| 75 |
+
'id': element_id + '_di'
|
| 76 |
+
})
|
| 77 |
+
for x, y in waypoints:
|
| 78 |
+
ET.SubElement(edge, 'di:waypoint', attrib={
|
| 79 |
+
'x': str(x),
|
| 80 |
+
'y': str(y)
|
| 81 |
+
})
|
| 82 |
+
|
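A hedged sketch (not part of the commit) of the two diagram helpers above writing a shape and an edge onto a hand-built BPMNPlane element; the ids and coordinates are made-up examples.

```python
import xml.etree.ElementTree as ET
from toXML import add_diagram_elements, add_diagram_edge

plane = ET.Element('bpmndi:BPMNPlane', attrib={'id': 'BPMNPlane_1', 'bpmnElement': 'collab_1'})
add_diagram_elements(plane, 'task_1', x=100, y=80, width=120, height=80)
add_diagram_edge(plane, 'sequenceflow_task_1_task_2', [(220, 120), (300, 120)])
print(ET.tostring(plane, encoding='unicode'))
```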
| 83 |
+
|
| 84 |
+
def check_status(link, keep_elements):
|
| 85 |
+
if link[0] in keep_elements and link[1] in keep_elements:
|
| 86 |
+
return 'middle'
|
| 87 |
+
elif link[0] is None and link[1] in keep_elements:
|
| 88 |
+
return 'start'
|
| 89 |
+
elif link[0] in keep_elements and link[1] is None:
|
| 90 |
+
return 'end'
|
| 91 |
+
else:
|
| 92 |
+
return 'middle'
|
| 93 |
+
|
| 94 |
+
def check_data_association(i, links, labels, keep_elements):
|
| 95 |
+
for j, (k,l) in enumerate(links):
|
| 96 |
+
if labels[j] == 14:
|
| 97 |
+
if k==i:
|
| 98 |
+
return 'output',j
|
| 99 |
+
elif l==i:
|
| 100 |
+
return 'input',j
|
| 101 |
+
|
| 102 |
+
return 'no association', None
|
| 103 |
+
|
| 104 |
+
def create_data_Association(bpmn,data,size,element_id,source_id,target_id):
|
| 105 |
+
waypoints = calculate_waypoints(data, size, source_id, target_id)
|
| 106 |
+
add_diagram_edge(bpmn, element_id, waypoints)
|
| 107 |
+
|
| 108 |
+
# Function to dynamically create and layout BPMN elements
|
| 109 |
+
def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data, keep_elements):
|
| 110 |
+
elements = data['BPMN_id']
|
| 111 |
+
positions = data['boxes']
|
| 112 |
+
links = data['links']
|
| 113 |
+
|
| 114 |
+
for i in keep_elements:
|
| 115 |
+
element_id = elements[i]
|
| 116 |
+
if element_id is None:
|
| 117 |
+
continue
|
| 118 |
+
|
| 119 |
+
element_type = element_id.split('_')[0]
|
| 120 |
+
x, y = positions[i][:2]
|
| 121 |
+
|
| 122 |
+
# Start Event
|
| 123 |
+
if element_type == 'start':
|
| 124 |
+
element = ET.SubElement(process, 'bpmn:startEvent', id=element_id, name=text_mapping[element_id])
|
| 125 |
+
add_diagram_elements(bpmnplane, element_id, x, y, size['start'][0], size['start'][1])
|
| 126 |
+
|
| 127 |
+
# Task
|
| 128 |
+
elif element_type == 'task':
|
| 129 |
+
element = ET.SubElement(process, 'bpmn:task', id=element_id, name=text_mapping[element_id])
|
| 130 |
+
status, dataAssociation_idx = check_data_association(i, data['links'], data['labels'], keep_elements)
|
| 131 |
+
|
| 132 |
+
# Handle Data Input Association
|
| 133 |
+
if status == 'input':
|
| 134 |
+
dataObject_idx = links[dataAssociation_idx][0]
|
| 135 |
+
dataObject_name = elements[dataObject_idx]
|
| 136 |
+
dataObject_ref = f'DataObjectReference_{dataObject_name.split("_")[1]}'
|
| 137 |
+
sub_element = ET.SubElement(element, 'bpmn:dataInputAssociation', id=f'dataInputAssociation_{dataObject_ref.split("_")[1]}')
|
| 138 |
+
ET.SubElement(sub_element, 'bpmn:sourceRef').text = dataObject_ref
|
| 139 |
+
create_data_Association(bpmnplane, data, size, sub_element.attrib['id'], dataObject_name, element_id)
|
| 140 |
+
|
| 141 |
+
# Handle Data Output Association
|
| 142 |
+
elif status == 'output':
|
| 143 |
+
dataObject_idx = links[dataAssociation_idx][1]
|
| 144 |
+
dataObject_name = elements[dataObject_idx]
|
| 145 |
+
dataObject_ref = f'DataObjectReference_{dataObject_name.split("_")[1]}'
|
| 146 |
+
sub_element = ET.SubElement(element, 'bpmn:dataOutputAssociation', id=f'dataOutputAssociation_{dataObject_ref.split("_")[1]}')
|
| 147 |
+
ET.SubElement(sub_element, 'bpmn:targetRef').text = dataObject_ref
|
| 148 |
+
create_data_Association(bpmnplane, data, size, sub_element.attrib['id'], element_id, dataObject_name)
|
| 149 |
+
|
| 150 |
+
add_diagram_elements(bpmnplane, element_id, x, y, size['task'][0], size['task'][1])
|
| 151 |
+
|
| 152 |
+
# Message Events (Start, Intermediate, End)
|
| 153 |
+
elif element_type == 'message':
|
| 154 |
+
status = check_status(links[i], keep_elements)
|
| 155 |
+
if status == 'start':
|
| 156 |
+
element = ET.SubElement(process, 'bpmn:startEvent', id=element_id, name=text_mapping[element_id])
|
| 157 |
+
elif status == 'middle':
|
| 158 |
+
element = ET.SubElement(process, 'bpmn:intermediateCatchEvent', id=element_id, name=text_mapping[element_id])
|
| 159 |
+
elif status == 'end':
|
| 160 |
+
element = ET.SubElement(process, 'bpmn:endEvent', id=element_id, name=text_mapping[element_id])
|
| 161 |
+
ET.SubElement(element, 'bpmn:messageEventDefinition', id=f'MessageEventDefinition_{i+1}')
|
| 162 |
+
add_diagram_elements(bpmnplane, element_id, x, y, size['message'][0], size['message'][1])
|
| 163 |
+
|
| 164 |
+
# End Event
|
| 165 |
+
elif element_type == 'end':
|
| 166 |
+
element = ET.SubElement(process, 'bpmn:endEvent', id=element_id, name=text_mapping[element_id])
|
| 167 |
+
add_diagram_elements(bpmnplane, element_id, x, y, size['end'][0], size['end'][1])
|
| 168 |
+
|
| 169 |
+
# Gateways (Exclusive, Parallel)
|
| 170 |
+
elif element_type in ['exclusiveGateway', 'parallelGateway']:
|
| 171 |
+
gateway_type = 'exclusiveGateway' if element_type == 'exclusiveGateway' else 'parallelGateway'
|
| 172 |
+
element = ET.SubElement(process, f'bpmn:{gateway_type}', id=element_id)
|
| 173 |
+
add_diagram_elements(bpmnplane, element_id, x, y, size[element_type][0], size[element_type][1])
|
| 174 |
+
|
| 175 |
+
# Data Object
|
| 176 |
+
elif element_type == 'dataObject':
|
| 177 |
+
dataObject_idx = element_id.split('_')[1]
|
| 178 |
+
dataObject_ref = f'DataObjectReference_{dataObject_idx}'
|
| 179 |
+
element = ET.SubElement(process, 'bpmn:dataObjectReference', id=dataObject_ref, dataObjectRef=element_id, name=text_mapping[element_id])
|
| 180 |
+
ET.SubElement(process, 'bpmn:dataObject', id=element_id)
|
| 181 |
+
add_diagram_elements(bpmnplane, dataObject_ref, x, y, size['dataObject'][0], size['dataObject'][1])
|
| 182 |
+
|
| 183 |
+
# Timer Event
|
| 184 |
+
elif element_type == 'timerEvent':
|
| 185 |
+
element = ET.SubElement(process, 'bpmn:intermediateCatchEvent', id=element_id, name=text_mapping[element_id])
|
| 186 |
+
ET.SubElement(element, 'bpmn:timerEventDefinition', id=f'TimerEventDefinition_{i+1}')
|
| 187 |
+
add_diagram_elements(bpmnplane, element_id, x, y, size['timerEvent'][0], size['timerEvent'][1])
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# Calculate simple waypoints between two elements (this function assumes direct horizontal links for simplicity)
|
| 192 |
+
def calculate_waypoints(data, size, source_id, target_id):
|
| 193 |
+
source_idx = data['BPMN_id'].index(source_id)
|
| 194 |
+
target_idx = data['BPMN_id'].index(target_id)
|
| 195 |
+
name_source = source_id.split('_')[0]
|
| 196 |
+
name_target = target_id.split('_')[0]
|
| 197 |
+
|
| 198 |
+
#Get the position of the source and target
|
| 199 |
+
source_x, source_y = data['boxes'][source_idx][:2]
|
| 200 |
+
target_x, target_y = data['boxes'][target_idx][:2]
|
| 201 |
+
|
| 202 |
+
# Calculate relative position between source and target from their centers
|
| 203 |
+
relative_x = (target_x+size[name_target][0])/2 - (source_x+size[name_source][0])/2
|
| 204 |
+
relative_y = (target_y+size[name_target][1])/2 - (source_y+size[name_source][1])/2
|
| 205 |
+
|
| 206 |
+
# Get the size of the elements
|
| 207 |
+
size_x_source = size[name_source][0]
|
| 208 |
+
size_y_source = size[name_source][1]
|
| 209 |
+
size_x_target = size[name_target][0]
|
| 210 |
+
size_y_target = size[name_target][1]
|
| 211 |
+
|
| 212 |
+
#if the flow goes to the right
|
| 213 |
+
if relative_x >= size[name_source][0]:
|
| 214 |
+
source_x += size_x_source
|
| 215 |
+
source_y += size_y_source / 2
|
| 216 |
+
target_x = target_x
|
| 217 |
+
target_y += size_y_target / 2
|
| 218 |
+
#if the source is going up
|
| 219 |
+
if relative_y < -size[name_source][1]:
|
| 220 |
+
source_x -= size_x_source / 2
|
| 221 |
+
source_y -= size_y_source / 2
|
| 222 |
+
#if the source is going down
|
| 223 |
+
elif relative_y > size[name_source][1]:
|
| 224 |
+
source_x -= size_x_source / 2
|
| 225 |
+
source_y += size_y_source / 2
|
| 226 |
+
#if the flow goes to the left
|
| 227 |
+
elif relative_x < -size[name_source][0]:
|
| 228 |
+
source_x = source_x
|
| 229 |
+
source_y += size_y_source / 2
|
| 230 |
+
target_x += size_x_target
|
| 231 |
+
target_y += size_y_target / 2
|
| 232 |
+
#if the source is going up
|
| 233 |
+
if relative_y < -size[name_source][1]:
|
| 234 |
+
source_x += size_x_source / 2
|
| 235 |
+
source_y -= size_y_source / 2
|
| 236 |
+
#if the source is going down
|
| 237 |
+
elif relative_y > size[name_source][1]:
|
| 238 |
+
source_x += size_x_source / 2
|
| 239 |
+
source_y += size_y_source / 2
|
| 240 |
+
#if the flow is mostly vertical (going up or down)
|
| 241 |
+
elif -size[name_source][0] < relative_x < size[name_source][0]:
|
| 242 |
+
source_x += size_x_source / 2
|
| 243 |
+
target_x += size_x_target / 2
|
| 244 |
+
#if it's going down
|
| 245 |
+
if relative_y >= size[name_source][1]/2:
|
| 246 |
+
source_y += size_y_source
|
| 247 |
+
#if it's going up
|
| 248 |
+
elif relative_y < -size[name_source][1]/2:
|
| 249 |
+
source_y = source_y
|
| 250 |
+
target_y += size_y_target
|
| 251 |
+
else:
|
| 252 |
+
if relative_x >= 0:
|
| 253 |
+
source_x += size_x_source/2
|
| 254 |
+
source_y += size_y_source/2
|
| 255 |
+
target_x -= size_x_target/2
|
| 256 |
+
target_y += size_y_target/2
|
| 257 |
+
else:
|
| 258 |
+
source_x -= size_x_source/2
|
| 259 |
+
source_y += size_y_source/2
|
| 260 |
+
target_x += size_x_target/2
|
| 261 |
+
target_y += size_y_target/2
|
| 262 |
+
|
| 263 |
+
return [(source_x, source_y), (target_x, target_y)]
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def calculate_pool_bounds(data, keep_elements, size):
|
| 267 |
+
min_x = min_y = float('10000')
|
| 268 |
+
max_x = max_y = float('0')
|
| 269 |
+
|
| 270 |
+
for i in keep_elements:
|
| 271 |
+
if i >= len(data['BPMN_id']):
|
| 272 |
+
print("Problem with the index")
|
| 273 |
+
continue
|
| 274 |
+
element = data['BPMN_id'][i]
|
| 275 |
+
if element is None or data['labels'][i] in (7, 13, 14, 15):
|
| 276 |
+
continue
|
| 277 |
+
|
| 278 |
+
element_type = element.split('_')[0]
|
| 279 |
+
x, y = data['boxes'][i][:2]
|
| 280 |
+
element_width, element_height = size[element_type]
|
| 281 |
+
|
| 282 |
+
min_x = min(min_x, x)
|
| 283 |
+
min_y = min(min_y, y)
|
| 284 |
+
max_x = max(max_x, x + element_width)
|
| 285 |
+
max_y = max(max_y, y + element_height)
|
| 286 |
+
|
| 287 |
+
return min_x, min_y, max_x, max_y
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_element, target_element):
|
| 291 |
+
# Get the bounding boxes of the source and target elements
|
| 292 |
+
source_box = data['boxes'][source_idx]
|
| 293 |
+
target_box = data['boxes'][target_idx]
|
| 294 |
+
|
| 295 |
+
# Get the midpoints of the source element
|
| 296 |
+
source_mid_x = (source_box[0] + source_box[2]) / 2
|
| 297 |
+
source_mid_y = (source_box[1] + source_box[3]) / 2
|
| 298 |
+
|
| 299 |
+
# Check if the connection involves a pool
|
| 300 |
+
if source_element == 'pool':
|
| 301 |
+
pool_box = source_box
|
| 302 |
+
element_box = (target_box[0], target_box[1], target_box[0]+size[target_element][0], target_box[1]+size[target_element][1])
|
| 303 |
+
element_mid_x = (element_box[0] + element_box[2]) / 2
|
| 304 |
+
element_mid_y = (element_box[1] + element_box[3]) / 2
|
| 305 |
+
# Connect the pool's bottom or top side to the target element's top or bottom center
|
| 306 |
+
if pool_box[3] < element_box[1]: # Pool is above the target element
|
| 307 |
+
waypoints = [(element_mid_x, pool_box[3]-50), (element_mid_x, element_box[1])]
|
| 308 |
+
else: # Pool is below the target element
|
| 309 |
+
waypoints = [(element_mid_x, element_box[3]), (element_mid_x, pool_box[1]-50)]
|
| 310 |
+
else:
|
| 311 |
+
pool_box = target_box
|
| 312 |
+
element_box = (source_box[0], source_box[1], source_box[0]+size[source_element][0], source_box[1]+size[source_element][1])
|
| 313 |
+
element_mid_x = (element_box[0] + element_box[2]) / 2
|
| 314 |
+
element_mid_y = (element_box[1] + element_box[3]) / 2
|
| 315 |
+
|
| 316 |
+
# Connect the element's bottom or top center to the pool's top or bottom side
|
| 317 |
+
if pool_box[3] < element_box[1]: # Pool is above the target element
|
| 318 |
+
waypoints = [(element_mid_x, element_box[1]), (element_mid_x, pool_box[3]-50)]
|
| 319 |
+
else: # Pool is below the target element
|
| 320 |
+
waypoints = [(element_mid_x, element_box[3]), (element_mid_x, pool_box[1]-50)]
|
| 321 |
+
|
| 322 |
+
return waypoints
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=False):
|
| 327 |
+
source_idx, target_idx = data['links'][idx]
|
| 328 |
+
source_id, target_id = data['BPMN_id'][source_idx], data['BPMN_id'][target_idx]
|
| 329 |
+
if message:
|
| 330 |
+
element_id = f'messageflow_{source_id}_{target_id}'
|
| 331 |
+
else:
|
| 332 |
+
element_id = f'sequenceflow_{source_id}_{target_id}'
|
| 333 |
+
|
| 334 |
+
if source_id.split('_')[0] == 'pool' or target_id.split('_')[0] == 'pool':
|
| 335 |
+
waypoints = calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_id.split('_')[0], target_id.split('_')[0])
|
| 336 |
+
#waypoints = data['best_points'][idx]
|
| 337 |
+
if source_id.split('_')[0] == 'pool':
|
| 338 |
+
source_id = f"participant_{source_id.split('_')[1]}"
|
| 339 |
+
if target_id.split('_')[0] == 'pool':
|
| 340 |
+
target_id = f"participant_{target_id.split('_')[1]}"
|
| 341 |
+
else:
|
| 342 |
+
waypoints = calculate_waypoints(data, size, source_id, target_id)
|
| 343 |
+
#waypoints = data['best_points'][idx]
|
| 344 |
+
|
| 345 |
+
#waypoints = data['best_points'][idx]
|
| 346 |
+
if message:
|
| 347 |
+
element = ET.SubElement(parent, 'bpmn:messageFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
|
| 348 |
+
else:
|
| 349 |
+
element = ET.SubElement(parent, 'bpmn:sequenceFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
|
| 350 |
+
add_diagram_edge(bpmn, element_id, waypoints)
|
| 351 |
+
|
train.py
ADDED
|
@@ -0,0 +1,394 @@
| 1 |
+
import copy
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import random
|
| 5 |
+
import time
|
| 6 |
+
import torch
|
| 7 |
+
import torchvision.transforms.functional as F
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
|
| 10 |
+
from eval import main_evaluation
|
| 11 |
+
from torch.optim import SGD, AdamW
|
| 12 |
+
from torch.utils.data import DataLoader, Dataset, Subset, ConcatDataset
|
| 13 |
+
from torch.utils.data.dataloader import default_collate
|
| 14 |
+
from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
|
| 15 |
+
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
| 16 |
+
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from utils import write_results
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_arrow_model(num_classes, num_keypoints=2):
|
| 24 |
+
"""
|
| 25 |
+
Configures and returns a modified Keypoint R-CNN model based on ResNet-50 with FPN, adapted for a custom number of classes and keypoints.
|
| 26 |
+
|
| 27 |
+
Parameters:
|
| 28 |
+
- num_classes (int): Number of classes for the model to detect, excluding the background class.
|
| 29 |
+
- num_keypoints (int): Number of keypoints to predict for each detected object.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
- model (torch.nn.Module): The modified Keypoint R-CNN model.
|
| 33 |
+
|
| 34 |
+
Steps:
|
| 35 |
+
1. Load a pre-trained Keypoint R-CNN model with a ResNet-50 backbone and Feature Pyramid Network (FPN).
|
| 36 |
+
The model is initially configured for the COCO dataset, which includes various object classes and keypoints.
|
| 37 |
+
2. Replace the box predictor to adjust the number of output classes. The box predictor is responsible for
|
| 38 |
+
classifying detected regions and predicting their bounding boxes.
|
| 39 |
+
3. Replace the keypoint predictor to adjust the number of keypoints the model predicts for each object.
|
| 40 |
+
This is necessary to tailor the model to specific tasks that may have different keypoint structures.
|
| 41 |
+
"""
|
| 42 |
+
# Load a model pre-trained on COCO, initialized without pre-trained weights
|
| 43 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
| 44 |
+
if device == torch.device('cuda'):
|
| 45 |
+
model = keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.COCO_V1)
|
| 46 |
+
else:
|
| 47 |
+
model = keypointrcnn_resnet50_fpn(weights=False)
|
| 48 |
+
|
| 49 |
+
# Get the number of input features for the classifier in the box predictor.
|
| 50 |
+
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
| 51 |
+
|
| 52 |
+
# Replace the box predictor in the ROI heads with a new one, tailored to the number of classes.
|
| 53 |
+
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
|
| 54 |
+
|
| 55 |
+
# Replace the keypoint predictor in the ROI heads with a new one, specifically designed for the desired number of keypoints.
|
| 56 |
+
model.roi_heads.keypoint_predictor = KeypointRCNNPredictor(512, num_keypoints)
|
| 57 |
+
|
| 58 |
+
return model
|
| 59 |
+
|
| 60 |
+
from torchvision.models.detection import fasterrcnn_resnet50_fpn
|
| 61 |
+
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
|
| 62 |
+
def get_faster_rcnn_model(num_classes):
|
| 63 |
+
"""
|
| 64 |
+
Configures and returns a modified Faster R-CNN model based on ResNet-50 with FPN, adapted for a custom number of classes.
|
| 65 |
+
|
| 66 |
+
Parameters:
|
| 67 |
+
- num_classes (int): Number of classes for the model to detect, including the background class.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
- model (torch.nn.Module): The modified Faster R-CNN model.
|
| 71 |
+
"""
|
| 72 |
+
# Load a pre-trained Faster R-CNN model
|
| 73 |
+
model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1)
|
| 74 |
+
|
| 75 |
+
# Get the number of input features for the classifier in the box predictor
|
| 76 |
+
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
| 77 |
+
|
| 78 |
+
# Replace the box predictor with a new one, tailored to the number of classes (num_classes includes the background)
|
| 79 |
+
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
|
| 80 |
+
|
| 81 |
+
return model
|
| 82 |
+
|
| 83 |
+
def prepare_model(dict,opti,learning_rate= 0.0003,model_to_load=None, model_type = 'object'):
|
| 84 |
+
# Adjusted to pass the class_dict directly
|
| 85 |
+
if model_type == 'object':
|
| 86 |
+
model = get_faster_rcnn_model(len(dict))
|
| 87 |
+
elif model_type == 'arrow':
|
| 88 |
+
model = get_arrow_model(len(dict),2)
|
| 89 |
+
|
| 90 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
| 91 |
+
# Load the model weights
|
| 92 |
+
if model_to_load:
|
| 93 |
+
model.load_state_dict(torch.load('./models/'+ model_to_load +'.pth', map_location=device))
|
| 94 |
+
print(f"Model '{model_to_load}' loaded")
|
| 95 |
+
|
| 96 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
| 97 |
+
model.to(device)
|
| 98 |
+
|
| 99 |
+
if opti == 'SGD':
|
| 100 |
+
#learning_rate= 0.002
|
| 101 |
+
optimizer = SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)
|
| 102 |
+
elif opti == 'Adam':
|
| 103 |
+
#learning_rate = 0.0003
|
| 104 |
+
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.00056, eps=1e-08, betas=(0.9, 0.999))
|
| 105 |
+
else:
|
| 106 |
+
print('Optimizer not found')
|
| 107 |
+
|
| 108 |
+
return model, optimizer, device
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=False):
|
| 114 |
+
model.train() # Set the model to evaluation mode
|
| 115 |
+
total_loss = 0
|
| 116 |
+
|
| 117 |
+
# Initialize lists to keep track of individual losses
|
| 118 |
+
loss_classifier_list = []
|
| 119 |
+
loss_box_reg_list = []
|
| 120 |
+
loss_objectness_list = []
|
| 121 |
+
loss_rpn_box_reg_list = []
|
| 122 |
+
loss_keypoints_list = []
|
| 123 |
+
|
| 124 |
+
with torch.no_grad(): # Disable gradient computation
|
| 125 |
+
for images, targets_im in tqdm(data_loader, desc="Evaluating"):
|
| 126 |
+
images = [image.to(device) for image in images]
|
| 127 |
+
targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]
|
| 128 |
+
|
| 129 |
+
loss_dict = model(images, targets)
|
| 130 |
+
|
| 131 |
+
# Calculate the total loss for the current batch
|
| 132 |
+
losses = 0
|
| 133 |
+
if loss_config is not None:
|
| 134 |
+
for key, loss in loss_dict.items():
|
| 135 |
+
if loss_config.get(key, False):
|
| 136 |
+
losses += loss
|
| 137 |
+
else:
|
| 138 |
+
losses = sum(loss for key, loss in loss_dict.items())
|
| 139 |
+
|
| 140 |
+
total_loss += losses.item()
|
| 141 |
+
|
| 142 |
+
# Collect individual losses
|
| 143 |
+
if loss_dict.get('loss_classifier') is not None:
|
| 144 |
+
loss_classifier_list.append(loss_dict['loss_classifier'].item())
|
| 145 |
+
else:
|
| 146 |
+
loss_classifier_list.append(0)
|
| 147 |
+
|
| 148 |
+
if loss_dict.get('loss_box_reg') is not None:
|
| 149 |
+
loss_box_reg_list.append(loss_dict['loss_box_reg'].item())
|
| 150 |
+
else:
|
| 151 |
+
loss_box_reg_list.append(0)
|
| 152 |
+
|
| 153 |
+
if loss_dict.get('loss_objectness') is not None:
|
| 154 |
+
loss_objectness_list.append(loss_dict['loss_objectness'].item())
|
| 155 |
+
else:
|
| 156 |
+
loss_objectness_list.append(0)
|
| 157 |
+
|
| 158 |
+
if loss_dict.get('loss_rpn_box_reg') is not None:
|
| 159 |
+
loss_rpn_box_reg_list.append(loss_dict['loss_rpn_box_reg'].item())
|
| 160 |
+
else:
|
| 161 |
+
loss_rpn_box_reg_list.append(0)
|
| 162 |
+
|
| 163 |
+
if 'loss_keypoint' in loss_dict:
|
| 164 |
+
loss_keypoints_list.append(loss_dict['loss_keypoint'].item())
|
| 165 |
+
else:
|
| 166 |
+
loss_keypoints_list.append(0)
|
| 167 |
+
|
| 168 |
+
# Calculate average loss
|
| 169 |
+
avg_loss = total_loss / len(data_loader)
|
| 170 |
+
|
| 171 |
+
avg_loss_classifier = np.mean(loss_classifier_list)
|
| 172 |
+
avg_loss_box_reg = np.mean(loss_box_reg_list)
|
| 173 |
+
avg_loss_objectness = np.mean(loss_objectness_list)
|
| 174 |
+
avg_loss_rpn_box_reg = np.mean(loss_rpn_box_reg_list)
|
| 175 |
+
avg_loss_keypoints = np.mean(loss_keypoints_list)
|
| 176 |
+
|
| 177 |
+
if print_losses:
|
| 178 |
+
print(f"Average Loss: {avg_loss:.4f}")
|
| 179 |
+
print(f"Average Classifier Loss: {avg_loss_classifier:.4f}")
|
| 180 |
+
print(f"Average Box Regression Loss: {avg_loss_box_reg:.4f}")
|
| 181 |
+
print(f"Average Objectness Loss: {avg_loss_objectness:.4f}")
|
| 182 |
+
print(f"Average RPN Box Regression Loss: {avg_loss_rpn_box_reg:.4f}")
|
| 183 |
+
print(f"Average Keypoints Loss: {avg_loss_keypoints:.4f}")
|
| 184 |
+
|
| 185 |
+
return avg_loss
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def training_model(num_epochs, model, data_loader, subset_test_loader,
|
| 189 |
+
optimizer, model_to_load=None, change_learning_rate=5, start_key=30,
|
| 190 |
+
batch_size=4, crop_prob=0.2, h_flip_prob=0.3, v_flip_prob=0.3,
|
| 191 |
+
max_rotate_deg=20, rotate_proba=0.2, blur_prob=0.2,
|
| 192 |
+
score_threshold=0.7, iou_threshold=0.5, early_stop_f1_score=0.97,
|
| 193 |
+
information_training='training', start_epoch=0, loss_config=None, model_type = 'object',
|
| 194 |
+
eval_metric='f1_score', device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
if loss_config is None:
|
| 198 |
+
print('No loss config found, all losses will be used.')
|
| 199 |
+
else:
|
| 200 |
+
#print the list of the losses that will be used
|
| 201 |
+
print('The following losses will be used: ', end='')
|
| 202 |
+
for key, value in loss_config.items():
|
| 203 |
+
if value:
|
| 204 |
+
print(key, end=", ")
|
| 205 |
+
print()
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# Initialize lists to store epoch-wise average losses
|
| 209 |
+
epoch_avg_losses = []
|
| 210 |
+
epoch_avg_loss_classifier = []
|
| 211 |
+
epoch_avg_loss_box_reg = []
|
| 212 |
+
epoch_avg_loss_objectness = []
|
| 213 |
+
epoch_avg_loss_rpn_box_reg = []
|
| 214 |
+
epoch_avg_loss_keypoints = []
|
| 215 |
+
epoch_precision = []
|
| 216 |
+
epoch_recall = []
|
| 217 |
+
epoch_f1_score = []
|
| 218 |
+
epoch_test_loss = []
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
start_tot = time.time()
|
| 222 |
+
best_metrics = -1000
|
| 223 |
+
best_epoch = 0
|
| 224 |
+
best_model_state = None
|
| 225 |
+
same = 0
|
| 226 |
+
learning_rate = optimizer.param_groups[0]['lr']
|
| 227 |
+
bad_test_loss = 0
|
| 228 |
+
previous_test_loss = 1000
|
| 229 |
+
|
| 230 |
+
print(f"Let's go training {model_type} model with {num_epochs} epochs!")
|
| 231 |
+
print(f"Learning rate: {learning_rate}, Batch size: {batch_size}, Crop prob: {crop_prob}, Flip prob: {h_flip_prob}, Rotate prob: {rotate_proba}, Blur prob: {blur_prob}")
|
| 232 |
+
|
| 233 |
+
for epoch in range(num_epochs):
|
| 234 |
+
|
| 235 |
+
if (epoch>0 and (epoch)%change_learning_rate == 0) or bad_test_loss>1:
|
| 236 |
+
learning_rate = 0.7*learning_rate
|
| 237 |
+
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=learning_rate, eps=1e-08, betas=(0.9, 0.999))
|
| 238 |
+
print(f'Learning rate changed to {learning_rate:.4} and the best epoch for now is {best_epoch}')
|
| 239 |
+
bad_test_loss = 0
|
| 240 |
+
if epoch>0 and (epoch)==start_key:
|
| 241 |
+
print("Now it's training Keypoints also")
|
| 242 |
+
loss_config['loss_keypoint'] = True
|
| 243 |
+
for name, param in model.named_parameters():
|
| 244 |
+
if 'keypoint' in name:
|
| 245 |
+
param.requires_grad = True
|
| 246 |
+
|
| 247 |
+
model.train()
|
| 248 |
+
start = time.time()
|
| 249 |
+
total_loss = 0
|
| 250 |
+
|
| 251 |
+
# Initialize lists to keep track of individual losses
|
| 252 |
+
loss_classifier_list = []
|
| 253 |
+
loss_box_reg_list = []
|
| 254 |
+
loss_objectness_list = []
|
| 255 |
+
loss_rpn_box_reg_list = []
|
| 256 |
+
loss_keypoints_list = []
|
| 257 |
+
|
| 258 |
+
# Create a tqdm progress bar
|
| 259 |
+
progress_bar = tqdm(data_loader, desc=f'Epoch {epoch+1+start_epoch}')
|
| 260 |
+
|
| 261 |
+
for images, targets_im in progress_bar:
|
| 262 |
+
images = [image.to(device) for image in images]
|
| 263 |
+
targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]
|
| 264 |
+
|
| 265 |
+
optimizer.zero_grad()
|
| 266 |
+
|
| 267 |
+
loss_dict = model(images, targets)
|
| 268 |
+
# Inside the training loop where losses are calculated:
|
| 269 |
+
losses = 0
|
| 270 |
+
if loss_config is not None:
|
| 271 |
+
for key, loss in loss_dict.items():
|
| 272 |
+
if loss_config.get(key, False):
|
| 273 |
+
if key == 'loss_classifier':
|
| 274 |
+
loss *= 3
|
| 275 |
+
losses += loss
|
| 276 |
+
else:
|
| 277 |
+
losses = sum(loss for key, loss in loss_dict.items())
|
| 278 |
+
|
| 279 |
+
# Collect individual losses
|
| 280 |
+
if loss_dict['loss_classifier']:
|
| 281 |
+
loss_classifier_list.append(loss_dict['loss_classifier'].item())
|
| 282 |
+
else:
|
| 283 |
+
loss_classifier_list.append(0)
|
| 284 |
+
|
| 285 |
+
if loss_dict['loss_box_reg']:
|
| 286 |
+
loss_box_reg_list.append(loss_dict['loss_box_reg'].item())
|
| 287 |
+
else:
|
| 288 |
+
loss_box_reg_list.append(0)
|
| 289 |
+
|
| 290 |
+
if loss_dict['loss_objectness']:
|
| 291 |
+
loss_objectness_list.append(loss_dict['loss_objectness'].item())
|
| 292 |
+
else:
|
| 293 |
+
loss_objectness_list.append(0)
|
| 294 |
+
|
| 295 |
+
if loss_dict['loss_rpn_box_reg']:
|
| 296 |
+
loss_rpn_box_reg_list.append(loss_dict['loss_rpn_box_reg'].item())
|
| 297 |
+
else:
|
| 298 |
+
loss_rpn_box_reg_list.append(0)
|
| 299 |
+
|
| 300 |
+
if 'loss_keypoint' in loss_dict:
|
| 301 |
+
loss_keypoints_list.append(loss_dict['loss_keypoint'].item())
|
| 302 |
+
else:
|
| 303 |
+
loss_keypoints_list.append(0)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
losses.backward()
|
| 307 |
+
optimizer.step()
|
| 308 |
+
|
| 309 |
+
total_loss += losses.item()
|
| 310 |
+
|
| 311 |
+
# Update the description with the current loss
|
| 312 |
+
progress_bar.set_description(f'Epoch {epoch+1+start_epoch}, Loss: {losses.item():.4f}')
|
| 313 |
+
|
| 314 |
+
# Calculate average loss
|
| 315 |
+
avg_loss = total_loss / len(data_loader)
|
| 316 |
+
|
| 317 |
+
epoch_avg_losses.append(avg_loss)
|
| 318 |
+
epoch_avg_loss_classifier.append(np.mean(loss_classifier_list))
|
| 319 |
+
epoch_avg_loss_box_reg.append(np.mean(loss_box_reg_list))
|
| 320 |
+
epoch_avg_loss_objectness.append(np.mean(loss_objectness_list))
|
| 321 |
+
epoch_avg_loss_rpn_box_reg.append(np.mean(loss_rpn_box_reg_list))
|
| 322 |
+
epoch_avg_loss_keypoints.append(np.mean(loss_keypoints_list))
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# Evaluate the model on the test set
|
| 326 |
+
if eval_metric != 'loss':
|
| 327 |
+
avg_test_loss = 0
|
| 328 |
+
labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, subset_test_loader,score_threshold=0.5, iou_threshold=0.5, distance_threshold=10, key_correction=False, model_type=model_type)
|
| 329 |
+
print(f"Epoch {epoch+1+start_epoch}, Average Loss: {avg_loss:.4f}, Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f} ", end=", ")
|
| 330 |
+
if eval_metric == 'all':
|
| 331 |
+
avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
|
| 332 |
+
print(f"Epoch {epoch+1+start_epoch}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
|
| 333 |
+
if eval_metric == 'loss':
|
| 334 |
+
labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = 0,0,0,0,0,0
|
| 335 |
+
avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
|
| 336 |
+
print(f"Epoch {epoch+1+start_epoch}, Average Training Loss: {avg_loss:.4f}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
|
| 337 |
+
|
| 338 |
+
print(f"Time: {time.time() - start:.2f} [s]")
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
if epoch>0 and (epoch)%start_key == 0:
|
| 342 |
+
print(f"Keypoints Accuracy: {key_accuracy:.4f}", end=", ")
|
| 343 |
+
|
| 344 |
+
if eval_metric == 'f1_score':
|
| 345 |
+
metric_used = f1_score
|
| 346 |
+
elif eval_metric == 'precision':
|
| 347 |
+
metric_used = precision
|
| 348 |
+
elif eval_metric == 'recall':
|
| 349 |
+
metric_used = recall
|
| 350 |
+
else:
|
| 351 |
+
metric_used = -avg_test_loss
|
| 352 |
+
|
| 353 |
+
# Check if this epoch's model has the lowest average loss
|
| 354 |
+
if metric_used > best_metrics:
|
| 355 |
+
best_metrics = metric_used
|
| 356 |
+
best_epoch = epoch+1+start_epoch
|
| 357 |
+
best_model_state = copy.deepcopy(model.state_dict())
|
| 358 |
+
|
| 359 |
+
if epoch>0 and f1_score>early_stop_f1_score:
|
| 360 |
+
same+=1
|
| 361 |
+
|
| 362 |
+
epoch_precision.append(precision)
|
| 363 |
+
epoch_recall.append(recall)
|
| 364 |
+
epoch_f1_score.append(f1_score)
|
| 365 |
+
epoch_test_loss.append(avg_test_loss)
|
| 366 |
+
|
| 367 |
+
name_model = f"model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}"
|
| 368 |
+
|
| 369 |
+
if same >=1 :
|
| 370 |
+
metrics_list = [epoch_avg_losses,epoch_avg_loss_classifier,epoch_avg_loss_box_reg,epoch_avg_loss_objectness,epoch_avg_loss_rpn_box_reg,epoch_avg_loss_keypoints,epoch_precision,epoch_recall,epoch_f1_score,epoch_test_loss]
|
| 371 |
+
torch.save(best_model_state, './models/'+ name_model +'.pth')
|
| 372 |
+
write_results(name_model,metrics_list,start_epoch)
|
| 373 |
+
break
|
| 374 |
+
|
| 375 |
+
if (epoch+1+start_epoch) % 5 == 0:
|
| 376 |
+
metrics_list = [epoch_avg_losses,epoch_avg_loss_classifier,epoch_avg_loss_box_reg,epoch_avg_loss_objectness,epoch_avg_loss_rpn_box_reg,epoch_avg_loss_keypoints,epoch_precision,epoch_recall,epoch_f1_score,epoch_test_loss]
|
| 377 |
+
torch.save(best_model_state, './models/'+ name_model +'.pth')
|
| 378 |
+
model.load_state_dict(best_model_state)
|
| 379 |
+
write_results(name_model,metrics_list,start_epoch)
|
| 380 |
+
|
| 381 |
+
if avg_test_loss > previous_test_loss:
|
| 382 |
+
bad_test_loss += 1
|
| 383 |
+
previous_test_loss = avg_test_loss
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
print(f"\n Total time: {(time.time() - start_tot)/60} minutes, Best Epoch is {best_epoch} with an f1_score of {best_metrics:.4f}")
|
| 387 |
+
if best_model_state:
|
| 388 |
+
metrics_list = [epoch_avg_losses,epoch_avg_loss_classifier,epoch_avg_loss_box_reg,epoch_avg_loss_objectness,epoch_avg_loss_rpn_box_reg,epoch_avg_loss_keypoints,epoch_precision,epoch_recall,epoch_f1_score,epoch_test_loss]
|
| 389 |
+
torch.save(best_model_state, './models/'+ name_model +'.pth')
|
| 390 |
+
model.load_state_dict(best_model_state)
|
| 391 |
+
write_results(name_model,metrics_list,start_epoch)
|
| 392 |
+
print(f"Name of the best model: model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}")
|
| 393 |
+
|
| 394 |
+
return model, metrics_list
|
utils.py
ADDED
|
@@ -0,0 +1,936 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchvision.models.detection import keypointrcnn_resnet50_fpn
|
| 2 |
+
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
| 3 |
+
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
|
| 4 |
+
from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights
|
| 5 |
+
import random
|
| 6 |
+
import torch
|
| 7 |
+
from torch.utils.data import Dataset
|
| 8 |
+
import torchvision.transforms.functional as F
|
| 9 |
+
import numpy as np
|
| 10 |
+
from torch.utils.data.dataloader import default_collate
|
| 11 |
+
import cv2
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
from torch.utils.data import DataLoader, Subset, ConcatDataset
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
from torch.optim import SGD
|
| 16 |
+
import time
|
| 17 |
+
from torch.optim import AdamW
|
| 18 |
+
import copy
|
| 19 |
+
from torchvision import transforms
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
object_dict = {
|
| 23 |
+
0: 'background',
|
| 24 |
+
1: 'task',
|
| 25 |
+
2: 'exclusiveGateway',
|
| 26 |
+
3: 'event',
|
| 27 |
+
4: 'parallelGateway',
|
| 28 |
+
5: 'messageEvent',
|
| 29 |
+
6: 'pool',
|
| 30 |
+
7: 'lane',
|
| 31 |
+
8: 'dataObject',
|
| 32 |
+
9: 'dataStore',
|
| 33 |
+
10: 'subProcess',
|
| 34 |
+
11: 'eventBasedGateway',
|
| 35 |
+
12: 'timerEvent',
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
arrow_dict = {
|
| 39 |
+
0: 'background',
|
| 40 |
+
1: 'sequenceFlow',
|
| 41 |
+
2: 'dataAssociation',
|
| 42 |
+
3: 'messageFlow',
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
class_dict = {
|
| 46 |
+
0: 'background',
|
| 47 |
+
1: 'task',
|
| 48 |
+
2: 'exclusiveGateway',
|
| 49 |
+
3: 'event',
|
| 50 |
+
4: 'parallelGateway',
|
| 51 |
+
5: 'messageEvent',
|
| 52 |
+
6: 'pool',
|
| 53 |
+
7: 'lane',
|
| 54 |
+
8: 'dataObject',
|
| 55 |
+
9: 'dataStore',
|
| 56 |
+
10: 'subProcess',
|
| 57 |
+
11: 'eventBasedGateway',
|
| 58 |
+
12: 'timerEvent',
|
| 59 |
+
13: 'sequenceFlow',
|
| 60 |
+
14: 'dataAssociation',
|
| 61 |
+
15: 'messageFlow',
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
def rescale_boxes(scale, boxes):
|
| 65 |
+
for i in range(len(boxes)):
|
| 66 |
+
boxes[i] = [boxes[i][0]*scale,
|
| 67 |
+
boxes[i][1]*scale,
|
| 68 |
+
boxes[i][2]*scale,
|
| 69 |
+
boxes[i][3]*scale]
|
| 70 |
+
return boxes
|
| 71 |
+
|
| 72 |
+
def iou(box1, box2):
|
| 73 |
+
# Calcule l'intersection des deux boîtes englobantes
|
| 74 |
+
inter_box = [max(box1[0], box2[0]), max(box1[1], box2[1]), min(box1[2], box2[2]), min(box1[3], box2[3])]
|
| 75 |
+
inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
|
| 76 |
+
|
| 77 |
+
# Calcule l'union des deux boîtes englobantes
|
| 78 |
+
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
| 79 |
+
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
| 80 |
+
union_area = box1_area + box2_area - inter_area
|
| 81 |
+
|
| 82 |
+
return inter_area / union_area
|
| 83 |
+
|
| 84 |
+
def proportion_inside(box1, box2):
|
| 85 |
+
# Calculate the intersection of the two bounding boxes
|
| 86 |
+
inter_box = [max(box1[0], box2[0]), max(box1[1], box2[1]), min(box1[2], box2[2]), min(box1[3], box2[3])]
|
| 87 |
+
inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
|
| 88 |
+
|
| 89 |
+
# Calculate the area of box1
|
| 90 |
+
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
| 91 |
+
|
| 92 |
+
# Calculate the proportion of box1 inside box2
|
| 93 |
+
if box1_area == 0:
|
| 94 |
+
return 0
|
| 95 |
+
proportion = inter_area / box1_area
|
| 96 |
+
|
| 97 |
+
# Ensure the proportion is at most 100%
|
| 98 |
+
return min(proportion, 1.0)
|
| 99 |
+
|
| 100 |
+
def resize_boxes(boxes, original_size, target_size):
|
| 101 |
+
"""
|
| 102 |
+
Resizes bounding boxes according to a new image size.
|
| 103 |
+
|
| 104 |
+
Parameters:
|
| 105 |
+
- boxes (np.array): The original bounding boxes as a numpy array of shape [N, 4].
|
| 106 |
+
- original_size (tuple): The original size of the image as (width, height).
|
| 107 |
+
- target_size (tuple): The desired size to resize the image to as (width, height).
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
- np.array: The resized bounding boxes as a numpy array of shape [N, 4].
|
| 111 |
+
"""
|
| 112 |
+
orig_width, orig_height = original_size
|
| 113 |
+
target_width, target_height = target_size
|
| 114 |
+
|
| 115 |
+
# Calculate the ratios for width and height
|
| 116 |
+
width_ratio = target_width / orig_width
|
| 117 |
+
height_ratio = target_height / orig_height
|
| 118 |
+
|
| 119 |
+
# Apply the ratios to the bounding boxes
|
| 120 |
+
boxes[:, 0] *= width_ratio
|
| 121 |
+
boxes[:, 1] *= height_ratio
|
| 122 |
+
boxes[:, 2] *= width_ratio
|
| 123 |
+
boxes[:, 3] *= height_ratio
|
| 124 |
+
|
| 125 |
+
return boxes
|
| 126 |
+
|
| 127 |
+
def resize_keypoints(keypoints: np.ndarray, original_size: tuple, target_size: tuple) -> np.ndarray:
|
| 128 |
+
"""
|
| 129 |
+
Resize keypoints based on the original and target dimensions of an image.
|
| 130 |
+
|
| 131 |
+
Parameters:
|
| 132 |
+
- keypoints (np.ndarray): The array of keypoints, where each keypoint is represented by its (x, y) coordinates.
|
| 133 |
+
- original_size (tuple): The width and height of the original image (width, height).
|
| 134 |
+
- target_size (tuple): The width and height of the target image (width, height).
|
| 135 |
+
|
| 136 |
+
Returns:
|
| 137 |
+
- np.ndarray: The resized keypoints.
|
| 138 |
+
|
| 139 |
+
Explanation:
|
| 140 |
+
The function calculates the ratio of the target dimensions to the original dimensions.
|
| 141 |
+
It then applies these ratios to the x and y coordinates of each keypoint to scale them
|
| 142 |
+
appropriately to the target image size.
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
orig_width, orig_height = original_size
|
| 146 |
+
target_width, target_height = target_size
|
| 147 |
+
|
| 148 |
+
# Calculate the ratios for width and height scaling
|
| 149 |
+
width_ratio = target_width / orig_width
|
| 150 |
+
height_ratio = target_height / orig_height
|
| 151 |
+
|
| 152 |
+
# Apply the scaling ratios to the x and y coordinates of each keypoint
|
| 153 |
+
keypoints[:, 0] *= width_ratio # Scale x coordinates
|
| 154 |
+
keypoints[:, 1] *= height_ratio # Scale y coordinates
|
| 155 |
+
|
| 156 |
+
return keypoints
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class RandomCrop:
|
| 161 |
+
def __init__(self, new_size=(1333,800),crop_fraction=0.5, min_objects=4):
|
| 162 |
+
self.crop_fraction = crop_fraction
|
| 163 |
+
self.min_objects = min_objects
|
| 164 |
+
self.new_size = new_size
|
| 165 |
+
|
| 166 |
+
def __call__(self, image, target):
|
| 167 |
+
new_w1, new_h1 = self.new_size
|
| 168 |
+
w, h = image.size
|
| 169 |
+
new_w = int(w * self.crop_fraction)
|
| 170 |
+
new_h = int(new_w*new_h1/new_w1)
|
| 171 |
+
|
| 172 |
+
i=0
|
| 173 |
+
for i in range(4):
|
| 174 |
+
if new_h >= h:
|
| 175 |
+
i += 0.05
|
| 176 |
+
new_w = int(w * (self.crop_fraction - i))
|
| 177 |
+
new_h = int(new_w*new_h1/new_w1)
|
| 178 |
+
if new_h < h:
|
| 179 |
+
continue
|
| 180 |
+
|
| 181 |
+
if new_h >= h:
|
| 182 |
+
return image, target
|
| 183 |
+
|
| 184 |
+
boxes = target["boxes"]
|
| 185 |
+
if 'keypoints' in target:
|
| 186 |
+
keypoints = target["keypoints"]
|
| 187 |
+
else:
|
| 188 |
+
keypoints = []
|
| 189 |
+
for i in range(len(boxes)):
|
| 190 |
+
keypoints.append(torch.zeros((2,3)))
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# Attempt to find a suitable crop region
|
| 194 |
+
success = False
|
| 195 |
+
for _ in range(100): # Max 100 attempts to find a valid crop
|
| 196 |
+
top = random.randint(0, h - new_h)
|
| 197 |
+
left = random.randint(0, w - new_w)
|
| 198 |
+
crop_region = [left, top, left + new_w, top + new_h]
|
| 199 |
+
|
| 200 |
+
# Check how many objects are fully contained in this region
|
| 201 |
+
contained_boxes = []
|
| 202 |
+
contained_keypoints = []
|
| 203 |
+
for box, kp in zip(boxes, keypoints):
|
| 204 |
+
if box[0] >= crop_region[0] and box[1] >= crop_region[1] and box[2] <= crop_region[2] and box[3] <= crop_region[3]:
|
| 205 |
+
# Adjust box and keypoints coordinates
|
| 206 |
+
new_box = box - torch.tensor([crop_region[0], crop_region[1], crop_region[0], crop_region[1]])
|
| 207 |
+
new_kp = kp - torch.tensor([crop_region[0], crop_region[1], 0])
|
| 208 |
+
contained_boxes.append(new_box)
|
| 209 |
+
contained_keypoints.append(new_kp)
|
| 210 |
+
|
| 211 |
+
if len(contained_boxes) >= self.min_objects:
|
| 212 |
+
success = True
|
| 213 |
+
break
|
| 214 |
+
|
| 215 |
+
if success:
|
| 216 |
+
# Perform the actual crop
|
| 217 |
+
image = F.crop(image, top, left, new_h, new_w)
|
| 218 |
+
target["boxes"] = torch.stack(contained_boxes) if contained_boxes else torch.zeros((0, 4))
|
| 219 |
+
if 'keypoints' in target:
|
| 220 |
+
target["keypoints"] = torch.stack(contained_keypoints) if contained_keypoints else torch.zeros((0, 2, 4))
|
| 221 |
+
|
| 222 |
+
return image, target
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class RandomFlip:
|
| 226 |
+
def __init__(self, h_flip_prob=0.5, v_flip_prob=0.5):
|
| 227 |
+
"""
|
| 228 |
+
Initializes the RandomFlip with probabilities for flipping.
|
| 229 |
+
|
| 230 |
+
Parameters:
|
| 231 |
+
- h_flip_prob (float): Probability of applying a horizontal flip to the image.
|
| 232 |
+
- v_flip_prob (float): Probability of applying a vertical flip to the image.
|
| 233 |
+
"""
|
| 234 |
+
self.h_flip_prob = h_flip_prob
|
| 235 |
+
self.v_flip_prob = v_flip_prob
|
| 236 |
+
|
| 237 |
+
def __call__(self, image, target):
|
| 238 |
+
"""
|
| 239 |
+
Applies random horizontal and/or vertical flip to the image and updates target data accordingly.
|
| 240 |
+
|
| 241 |
+
Parameters:
|
| 242 |
+
- image (PIL Image): The image to be flipped.
|
| 243 |
+
- target (dict): The target dictionary containing 'boxes' and 'keypoints'.
|
| 244 |
+
|
| 245 |
+
Returns:
|
| 246 |
+
- PIL Image, dict: The flipped image and its updated target dictionary.
|
| 247 |
+
"""
|
| 248 |
+
if random.random() < self.h_flip_prob:
|
| 249 |
+
image = F.hflip(image)
|
| 250 |
+
w, _ = image.size # Get the new width of the image after flip for bounding box adjustment
|
| 251 |
+
# Adjust bounding boxes for horizontal flip
|
| 252 |
+
for i, box in enumerate(target['boxes']):
|
| 253 |
+
xmin, ymin, xmax, ymax = box
|
| 254 |
+
target['boxes'][i] = torch.tensor([w - xmax, ymin, w - xmin, ymax], dtype=torch.float32)
|
| 255 |
+
|
| 256 |
+
# Adjust keypoints for horizontal flip
|
| 257 |
+
if 'keypoints' in target:
|
| 258 |
+
new_keypoints = []
|
| 259 |
+
for keypoints_for_object in target['keypoints']:
|
| 260 |
+
flipped_keypoints_for_object = []
|
| 261 |
+
for kp in keypoints_for_object:
|
| 262 |
+
x, y = kp[:2]
|
| 263 |
+
new_x = w - x
|
| 264 |
+
flipped_keypoints_for_object.append(torch.tensor([new_x, y] + list(kp[2:])))
|
| 265 |
+
new_keypoints.append(torch.stack(flipped_keypoints_for_object))
|
| 266 |
+
target['keypoints'] = torch.stack(new_keypoints)
|
| 267 |
+
|
| 268 |
+
if random.random() < self.v_flip_prob:
|
| 269 |
+
image = F.vflip(image)
|
| 270 |
+
_, h = image.size # Get the new height of the image after flip for bounding box adjustment
|
| 271 |
+
# Adjust bounding boxes for vertical flip
|
| 272 |
+
for i, box in enumerate(target['boxes']):
|
| 273 |
+
xmin, ymin, xmax, ymax = box
|
| 274 |
+
target['boxes'][i] = torch.tensor([xmin, h - ymax, xmax, h - ymin], dtype=torch.float32)
|
| 275 |
+
|
| 276 |
+
# Adjust keypoints for vertical flip
|
| 277 |
+
if 'keypoints' in target:
|
| 278 |
+
new_keypoints = []
|
| 279 |
+
for keypoints_for_object in target['keypoints']:
|
| 280 |
+
flipped_keypoints_for_object = []
|
| 281 |
+
for kp in keypoints_for_object:
|
| 282 |
+
x, y = kp[:2]
|
| 283 |
+
new_y = h - y
|
| 284 |
+
flipped_keypoints_for_object.append(torch.tensor([x, new_y] + list(kp[2:])))
|
| 285 |
+
new_keypoints.append(torch.stack(flipped_keypoints_for_object))
|
| 286 |
+
target['keypoints'] = torch.stack(new_keypoints)
|
| 287 |
+
|
| 288 |
+
return image, target
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class RandomRotate:
|
| 292 |
+
def __init__(self, max_rotate_deg=20, rotate_proba=0.3):
|
| 293 |
+
"""
|
| 294 |
+
Initializes the RandomRotate with a maximum rotation angle and probability of rotating.
|
| 295 |
+
|
| 296 |
+
Parameters:
|
| 297 |
+
- max_rotate_deg (int): Maximum degree to rotate the image.
|
| 298 |
+
- rotate_proba (float): Probability of applying rotation to the image.
|
| 299 |
+
"""
|
| 300 |
+
self.max_rotate_deg = max_rotate_deg
|
| 301 |
+
self.rotate_proba = rotate_proba
|
| 302 |
+
|
| 303 |
+
def __call__(self, image, target):
|
| 304 |
+
"""
|
| 305 |
+
Randomly rotates the image and updates the target data accordingly.
|
| 306 |
+
|
| 307 |
+
Parameters:
|
| 308 |
+
- image (PIL Image): The image to be rotated.
|
| 309 |
+
- target (dict): The target dictionary containing 'boxes', 'labels', and 'keypoints'.
|
| 310 |
+
|
| 311 |
+
Returns:
|
| 312 |
+
- PIL Image, dict: The rotated image and its updated target dictionary.
|
| 313 |
+
"""
|
| 314 |
+
if random.random() < self.rotate_proba:
|
| 315 |
+
angle = random.uniform(-self.max_rotate_deg, self.max_rotate_deg)
|
| 316 |
+
image = F.rotate(image, angle, expand=False, fill=200)
|
| 317 |
+
|
| 318 |
+
# Rotate bounding boxes
|
| 319 |
+
w, h = image.size
|
| 320 |
+
cx, cy = w / 2, h / 2
|
| 321 |
+
boxes = target["boxes"]
|
| 322 |
+
new_boxes = []
|
| 323 |
+
for box in boxes:
|
| 324 |
+
new_box = self.rotate_box(box, angle, cx, cy)
|
| 325 |
+
new_boxes.append(new_box)
|
| 326 |
+
target["boxes"] = torch.stack(new_boxes)
|
| 327 |
+
|
| 328 |
+
# Rotate keypoints
|
| 329 |
+
if 'keypoints' in target:
|
| 330 |
+
new_keypoints = []
|
| 331 |
+
for keypoints in target["keypoints"]:
|
| 332 |
+
new_kp = self.rotate_keypoints(keypoints, angle, cx, cy)
|
| 333 |
+
new_keypoints.append(new_kp)
|
| 334 |
+
target["keypoints"] = torch.stack(new_keypoints)
|
| 335 |
+
|
| 336 |
+
return image, target
|
| 337 |
+
|
| 338 |
+
def rotate_box(self, box, angle, cx, cy):
|
| 339 |
+
"""
|
| 340 |
+
Rotates a bounding box by a given angle around the center of the image.
|
| 341 |
+
"""
|
| 342 |
+
x1, y1, x2, y2 = box
|
| 343 |
+
corners = torch.tensor([
|
| 344 |
+
[x1, y1],
|
| 345 |
+
[x2, y1],
|
| 346 |
+
[x2, y2],
|
| 347 |
+
[x1, y2]
|
| 348 |
+
])
|
| 349 |
+
corners = torch.cat((corners, torch.ones(corners.shape[0], 1)), dim=1)
|
| 350 |
+
M = cv2.getRotationMatrix2D((cx, cy), angle, 1)
|
| 351 |
+
corners = torch.matmul(torch.tensor(M, dtype=torch.float32), corners.T).T
|
| 352 |
+
x_ = corners[:, 0]
|
| 353 |
+
y_ = corners[:, 1]
|
| 354 |
+
x_min, x_max = torch.min(x_), torch.max(x_)
|
| 355 |
+
y_min, y_max = torch.min(y_), torch.max(y_)
|
| 356 |
+
return torch.tensor([x_min, y_min, x_max, y_max], dtype=torch.float32)
|
| 357 |
+
|
| 358 |
+
def rotate_keypoints(self, keypoints, angle, cx, cy):
|
| 359 |
+
"""
|
| 360 |
+
Rotates keypoints by a given angle around the center of the image.
|
| 361 |
+
"""
|
| 362 |
+
new_keypoints = []
|
| 363 |
+
for kp in keypoints:
|
| 364 |
+
x, y, v = kp
|
| 365 |
+
point = torch.tensor([x, y, 1])
|
| 366 |
+
M = cv2.getRotationMatrix2D((cx, cy), angle, 1)
|
| 367 |
+
new_point = torch.matmul(torch.tensor(M, dtype=torch.float32), point)
|
| 368 |
+
new_keypoints.append(torch.tensor([new_point[0], new_point[1], v], dtype=torch.float32))
|
| 369 |
+
return torch.stack(new_keypoints)
|
| 370 |
+
|
| 371 |
+
def rotate_90_box(box, angle, w, h):
|
| 372 |
+
x1, y1, x2, y2 = box
|
| 373 |
+
if angle == 90:
|
| 374 |
+
return torch.tensor([y1,h-x2,y2,h-x1])
|
| 375 |
+
elif angle == 270 or angle == -90:
|
| 376 |
+
return torch.tensor([w-y2,x1,w-y1,x2])
|
| 377 |
+
else:
|
| 378 |
+
print("angle not supported")
|
| 379 |
+
|
| 380 |
+
def rotate_90_keypoints(kp, angle, w, h):
|
| 381 |
+
# Extract coordinates and visibility from each keypoint tensor
|
| 382 |
+
x1, y1, v1 = kp[0][0], kp[0][1], kp[0][2]
|
| 383 |
+
x2, y2, v2 = kp[1][0], kp[1][1], kp[1][2]
|
| 384 |
+
# Swap x and y coordinates for each keypoint
|
| 385 |
+
if angle == 90:
|
| 386 |
+
new = [[y1, h-x1, v1], [y2, h-x2, v2]]
|
| 387 |
+
elif angle == 270 or angle == -90:
|
| 388 |
+
new = [[w-y1, x1, v1], [w-y2, x2, v2]]
|
| 389 |
+
|
| 390 |
+
return torch.tensor(new, dtype=torch.float32)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def rotate_vertical(image, target):
|
| 394 |
+
# Rotate the image and target if the image is vertical
|
| 395 |
+
new_boxes = []
|
| 396 |
+
angle = random.choice([-90,90])
|
| 397 |
+
image = F.rotate(image, angle, expand=True, fill=200)
|
| 398 |
+
for box in target["boxes"]:
|
| 399 |
+
new_box = rotate_90_box(box, angle, image.size[0], image.size[1])
|
| 400 |
+
new_boxes.append(new_box)
|
| 401 |
+
target["boxes"] = torch.stack(new_boxes)
|
| 402 |
+
|
| 403 |
+
if 'keypoints' in target:
|
| 404 |
+
new_kp = []
|
| 405 |
+
for kp in target['keypoints']:
|
| 406 |
+
new_key = rotate_90_keypoints(kp, angle, image.size[0], image.size[1])
|
| 407 |
+
new_kp.append(new_key)
|
| 408 |
+
target['keypoints'] = torch.stack(new_kp)
|
| 409 |
+
return image, target
|
| 410 |
+
|
| 411 |
+
class BPMN_Dataset(Dataset):
|
| 412 |
+
def __init__(self, annotations, transform=None, crop_transform=None, crop_prob=0.3, rotate_90_proba=0.2, flip_transform=None, rotate_transform=None, new_size=(1333,800),keep_ratio=False,resize=True, model_type='object', rotate_vertical=False):
|
| 413 |
+
self.annotations = annotations
|
| 414 |
+
print(f"Loaded {len(self.annotations)} annotations.")
|
| 415 |
+
self.transform = transform
|
| 416 |
+
self.crop_transform = crop_transform
|
| 417 |
+
self.crop_prob = crop_prob
|
| 418 |
+
self.flip_transform = flip_transform
|
| 419 |
+
self.rotate_transform = rotate_transform
|
| 420 |
+
self.resize = resize
|
| 421 |
+
self.rotate_vertical = rotate_vertical
|
| 422 |
+
self.new_size = new_size
|
| 423 |
+
self.keep_ratio = keep_ratio
|
| 424 |
+
self.model_type = model_type
|
| 425 |
+
if model_type == 'object':
|
| 426 |
+
self.dict = object_dict
|
| 427 |
+
elif model_type == 'arrow':
|
| 428 |
+
self.dict = arrow_dict
|
| 429 |
+
self.rotate_90_proba = rotate_90_proba
|
| 430 |
+
|
| 431 |
+
def __len__(self):
|
| 432 |
+
return len(self.annotations)
|
| 433 |
+
|
| 434 |
+
def __getitem__(self, idx):
|
| 435 |
+
annotation = self.annotations[idx]
|
| 436 |
+
image = annotation.img.convert("RGB")
|
| 437 |
+
boxes = torch.tensor(np.array(annotation.boxes_ltrb), dtype=torch.float32)
|
| 438 |
+
labels_names = [ann for ann in annotation.categories]
|
| 439 |
+
|
| 440 |
+
#only keep the labels, boxes and keypoints that are in the class_dict
|
| 441 |
+
kept_indices = [i for i, ann in enumerate(annotation.categories) if ann in self.dict.values()]
|
| 442 |
+
boxes = boxes[kept_indices]
|
| 443 |
+
labels_names = [ann for i, ann in enumerate(labels_names) if i in kept_indices]
|
| 444 |
+
|
| 445 |
+
labels_id = torch.tensor([(list(self.dict.values()).index(ann)) for ann in labels_names], dtype=torch.int64)
|
| 446 |
+
|
| 447 |
+
# Initialize keypoints tensor
|
| 448 |
+
max_keypoints = 2
|
| 449 |
+
keypoints = torch.zeros((len(labels_id), max_keypoints, 3), dtype=torch.float32)
|
| 450 |
+
|
| 451 |
+
ii=0
|
| 452 |
+
for i, ann in enumerate(annotation.annotations):
|
| 453 |
+
#only keep the keypoints that are in the kept indices
|
| 454 |
+
if i not in kept_indices:
|
| 455 |
+
continue
|
| 456 |
+
if ann.category in ["sequenceFlow", "messageFlow", "dataAssociation"]:
|
| 457 |
+
# Fill the keypoints tensor for this annotation, mark as visible (1)
|
| 458 |
+
kp = np.array(ann.keypoints, dtype=np.float32).reshape(-1, 3)
|
| 459 |
+
kp = kp[:,:2]
|
| 460 |
+
visible = np.ones((kp.shape[0], 1), dtype=np.float32)
|
| 461 |
+
kp = np.hstack([kp, visible])
|
| 462 |
+
keypoints[ii, :kp.shape[0], :] = torch.tensor(kp, dtype=torch.float32)
|
| 463 |
+
ii += 1
|
| 464 |
+
|
| 465 |
+
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
|
| 466 |
+
|
| 467 |
+
if self.model_type == 'object':
|
| 468 |
+
target = {
|
| 469 |
+
"boxes": boxes,
|
| 470 |
+
"labels": labels_id,
|
| 471 |
+
#"area": area,
|
| 472 |
+
#"keypoints": keypoints,
|
| 473 |
+
}
|
| 474 |
+
elif self.model_type == 'arrow':
|
| 475 |
+
target = {
|
| 476 |
+
"boxes": boxes,
|
| 477 |
+
"labels": labels_id,
|
| 478 |
+
#"area": area,
|
| 479 |
+
"keypoints": keypoints,
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
# Randomly apply flip transform
|
| 483 |
+
if self.flip_transform:
|
| 484 |
+
image, target = self.flip_transform(image, target)
|
| 485 |
+
|
| 486 |
+
# Randomly apply rotate transform
|
| 487 |
+
if self.rotate_transform:
|
| 488 |
+
image, target = self.rotate_transform(image, target)
|
| 489 |
+
|
| 490 |
+
# Randomly apply the custom cropping transform
|
| 491 |
+
if self.crop_transform and random.random() < self.crop_prob:
|
| 492 |
+
image, target = self.crop_transform(image, target)
|
| 493 |
+
|
| 494 |
+
# Rotate vertical image
|
| 495 |
+
if self.rotate_vertical and random.random() < self.rotate_90_proba:
|
| 496 |
+
image, target = rotate_vertical(image, target)
|
| 497 |
+
|
| 498 |
+
if self.resize:
|
| 499 |
+
if self.keep_ratio:
|
| 500 |
+
original_size = image.size
|
| 501 |
+
# Calculate scale to fit the new size while maintaining aspect ratio
|
| 502 |
+
scale = min(self.new_size[0] / original_size[0], self.new_size[1] / original_size[1])
|
| 503 |
+
new_scaled_size = (int(original_size[0] * scale), int(original_size[1] * scale))
|
| 504 |
+
|
| 505 |
+
target['boxes'] = resize_boxes(target['boxes'], (image.size[0],image.size[1]), (new_scaled_size))
|
| 506 |
+
if 'area' in target:
|
| 507 |
+
target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0])
|
| 508 |
+
|
| 509 |
+
if 'keypoints' in target:
|
| 510 |
+
for i in range(len(target['keypoints'])):
|
| 511 |
+
target['keypoints'][i] = resize_keypoints(target['keypoints'][i], (image.size[0],image.size[1]), (new_scaled_size))
|
| 512 |
+
|
| 513 |
+
# Resize image to new scaled size
|
| 514 |
+
image = F.resize(image, (new_scaled_size[1], new_scaled_size[0]))
|
| 515 |
+
|
| 516 |
+
# Pad the resized image to make it exactly the desired size
|
| 517 |
+
padding = [0, 0, self.new_size[0] - new_scaled_size[0], self.new_size[1] - new_scaled_size[1]]
|
| 518 |
+
image = F.pad(image, padding, fill=200, padding_mode='constant')
|
| 519 |
+
else:
|
| 520 |
+
target['boxes'] = resize_boxes(target['boxes'], (image.size[0],image.size[1]), self.new_size)
|
| 521 |
+
if 'area' in target:
|
| 522 |
+
target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0])
|
| 523 |
+
if 'keypoints' in target:
|
| 524 |
+
for i in range(len(target['keypoints'])):
|
| 525 |
+
target['keypoints'][i] = resize_keypoints(target['keypoints'][i], (image.size[0],image.size[1]), self.new_size)
|
| 526 |
+
image = F.resize(image, (self.new_size[1], self.new_size[0]))
|
| 527 |
+
|
| 528 |
+
return self.transform(image), target
|
| 529 |
+
|
| 530 |
+
def collate_fn(batch):
|
| 531 |
+
"""
|
| 532 |
+
Custom collation function for DataLoader that handles batches of images and targets.
|
| 533 |
+
|
| 534 |
+
This function ensures that images are properly batched together using PyTorch's default collation,
|
| 535 |
+
while keeping the targets (such as bounding boxes and labels) in a list of dictionaries,
|
| 536 |
+
as each image might have a different number of objects detected.
|
| 537 |
+
|
| 538 |
+
Parameters:
|
| 539 |
+
- batch (list): A list of tuples, where each tuple contains an image and its corresponding target dictionary.
|
| 540 |
+
|
| 541 |
+
Returns:
|
| 542 |
+
- Tuple containing:
|
| 543 |
+
- Tensor: Batched images.
|
| 544 |
+
- List of dicts: Targets corresponding to each image in the batch.
|
| 545 |
+
"""
|
| 546 |
+
images, targets = zip(*batch) # Unzip the batch into separate lists for images and targets.
|
| 547 |
+
|
| 548 |
+
# Batch images using the default collate function which handles tensors, numpy arrays, numbers, etc.
|
| 549 |
+
images = default_collate(images)
|
| 550 |
+
|
| 551 |
+
return images, targets
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def create_loader(new_size,transformation, annotations1, annotations2=None,
|
| 556 |
+
batch_size=4, crop_prob=0.2, crop_fraction=0.7, min_objects=3,
|
| 557 |
+
h_flip_prob=0.3, v_flip_prob=0.3, max_rotate_deg=20, rotate_90_proba=0.2, rotate_proba=0.3,
|
| 558 |
+
seed=42, resize=True, rotate_vertical=False, keep_ratio=False, model_type = 'object'):
|
| 559 |
+
"""
|
| 560 |
+
Creates a DataLoader for BPMN datasets with optional transformations and concatenation of two datasets.
|
| 561 |
+
|
| 562 |
+
Parameters:
|
| 563 |
+
- transformation (callable): Transformation function to apply to each image (e.g., normalization).
|
| 564 |
+
- annotations1 (list): Primary list of annotations.
|
| 565 |
+
- annotations2 (list, optional): Secondary list of annotations to concatenate with the first.
|
| 566 |
+
- batch_size (int): Number of images per batch.
|
| 567 |
+
- crop_prob (float): Probability of applying the crop transformation.
|
| 568 |
+
- crop_fraction (float): Fraction of the original width to use when cropping.
|
| 569 |
+
- min_objects (int): Minimum number of objects required to be within the crop.
|
| 570 |
+
- h_flip_prob (float): Probability of applying horizontal flip.
|
| 571 |
+
- v_flip_prob (float): Probability of applying vertical flip.
|
| 572 |
+
- seed (int): Seed for random number generators for reproducibility.
|
| 573 |
+
- resize (bool): Flag indicating whether to resize images after transformations.
|
| 574 |
+
|
| 575 |
+
Returns:
|
| 576 |
+
- DataLoader: Configured data loader for the dataset.
|
| 577 |
+
"""
|
| 578 |
+
|
| 579 |
+
# Initialize custom transformations for cropping and flipping
|
| 580 |
+
custom_crop_transform = RandomCrop(new_size,crop_fraction, min_objects)
|
| 581 |
+
custom_flip_transform = RandomFlip(h_flip_prob, v_flip_prob)
|
| 582 |
+
custom_rotate_transform = RandomRotate(max_rotate_deg, rotate_proba)
|
| 583 |
+
|
| 584 |
+
# Create the primary dataset
|
| 585 |
+
dataset = BPMN_Dataset(
|
| 586 |
+
annotations=annotations1,
|
| 587 |
+
transform=transformation,
|
| 588 |
+
crop_transform=custom_crop_transform,
|
| 589 |
+
crop_prob=crop_prob,
|
| 590 |
+
rotate_90_proba=rotate_90_proba,
|
| 591 |
+
flip_transform=custom_flip_transform,
|
| 592 |
+
rotate_transform=custom_rotate_transform,
|
| 593 |
+
rotate_vertical=rotate_vertical,
|
| 594 |
+
new_size=new_size,
|
| 595 |
+
keep_ratio=keep_ratio,
|
| 596 |
+
model_type=model_type,
|
| 597 |
+
resize=resize
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
# Optionally concatenate a second dataset
|
| 601 |
+
if annotations2:
|
| 602 |
+
dataset2 = BPMN_Dataset(
|
| 603 |
+
annotations=annotations2,
|
| 604 |
+
transform=transformation,
|
| 605 |
+
crop_transform=custom_crop_transform,
|
| 606 |
+
crop_prob=crop_prob,
|
| 607 |
+
rotate_90_proba=rotate_90_proba,
|
| 608 |
+
flip_transform=custom_flip_transform,
|
| 609 |
+
rotate_vertical=rotate_vertical,
|
| 610 |
+
new_size=new_size,
|
| 611 |
+
keep_ratio=keep_ratio,
|
| 612 |
+
model_type=model_type,
|
| 613 |
+
resize=resize
|
| 614 |
+
)
|
| 615 |
+
dataset = ConcatDataset([dataset, dataset2]) # Concatenate the two datasets
|
| 616 |
+
|
| 617 |
+
# Set the seed for reproducibility in random operations within transformations and data loading
|
| 618 |
+
random.seed(seed)
|
| 619 |
+
torch.manual_seed(seed)
|
| 620 |
+
|
| 621 |
+
# Create the DataLoader with the dataset
|
| 622 |
+
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
|
| 623 |
+
|
| 624 |
+
return data_loader
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
def write_results(name_model,metrics_list,start_epoch):
|
| 629 |
+
with open('./results/'+ name_model+ '.txt', 'w') as f:
|
| 630 |
+
for i in range(len(metrics_list[0])):
|
| 631 |
+
f.write(f"{i+1+start_epoch},{metrics_list[0][i]},{metrics_list[1][i]},{metrics_list[2][i]},{metrics_list[3][i]},{metrics_list[4][i]},{metrics_list[5][i]},{metrics_list[6][i]},{metrics_list[7][i]},{metrics_list[8][i]},{metrics_list[9][i]} \n")
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def find_other_keypoint(idx, keypoints, boxes):
|
| 635 |
+
box = boxes[idx]
|
| 636 |
+
key1,key2 = keypoints[idx]
|
| 637 |
+
x1, y1, x2, y2 = box
|
| 638 |
+
center = ((x1 + x2) // 2, (y1 + y2) // 2)
|
| 639 |
+
average_keypoint = (key1 + key2) // 2
|
| 640 |
+
#find the opposite keypoint to the center
|
| 641 |
+
if average_keypoint[0] < center[0]:
|
| 642 |
+
x = center[0] + abs(center[0] - average_keypoint[0])
|
| 643 |
+
else:
|
| 644 |
+
x = center[0] - abs(center[0] - average_keypoint[0])
|
| 645 |
+
if average_keypoint[1] < center[1]:
|
| 646 |
+
y = center[1] + abs(center[1] - average_keypoint[1])
|
| 647 |
+
else:
|
| 648 |
+
y = center[1] - abs(center[1] - average_keypoint[1])
|
| 649 |
+
return x, y, average_keypoint[0], average_keypoint[1]
|
| 650 |
+
|
| 651 |
+
|
def filter_overlap_boxes(boxes, scores, labels, keypoints, iou_threshold=0.5):
    """
    Filters overlapping boxes based on the Intersection over Union (IoU) metric, keeping only the boxes with the highest scores.

    Parameters:
    - boxes (np.ndarray): Array of bounding boxes with shape (N, 4), where each row contains [x_min, y_min, x_max, y_max].
    - scores (np.ndarray): Array of scores for each box, reflecting the confidence of detection.
    - labels (np.ndarray): Array of labels corresponding to each box.
    - keypoints (np.ndarray): Array of keypoints associated with each box.
    - iou_threshold (float): Threshold for IoU above which a box is considered overlapping.

    Returns:
    - tuple: Filtered boxes, scores, labels, and keypoints.
    """
    # Calculate the area of each bounding box to use in IoU calculation.
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    # Sort the indices of the boxes based on their scores in descending order.
    order = scores.argsort()[::-1]

    keep = []  # List to store indices of boxes to keep.

    while order.size > 0:
        # Take the first index (highest score) from the sorted list.
        i = order[0]
        keep.append(i)  # Add this index to 'keep' list.

        # Compute the coordinates of the intersection rectangle.
        xx1 = np.maximum(boxes[i, 0], boxes[order[1:], 0])
        yy1 = np.maximum(boxes[i, 1], boxes[order[1:], 1])
        xx2 = np.minimum(boxes[i, 2], boxes[order[1:], 2])
        yy2 = np.minimum(boxes[i, 3], boxes[order[1:], 3])

        # Compute the area of the intersection rectangle.
        w = np.maximum(0.0, xx2 - xx1)
        h = np.maximum(0.0, yy2 - yy1)
        inter = w * h

        # Calculate IoU and find boxes with IoU less than the threshold to keep.
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        inds = np.where(iou <= iou_threshold)[0]

        # Update the list of box indices to consider in the next iteration.
        order = order[inds + 1]  # Skip the first element since it's already included in 'keep'.

    # Use the indices in 'keep' to select the boxes, scores, labels, and keypoints to return.
    boxes = boxes[keep]
    scores = scores[keep]
    labels = labels[keep]
    keypoints = keypoints[keep]

    return boxes, scores, labels, keypoints

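# Illustrative sketch (not part of the original file): greedy IoU suppression with
# filter_overlap_boxes on made-up detections. The second box overlaps the first
# with an IoU of about 0.68 > 0.5, so it is dropped; the third box is kept.
#
#   import numpy as np
#   boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60]], dtype=float)
#   scores = np.array([0.9, 0.8, 0.7])
#   labels = np.array([1, 1, 2])
#   keypoints = np.zeros((3, 2, 3))
#   boxes, scores, labels, keypoints = filter_overlap_boxes(boxes, scores, labels, keypoints, iou_threshold=0.5)
#   # boxes.shape == (2, 4), keeping the detections with scores 0.9 and 0.7
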
def draw_annotations(image,
                     target=None,
                     prediction=None,
                     full_prediction=None,
                     text_predictions=None,
                     model_dict=class_dict,
                     draw_keypoints=False,
                     draw_boxes=False,
                     draw_text=False,
                     draw_links=False,
                     draw_twins=False,
                     write_class=False,
                     write_score=False,
                     write_text=False,
                     write_idx=False,
                     score_threshold=0.4,
                     keypoints_correction=False,
                     only_print=None,
                     axis=False,
                     return_image=False,
                     new_size=(1333, 800),
                     resize=False):
    """
    Draws annotations on images including bounding boxes, keypoints, links, and text.

    Parameters:
    - image (np.array): The image on which annotations will be drawn.
    - target (dict): Ground truth data containing boxes, labels, etc.
    - prediction (dict): Prediction data from a model.
    - full_prediction (dict): Additional detailed prediction data, potentially including relationships.
    - text_predictions (tuple): OCR text predictions containing bounding boxes and texts.
    - model_dict (dict): Mapping from class IDs to class names.
    - draw_keypoints (bool): Flag to draw keypoints.
    - draw_boxes (bool): Flag to draw bounding boxes.
    - draw_text (bool): Flag to draw text annotations.
    - draw_links (bool): Flag to draw links between annotations.
    - draw_twins (bool): Flag to highlight twin keypoints.
    - write_class (bool): Flag to write class names near the annotations.
    - write_score (bool): Flag to write scores near the annotations.
    - write_text (bool): Flag to write OCR-recognized text.
    - score_threshold (float): Threshold for scores above which annotations will be drawn.
    - only_print (str): Specific class name to filter annotations by.
    - resize (bool): Whether to resize annotations to fit the image size.
    """

    # Convert image to RGB (if not already in that format)
    if prediction is None:
        image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_copy = image.copy()
    scale = max(image.shape[0], image.shape[1]) / 1000

    # Helper function to draw bounding boxes and keypoints
    def draw(data, is_prediction=False):
        """Helper function to draw annotations based on provided data."""
        for i in range(len(data['boxes'])):
            if is_prediction:
                box = data['boxes'][i].tolist()
                x1, y1, x2, y2 = box
                if resize:
                    x1, y1, x2, y2 = resize_boxes(np.array([box]), new_size, (image_copy.shape[1], image_copy.shape[0]))[0]
                score = data['scores'][i].item()
                if score < score_threshold:
                    continue
            else:
                box = data['boxes'][i].tolist()
                x1, y1, x2, y2 = box
            if draw_boxes:
                if only_print is not None:
                    if data['labels'][i] != list(model_dict.values()).index(only_print):
                        continue
                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 0), int(2 * scale))
            if is_prediction and write_score:
                cv2.putText(image_copy, str(round(score, 2)), (int(x1), int(y1) + int(15 * scale)), cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (100, 100, 255), 2)

            if write_class and 'labels' in data:
                class_id = data['labels'][i].item()
                cv2.putText(image_copy, model_dict[class_id], (int(x1), int(y1) - int(2 * scale)), cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (255, 100, 100), 2)

            if write_idx:
                cv2.putText(image_copy, str(i), (int(x1) + int(15 * scale), int(y1) + int(15 * scale)), cv2.FONT_HERSHEY_SIMPLEX, 2 * scale, (0, 0, 0), 2)

        # Draw keypoints if available
        if draw_keypoints and 'keypoints' in data:
            if is_prediction and keypoints_correction:
                for idx, (key1, key2) in enumerate(data['keypoints']):
                    if data['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'),
                                                   list(model_dict.values()).index('messageFlow'),
                                                   list(model_dict.values()).index('dataAssociation')]:
                        continue
                    # Calculate the Euclidean distance between the two keypoints
                    distance = np.linalg.norm(key1[:2] - key2[:2])

                    if distance < 5:
                        x_new, y_new, x, y = find_other_keypoint(idx, data['keypoints'], data['boxes'])
                        data['keypoints'][idx][0] = torch.tensor([x_new, y_new, 1])
                        data['keypoints'][idx][1] = torch.tensor([x, y, 1])
                        print("keypoint has been changed")
            for i in range(len(data['keypoints'])):
                kp = data['keypoints'][i]
                for j in range(kp.shape[0]):
                    if is_prediction and data['labels'][i] not in [list(model_dict.values()).index('sequenceFlow'),
                                                                   list(model_dict.values()).index('messageFlow'),
                                                                   list(model_dict.values()).index('dataAssociation')]:
                        continue
                    if is_prediction:
                        score = data['scores'][i]
                        if score < score_threshold:
                            continue
                    x, y, v = np.array(kp[j])
                    if resize:
                        x, y, v = resize_keypoints(np.array([kp[j]]), new_size, (image_copy.shape[1], image_copy.shape[0]))[0]
                    if j == 0:
                        cv2.circle(image_copy, (int(x), int(y)), int(5 * scale), (0, 0, 255), -1)
                    else:
                        cv2.circle(image_copy, (int(x), int(y)), int(5 * scale), (255, 0, 0), -1)

        # Draw text predictions if available
        if (draw_text or write_text) and text_predictions is not None:
            for i in range(len(text_predictions[0])):
                x1, y1, x2, y2 = text_predictions[0][i]
                text = text_predictions[1][i]
                if resize:
                    x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), new_size, (image_copy.shape[1], image_copy.shape[0]))[0]
                if draw_text:
                    cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2 * scale))
                if write_text:
                    cv2.putText(image_copy, text, (int(x1 + int(2 * scale)), int((y1 + y2) / 2)), cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (0, 0, 0), 2)

    def draw_with_links(full_prediction):
        """Draws links between objects based on the full prediction data."""
        # Check whether the two keypoints of a link were detected at (almost) the same position
        if draw_twins and full_prediction is not None:
            circle_color = (0, 255, 0)  # Green color for the circle
            circle_radius = int(10 * scale)  # Circle radius scaled by image scale

            for idx, (key1, key2) in enumerate(full_prediction['keypoints']):
                if full_prediction['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'),
                                                          list(model_dict.values()).index('messageFlow'),
                                                          list(model_dict.values()).index('dataAssociation')]:
                    continue
                # Calculate the Euclidean distance between the two keypoints
                distance = np.linalg.norm(key1[:2] - key2[:2])
                if distance < 10:
                    x_new, y_new, x, y = find_other_keypoint(idx, full_prediction['keypoints'], full_prediction['boxes'])
                    cv2.circle(image_copy, (int(x), int(y)), circle_radius, circle_color, -1)
                    cv2.circle(image_copy, (int(x_new), int(y_new)), circle_radius, (0, 0, 0), -1)

        # Draw links between objects
        if draw_links and full_prediction is not None:
            for i, (start_idx, end_idx) in enumerate(full_prediction['links']):
                if start_idx is None or end_idx is None:
                    continue
                start_box = full_prediction['boxes'][start_idx]
                end_box = full_prediction['boxes'][end_idx]
                current_box = full_prediction['boxes'][i]
                # Calculate the center of each bounding box
                start_center = ((start_box[0] + start_box[2]) // 2, (start_box[1] + start_box[3]) // 2)
                end_center = ((end_box[0] + end_box[2]) // 2, (end_box[1] + end_box[3]) // 2)
                current_center = ((current_box[0] + current_box[2]) // 2, (current_box[1] + current_box[3]) // 2)
                # Draw a line between the centers of the connected objects
                cv2.line(image_copy, (int(start_center[0]), int(start_center[1])), (int(current_center[0]), int(current_center[1])), (0, 0, 255), int(2 * scale))
                cv2.line(image_copy, (int(current_center[0]), int(current_center[1])), (int(end_center[0]), int(end_center[1])), (255, 0, 0), int(2 * scale))

    # Draw ground-truth annotations
    if target is not None:
        draw(target, is_prediction=False)
    # Draw predictions
    if prediction is not None:
        # prediction = prediction[0]
        draw(prediction, is_prediction=True)
    # Draw links with full predictions
    if full_prediction is not None:
        draw_with_links(full_prediction)

    # Display the image
    image_copy = cv2.cvtColor(image_copy, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(12, 12))
    plt.imshow(image_copy)
    if not axis:
        plt.axis('off')
    plt.show()

    if return_image:
        return image_copy

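# Illustrative call (not part of the original file): draw a prediction with its
# links and class names and get the annotated image back. `image`, `pred` and
# `full_pred` are placeholders for values produced elsewhere in the pipeline.
#
#   annotated = draw_annotations(image,
#                                prediction=pred,
#                                full_prediction=full_pred,
#                                draw_boxes=True,
#                                draw_links=True,
#                                write_class=True,
#                                score_threshold=0.5,
#                                return_image=True)
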
def find_closest_object(keypoint, boxes, labels):
    """
    Find the closest object to a keypoint based on their proximity.

    Parameters:
    - keypoint (numpy.ndarray): The coordinates of the keypoint.
    - boxes (numpy.ndarray): The bounding boxes of the objects.
    - labels (numpy.ndarray): The class labels of the objects; flow and lane elements are skipped.

    Returns:
    - tuple: (index of the closest object or None, closest attachment point or None).
    """
    min_distance = float('inf')
    closest_object_idx = None
    best_point = None
    # Iterate over each bounding box
    for i, box in enumerate(boxes):
        if labels[i] in [list(class_dict.values()).index('sequenceFlow'),
                         list(class_dict.values()).index('messageFlow'),
                         list(class_dict.values()).index('dataAssociation'),
                         # list(class_dict.values()).index('pool'),
                         list(class_dict.values()).index('lane')]:
            continue
        x1, y1, x2, y2 = box

        top = ((x1 + x2) / 2, y1)
        bottom = ((x1 + x2) / 2, y2)
        left = (x1, (y1 + y2) / 2)
        right = (x2, (y1 + y2) / 2)
        points = [left, top, right, bottom]

        # Calculate the distance between the keypoint and the midpoint of each side of the bounding box
        for point in points:
            distance = np.linalg.norm(keypoint[:2] - point)
            # Update the closest object index if this object is closer
            if distance < min_distance:
                min_distance = distance
                closest_object_idx = i
                best_point = point

    return closest_object_idx, best_point
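
# Illustrative sketch (not part of the original file): snap an arrow endpoint to
# the nearest side midpoint of a non-flow object with find_closest_object.
# Dummy values; the label value below is hypothetical and just assumed not to be
# one of the flow/lane classes.
#
#   import numpy as np
#   boxes = np.array([[0, 0, 100, 50]])
#   labels = np.array([5])
#   idx, point = find_closest_object(np.array([105.0, 25.0]), boxes, labels)
#   # idx == 0 and point == (100, 25.0), the midpoint of the box's right side
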