BenjiELCA committed on
Commit
acc7969
·
1 Parent(s): 2ecdafc

Change the model object, fix some bugs, and add warnings

Browse files
Files changed (4) hide show
  1. app.py +25 -17
  2. modules/eval.py +28 -8
  3. modules/toXML.py +26 -2
  4. modules/utils.py +9 -1
app.py CHANGED
@@ -13,7 +13,7 @@ import gdown
13
 
14
  from modules.htlm_webpage import display_bpmn_xml
15
  from modules.OCR import text_prediction, filter_text, mapping_text, rescale
16
- from modules.utils import class_dict, arrow_dict, object_dict, find_closest_object
17
  from modules.toXML import calculate_pool_bounds, add_diagram_elements, create_bpmn_object, create_flow_element
18
  from modules.display import draw_stream
19
  from modules.eval import full_prediction
@@ -49,6 +49,8 @@ def modif_box_pos(pred, size):
49
  modified_pred['boxes'][i] = [center[0] - size[label][0] / 2, center[1] - size[label][1] / 2, center[0] + size[label][0] / 2, center[1] + size[label][1] / 2]
50
  return modified_pred['boxes']
51
 
 
 
52
  # Function to create a BPMN XML file from prediction results
53
  def create_XML(full_pred, text_mapping, scale):
54
  namespaces = {
@@ -58,22 +60,24 @@ def create_XML(full_pred, text_mapping, scale):
58
  'dc': 'http://www.omg.org/spec/DD/20100524/DC',
59
  'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
60
  }
 
61
 
62
  size_elements = {
63
- 'event': (43.2, 43.2),
64
- 'task': (120, 96),
65
- 'message': (43.2, 43.2),
66
- 'messageEvent': (43.2, 43.2),
67
- 'exclusiveGateway': (60, 60),
68
- 'parallelGateway': (60, 60),
69
- 'dataObject': (48, 72),
70
- 'dataStore': (72, 72),
71
- 'subProcess': (144, 108),
72
- 'eventBasedGateway': (60, 60),
73
- 'timerEvent': (48, 48),
74
  }
75
 
