File size: 4,791 Bytes
39debef
3170d22
837f8a9
 
4c6c8f5
bd7c529
837f8a9
0715e8c
3170d22
 
 
bd7c529
 
0715e8c
39debef
8ed86f2
3170d22
8ed86f2
3cf084f
3170d22
8ed86f2
3170d22
 
cf4de60
 
837f8a9
 
 
 
 
 
b113de0
837f8a9
 
 
 
6c45483
 
 
 
 
 
 
 
 
 
 
 
 
bd7c529
6c45483
bd7c529
 
6c45483
bd7c529
56e0661
8c0dd07
0715e8c
76a2d16
cf4de60
108ee38
6494a0c
4d1e216
 
76a2d16
3170d22
76a2d16
96b4bfb
4d1e216
9af4bab
 
108ee38
4d1e216
 
 
 
 
 
 
 
 
ee620ac
4d1e216
 
 
9af4bab
4d1e216
55018fa
4d1e216
c72147b
9af4bab
 
9448fdb
108ee38
53b109e
b113de0
 
 
 
108ee38
 
53b109e
108ee38
 
 
 
9af4bab
108ee38
53b109e
9448fdb
 
43c365e
 
9448fdb
56e0661
 
 
b574fc0
56e0661
 
837f8a9
 
 
bed94f9
837f8a9
 
b113de0
837f8a9
 
108ee38
837f8a9
 
 
108ee38
 
 
 
837f8a9
 
 
108ee38
 
bd7c529
 
 
cc0e43c
0b13976
 
cc0e43c
0b13976
 
 
a5158fe
 
cc0e43c
0b13976
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import sahi.utils.cv
import sahi.utils.file
import streamlit as st
from PIL import Image
from sahi import AutoDetectionModel
from streamlit_image_comparison import image_comparison
from ultralyticsplus.hf_utils import download_from_hub

from utils import sahi_yolov8m_inference

# CDN host serving the example diagrams and the precomputed prediction overlay.
_CDN_BASE_URL = 'https://d1afc1j4569hs1.cloudfront.net'

# Local filename -> CDN URL. The keys double as the radio-button options and
# as the filenames the comparison images are saved under.
IMAGE_TO_URL = {
    'factory_pid.png': f'{_CDN_BASE_URL}/factory-pid.png',
    'plant_pid.png': f'{_CDN_BASE_URL}/plant-pid.png',
    'processing_pid.png': f'{_CDN_BASE_URL}/processing-pid.png',
    'prediction_visual.png': f'{_CDN_BASE_URL}/prediction_visual.png',
}

# Title shared by the browser tab and the on-page heading.
_PAGE_TITLE = "P&ID Object Detection"

# Must remain the first Streamlit command executed by the script.
st.set_page_config(page_title=_PAGE_TITLE, layout="wide", initial_sidebar_state="expanded")

# Page header: heading, rainbow-underlined tagline, and credit line.
st.title(_PAGE_TITLE)
st.subheader(' Identify valves and pumps with deep learning model ', divider='rainbow')
st.caption('Developed by Deep Drawings Co.')

@st.cache_resource(show_spinner=False)
def get_model():
    """Build the SAHI detection model around the P&ID YOLOv8 weights.

    The weights are pulled from the ``DanielCerda/pid_yolov8`` Hub repository
    and loaded on CPU with a fixed 0.8 confidence threshold. ``st.cache_resource``
    keeps the loaded model around, so later reruns reuse it instead of
    downloading and loading it again.
    """
    weights_path = download_from_hub('DanielCerda/pid_yolov8')
    return AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path=weights_path,
        confidence_threshold=0.8,
        device="cpu",
    )
    
@st.cache_data(show_spinner=False)
def download_comparison_images():
    """Download the default raw/annotated diagram pair into the working directory.

    These two files seed the before/after comparison viewer shown before the
    user runs a prediction. URLs are looked up in ``IMAGE_TO_URL`` so every
    CDN address is defined in exactly one place. ``st.cache_data`` keeps the
    download from being repeated on every script rerun.
    """
    for filename in ('plant_pid.png', 'prediction_visual.png'):
        # Save under the mapping key; the viewer later opens these exact paths.
        sahi.utils.file.download_from_url(IMAGE_TO_URL[filename], filename)

