sketch-to-BPMN / app.py
BenjiELCA's picture
add some file to reduce code in the app.py
e108fc3
raw
history blame
10.2 kB
import streamlit as st
from torchvision.transforms import functional as F
import gc
import copy
import xml.etree.ElementTree as ET
import numpy as np
from xml.dom import minidom
from modules.htlm_webpage import display_bpmn_xml
from modules.utils import class_dict, rescale_boxes
from modules.toXML import calculate_pool_bounds, add_diagram_elements, create_bpmn_object, create_flow_element, get_size_elements, definitions
from streamlit_cropper import st_cropper
from streamlit_image_select import image_select
from streamlit_js_eval import streamlit_js_eval
from modules.streamlit_utils import get_memory_usage, clear_memory, get_image, load_models, perform_inference, display_options, align_boxes, sidebar
# Function to create a BPMN XML file from prediction results
def create_XML(full_pred, text_mapping, size_scale, scale):
    """Build a pretty-printed BPMN XML document from a prediction dictionary.

    Args:
        full_pred: Prediction dictionary. This function reads its
            ``'pool_dict'``, ``'BPMN_id'``, ``'labels'`` and ``'boxes'``
            entries. ``'boxes'`` is temporarily replaced by a rescaled and
            aligned version while the document is built, and the original
            object is put back before returning (even if building fails).
        text_mapping: Mapping indexed by BPMN id; used here to name the pools
            and forwarded to the element builders.
        size_scale: Factor forwarded to ``get_size_elements``.
        scale: Factor forwarded to ``rescale_boxes`` for the box positions.

    Returns:
        The BPMN document as an indented XML string.
    """
    size_elements = get_size_elements(size_scale)

    # ``definitions`` is imported from modules.toXML and therefore shared by
    # every call. Appending to it directly would keep the collaboration,
    # process and diagram nodes of earlier calls (duplicate ids on each
    # Streamlit rerun), so each document is built on a private copy.
    root = copy.deepcopy(definitions)

    # Remember the caller's boxes so they can be restored afterwards.
    original_boxes = full_pred['boxes']

    pool_dict = full_pred['pool_dict']
    pool_names = [text_mapping[full_pred['BPMN_id'][pool_index]] for pool_index in pool_dict]

    try:
        # Create BPMN collaboration element
        collaboration = ET.SubElement(root, 'bpmn:collaboration', id='collaboration_1')

        # Create one BPMN process element per pool
        processes = [
            ET.SubElement(root, 'bpmn:process', id=f'process_{number}', isExecutable='false', name=pool_name)
            for number, pool_name in enumerate(pool_names, start=1)
        ]

        bpmndi = ET.SubElement(root, 'bpmndi:BPMNDiagram', id='BPMNDiagram_1')
        bpmnplane = ET.SubElement(bpmndi, 'bpmndi:BPMNPlane', id='BPMNPlane_1', bpmnElement='collaboration_1')

        # Rescale a copy so the original boxes stay intact whatever
        # rescale_boxes does with its argument.
        full_pred['boxes'] = rescale_boxes(scale, copy.deepcopy(original_boxes))
        full_pred['boxes'] = align_boxes(full_pred, size_elements)

        # Add the participant and its diagram shape for each pool
        for number, (pool_index, keep_elements) in enumerate(pool_dict.items(), start=1):
            pool_id = f'participant_{number}'
            ET.SubElement(
                collaboration,
                'bpmn:participant',
                id=pool_id,
                processRef=f'process_{number}',
                name=pool_names[number - 1],
            )

            # ``len(...) == 0`` is kept on purpose: keep_elements may be an
            # array-like whose truth value is ambiguous.
            if len(keep_elements) == 0:
                # Empty pool: use its own detected box, without padding.
                min_x, min_y, max_x, max_y = full_pred['boxes'][pool_index]
                pool_width = max_x - min_x
                pool_height = max_y - min_y
            else:
                min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred, keep_elements, size_elements)
                pool_width = max_x - min_x + 100  # Adding padding
                pool_height = max_y - min_y + 100  # Adding padding

            add_diagram_elements(bpmnplane, pool_id, min_x - 50, min_y - 50, pool_width, pool_height)

        # Create BPMN elements for each pool
        for process, keep_elements in zip(processes, pool_dict.values()):
            create_bpmn_object(process, bpmnplane, text_mapping, root, size_elements, full_pred, keep_elements)

        # Create message flow elements (they belong to the collaboration)
        message_flows = [i for i, label in enumerate(full_pred['labels']) if class_dict[label] == 'messageFlow']
        for flow_index in message_flows:
            create_flow_element(bpmnplane, text_mapping, flow_index, size_elements, full_pred, collaboration, message=True)

        # Create sequence flow elements (they belong to their pool's process)
        sequence_flow_label = list(class_dict.values()).index('sequenceFlow')
        for process, keep_elements in zip(processes, pool_dict.values()):
            for element_index in keep_elements:
                if full_pred['labels'][element_index] == sequence_flow_label:
                    create_flow_element(bpmnplane, text_mapping, element_index, size_elements, full_pred, process, message=False)

        # Generate pretty XML string
        rough_string = ET.tostring(root, 'utf-8')
        pretty_xml_as_string = minidom.parseString(rough_string).toprettyxml(indent=" ")
    finally:
        # Put the caller's boxes back (the original stored the whole copied
        # prediction dict under 'boxes' here).
        full_pred['boxes'] = original_boxes

    return pretty_xml_as_string
