Daniel Cerda Escobar
Plot graph
61ccf4c
raw
history blame
7 kB
import matplotlib.pyplot as plt
import pandas as pd
import sahi.utils.cv
import sahi.utils.file
import seaborn as sns
import streamlit as st
from PIL import Image
from sahi import AutoDetectionModel
from ultralyticsplus.hf_utils import download_from_hub

from utils import sahi_yolov8m_inference
# Remote location of every bundled image, keyed by the local file name the
# rest of the app uses to refer to it (radio options, downloads, previews).
IMAGE_TO_URL = {
    'factory_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/factory-pid.png',
    'plant_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/plant-pid.png',
    'processing_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/processing-pid.png',
    'prediction_visual.png': 'https://d1afc1j4569hs1.cloudfront.net/prediction_visual.png',
}
# Browser-tab/page layout settings, followed by the page heading.
_APP_TITLE = 'P&ID Object Detection'
st.set_page_config(
    page_title=_APP_TITLE,
    layout="wide",
    initial_sidebar_state="expanded",
)
st.title(_APP_TITLE)
st.subheader(
    ' Identify valves and pumps with deep learning model ',
    divider='rainbow',
)
st.caption('Developed by Deep Drawings Co.')
@st.cache_resource(show_spinner=False)
def get_model(postprocess_match_threshold):
    """Build the P&ID YOLOv8 detector wrapped as a SAHI detection model.

    The weights come from ``download_from_hub('DanielCerda/pid_yolov8')`` and
    the model is created on the CPU. ``st.cache_resource`` caches the result
    per argument value, so every distinct threshold gets its own instance.

    Args:
        postprocess_match_threshold: Forwarded to SAHI as
            ``confidence_threshold``; the app feeds it from the
            "Confidence Threshold" slider.

    Returns:
        The model object produced by ``AutoDetectionModel.from_pretrained``.
    """
    weights_path = download_from_hub('DanielCerda/pid_yolov8')
    return AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path=weights_path,
        confidence_threshold=postprocess_match_threshold,
        device="cpu",
    )
@st.cache_data(show_spinner=False)
def download_comparison_images():
    """Download the images displayed in the result tabs before any prediction.

    Saves ``plant_pid.png`` and ``prediction_visual.png`` into the working
    directory under those names. URLs are taken from ``IMAGE_TO_URL`` rather
    than duplicated here, so both places always agree. ``st.cache_data``
    means the body runs only on the first call, not on every rerun.
    """
    for file_name in ('plant_pid.png', 'prediction_visual.png'):
        sahi.utils.file.download_from_url(IMAGE_TO_URL[file_name], file_name)
download_comparison_images()

# Placeholder results shown in the 'Data' and 'Insights' tabs until the user
# runs a prediction, which overwrites the session-state slots seeded below.
coco_df = pd.DataFrame({
    'category': ['centrifugal-pump'] * 2 + ['gate-valve'] * 9,
    'score': [0.88, 0.85, 0.87, 0.87, 0.86, 0.86, 0.85, 0.84, 0.81, 0.81, 0.76],
})
output_df = pd.DataFrame({
    'category': ['ball-valve', 'butterfly-valve', 'centrifugal-pump', 'check-valve', 'gate-valve'],
    'count': [0, 0, 2, 0, 9],
    'percentage': [0, 0, 18.2, 0, 81.8],
})

# session state: seed each result slot only if it is not already present.
for state_key, image_path in (
    ("output_1", 'plant_pid.png'),
    ("output_2", 'prediction_visual.png'),
):
    if state_key not in st.session_state:
        # `with` closes the file handle Image.open acquires; resize() returns a
        # new in-memory image, so it stays valid after the file is closed.
        with Image.open(image_path) as opened:
            st.session_state[state_key] = opened.resize((4960, 3508))
if "output_3" not in st.session_state:
    st.session_state["output_3"] = coco_df
if "output_4" not in st.session_state:
    st.session_state["output_4"] = output_df
col1, col2, col3 = st.columns(3, gap='medium')
with col1:
    # Collapsible quick-start guide in the left-most column.
    with st.expander('How to use it'):
        st.markdown(
            '''
1) Upload or select any example diagram 👆🏻
2) Set model parameters 📈
3) Press to perform inference 🚀
4) Visualize model predictions 🔎
'''
        )