@st.cache_data(show_spinner=False)
def load_example_image(url):
    """Download an example diagram and return it as a PIL image.

    Streamlit reruns this whole script on every widget interaction. Caching
    keeps the preview from being re-fetched from the CDN each time a slider
    moves.
    """
    return sahi.utils.cv.read_image_as_pil(url)


# Short radio-button label for each selectable example diagram.
# Keys must also exist in IMAGE_TO_URL.
EXAMPLE_LABELS = {
    'factory_pid.png': 'A',
    'plant_pid.png': 'B',
    'processing_pid.png': 'C',
}

download_comparison_images()

# Seed the comparison viewer with a precomputed example so it is not empty
# before the first prediction. st.session_state persists across reruns, so
# this only happens once per browser session.
if "output_1" not in st.session_state:
    st.session_state["output_1"] = Image.open('plant_pid.png')

if "output_2" not in st.session_state:
    st.session_state["output_2"] = Image.open('prediction_visual.png')

col1, col2, col3 = st.columns(3, gap='medium')
with col1:
    with st.expander('How to use it'):
        st.markdown(
        '''
        1) Select any example diagram  👆🏻
        2) Set confidence threshold 📈
        3) Press to perform inference   🚀
        4) Visualize model predictions  🔎
        '''
        )

st.write('##')

col1, col2, col3 = st.columns(3, gap='large')
with col1:
    st.markdown('##### Input Data')
    # set input images from examples
    def radio_func(option):
        """Return the short label displayed for an example filename."""
        return EXAMPLE_LABELS[option]
    radio = st.radio(
        'Select from the following examples',
        options=list(EXAMPLE_LABELS),
        format_func=radio_func,
    )
with col2:
    st.markdown('##### Preview')
    image = load_example_image(IMAGE_TO_URL[radio])
    with st.container(border=True):
        st.image(image, use_column_width=True)

with col3:
    st.markdown('##### Set model parameters')
    slice_size = st.slider(
        label='Slice Size',
        min_value=1240,
        max_value=4960,
        value=2480,
        step=1240,
    )
    overlap_ratio = st.slider(
        label='Overlap Ratio',
        min_value=0.0,
        max_value=0.5,
        value=0.1,
        step=0.1,
    )
    # NOTE(review): this value is passed on as `postprocess_match_threshold`.
    # In SAHI that parameter is presumably the overlap threshold for merging
    # detections from neighbouring slices, not a score cutoff. The model's
    # score cutoff is fixed at 0.8 in get_model(). Confirm against
    # utils.sahi_yolov8m_inference before trusting the "Confidence" label.
    postprocess_match_threshold = st.slider(
        label='Confidence Threshold',
        min_value=0.0,
        max_value=1.0,
        value=0.8,
        step=0.1,
    )

st.write('##')

col1, col2, col3 = st.columns([3, 1, 3])
with col2:
    submit = st.button("🚀 Perform Prediction")

if submit:
    # perform prediction
    with st.spinner(text="Downloading model weights ... "):
        detection_model = get_model()

    # Presumably the inference resolution handed to the model; it matches the
    # largest selectable slice size. TODO confirm in utils.sahi_yolov8m_inference.
    image_size = 4960

    with st.spinner(text="Performing prediction ... "):
        output = sahi_yolov8m_inference(
            image,
            detection_model,
            image_size=image_size,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap_ratio,
            overlap_width_ratio=overlap_ratio,
            postprocess_match_threshold=postprocess_match_threshold,
        )

    # Update the comparison viewer: raw diagram left, predictions right.
    st.session_state["output_1"] = image
    st.session_state["output_2"] = output

st.write('##')

col1, col2, col3 = st.columns([1, 3, 1])
with col2:
    st.markdown("#### Object Detection Result")
    with st.container(height=800, border=True):
        image_comparison(
            img1=st.session_state["output_1"],
            img2=st.session_state["output_2"],
            label1='Raw Diagram',
            label2='Inference Prediction',
            width=768,
            starting_position=50,
            show_labels=True,
            make_responsive=True,
            in_memory=True,
        )