def main():
    """Render the Streamlit page.

    The page lets the user pick an example image or upload one, optionally
    crop it (desktop layout only), run the detection models on it, and then
    displays the prediction together with the generated BPMN XML.
    """
    st.set_page_config(layout="wide")

    screen_width = streamlit_js_eval(js_expressions='screen.width', want_output = True, key = 'SCR')
    print("Screen width:", screen_width)

    # An unknown width (None) is treated as desktop.
    is_mobile = screen_width is not None and screen_width < 800
    if is_mobile:
        print('Mobile version')
    else:
        print('Desktop version')

    # The width is used in arithmetic further down; fall back to a typical
    # desktop width when the browser has not reported one, instead of raising
    # TypeError on ``None // 3``.
    default_screen_width = 1280
    if screen_width is None:
        screen_width = default_screen_width

    # Add your company logo banner
    if is_mobile:
        st.image("./images/banner_mobile.png", use_column_width=True)
    else:
        st.image("./images/banner_desktop.png", use_column_width=True)

    # Use is_mobile flag in your logic
    if is_mobile:
        st.title("Welcome on the mobile version of BPMN AI model recognition app")
    else:
        st.title("Welcome on BPMN AI model recognition app")

    sidebar()  # Display the sidebar

    # Display current memory usage
    memory_usage = get_memory_usage()
    print(f"Current memory usage: {memory_usage:.2f} MB")

    # Initialize the session state for storing pool bounding boxes
    if 'pool_bboxes' not in st.session_state:
        st.session_state.pool_bboxes = []

    # Load the models once per session; the return values are not needed
    # because the models are read back from the session state just below.
    if 'model_object' not in st.session_state or 'model_arrow' not in st.session_state:
        clear_memory()
        load_models()

    model_arrow = st.session_state.model_arrow
    model_object = st.session_state.model_object

    with st.expander("Use example images"):
        img_selected = image_select("If you have no image and just want to test the demo, click on one of these images", ["./images/none.jpg", "./images/example1.jpg", "./images/example2.jpg", "./images/example3.jpg", "./images/example4.jpg"],
                                    captions=["None", "Example 1", "Example 2", "Example 3", "Example 4"], index=0, use_container_width=False, return_value="original")

    # The first tile is a placeholder meaning "no example selected".
    if img_selected == './images/none.jpg':
        print('No example image selected')
        img_selected = None

    # An example image takes precedence; the uploader is only shown otherwise.
    if not is_mobile:
        # Create the layout for the app
        col1, col2 = st.columns(2)
        with col1:
            if img_selected is not None:
                uploaded_file = img_selected
            else:
                uploaded_file = st.file_uploader("Choose an image from my computer...", type=["jpg", "jpeg", "png"])
    else:
        if img_selected is not None:
            uploaded_file = img_selected
        else:
            uploaded_file = st.file_uploader("Choose an image from my computer...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        with st.spinner('Waiting for image display...'):
            original_image = get_image(uploaded_file)
            # Preview at one third of the screen width, keeping the aspect ratio.
            preview_width = screen_width // 3
            resized_image = original_image.resize((preview_width, int(original_image.height * preview_width / original_image.width)))

            if not is_mobile:
                col1, col2 = st.columns(2)
                with col1:
                    marge = 10
                    cropped_box = st_cropper(
                        resized_image,
                        realtime_update=True,
                        box_color='#0000FF',
                        return_type='box',
                        should_resize_image=False,
                        default_coords=(marge, resized_image.width-marge, marge, resized_image.height-marge)
                    )
                    # The crop box is expressed on the preview; map it back
                    # onto the full-resolution image.
                    scale_x = original_image.width / resized_image.width
                    scale_y = original_image.height / resized_image.height
                    x0 = int(cropped_box['left'] * scale_x)
                    y0 = int(cropped_box['top'] * scale_y)
                    x1 = int((cropped_box['left'] + cropped_box['width']) * scale_x)
                    y1 = int((cropped_box['top'] + cropped_box['height']) * scale_y)
                    cropped_image = original_image.crop((x0, y0, x1, y1))
                with col2:
                    st.image(cropped_image, caption="Cropped Image", use_column_width=False, width=int(screen_width//4))
            else:
                # No cropping on mobile: the whole image is used.
                st.image(resized_image, caption="Image", use_column_width=False, width=int(4/5*screen_width))
                cropped_image = original_image

        if cropped_image is not None:
            if not is_mobile:
                col1, col2 = st.columns(2)
                with col1:
                    score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
            else:
                score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.6, step=0.05)

            if st.button("Launch Prediction"):
                st.session_state.crop_image = cropped_image
                with st.spinner('Processing...'):
                    perform_inference(model_object, model_arrow, st.session_state.crop_image, score_threshold, is_mobile, screen_width, iou_threshold=0.3, distance_treshold=30, percentage_text_dist_thresh=0.5)
                    st.balloons()

    # Results are shown on every rerun once a prediction exists in the session.
    if 'prediction' in st.session_state and uploaded_file is not None:
        with st.spinner('Waiting for result display...'):
            display_options(st.session_state.crop_image, score_threshold, is_mobile, int(5/6*screen_width))

        with st.spinner('Waiting for BPMN modeler...'):
            col1, col2 = st.columns(2)
            with col1:
                st.session_state.scale = st.slider("Set distance scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
                if not is_mobile:
                    st.session_state.size_scale = st.slider("Set size object scale for XML file", min_value=0.5, max_value=2.0, value=1.0, step=0.1)
                else:
                    st.session_state.size_scale = 1.0

            # A shallow copy of the prediction is handed over to create_XML.
            st.session_state.bpmn_xml = create_XML(st.session_state.prediction.copy(), st.session_state.text_mapping, st.session_state.size_scale, st.session_state.scale)
            display_bpmn_xml(st.session_state.bpmn_xml, is_mobile=is_mobile, screen_width=int(4/5*screen_width))

    gc.collect()
# Script entry point: log the startup on stdout, then render the page.
if __name__ == "__main__":
    print('Starting the app...')
    main()