File size: 8,904 Bytes
615e9f1
 
 
 
b76c717
615e9f1
2d1db93
bb42123
00a4c90
 
 
 
 
615e9f1
00a4c90
7029382
00a4c90
 
 
e508e94
00a4c90
74f41b4
9ef60c2
74f41b4
9ef60c2
e508e94
00a4c90
 
7029382
00a4c90
 
7029382
00a4c90
 
e508e94
00a4c90
615e9f1
 
 
 
00a4c90
615e9f1
00a4c90
9134c9f
00a4c90
 
 
 
 
 
 
 
 
 
 
 
9134c9f
615e9f1
00a4c90
 
74f41b4
00a4c90
 
2d1db93
053df76
 
00a4c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
615e9f1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import streamlit as st
from torchvision.transforms import functional as F
import gc
import numpy as np
from modules.htlm_webpage import display_bpmn_xml
from streamlit_cropper import st_cropper
from streamlit_image_select import image_select
from streamlit_js_eval import streamlit_js_eval
from streamlit_drawable_canvas import st_canvas
from modules.streamlit_utils import *
from glob import glob
from streamlit_image_annotation import detection
from modules.toXML import create_XML

def configure_page():
    """Switch the page to the wide layout and detect a mobile client.

    Returns:
        tuple: ``(is_mobile, screen_width)``. ``screen_width`` is whatever
        ``streamlit_js_eval`` returns for ``screen.width`` and may be ``None``;
        ``is_mobile`` is True only when a width is available and below 800.
    """
    st.set_page_config(layout="wide")
    width = streamlit_js_eval(js_expressions='screen.width', want_output=True, key='SCR')
    mobile = False if width is None else width < 800
    return mobile, width

def display_banner(is_mobile):
    """Render the banner image that matches the device type (mobile vs desktop)."""
    banner_path = "./images/banner_mobile.png" if is_mobile else "./images/banner_desktop.png"
    st.image(banner_path, use_column_width=True)

def display_title(is_mobile):
    """Render the page title, with a dedicated wording for mobile devices."""
    if is_mobile:
        st.title("Welcome on the mobile version of BPMN AI model recognition app")
    else:
        st.title("Welcome on the BPMN AI model recognition app")

def display_sidebar():
    """Draw the application sidebar.

    Thin wrapper around ``sidebar()``, which is not defined in this file
    (presumably provided by ``from modules.streamlit_utils import *``).
    """
    sidebar()
    return None

def initialize_session_state():
    """Prepare ``st.session_state`` for this run.

    Ensures ``pool_bboxes`` exists (empty list by default). If either
    ``model_object`` or ``model_arrow`` is missing from the session, memory is
    cleared with ``clear_memory()`` and the models are (re)loaded with
    ``load_models()``.
    """
    state = st.session_state
    if 'pool_bboxes' not in state:
        state.pool_bboxes = []

    models_loaded = 'model_object' in state and 'model_arrow' in state
    if not models_loaded:
        clear_memory()
        load_models()

def load_example_image():
    """Offer a gallery of example images inside an expander.

    Returns:
        The value returned by ``image_select`` (``return_value="original"``)
        for the chosen thumbnail. The first entry, ``./images/none.jpg``
        (caption "None"), is selected by default and means "no example".
    """
    example_paths = [
        "./images/none.jpg",
        "./images/example1.jpg",
        "./images/example2.jpg",
        "./images/example3.jpg",
        "./images/example4.jpg",
    ]
    example_captions = ["None", "Example 1", "Example 2", "Example 3", "Example 4"]

    with st.expander("Use example images"):
        return image_select(
            "If you have no image and just want to test the demo, click on one of these images",
            example_paths,
            captions=example_captions,
            index=0,
            use_container_width=False,
            return_value="original",
        )

def load_user_image(img_selected, is_mobile):
    """Decide which image the app should work on.

    If an example image was picked (anything other than ``None`` or the
    placeholder ``./images/none.jpg``), it is returned unchanged. Otherwise a
    file uploader is shown and its result is returned. On mobile the uploader
    is full width; on desktop it sits in the left half of a two-column layout.
    """
    has_example = img_selected is not None and img_selected != './images/none.jpg'
    if has_example:
        return img_selected

    uploader_label = "Choose an image from my computer..."
    accepted_types = ["jpg", "jpeg", "png"]
    if is_mobile:
        return st.file_uploader(uploader_label, type=accepted_types, accept_multiple_files=False)

    left_column, _ = st.columns(2)
    with left_column:
        return st.file_uploader(uploader_label, type=accepted_types)

