Spaces:
Running
Running
add verification of length and ready for demo
Browse files- app.py +84 -10
- modules/display.py +2 -0
- modules/streamlit_utils.py +1 -1
- modules/toXML.py +10 -2
app.py
CHANGED
|
@@ -13,7 +13,7 @@ from glob import glob
|
|
| 13 |
from streamlit_image_annotation import detection
|
| 14 |
from modules.toXML import create_XML
|
| 15 |
from modules.eval import develop_prediction, generate_data
|
| 16 |
-
from modules.utils import class_dict
|
| 17 |
|
| 18 |
def configure_page():
|
| 19 |
st.set_page_config(layout="wide")
|
|
@@ -114,41 +114,114 @@ def launch_prediction(cropped_image, score_threshold, is_mobile, screen_width):
|
|
| 114 |
score_threshold, is_mobile, screen_width, iou_threshold=0.3, distance_treshold=30, percentage_text_dist_thresh=0.5
|
| 115 |
)
|
| 116 |
st.balloons()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
|
| 119 |
def modify_results(percentage_text_dist_thresh=0.5):
|
| 120 |
with st.expander("Method and Style modification (beta version)"):
|
| 121 |
-
label_list = list(
|
| 122 |
bboxes = [[int(coord) for coord in box] for box in st.session_state.prediction['boxes']]
|
| 123 |
for i in range(len(bboxes)):
|
| 124 |
bboxes[i][2] = bboxes[i][2] - bboxes[i][0]
|
| 125 |
bboxes[i][3] = bboxes[i][3] - bboxes[i][1]
|
| 126 |
labels = [int(label) for label in st.session_state.prediction['labels']]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
uploaded_image = prepare_image(st.session_state.crop_image, new_size=(1333, 1333), pad=False)
|
| 128 |
scale = 2000 / uploaded_image.size[0]
|
| 129 |
new_labels = detection(
|
| 130 |
-
image=uploaded_image, bboxes=
|
| 131 |
label_list=label_list, line_width=3, width=2000, use_space=False
|
| 132 |
)
|
| 133 |
|
| 134 |
if new_labels is not None:
|
| 135 |
-
new_lab = np.array([label['label_id'] for label in new_labels])
|
| 136 |
-
|
| 137 |
# Convert back to original format
|
| 138 |
bboxes = np.array([label['bbox'] for label in new_labels])
|
| 139 |
for i in range(len(bboxes)):
|
| 140 |
bboxes[i][2] = bboxes[i][2] + bboxes[i][0]
|
| 141 |
bboxes[i][3] = bboxes[i][3] + bboxes[i][1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
scores = st.session_state.prediction['scores']
|
| 144 |
keypoints = st.session_state.prediction['keypoints']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
#print('Old prediction:', st.session_state.prediction['keypoints'])
|
| 146 |
-
boxes, labels, scores, keypoints, flow_links, best_points, pool_dict = develop_prediction(
|
| 147 |
|
| 148 |
st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, flow_links, best_points, pool_dict, class_dict)
|
| 149 |
st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
|
| 150 |
|
| 151 |
#print('New prediction:', st.session_state.prediction['keypoints'])
|
|
|
|
| 152 |
|
| 153 |
|
| 154 |
def display_bpmn_modeler(is_mobile, screen_width):
|
|
@@ -186,15 +259,16 @@ def main():
|
|
| 186 |
|
| 187 |
if cropped_image is not None:
|
| 188 |
get_score_threshold(is_mobile)
|
| 189 |
-
if st.button("Launch Prediction"):
|
| 190 |
launch_prediction(cropped_image, st.session_state.score_threshold, is_mobile, screen_width)
|
|
|
|
| 191 |
st.rerun()
|
| 192 |
|
| 193 |
if 'prediction' in st.session_state and uploaded_file:
|
| 194 |
-
if st.button("π Refresh image"):
|
| 195 |
-
st.rerun()
|
| 196 |
|
| 197 |
-
with st.expander("Show result"):
|
| 198 |
with st.spinner('Waiting for result display...'):
|
| 199 |
display_options(st.session_state.crop_image, st.session_state.score_threshold, is_mobile, int(5/6 * screen_width))
|
| 200 |
|
|
|
|
| 13 |
from streamlit_image_annotation import detection
|
| 14 |
from modules.toXML import create_XML
|
| 15 |
from modules.eval import develop_prediction, generate_data
|
| 16 |
+
from modules.utils import class_dict, object_dict
|
| 17 |
|
| 18 |
def configure_page():
|
| 19 |
st.set_page_config(layout="wide")
|
|
|
|
| 114 |
score_threshold, is_mobile, screen_width, iou_threshold=0.3, distance_treshold=30, percentage_text_dist_thresh=0.5
|
| 115 |
)
|
| 116 |
st.balloons()
|
| 117 |
+
|
| 118 |
+
def mix_new_pred(objects_pred, arrow_pred):
|
| 119 |
+
# Initialize the list of lists for keypoints
|
| 120 |
+
object_keypoints = []
|
| 121 |
+
|
| 122 |
+
# Number of boxes
|
| 123 |
+
num_boxes = len(objects_pred['boxes'])
|
| 124 |
+
|
| 125 |
+
# Iterate over the number of boxes
|
| 126 |
+
for _ in range(num_boxes):
|
| 127 |
+
# Each box has 2 keypoints, both initialized to [0, 0, 0]
|
| 128 |
+
keypoints = [[0, 0, 0], [0, 0, 0]]
|
| 129 |
+
object_keypoints.append(keypoints)
|
| 130 |
+
|
| 131 |
+
#concatenate the two predictions
|
| 132 |
+
boxes = np.concatenate((objects_pred['boxes'], arrow_pred['boxes']))
|
| 133 |
+
labels = np.concatenate((objects_pred['labels'], arrow_pred['labels']))
|
| 134 |
+
|
| 135 |
+
return boxes, labels, keypoints
|
| 136 |
|
| 137 |
|
| 138 |
def modify_results(percentage_text_dist_thresh=0.5):
|
| 139 |
with st.expander("Method and Style modification (beta version)"):
|
| 140 |
+
label_list = list(object_dict.values())
|
| 141 |
bboxes = [[int(coord) for coord in box] for box in st.session_state.prediction['boxes']]
|
| 142 |
for i in range(len(bboxes)):
|
| 143 |
bboxes[i][2] = bboxes[i][2] - bboxes[i][0]
|
| 144 |
bboxes[i][3] = bboxes[i][3] - bboxes[i][1]
|
| 145 |
labels = [int(label) for label in st.session_state.prediction['labels']]
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# Filter boxes and labels where label is less than 12
|
| 149 |
+
ignore_labels = [6, 7]
|
| 150 |
+
object_bboxes = []
|
| 151 |
+
object_labels = []
|
| 152 |
+
arrow_bboxes = []
|
| 153 |
+
arrow_labels = []
|
| 154 |
+
for i in range(len(bboxes)):
|
| 155 |
+
if labels[i] <= 12:
|
| 156 |
+
object_bboxes.append(bboxes[i])
|
| 157 |
+
object_labels.append(labels[i])
|
| 158 |
+
else:
|
| 159 |
+
arrow_bboxes.append(bboxes[i])
|
| 160 |
+
arrow_labels.append(labels[i])
|
| 161 |
+
|
| 162 |
+
print('Object bboxes:', object_bboxes)
|
| 163 |
+
print('Object labels:', object_labels)
|
| 164 |
+
print('Arrow bboxes:', arrow_bboxes)
|
| 165 |
+
print('Arrow labels:', arrow_labels)
|
| 166 |
+
|
| 167 |
+
original_obj_len = len(object_bboxes)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
uploaded_image = prepare_image(st.session_state.crop_image, new_size=(1333, 1333), pad=False)
|
| 171 |
scale = 2000 / uploaded_image.size[0]
|
| 172 |
new_labels = detection(
|
| 173 |
+
image=uploaded_image, bboxes=object_bboxes, labels=object_labels,
|
| 174 |
label_list=label_list, line_width=3, width=2000, use_space=False
|
| 175 |
)
|
| 176 |
|
| 177 |
if new_labels is not None:
|
| 178 |
+
new_lab = np.array([label['label_id'] for label in new_labels])
|
|
|
|
| 179 |
# Convert back to original format
|
| 180 |
bboxes = np.array([label['bbox'] for label in new_labels])
|
| 181 |
for i in range(len(bboxes)):
|
| 182 |
bboxes[i][2] = bboxes[i][2] + bboxes[i][0]
|
| 183 |
bboxes[i][3] = bboxes[i][3] + bboxes[i][1]
|
| 184 |
+
for i in range(len(arrow_bboxes)):
|
| 185 |
+
arrow_bboxes[i][2] = arrow_bboxes[i][2] + arrow_bboxes[i][0]
|
| 186 |
+
arrow_bboxes[i][3] = arrow_bboxes[i][3] + arrow_bboxes[i][1]
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
new_bbox = np.concatenate((bboxes, arrow_bboxes))
|
| 190 |
+
new_lab = np.concatenate((new_lab, arrow_labels))
|
| 191 |
+
|
| 192 |
+
print('New labels:', new_lab)
|
| 193 |
|
| 194 |
scores = st.session_state.prediction['scores']
|
| 195 |
keypoints = st.session_state.prediction['keypoints']
|
| 196 |
+
|
| 197 |
+
#delete element in keypoints to make it match the new number of boxes
|
| 198 |
+
len_keypoints = len(keypoints)
|
| 199 |
+
keypoints = keypoints.tolist()
|
| 200 |
+
scores = scores.tolist()
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
diff = original_obj_len-len(bboxes)
|
| 204 |
+
if diff > 0:
|
| 205 |
+
for i in range(diff):
|
| 206 |
+
keypoints.pop(0)
|
| 207 |
+
scores.pop(0)
|
| 208 |
+
elif diff < 0:
|
| 209 |
+
for i in range(-diff):
|
| 210 |
+
keypoints.insert(0, [[0, 0, 0], [0, 0, 0]])
|
| 211 |
+
scores.insert(0, 0.0)
|
| 212 |
+
|
| 213 |
+
print('lenghts: ',len(bboxes), len(new_lab), len(scores), len(keypoints))
|
| 214 |
+
keypoints = np.array(keypoints)
|
| 215 |
+
scores = np.array(scores)
|
| 216 |
+
|
| 217 |
#print('Old prediction:', st.session_state.prediction['keypoints'])
|
| 218 |
+
boxes, labels, scores, keypoints, flow_links, best_points, pool_dict = develop_prediction(new_bbox, new_lab, scores, keypoints, class_dict, correction=False)
|
| 219 |
|
| 220 |
st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, flow_links, best_points, pool_dict, class_dict)
|
| 221 |
st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
|
| 222 |
|
| 223 |
#print('New prediction:', st.session_state.prediction['keypoints'])
|
| 224 |
+
st.rerun()
|
| 225 |
|
| 226 |
|
| 227 |
def display_bpmn_modeler(is_mobile, screen_width):
|
|
|
|
| 259 |
|
| 260 |
if cropped_image is not None:
|
| 261 |
get_score_threshold(is_mobile)
|
| 262 |
+
if st.button("π Launch Prediction"):
|
| 263 |
launch_prediction(cropped_image, st.session_state.score_threshold, is_mobile, screen_width)
|
| 264 |
+
st.session_state.original_prediction = st.session_state.prediction.copy()
|
| 265 |
st.rerun()
|
| 266 |
|
| 267 |
if 'prediction' in st.session_state and uploaded_file:
|
| 268 |
+
#if st.button("π Refresh image"):
|
| 269 |
+
#st.rerun()
|
| 270 |
|
| 271 |
+
with st.expander("Show result of prediction"):
|
| 272 |
with st.spinner('Waiting for result display...'):
|
| 273 |
display_options(st.session_state.crop_image, st.session_state.score_threshold, is_mobile, int(5/6 * screen_width))
|
| 274 |
|
modules/display.py
CHANGED
|
@@ -94,6 +94,8 @@ def draw_stream(image,
|
|
| 94 |
# Draw keypoints if available
|
| 95 |
if draw_keypoints and 'keypoints' in prediction:
|
| 96 |
for i in range(len(prediction['keypoints'])):
|
|
|
|
|
|
|
| 97 |
kp = prediction['keypoints'][i]
|
| 98 |
for j in range(kp.shape[0]):
|
| 99 |
if prediction['labels'][i] != list(class_dict.values()).index('sequenceFlow') and prediction['labels'][i] != list(class_dict.values()).index('messageFlow') and prediction['labels'][i] != list(class_dict.values()).index('dataAssociation'):
|
|
|
|
| 94 |
# Draw keypoints if available
|
| 95 |
if draw_keypoints and 'keypoints' in prediction:
|
| 96 |
for i in range(len(prediction['keypoints'])):
|
| 97 |
+
if i >= len(prediction['keypoints']):
|
| 98 |
+
continue
|
| 99 |
kp = prediction['keypoints'][i]
|
| 100 |
for j in range(kp.shape[0]):
|
| 101 |
if prediction['labels'][i] != list(class_dict.values()).index('sequenceFlow') and prediction['labels'][i] != list(class_dict.values()).index('messageFlow') and prediction['labels'][i] != list(class_dict.values()).index('dataAssociation'):
|
modules/streamlit_utils.py
CHANGED
|
@@ -130,7 +130,7 @@ def display_options(image, score_threshold, is_mobile, screen_width):
|
|
| 130 |
|
| 131 |
# Draw the annotated image with selected options
|
| 132 |
annotated_image = draw_stream(
|
| 133 |
-
np.array(image), prediction=st.session_state.
|
| 134 |
draw_keypoints=draw_keypoints, draw_boxes=draw_boxes, draw_links=draw_links, draw_twins=False, draw_grouped_text=draw_text,
|
| 135 |
write_class=write_class, write_text=write_text, keypoints_correction=True, write_idx=write_idx, only_show=selected_option,
|
| 136 |
score_threshold=score_threshold, write_score=write_score, resize=True, return_image=True, axis=True
|
|
|
|
| 130 |
|
| 131 |
# Draw the annotated image with selected options
|
| 132 |
annotated_image = draw_stream(
|
| 133 |
+
np.array(image), prediction=st.session_state.original_prediction, text_predictions=st.session_state.text_pred,
|
| 134 |
draw_keypoints=draw_keypoints, draw_boxes=draw_boxes, draw_links=draw_links, draw_twins=False, draw_grouped_text=draw_text,
|
| 135 |
write_class=write_class, write_text=write_text, keypoints_correction=True, write_idx=write_idx, only_show=selected_option,
|
| 136 |
score_threshold=score_threshold, write_score=write_score, resize=True, return_image=True, axis=True
|
modules/toXML.py
CHANGED
|
@@ -13,7 +13,7 @@ def align_boxes(pred, size):
|
|
| 13 |
for pool_index, element_indices in pred['pool_dict'].items():
|
| 14 |
pool_groups[pool_index] = []
|
| 15 |
for i in element_indices:
|
| 16 |
-
if i
|
| 17 |
continue
|
| 18 |
if class_dict[modified_pred['labels'][i]] != 'dataObject' or class_dict[modified_pred['labels'][i]] != 'dataStore':
|
| 19 |
x1, y1, x2, y2 = modified_pred['boxes'][i]
|
|
@@ -138,7 +138,7 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
|
|
| 138 |
pool_width = max_x - min_x + 100 # Adding padding
|
| 139 |
pool_height = max_y - min_y + 100 # Adding padding
|
| 140 |
#check area
|
| 141 |
-
if pool_width <
|
| 142 |
print("The pool is too small, please add more elements or increase the scale")
|
| 143 |
continue
|
| 144 |
|
|
@@ -157,6 +157,9 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
|
|
| 157 |
# Create sequence flow elements
|
| 158 |
for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
|
| 159 |
for i in keep_elements:
|
|
|
|
|
|
|
|
|
|
| 160 |
if full_pred['labels'][i] == list(class_dict.values()).index('sequenceFlow'):
|
| 161 |
create_flow_element(bpmnplane, text_mapping, i, size_elements, full_pred, process[idx], message=False)
|
| 162 |
|
|
@@ -259,6 +262,8 @@ def add_diagram_edge(parent, element_id, waypoints):
|
|
| 259 |
'id': element_id + '_di'
|
| 260 |
})
|
| 261 |
for x, y in waypoints:
|
|
|
|
|
|
|
| 262 |
ET.SubElement(edge, 'di:waypoint', attrib={
|
| 263 |
'x': str(x),
|
| 264 |
'y': str(y)
|
|
@@ -312,6 +317,9 @@ def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data
|
|
| 312 |
links = data['links']
|
| 313 |
|
| 314 |
for i in keep_elements:
|
|
|
|
|
|
|
|
|
|
| 315 |
element_id = elements[i]
|
| 316 |
|
| 317 |
if element_id is None:
|
|
|
|
| 13 |
for pool_index, element_indices in pred['pool_dict'].items():
|
| 14 |
pool_groups[pool_index] = []
|
| 15 |
for i in element_indices:
|
| 16 |
+
if i >= len(modified_pred['labels']):
|
| 17 |
continue
|
| 18 |
if class_dict[modified_pred['labels'][i]] != 'dataObject' or class_dict[modified_pred['labels'][i]] != 'dataStore':
|
| 19 |
x1, y1, x2, y2 = modified_pred['boxes'][i]
|
|
|
|
| 138 |
pool_width = max_x - min_x + 100 # Adding padding
|
| 139 |
pool_height = max_y - min_y + 100 # Adding padding
|
| 140 |
#check area
|
| 141 |
+
if pool_width < 300 or pool_height < 30:
|
| 142 |
print("The pool is too small, please add more elements or increase the scale")
|
| 143 |
continue
|
| 144 |
|
|
|
|
| 157 |
# Create sequence flow elements
|
| 158 |
for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
|
| 159 |
for i in keep_elements:
|
| 160 |
+
if i >= len(full_pred['labels']):
|
| 161 |
+
print("Problem with the index")
|
| 162 |
+
continue
|
| 163 |
if full_pred['labels'][i] == list(class_dict.values()).index('sequenceFlow'):
|
| 164 |
create_flow_element(bpmnplane, text_mapping, i, size_elements, full_pred, process[idx], message=False)
|
| 165 |
|
|
|
|
| 262 |
'id': element_id + '_di'
|
| 263 |
})
|
| 264 |
for x, y in waypoints:
|
| 265 |
+
if x is None or y is None:
|
| 266 |
+
return
|
| 267 |
ET.SubElement(edge, 'di:waypoint', attrib={
|
| 268 |
'x': str(x),
|
| 269 |
'y': str(y)
|
|
|
|
| 317 |
links = data['links']
|
| 318 |
|
| 319 |
for i in keep_elements:
|
| 320 |
+
if i >= len(elements):
|
| 321 |
+
print("Problem with the index")
|
| 322 |
+
continue
|
| 323 |
element_id = elements[i]
|
| 324 |
|
| 325 |
if element_id is None:
|