Spaces:
Sleeping
Sleeping
change the model object, correct some bug and add warnings
Browse files- app.py +25 -17
- modules/eval.py +28 -8
- modules/toXML.py +26 -2
- modules/utils.py +9 -1
app.py
CHANGED
@@ -13,7 +13,7 @@ import gdown
|
|
13 |
|
14 |
from modules.htlm_webpage import display_bpmn_xml
|
15 |
from modules.OCR import text_prediction, filter_text, mapping_text, rescale
|
16 |
-
from modules.utils import class_dict, arrow_dict, object_dict
|
17 |
from modules.toXML import calculate_pool_bounds, add_diagram_elements, create_bpmn_object, create_flow_element
|
18 |
from modules.display import draw_stream
|
19 |
from modules.eval import full_prediction
|
@@ -49,6 +49,8 @@ def modif_box_pos(pred, size):
|
|
49 |
modified_pred['boxes'][i] = [center[0] - size[label][0] / 2, center[1] - size[label][1] / 2, center[0] + size[label][0] / 2, center[1] + size[label][1] / 2]
|
50 |
return modified_pred['boxes']
|
51 |
|
|
|
|
|
52 |
# Function to create a BPMN XML file from prediction results
|
53 |
def create_XML(full_pred, text_mapping, scale):
|
54 |
namespaces = {
|
@@ -58,22 +60,24 @@ def create_XML(full_pred, text_mapping, scale):
|
|
58 |
'dc': 'http://www.omg.org/spec/DD/20100524/DC',
|
59 |
'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
|
60 |
}
|
|
|
61 |
|
62 |
size_elements = {
|
63 |
-
'event': (43.2, 43.2),
|
64 |
-
'task': (120, 96),
|
65 |
-
'message': (43.2, 43.2),
|
66 |
-
'messageEvent': (43.2, 43.2),
|
67 |
-
'exclusiveGateway': (60, 60),
|
68 |
-
'parallelGateway': (60, 60),
|
69 |
-
'dataObject': (48, 72),
|
70 |
-
'dataStore': (72, 72),
|
71 |
-
'subProcess': (144, 108),
|
72 |
-
'eventBasedGateway': (60, 60),
|
73 |
-
'timerEvent': (48, 48),
|
74 |
}
|
75 |
|
76 |
|
|
|
77 |
definitions = ET.Element('bpmn:definitions', {
|
78 |
'xmlns:xsi': namespaces['xsi'],
|
79 |
'xmlns:bpmn': namespaces['bpmn'],
|
@@ -153,7 +157,7 @@ def load_models():
|
|
153 |
model_arrow = get_arrow_model(len(arrow_dict),2)
|
154 |
|
155 |
url_arrow = 'https://drive.google.com/uc?id=1vv1X_r_lZ8gnzMAIKxcVEb_T_Qb-NkyA'
|
156 |
-
url_object = 'https://drive.google.com/uc?id=
|
157 |
|
158 |
# Define paths to save models
|
159 |
output_arrow = 'model_arrow.pth'
|
@@ -190,7 +194,7 @@ def prepare_image(image, pad=True, new_size=(1333, 1333)):
|
|
190 |
|
191 |
if pad:
|
192 |
enhancer = ImageEnhance.Brightness(image)
|
193 |
-
image = enhancer.enhance(1.
|
194 |
# Pad the resized image to make it exactly the desired size
|
195 |
padding = [0, 0, new_size[0] - new_scaled_size[0], new_size[1] - new_scaled_size[1]]
|
196 |
image = F.pad(image, padding, fill=200, padding_mode='edge')
|
@@ -324,7 +328,7 @@ def main():
|
|
324 |
st.sidebar.subheader("Instructions:")
|
325 |
st.sidebar.text("1. Upload you image")
|
326 |
st.sidebar.text("2. Crop the image \n (try to put the BPMN diagram \n in the center of the image)")
|
327 |
-
st.sidebar.text("3. Set the score threshold \n for prediction (default is 0.
|
328 |
st.sidebar.text("4. Click on 'Launch Prediction'")
|
329 |
st.sidebar.text("5. You can now see the annotation \n and the BPMN XML result")
|
330 |
st.sidebar.text("6. You can change the scale for \n the XML file (default is 1.0)")
|
@@ -410,7 +414,7 @@ def main():
|
|
410 |
with col1:
|
411 |
score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
|
412 |
else:
|
413 |
-
score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.
|
414 |
|
415 |
if st.button("Launch Prediction"):
|
416 |
st.session_state.crop_image = cropped_image
|
@@ -425,7 +429,11 @@ def main():
|
|
425 |
with st.spinner('Waiting for BPMN modeler...'):
|
426 |
col1, col2 = st.columns(2)
|
427 |
with col1:
|
428 |
-
st.session_state.scale = st.slider("Set scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
|
|
|
|
|
|
|
|
|
429 |
st.session_state.bpmn_xml = create_XML(st.session_state.prediction.copy(), st.session_state.text_mapping, st.session_state.scale)
|
430 |
display_bpmn_xml(st.session_state.bpmn_xml, is_mobile=is_mobile, screen_width=int(4/5*screen_width))
|
431 |
|
|
|
13 |
|
14 |
from modules.htlm_webpage import display_bpmn_xml
|
15 |
from modules.OCR import text_prediction, filter_text, mapping_text, rescale
|
16 |
+
from modules.utils import class_dict, arrow_dict, object_dict
|
17 |
from modules.toXML import calculate_pool_bounds, add_diagram_elements, create_bpmn_object, create_flow_element
|
18 |
from modules.display import draw_stream
|
19 |
from modules.eval import full_prediction
|
|
|
49 |
modified_pred['boxes'][i] = [center[0] - size[label][0] / 2, center[1] - size[label][1] / 2, center[0] + size[label][0] / 2, center[1] + size[label][1] / 2]
|
50 |
return modified_pred['boxes']
|
51 |
|
52 |
+
|
53 |
+
|
54 |
# Function to create a BPMN XML file from prediction results
|
55 |
def create_XML(full_pred, text_mapping, scale):
|
56 |
namespaces = {
|
|
|
60 |
'dc': 'http://www.omg.org/spec/DD/20100524/DC',
|
61 |
'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
|
62 |
}
|
63 |
+
|
64 |
|
65 |
size_elements = {
|
66 |
+
'event': (st.session_state.size_scale*43.2, st.session_state.size_scale*43.2),
|
67 |
+
'task': (st.session_state.size_scale*120, st.session_state.size_scale*96),
|
68 |
+
'message': (st.session_state.size_scale*43.2, st.session_state.size_scale*43.2),
|
69 |
+
'messageEvent': (st.session_state.size_scale*43.2, st.session_state.size_scale*43.2),
|
70 |
+
'exclusiveGateway': (st.session_state.size_scale*60, st.session_state.size_scale*60),
|
71 |
+
'parallelGateway': (st.session_state.size_scale*60, st.session_state.size_scale*60),
|
72 |
+
'dataObject': ( st.session_state.size_scale*48, st.session_state.size_scale*72),
|
73 |
+
'dataStore': (st.session_state.size_scale*72, st.session_state.size_scale*72),
|
74 |
+
'subProcess': (st.session_state.size_scale*144, st.session_state.size_scale*108),
|
75 |
+
'eventBasedGateway': (st.session_state.size_scale*60, st.session_state.size_scale*60),
|
76 |
+
'timerEvent': (st.session_state.size_scale*48, st.session_state.size_scale*48),
|
77 |
}
|
78 |
|
79 |
|
80 |
+
|
81 |
definitions = ET.Element('bpmn:definitions', {
|
82 |
'xmlns:xsi': namespaces['xsi'],
|
83 |
'xmlns:bpmn': namespaces['bpmn'],
|
|
|
157 |
model_arrow = get_arrow_model(len(arrow_dict),2)
|
158 |
|
159 |
url_arrow = 'https://drive.google.com/uc?id=1vv1X_r_lZ8gnzMAIKxcVEb_T_Qb-NkyA'
|
160 |
+
url_object = 'https://drive.google.com/uc?id=1b1bqogxqdPS-SnvaOfWJGV1I1qOrTKh5'
|
161 |
|
162 |
# Define paths to save models
|
163 |
output_arrow = 'model_arrow.pth'
|
|
|
194 |
|
195 |
if pad:
|
196 |
enhancer = ImageEnhance.Brightness(image)
|
197 |
+
image = enhancer.enhance(1.0) # Adjust the brightness if necessary
|
198 |
# Pad the resized image to make it exactly the desired size
|
199 |
padding = [0, 0, new_size[0] - new_scaled_size[0], new_size[1] - new_scaled_size[1]]
|
200 |
image = F.pad(image, padding, fill=200, padding_mode='edge')
|
|
|
328 |
st.sidebar.subheader("Instructions:")
|
329 |
st.sidebar.text("1. Upload you image")
|
330 |
st.sidebar.text("2. Crop the image \n (try to put the BPMN diagram \n in the center of the image)")
|
331 |
+
st.sidebar.text("3. Set the score threshold \n for prediction (default is 0.5)")
|
332 |
st.sidebar.text("4. Click on 'Launch Prediction'")
|
333 |
st.sidebar.text("5. You can now see the annotation \n and the BPMN XML result")
|
334 |
st.sidebar.text("6. You can change the scale for \n the XML file (default is 1.0)")
|
|
|
414 |
with col1:
|
415 |
score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
|
416 |
else:
|
417 |
+
score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.6, step=0.05)
|
418 |
|
419 |
if st.button("Launch Prediction"):
|
420 |
st.session_state.crop_image = cropped_image
|
|
|
429 |
with st.spinner('Waiting for BPMN modeler...'):
|
430 |
col1, col2 = st.columns(2)
|
431 |
with col1:
|
432 |
+
st.session_state.scale = st.slider("Set distance scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
|
433 |
+
if is_mobile is False:
|
434 |
+
st.session_state.size_scale = st.slider("Set size object scale for XML file", min_value=0.5, max_value=2.0, value=1.0, step=0.1)
|
435 |
+
else:
|
436 |
+
st.session_state.size_scale = 1.0
|
437 |
st.session_state.bpmn_xml = create_XML(st.session_state.prediction.copy(), st.session_state.text_mapping, st.session_state.scale)
|
438 |
display_bpmn_xml(st.session_state.bpmn_xml, is_mobile=is_mobile, screen_width=int(4/5*screen_width))
|
439 |
|
modules/eval.py
CHANGED
@@ -172,14 +172,17 @@ def mix_predictions(objects_pred, arrow_pred):
|
|
172 |
|
173 |
return boxes, labels, scores, keypoints
|
174 |
|
175 |
-
|
|
|
176 |
"""
|
177 |
Regroups elements by the pool they belong to, and creates a single new pool for elements that are not in any existing pool.
|
|
|
178 |
|
179 |
Parameters:
|
180 |
- boxes (list): List of bounding boxes.
|
181 |
- labels (list): List of labels corresponding to each bounding box.
|
182 |
- class_dict (dict): Dictionary mapping class indices to class names.
|
|
|
183 |
|
184 |
Returns:
|
185 |
- dict: A dictionary where each key is a pool's index and the value is a list of elements within that pool.
|
@@ -187,14 +190,29 @@ def regroup_elements_by_pool(boxes, labels, class_dict):
|
|
187 |
# Initialize a dictionary to hold the elements in each pool
|
188 |
pool_dict = {}
|
189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
# Identify the bounding boxes of the pools
|
191 |
pool_indices = [i for i, label in enumerate(labels) if (class_dict[label.item()] == 'pool')]
|
192 |
pool_boxes = [boxes[i] for i in pool_indices]
|
193 |
|
|
|
194 |
if not pool_indices:
|
195 |
# If no pools or lanes are detected, create a single pool with all elements
|
196 |
labels = np.append(labels, list(class_dict.values()).index('pool'))
|
197 |
-
pool_dict[len(labels)-1] = list(range(len(boxes)))
|
198 |
else:
|
199 |
# Initialize each pool index with an empty list
|
200 |
for pool_index in pool_indices:
|
@@ -232,7 +250,8 @@ def regroup_elements_by_pool(boxes, labels, class_dict):
|
|
232 |
# Merge non-empty pools followed by empty pools
|
233 |
pool_dict = {**non_empty_pools, **empty_pools}
|
234 |
|
235 |
-
return pool_dict, labels
|
|
|
236 |
|
237 |
|
238 |
def create_links(keypoints, boxes, labels, class_dict):
|
@@ -260,8 +279,7 @@ def create_links(keypoints, boxes, labels, class_dict):
|
|
260 |
|
261 |
return links, best_points
|
262 |
|
263 |
-
def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
|
264 |
-
|
265 |
for pool_index, elements in pool_dict.items():
|
266 |
print(f"Pool {pool_index} contains elements: {elements}")
|
267 |
#check if the label sequenceflow is good
|
@@ -307,10 +325,12 @@ def last_correction(boxes, labels, scores, keypoints, links, best_points, pool_d
|
|
307 |
#delete pool that are have only messageFlow on it
|
308 |
delete_pool = []
|
309 |
for pool_index, elements in pool_dict.items():
|
310 |
-
if all([labels[i]
|
|
|
|
|
311 |
if len(elements) > 0:
|
312 |
delete_pool.append(pool_dict[pool_index])
|
313 |
-
print(f"Pool {pool_index} contains only
|
314 |
|
315 |
#sort index
|
316 |
delete_pool = sorted(delete_pool, reverse=True)
|
@@ -371,7 +391,7 @@ def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_t
|
|
371 |
boxes, labels, scores, keypoints = mix_predictions(objects_pred, arrow_pred)
|
372 |
|
373 |
# Regroup elements by pool
|
374 |
-
pool_dict, labels = regroup_elements_by_pool(boxes,labels, class_dict)
|
375 |
# Create links between elements
|
376 |
flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
|
377 |
#Correct the labels of some sequenceflow that cross multiple pool
|
|
|
172 |
|
173 |
return boxes, labels, scores, keypoints
|
174 |
|
175 |
+
|
176 |
+
def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_threshold=0.3):
|
177 |
"""
|
178 |
Regroups elements by the pool they belong to, and creates a single new pool for elements that are not in any existing pool.
|
179 |
+
Filters out pools that have an IoU greater than the specified threshold.
|
180 |
|
181 |
Parameters:
|
182 |
- boxes (list): List of bounding boxes.
|
183 |
- labels (list): List of labels corresponding to each bounding box.
|
184 |
- class_dict (dict): Dictionary mapping class indices to class names.
|
185 |
+
- iou_threshold (float): IoU threshold for filtering pools.
|
186 |
|
187 |
Returns:
|
188 |
- dict: A dictionary where each key is a pool's index and the value is a list of elements within that pool.
|
|
|
190 |
# Initialize a dictionary to hold the elements in each pool
|
191 |
pool_dict = {}
|
192 |
|
193 |
+
# Filter out pools with IoU greater than the threshold
|
194 |
+
to_delete = []
|
195 |
+
for i in range(len(boxes)):
|
196 |
+
for j in range(i + 1, len(boxes)):
|
197 |
+
if labels[i] == labels[j] and labels[i] == list(class_dict.values()).index('pool'):
|
198 |
+
if iou(np.array(boxes[i]), np.array(boxes[j])) > iou_threshold:
|
199 |
+
to_delete.append(j)
|
200 |
+
|
201 |
+
|
202 |
+
boxes = np.delete(boxes, to_delete, axis=0)
|
203 |
+
labels = np.delete(labels, to_delete)
|
204 |
+
scores = np.delete(scores, to_delete)
|
205 |
+
keypoints = np.delete(keypoints, to_delete, axis=0)
|
206 |
+
|
207 |
# Identify the bounding boxes of the pools
|
208 |
pool_indices = [i for i, label in enumerate(labels) if (class_dict[label.item()] == 'pool')]
|
209 |
pool_boxes = [boxes[i] for i in pool_indices]
|
210 |
|
211 |
+
|
212 |
if not pool_indices:
|
213 |
# If no pools or lanes are detected, create a single pool with all elements
|
214 |
labels = np.append(labels, list(class_dict.values()).index('pool'))
|
215 |
+
pool_dict[len(labels) - 1] = list(range(len(boxes)))
|
216 |
else:
|
217 |
# Initialize each pool index with an empty list
|
218 |
for pool_index in pool_indices:
|
|
|
250 |
# Merge non-empty pools followed by empty pools
|
251 |
pool_dict = {**non_empty_pools, **empty_pools}
|
252 |
|
253 |
+
return pool_dict, boxes, labels, scores, keypoints
|
254 |
+
|
255 |
|
256 |
|
257 |
def create_links(keypoints, boxes, labels, class_dict):
|
|
|
279 |
|
280 |
return links, best_points
|
281 |
|
282 |
+
def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
|
|
|
283 |
for pool_index, elements in pool_dict.items():
|
284 |
print(f"Pool {pool_index} contains elements: {elements}")
|
285 |
#check if the label sequenceflow is good
|
|
|
325 |
#delete pool that are have only messageFlow on it
|
326 |
delete_pool = []
|
327 |
for pool_index, elements in pool_dict.items():
|
328 |
+
if all([labels[i] in [list(class_dict.values()).index('messageFlow'),
|
329 |
+
list(class_dict.values()).index('sequenceFlow'),
|
330 |
+
list(class_dict.values()).index('dataAssociation')] for i in elements]):
|
331 |
if len(elements) > 0:
|
332 |
delete_pool.append(pool_dict[pool_index])
|
333 |
+
print(f"Pool {pool_index} contains only arrow elements, deleting it")
|
334 |
|
335 |
#sort index
|
336 |
delete_pool = sorted(delete_pool, reverse=True)
|
|
|
391 |
boxes, labels, scores, keypoints = mix_predictions(objects_pred, arrow_pred)
|
392 |
|
393 |
# Regroup elements by pool
|
394 |
+
pool_dict, boxes, labels, scores, keypoints = regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict)
|
395 |
# Create links between elements
|
396 |
flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
|
397 |
#Correct the labels of some sequenceflow that cross multiple pool
|
modules/toXML.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import xml.etree.ElementTree as ET
|
2 |
-
from modules.utils import class_dict
|
3 |
|
4 |
def rescale(scale, boxes):
|
5 |
for i in range(len(boxes)):
|
@@ -317,6 +317,17 @@ def calculate_waypoints(data, size, current_idx, source_id, target_id):
|
|
317 |
|
318 |
source_idx = data['BPMN_id'].index(source_id)
|
319 |
target_idx = data['BPMN_id'].index(target_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
name_source = source_id.split('_')[0]
|
321 |
name_target = target_id.split('_')[0]
|
322 |
|
@@ -324,6 +335,10 @@ def calculate_waypoints(data, size, current_idx, source_id, target_id):
|
|
324 |
source_x, source_y = data['boxes'][source_idx][:2]
|
325 |
target_x, target_y = data['boxes'][target_idx][:2]
|
326 |
|
|
|
|
|
|
|
|
|
327 |
if pos_source == 'left':
|
328 |
source_x = source_x
|
329 |
source_y += size[name_source][1]/2
|
@@ -352,8 +367,13 @@ def calculate_waypoints(data, size, current_idx, source_id, target_id):
|
|
352 |
|
353 |
return [(source_x, source_y), (target_x, target_y)]
|
354 |
|
355 |
-
def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=False):
|
356 |
source_idx, target_idx = data['links'][idx]
|
|
|
|
|
|
|
|
|
|
|
357 |
source_id, target_id = data['BPMN_id'][source_idx], data['BPMN_id'][target_idx]
|
358 |
if message:
|
359 |
element_id = f'messageflow_{source_id}_{target_id}'
|
@@ -373,9 +393,13 @@ def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=Fal
|
|
373 |
element = ET.SubElement(parent, 'bpmn:messageFlow', id=element_id, sourceRef=XML_source_id, targetRef=XML_target_id, name=text_mapping[data['BPMN_id'][idx]])
|
374 |
else:
|
375 |
waypoints = calculate_waypoints(data, size, idx, source_id, target_id)
|
|
|
|
|
376 |
element = ET.SubElement(parent, 'bpmn:messageFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
|
377 |
else:
|
378 |
waypoints = calculate_waypoints(data, size, idx, source_id, target_id)
|
|
|
|
|
379 |
element = ET.SubElement(parent, 'bpmn:sequenceFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
|
380 |
add_diagram_edge(bpmn, element_id, waypoints)
|
381 |
|
|
|
1 |
import xml.etree.ElementTree as ET
|
2 |
+
from modules.utils import class_dict, error, warning
|
3 |
|
4 |
def rescale(scale, boxes):
|
5 |
for i in range(len(boxes)):
|
|
|
317 |
|
318 |
source_idx = data['BPMN_id'].index(source_id)
|
319 |
target_idx = data['BPMN_id'].index(target_id)
|
320 |
+
|
321 |
+
if source_idx==target_idx:
|
322 |
+
warning()
|
323 |
+
#return [data['keypoints'][current_idx][0][:2], data['keypoints'][current_idx][1][:2]]
|
324 |
+
return None
|
325 |
+
|
326 |
+
if source_idx is None or target_idx is None:
|
327 |
+
warning()
|
328 |
+
return [(source_x, source_y), (target_x, target_y)]
|
329 |
+
|
330 |
+
|
331 |
name_source = source_id.split('_')[0]
|
332 |
name_target = target_id.split('_')[0]
|
333 |
|
|
|
335 |
source_x, source_y = data['boxes'][source_idx][:2]
|
336 |
target_x, target_y = data['boxes'][target_idx][:2]
|
337 |
|
338 |
+
if name_source == 'pool' or name_target == 'pool':
|
339 |
+
warning()
|
340 |
+
return [(source_x, source_y), (target_x, target_y)]
|
341 |
+
|
342 |
if pos_source == 'left':
|
343 |
source_x = source_x
|
344 |
source_y += size[name_source][1]/2
|
|
|
367 |
|
368 |
return [(source_x, source_y), (target_x, target_y)]
|
369 |
|
370 |
+
def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=False):
|
371 |
source_idx, target_idx = data['links'][idx]
|
372 |
+
|
373 |
+
if source_idx is None or target_idx is None:
|
374 |
+
warning()
|
375 |
+
return
|
376 |
+
|
377 |
source_id, target_id = data['BPMN_id'][source_idx], data['BPMN_id'][target_idx]
|
378 |
if message:
|
379 |
element_id = f'messageflow_{source_id}_{target_id}'
|
|
|
393 |
element = ET.SubElement(parent, 'bpmn:messageFlow', id=element_id, sourceRef=XML_source_id, targetRef=XML_target_id, name=text_mapping[data['BPMN_id'][idx]])
|
394 |
else:
|
395 |
waypoints = calculate_waypoints(data, size, idx, source_id, target_id)
|
396 |
+
if waypoints is None:
|
397 |
+
return
|
398 |
element = ET.SubElement(parent, 'bpmn:messageFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
|
399 |
else:
|
400 |
waypoints = calculate_waypoints(data, size, idx, source_id, target_id)
|
401 |
+
if waypoints is None:
|
402 |
+
return
|
403 |
element = ET.SubElement(parent, 'bpmn:sequenceFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
|
404 |
add_diagram_edge(bpmn, element_id, waypoints)
|
405 |
|
modules/utils.py
CHANGED
@@ -17,6 +17,7 @@ import time
|
|
17 |
from torch.optim import AdamW
|
18 |
import copy
|
19 |
from torchvision import transforms
|
|
|
20 |
|
21 |
|
22 |
object_dict = {
|
@@ -912,8 +913,9 @@ def find_closest_object(keypoint, boxes, labels):
|
|
912 |
Returns:
|
913 |
- int or None: The index of the closest object to the keypoint, or None if no object is found.
|
914 |
"""
|
915 |
-
min_distance = float('inf')
|
916 |
closest_object_idx = None
|
|
|
|
|
917 |
# Iterate over each bounding box
|
918 |
for i, box in enumerate(boxes):
|
919 |
if labels[i] in [list(class_dict.values()).index('sequenceFlow'),
|
@@ -943,3 +945,9 @@ def find_closest_object(keypoint, boxes, labels):
|
|
943 |
|
944 |
return closest_object_idx, best_point
|
945 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
from torch.optim import AdamW
|
18 |
import copy
|
19 |
from torchvision import transforms
|
20 |
+
import streamlit as st
|
21 |
|
22 |
|
23 |
object_dict = {
|
|
|
913 |
Returns:
|
914 |
- int or None: The index of the closest object to the keypoint, or None if no object is found.
|
915 |
"""
|
|
|
916 |
closest_object_idx = None
|
917 |
+
best_point = None
|
918 |
+
min_distance = float('inf')
|
919 |
# Iterate over each bounding box
|
920 |
for i, box in enumerate(boxes):
|
921 |
if labels[i] in [list(class_dict.values()).index('sequenceFlow'),
|
|
|
945 |
|
946 |
return closest_object_idx, best_point
|
947 |
|
948 |
+
|
949 |
+
def error():
|
950 |
+
st.error('There is an error in the detection', icon="🚨")
|
951 |
+
|
952 |
+
def warning():
|
953 |
+
st.warning('Some element are not detected, verify your parameters', icon="⚠️")
|