def display_image(uploaded_file, screen_width, is_mobile, fallback_width=1000):
    """Show the selected image and return the region to run prediction on.

    The image is loaded with ``get_image``. A copy resized to half the screen
    width (aspect ratio preserved) is used for display. On desktop the user
    crops it interactively via ``crop_image``. On mobile it is only displayed
    and the whole original image is used.

    Args:
        uploaded_file: example image path or uploaded file, as accepted by
            ``get_image``.
        screen_width: browser screen width in pixels. May be ``None``
            (``configure_page`` can return ``None`` for it).
        is_mobile: whether the mobile layout is active.
        fallback_width: width used in place of ``screen_width`` when the
            latter is ``None``.

    Returns:
        The cropped image (desktop) or the original image (mobile), both at
        the original resolution.
    """
    if screen_width is None:
        # ``screen_width // 2`` and ``4/5 * screen_width`` would raise TypeError.
        screen_width = fallback_width

    with st.spinner('Waiting for image display...'):
        original_image = get_image(uploaded_file)
        display_width = max(1, screen_width // 2)
        # Keep the aspect ratio. Clamp to 1 px so very wide images do not
        # produce an invalid zero-height resize target.
        display_height = max(1, int(original_image.height * display_width / original_image.width))
        resized_image = original_image.resize((display_width, display_height))

        if not is_mobile:
            cropped_image = crop_image(resized_image, original_image)
        else:
            st.image(resized_image, caption="Image", use_column_width=False, width=int(4/5 * screen_width))
            cropped_image = original_image

    return cropped_image

def crop_image(resized_image, original_image):
    """Let the user draw a crop box on the display image and apply it to the original.

    The cropper starts with a box inset by 10 px from each edge of
    ``resized_image``. The box returned by ``st_cropper`` (in display
    coordinates) is rescaled to ``original_image`` coordinates, truncated to
    ints, and used to crop the full-resolution image.
    """
    inset = 10
    box = st_cropper(
        resized_image,
        realtime_update=True,
        box_color='#0000FF',
        return_type='box',
        should_resize_image=False,
        default_coords=(inset, resized_image.width - inset, inset, resized_image.height - inset),
    )

    # Ratios between the full-resolution image and the one shown in the cropper.
    ratio_x = original_image.width / resized_image.width
    ratio_y = original_image.height / resized_image.height

    left, top = box['left'], box['top']
    right = left + box['width']
    bottom = top + box['height']
    crop_box = (int(left * ratio_x), int(top * ratio_y), int(right * ratio_x), int(bottom * ratio_y))
    return original_image.crop(crop_box)

def get_score_threshold(is_mobile):
    """Show a slider (left column) for the prediction score threshold.

    The chosen value is stored in ``st.session_state.score_threshold``. The
    default is 0.6 on mobile and 0.5 otherwise.
    """
    default_threshold = 0.6 if is_mobile else 0.5
    left_column, _ = st.columns(2)
    with left_column:
        st.session_state.score_threshold = st.slider(
            "Set score threshold for prediction",
            min_value=0.0,
            max_value=1.0,
            value=default_threshold,
            step=0.05,
        )

def launch_prediction(cropped_image, score_threshold, is_mobile, screen_width):
    """Run inference on ``cropped_image``, then show balloons.

    The image is first stored in ``st.session_state.crop_image``. Inference is
    delegated to ``perform_inference`` with both models from the session state
    and fixed keyword thresholds (IoU 0.3, distance 30, text distance 0.5).
    """
    session = st.session_state
    session.crop_image = cropped_image
    with st.spinner('Processing...'):
        perform_inference(
            session.model_object,
            session.model_arrow,
            session.crop_image,
            score_threshold,
            is_mobile,
            screen_width,
            iou_threshold=0.3,
            distance_treshold=30,
            percentage_text_dist_thresh=0.5,
        )
        st.balloons()

from modules.eval import develop_prediction, generate_data
from modules.utils import class_dict
def _xyxy_to_xywh(boxes):
    """Return new int ``[x, y, w, h]`` lists built from ``[x0, y0, x1, y1]`` boxes."""
    converted = [[int(coord) for coord in box] for box in boxes]
    for box in converted:
        box[2] -= box[0]
        box[3] -= box[1]
    return converted


def _xywh_to_xyxy(boxes):
    """Return a NumPy array of ``[x0, y0, x1, y1]`` boxes built from ``[x, y, w, h]`` boxes."""
    converted = np.array(boxes)
    # Iterating a 2-D array yields row views, so the in-place updates stick.
    for box in converted:
        box[2] += box[0]
        box[3] += box[1]
    return converted


def modify_results(percentage_text_dist_thresh=0.5):
    """Let the user edit the predicted boxes/labels and rebuild the prediction.

    Shows the current prediction in an interactive ``detection`` annotation
    widget. When the widget returns annotations, the edited boxes and labels
    are passed through ``develop_prediction`` and ``generate_data``, keeping the
    previous scores and keypoints. The results replace
    ``st.session_state.prediction`` and ``st.session_state.text_mapping``.

    Args:
        percentage_text_dist_thresh: forwarded to ``mapping_text`` as
            ``percentage_thresh``.
    """
    with st.expander("Method and Style modification"):
        label_list = list(class_dict.values())
        prediction = st.session_state.prediction
        # The annotation widget is fed [x, y, width, height] boxes.
        bboxes = _xyxy_to_xywh(prediction['boxes'])
        labels = [int(label) for label in prediction['labels']]
        uploaded_image = prepare_image(st.session_state.crop_image, new_size=(1333, 1333), pad=False)
        new_labels = detection(
            image=uploaded_image, bboxes=bboxes, labels=labels,
            label_list=label_list, line_width=3, width=2000, use_space=False,
        )

        if new_labels is None:
            return

        new_lab = np.array([label['label_id'] for label in new_labels])
        # Convert the edited boxes back to [x0, y0, x1, y1].
        new_boxes = _xywh_to_xyxy([label['bbox'] for label in new_labels])

        scores = prediction['scores']
        keypoints = prediction['keypoints']
        boxes, labels, scores, keypoints, flow_links, best_points, pool_dict = develop_prediction(
            new_boxes, new_lab, scores, keypoints, class_dict, correction=False
        )

        st.session_state.prediction = generate_data(
            prediction['image'], boxes, labels, scores, keypoints,
            flow_links, best_points, pool_dict, class_dict,
        )
        st.session_state.text_mapping = mapping_text(
            st.session_state.prediction, st.session_state.text_pred,
            print_sentences=False, percentage_thresh=percentage_text_dist_thresh,
        )


def display_bpmn_modeler(is_mobile, screen_width):
    """Build the BPMN XML from the current prediction and render it in the modeler.

    ``create_XML`` receives a ``.copy()`` of ``st.session_state.prediction``,
    the text mapping and the size/distance scales. The XML is stored in
    ``st.session_state.bpmn_xml`` and displayed at 4/5 of ``screen_width``.
    """
    state = st.session_state
    with st.spinner('Waiting for BPMN modeler...'):
        xml = create_XML(state.prediction.copy(), state.text_mapping, state.size_scale, state.scale)
        state.bpmn_xml = xml
        display_bpmn_xml(state.bpmn_xml, is_mobile=is_mobile, screen_width=int(4/5 * screen_width))

def modeler_options(is_mobile):
    """Let desktop users tune the XML export scales; mobile uses 1.0 for both.

    Sets ``st.session_state.scale`` (distance scale) and
    ``st.session_state.size_scale`` (object size scale).
    """
    left_column, _ = st.columns(2)
    with left_column:
        if is_mobile:
            st.session_state.scale = 1.0
            st.session_state.size_scale = 1.0
        else:
            st.session_state.scale = st.slider("Set distance scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
            st.session_state.size_scale = st.slider("Set size object scale for XML file", min_value=0.5, max_value=2.0, value=1.0, step=0.1)

def main():
    """Run one pass of the Streamlit app script.

    Steps, in order:
    - page setup and header;
    - session initialisation;
    - image selection/upload and cropping;
    - prediction on demand.

    Once a prediction exists (and an image is selected), it also shows the
    result view, the modeler options and the BPMN modeler.
    """
    is_mobile, screen_width = configure_page()
    if screen_width is None:
        # configure_page may return no width; the width arithmetic below
        # (e.g. ``int(5/6 * screen_width)``) would otherwise raise TypeError.
        fallback_width = 1000
        screen_width = fallback_width

    display_banner(is_mobile)
    display_title(is_mobile)
    display_sidebar()
    initialize_session_state()

    cropped_image = None

    img_selected = load_example_image()
    uploaded_file = load_user_image(img_selected, is_mobile)
    if uploaded_file is not None:
        cropped_image = display_image(uploaded_file, screen_width, is_mobile)

    if cropped_image is not None:
        get_score_threshold(is_mobile)
        if st.button("Launch Prediction"):
            launch_prediction(cropped_image, st.session_state.score_threshold, is_mobile, screen_width)
            st.rerun()

    if 'prediction' in st.session_state and uploaded_file:
        # Label restored from mis-encoded UTF-8 ("πŸ”„") to the intended emoji.
        if st.button("🔄 Refresh image"):
            st.rerun()

        with st.expander("Show result"):
            with st.spinner('Waiting for result display...'):
                display_options(st.session_state.crop_image, st.session_state.score_threshold, is_mobile, int(5/6 * screen_width))

        # NOTE: interactive result editing is currently disabled.
        #if not is_mobile:
            #modify_results()

        with st.expander("Options for BPMN modeler"):
            modeler_options(is_mobile)

        display_bpmn_modeler(is_mobile, screen_width)

    gc.collect()

if __name__ == "__main__":
    # Announce start-up on stdout, then render the app.
    print('Starting the app...')
    main()