76
 
 
77
  definitions = ET.Element('bpmn:definitions', {
78
  'xmlns:xsi': namespaces['xsi'],
79
  'xmlns:bpmn': namespaces['bpmn'],
@@ -153,7 +157,7 @@ def load_models():
153
  model_arrow = get_arrow_model(len(arrow_dict),2)
154
 
155
  url_arrow = 'https://drive.google.com/uc?id=1vv1X_r_lZ8gnzMAIKxcVEb_T_Qb-NkyA'
156
- url_object = 'https://drive.google.com/uc?id=1lmeVBRfTrcrBsq3904m4TgppSA4nqsIi'
157
 
158
  # Define paths to save models
159
  output_arrow = 'model_arrow.pth'
@@ -190,7 +194,7 @@ def prepare_image(image, pad=True, new_size=(1333, 1333)):
190
 
191
  if pad:
192
  enhancer = ImageEnhance.Brightness(image)
193
- image = enhancer.enhance(1.5) # Adjust the brightness if necessary
194
  # Pad the resized image to make it exactly the desired size
195
  padding = [0, 0, new_size[0] - new_scaled_size[0], new_size[1] - new_scaled_size[1]]
196
  image = F.pad(image, padding, fill=200, padding_mode='edge')
@@ -324,7 +328,7 @@ def main():
324
  st.sidebar.subheader("Instructions:")
325
  st.sidebar.text("1. Upload you image")
326
  st.sidebar.text("2. Crop the image \n (try to put the BPMN diagram \n in the center of the image)")
327
- st.sidebar.text("3. Set the score threshold \n for prediction (default is 0.6)")
328
  st.sidebar.text("4. Click on 'Launch Prediction'")
329
  st.sidebar.text("5. You can now see the annotation \n and the BPMN XML result")
330
  st.sidebar.text("6. You can change the scale for \n the XML file (default is 1.0)")
@@ -410,7 +414,7 @@ def main():
410
  with col1:
411
  score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
412
  else:
413
- score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
414
 
415
  if st.button("Launch Prediction"):
416
  st.session_state.crop_image = cropped_image
@@ -425,7 +429,11 @@ def main():
425
  with st.spinner('Waiting for BPMN modeler...'):
426
  col1, col2 = st.columns(2)
427
  with col1:
428
- st.session_state.scale = st.slider("Set scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
 
 
 
 
429
  st.session_state.bpmn_xml = create_XML(st.session_state.prediction.copy(), st.session_state.text_mapping, st.session_state.scale)
430
  display_bpmn_xml(st.session_state.bpmn_xml, is_mobile=is_mobile, screen_width=int(4/5*screen_width))
431
 
 
13
 
14
  from modules.htlm_webpage import display_bpmn_xml
15
  from modules.OCR import text_prediction, filter_text, mapping_text, rescale
16
+ from modules.utils import class_dict, arrow_dict, object_dict
17
  from modules.toXML import calculate_pool_bounds, add_diagram_elements, create_bpmn_object, create_flow_element
18
  from modules.display import draw_stream
19
  from modules.eval import full_prediction
 
49
  modified_pred['boxes'][i] = [center[0] - size[label][0] / 2, center[1] - size[label][1] / 2, center[0] + size[label][0] / 2, center[1] + size[label][1] / 2]
50
  return modified_pred['boxes']
51
 
52
+
53
+
54
  # Function to create a BPMN XML file from prediction results
55
  def create_XML(full_pred, text_mapping, scale):
56
  namespaces = {
 
60
  'dc': 'http://www.omg.org/spec/DD/20100524/DC',
61
  'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
62
  }
63
+
64
 
65
  size_elements = {
66
+ 'event': (st.session_state.size_scale*43.2, st.session_state.size_scale*43.2),
67
+ 'task': (st.session_state.size_scale*120, st.session_state.size_scale*96),
68
+ 'message': (st.session_state.size_scale*43.2, st.session_state.size_scale*43.2),
69
+ 'messageEvent': (st.session_state.size_scale*43.2, st.session_state.size_scale*43.2),
70
+ 'exclusiveGateway': (st.session_state.size_scale*60, st.session_state.size_scale*60),
71
+ 'parallelGateway': (st.session_state.size_scale*60, st.session_state.size_scale*60),
72
+ 'dataObject': ( st.session_state.size_scale*48, st.session_state.size_scale*72),
73
+ 'dataStore': (st.session_state.size_scale*72, st.session_state.size_scale*72),
74
+ 'subProcess': (st.session_state.size_scale*144, st.session_state.size_scale*108),
75
+ 'eventBasedGateway': (st.session_state.size_scale*60, st.session_state.size_scale*60),
76
+ 'timerEvent': (st.session_state.size_scale*48, st.session_state.size_scale*48),
77
  }
78
 
79
 
80
+
81
  definitions = ET.Element('bpmn:definitions', {
82
  'xmlns:xsi': namespaces['xsi'],
83
  'xmlns:bpmn': namespaces['bpmn'],
 
157
  model_arrow = get_arrow_model(len(arrow_dict),2)
158
 
159
  url_arrow = 'https://drive.google.com/uc?id=1vv1X_r_lZ8gnzMAIKxcVEb_T_Qb-NkyA'
160
+ url_object = 'https://drive.google.com/uc?id=1b1bqogxqdPS-SnvaOfWJGV1I1qOrTKh5'
161
 
162
  # Define paths to save models
163
  output_arrow = 'model_arrow.pth'
 
194
 
195
  if pad:
196
  enhancer = ImageEnhance.Brightness(image)
197
+ image = enhancer.enhance(1.0) # Adjust the brightness if necessary
198
  # Pad the resized image to make it exactly the desired size
199
  padding = [0, 0, new_size[0] - new_scaled_size[0], new_size[1] - new_scaled_size[1]]
200
  image = F.pad(image, padding, fill=200, padding_mode='edge')
 
328
  st.sidebar.subheader("Instructions:")
329
  st.sidebar.text("1. Upload you image")
330
  st.sidebar.text("2. Crop the image \n (try to put the BPMN diagram \n in the center of the image)")
331
+ st.sidebar.text("3. Set the score threshold \n for prediction (default is 0.5)")
332
  st.sidebar.text("4. Click on 'Launch Prediction'")
333
  st.sidebar.text("5. You can now see the annotation \n and the BPMN XML result")
334
  st.sidebar.text("6. You can change the scale for \n the XML file (default is 1.0)")
 
414
  with col1:
415
  score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
416
  else:
417
+ score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.6, step=0.05)
418
 
419
  if st.button("Launch Prediction"):
420
  st.session_state.crop_image = cropped_image
 
429
  with st.spinner('Waiting for BPMN modeler...'):
430
  col1, col2 = st.columns(2)
431
  with col1:
432
+ st.session_state.scale = st.slider("Set distance scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
433
+ if is_mobile is False:
434
+ st.session_state.size_scale = st.slider("Set size object scale for XML file", min_value=0.5, max_value=2.0, value=1.0, step=0.1)
435
+ else:
436
+ st.session_state.size_scale = 1.0
437
  st.session_state.bpmn_xml = create_XML(st.session_state.prediction.copy(), st.session_state.text_mapping, st.session_state.scale)
438
  display_bpmn_xml(st.session_state.bpmn_xml, is_mobile=is_mobile, screen_width=int(4/5*screen_width))
439
 
modules/eval.py CHANGED
@@ -172,14 +172,17 @@ def mix_predictions(objects_pred, arrow_pred):
172
 
173
  return boxes, labels, scores, keypoints
174
 
175
- def regroup_elements_by_pool(boxes, labels, class_dict):
 
176
  """
177
  Regroups elements by the pool they belong to, and creates a single new pool for elements that are not in any existing pool.
 
178
 
179
  Parameters:
180
  - boxes (list): List of bounding boxes.
181
  - labels (list): List of labels corresponding to each bounding box.
182
  - class_dict (dict): Dictionary mapping class indices to class names.
 
183
 
184
  Returns:
185
  - dict: A dictionary where each key is a pool's index and the value is a list of elements within that pool.
@@ -187,14 +190,29 @@ def regroup_elements_by_pool(boxes, labels, class_dict):
187
  # Initialize a dictionary to hold the elements in each pool
188
  pool_dict = {}
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  # Identify the bounding boxes of the pools
191
  pool_indices = [i for i, label in enumerate(labels) if (class_dict[label.item()] == 'pool')]
192
  pool_boxes = [boxes[i] for i in pool_indices]
193
 
 
194
  if not pool_indices:
195
  # If no pools or lanes are detected, create a single pool with all elements
196
  labels = np.append(labels, list(class_dict.values()).index('pool'))
197
- pool_dict[len(labels)-1] = list(range(len(boxes)))
198
  else:
199
  # Initialize each pool index with an empty list
200
  for pool_index in pool_indices:
@@ -232,7 +250,8 @@ def regroup_elements_by_pool(boxes, labels, class_dict):
232
  # Merge non-empty pools followed by empty pools
233
  pool_dict = {**non_empty_pools, **empty_pools}
234
 
235
- return pool_dict, labels
 
236
 
237
 
238
  def create_links(keypoints, boxes, labels, class_dict):
@@ -260,8 +279,7 @@ def create_links(keypoints, boxes, labels, class_dict):
260
 
261
  return links, best_points
262
 
263
- def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
264
-
265
  for pool_index, elements in pool_dict.items():
266
  print(f"Pool {pool_index} contains elements: {elements}")
267
  #check if the label sequenceflow is good
@@ -307,10 +325,12 @@ def last_correction(boxes, labels, scores, keypoints, links, best_points, pool_d
307
  #delete pool that are have only messageFlow on it
308
  delete_pool = []
309
  for pool_index, elements in pool_dict.items():
310
- if all([labels[i] == list(class_dict.values()).index('messageFlow') for i in elements]):
 
 
311
  if len(elements) > 0:
312
  delete_pool.append(pool_dict[pool_index])
313
- print(f"Pool {pool_index} contains only messageFlow elements, deleting it")
314
 
315
  #sort index
316
  delete_pool = sorted(delete_pool, reverse=True)
@@ -371,7 +391,7 @@ def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_t
371
  boxes, labels, scores, keypoints = mix_predictions(objects_pred, arrow_pred)
372
 
373
  # Regroup elements by pool
374
- pool_dict, labels = regroup_elements_by_pool(boxes,labels, class_dict)
375
  # Create links between elements
376
  flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
377
  #Correct the labels of some sequenceflow that cross multiple pool
 
172
 
173
  return boxes, labels, scores, keypoints
174
 
175
+
176
+ def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_threshold=0.3):
177
  """
178
  Regroups elements by the pool they belong to, and creates a single new pool for elements that are not in any existing pool.
179
+ Filters out pools that have an IoU greater than the specified threshold.
180
 
181
  Parameters:
182
  - boxes (list): List of bounding boxes.
183
  - labels (list): List of labels corresponding to each bounding box.
184
  - class_dict (dict): Dictionary mapping class indices to class names.
185
+ - iou_threshold (float): IoU threshold for filtering pools.
186
 
187
  Returns:
188
  - dict: A dictionary where each key is a pool's index and the value is a list of elements within that pool.
 
190
  # Initialize a dictionary to hold the elements in each pool
191
  pool_dict = {}
192
 
193
+ # Filter out pools with IoU greater than the threshold
194
+ to_delete = []
195
+ for i in range(len(boxes)):
196
+ for j in range(i + 1, len(boxes)):
197
+ if labels[i] == labels[j] and labels[i] == list(class_dict.values()).index('pool'):
198
+ if iou(np.array(boxes[i]), np.array(boxes[j])) > iou_threshold:
199
+ to_delete.append(j)
200
+
201
+
202
+ boxes = np.delete(boxes, to_delete, axis=0)
203
+ labels = np.delete(labels, to_delete)
204
+ scores = np.delete(scores, to_delete)
205
+ keypoints = np.delete(keypoints, to_delete, axis=0)
206
+
207
  # Identify the bounding boxes of the pools
208
  pool_indices = [i for i, label in enumerate(labels) if (class_dict[label.item()] == 'pool')]
209
  pool_boxes = [boxes[i] for i in pool_indices]
210
 
211
+
212
  if not pool_indices:
213
  # If no pools or lanes are detected, create a single pool with all elements
214
  labels = np.append(labels, list(class_dict.values()).index('pool'))
215
+ pool_dict[len(labels) - 1] = list(range(len(boxes)))
216
  else:
217
  # Initialize each pool index with an empty list
218
  for pool_index in pool_indices:
 
250
  # Merge non-empty pools followed by empty pools
251
  pool_dict = {**non_empty_pools, **empty_pools}
252
 
253
+ return pool_dict, boxes, labels, scores, keypoints
254
+
255
 
256
 
257
  def create_links(keypoints, boxes, labels, class_dict):
 
279
 
280
  return links, best_points
281
 
282
+ def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
 
283
  for pool_index, elements in pool_dict.items():
284
  print(f"Pool {pool_index} contains elements: {elements}")
285
  #check if the label sequenceflow is good
 
325
  #delete pool that are have only messageFlow on it
326
  delete_pool = []
327
  for pool_index, elements in pool_dict.items():
328
+ if all([labels[i] in [list(class_dict.values()).index('messageFlow'),
329
+ list(class_dict.values()).index('sequenceFlow'),
330
+ list(class_dict.values()).index('dataAssociation')] for i in elements]):
331
  if len(elements) > 0:
332
  delete_pool.append(pool_dict[pool_index])
333
+ print(f"Pool {pool_index} contains only arrow elements, deleting it")
334
 
335
  #sort index
336
  delete_pool = sorted(delete_pool, reverse=True)
 
391
  boxes, labels, scores, keypoints = mix_predictions(objects_pred, arrow_pred)
392
 
393
  # Regroup elements by pool
394
+ pool_dict, boxes, labels, scores, keypoints = regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict)
395
  # Create links between elements
396
  flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
397
  #Correct the labels of some sequenceflow that cross multiple pool
modules/toXML.py CHANGED
@@ -1,5 +1,5 @@
1
  import xml.etree.ElementTree as ET
2
- from modules.utils import class_dict
3
 
4
  def rescale(scale, boxes):
5
  for i in range(len(boxes)):
@@ -317,6 +317,17 @@ def calculate_waypoints(data, size, current_idx, source_id, target_id):
317
 
318
  source_idx = data['BPMN_id'].index(source_id)
319
  target_idx = data['BPMN_id'].index(target_id)
 
 
 
 
 
 
 
 
 
 
 
320
  name_source = source_id.split('_')[0]
321
  name_target = target_id.split('_')[0]
322
 
@@ -324,6 +335,10 @@ def calculate_waypoints(data, size, current_idx, source_id, target_id):
324
  source_x, source_y = data['boxes'][source_idx][:2]
325
  target_x, target_y = data['boxes'][target_idx][:2]
326
 
 
 
 
 
327
  if pos_source == 'left':
328
  source_x = source_x
329
  source_y += size[name_source][1]/2
@@ -352,8 +367,13 @@ def calculate_waypoints(data, size, current_idx, source_id, target_id):
352
 
353
  return [(source_x, source_y), (target_x, target_y)]
354
 
355
- def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=False):
356
  source_idx, target_idx = data['links'][idx]
 
 
 
 
 
357
  source_id, target_id = data['BPMN_id'][source_idx], data['BPMN_id'][target_idx]
358
  if message:
359
  element_id = f'messageflow_{source_id}_{target_id}'
@@ -373,9 +393,13 @@ def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=Fal
373
  element = ET.SubElement(parent, 'bpmn:messageFlow', id=element_id, sourceRef=XML_source_id, targetRef=XML_target_id, name=text_mapping[data['BPMN_id'][idx]])
374
  else:
375
  waypoints = calculate_waypoints(data, size, idx, source_id, target_id)
 
 
376
  element = ET.SubElement(parent, 'bpmn:messageFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
377
  else:
378
  waypoints = calculate_waypoints(data, size, idx, source_id, target_id)
 
 
379
  element = ET.SubElement(parent, 'bpmn:sequenceFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
380
  add_diagram_edge(bpmn, element_id, waypoints)
381
 
 
1
  import xml.etree.ElementTree as ET
2
+ from modules.utils import class_dict, error, warning
3
 
4
  def rescale(scale, boxes):
5
  for i in range(len(boxes)):
 
317
 
318
  source_idx = data['BPMN_id'].index(source_id)
319
  target_idx = data['BPMN_id'].index(target_id)
320
+
321
+ if source_idx==target_idx:
322
+ warning()
323
+ #return [data['keypoints'][current_idx][0][:2], data['keypoints'][current_idx][1][:2]]
324
+ return None
325
+
326
+ if source_idx is None or target_idx is None:
327
+ warning()
328
+ return [(source_x, source_y), (target_x, target_y)]
329
+
330
+
331
  name_source = source_id.split('_')[0]
332
  name_target = target_id.split('_')[0]
333
 
 
335
  source_x, source_y = data['boxes'][source_idx][:2]
336
  target_x, target_y = data['boxes'][target_idx][:2]
337
 
338
+ if name_source == 'pool' or name_target == 'pool':
339
+ warning()
340
+ return [(source_x, source_y), (target_x, target_y)]
341
+
342
  if pos_source == 'left':
343
  source_x = source_x
344
  source_y += size[name_source][1]/2
 
367
 
368
  return [(source_x, source_y), (target_x, target_y)]
369
 
370
+ def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=False):
371
  source_idx, target_idx = data['links'][idx]
372
+
373
+ if source_idx is None or target_idx is None:
374
+ warning()
375
+ return
376
+
377
  source_id, target_id = data['BPMN_id'][source_idx], data['BPMN_id'][target_idx]
378
  if message:
379
  element_id = f'messageflow_{source_id}_{target_id}'
 
393
  element = ET.SubElement(parent, 'bpmn:messageFlow', id=element_id, sourceRef=XML_source_id, targetRef=XML_target_id, name=text_mapping[data['BPMN_id'][idx]])
394
  else:
395
  waypoints = calculate_waypoints(data, size, idx, source_id, target_id)
396
+ if waypoints is None:
397
+ return
398
  element = ET.SubElement(parent, 'bpmn:messageFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
399
  else:
400
  waypoints = calculate_waypoints(data, size, idx, source_id, target_id)
401
+ if waypoints is None:
402
+ return
403
  element = ET.SubElement(parent, 'bpmn:sequenceFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
404
  add_diagram_edge(bpmn, element_id, waypoints)
405
 
modules/utils.py CHANGED
@@ -17,6 +17,7 @@ import time
17
  from torch.optim import AdamW
18
  import copy
19
  from torchvision import transforms
 
20
 
21
 
22
  object_dict = {
@@ -912,8 +913,9 @@ def find_closest_object(keypoint, boxes, labels):
912
  Returns:
913
  - int or None: The index of the closest object to the keypoint, or None if no object is found.
914
  """
915
- min_distance = float('inf')
916
  closest_object_idx = None
 
 
917
  # Iterate over each bounding box
918
  for i, box in enumerate(boxes):
919
  if labels[i] in [list(class_dict.values()).index('sequenceFlow'),
@@ -943,3 +945,9 @@ def find_closest_object(keypoint, boxes, labels):
943
 
944
  return closest_object_idx, best_point
945
 
 
 
 
 
 
 
 
17
  from torch.optim import AdamW
18
  import copy
19
  from torchvision import transforms
20
+ import streamlit as st
21
 
22
 
23
  object_dict = {
 
913
  Returns:
914
  - int or None: The index of the closest object to the keypoint, or None if no object is found.
915
  """
 
916
  closest_object_idx = None
917
+ best_point = None
918
+ min_distance = float('inf')
919
  # Iterate over each bounding box
920
  for i, box in enumerate(boxes):
921
  if labels[i] in [list(class_dict.values()).index('sequenceFlow'),
 
945
 
946
  return closest_object_idx, best_point
947
 
948
+
949
+ def error():
950
+ st.error('There is an error in the detection', icon="🚨")
951
+
952
+ def warning():
953
+ st.warning('Some element are not detected, verify your parameters', icon="⚠️")