st.write('##')
col1, col2, col3 = st.columns(3, gap='large')
with col1:
    st.markdown('##### Set Input Image')

    # Source 1: a diagram uploaded by the user (None until a file is given).
    image_file = st.file_uploader(
        'Upload your P&ID',
        type=['jpg', 'jpeg', 'png'],
    )

    # Source 2: one of the bundled examples, displayed as A / B / C.
    def radio_func(option):
        """Return the short label shown on the radio button for *option*."""
        return {
            'factory_pid.png': 'A',
            'plant_pid.png': 'B',
            'processing_pid.png': 'C',
        }[option]

    radio = st.radio(
        'Select from the following examples',
        options=['factory_pid.png', 'plant_pid.png', 'processing_pid.png'],
        format_func=radio_func,
    )
with col2:
    # Without caching, every rerun (any widget change) would fetch the
    # selected example from its URL again; cache it per URL instead.
    @st.cache_data(show_spinner=False)
    def _load_example_image(url):
        """Return the example diagram at *url* as a PIL image."""
        return sahi.utils.cv.read_image_as_pil(url)

    # Prefer the uploaded diagram; otherwise use the selected example.
    if image_file is not None:
        image = Image.open(image_file)
    else:
        image = _load_example_image(IMAGE_TO_URL[radio])
    st.markdown('##### Preview')
    with st.container(border=True):
        st.image(image, use_column_width=True)
with col3:
    # Slicing and post-processing controls for the SAHI inference.
    st.markdown('##### Set model parameters')

    # Kept as a string; converted to a number when the slice size is computed.
    slice_number = st.select_slider(
        label='Slices per Image',
        options=['1', '4', '16', '64'],
        value='4',
    )

    # Used for both the height and the width overlap ratio.
    overlap_ratio = st.slider(
        label='Slicing Overlap Ratio',
        min_value=0.0,
        max_value=0.5,
        value=0.1,
        step=0.1,
    )

    # Passed to get_model() as the detector's confidence threshold.
    postprocess_match_threshold = st.slider(
        label='Confidence Threshold',
        min_value=0.0,
        max_value=1.0,
        value=0.85,
        step=0.05,
    )
st.write('##')
# The narrow middle column (weights 4:1:4) centers the button.
col1, col2, col3 = st.columns([4, 1, 4])
with col2:
    submit = st.button("🚀 Perform Prediction")
if submit:
    # Run sliced inference on the current input and publish the results to
    # the session-state slots that the result tabs below display.
    with st.spinner(text="Downloading model weights ... "):
        detection_model = get_model(postprocess_match_threshold)
    # Square slices sized so that sqrt(slice_number) of them span
    # `image_size` pixels (1 -> 4960, 4 -> 2480, 16 -> 1240, 64 -> 620).
    # Derived from image_size so the two values cannot drift apart.
    image_size = 4960
    slice_size = int(image_size / (float(slice_number) ** 0.5))
    with st.spinner(text="Performing prediction ... "):
        output_visual, coco_df, output_df = sahi_yolov8m_inference(
            image,
            detection_model,
            image_size=image_size,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap_ratio,
            overlap_width_ratio=overlap_ratio,
        )
    st.session_state["output_1"] = image
    st.session_state["output_2"] = output_visual
    st.session_state["output_3"] = coco_df
    st.session_state["output_4"] = output_df
st.write('##')
col1, col2, col3 = st.columns([1, 5, 1], gap='small')
with col2:
    st.markdown("#### Object Detection Result")
    with st.container(border=True):
        tab1, tab2, tab3, tab4 = st.tabs(
            ['Original Image', 'Inference Prediction', 'Data', 'Insights']
        )
        with tab1:
            st.image(st.session_state["output_1"])
        with tab2:
            st.image(st.session_state["output_2"])
        with tab3:
            col1, col2, col3 = st.columns([1, 2, 1])
            with col2:
                st.dataframe(
                    st.session_state["output_3"],
                    column_config={
                        'category': 'Predicted Category',
                        'score': 'Confidence',
                    },
                    use_container_width=True,
                    hide_index=True,
                )
        with tab4:
            col1, col2, col3 = st.columns([1, 4, 1])
            with col2:
                chart_data = st.session_state["output_4"]
                # Draw on a figure created for this run. Drawing on pyplot's
                # implicit current figure (which st.pyplot does not clear
                # when given a figure) stacks bars and labels on every rerun.
                fig, ax = plt.subplots()
                sns.barplot(
                    x='count',
                    y='category',
                    data=chart_data,
                    hue='category',
                    legend=False,
                    ax=ax,
                )
                # With `hue` set, seaborn draws each category as its own bar
                # container, so label all of them, not just containers[0].
                for container in ax.containers:
                    ax.bar_label(container, fontsize=10)
                st.pyplot(fig, use_container_width=True)
                # st.pyplot has already rendered the image; release the figure.
                plt.close(fig)