|
import os
import random
import tempfile

import numpy as np
import pandas as pd
import sahi.utils.cv
import sahi.utils.file
import streamlit as st
from PIL import Image
from sahi import AutoDetectionModel
from streamlit_image_comparison import image_comparison
from ultralyticsplus.hf_utils import download_from_hub

from utils import sahi_yolov8m_inference
|
|
|
# All example diagrams and the pre-rendered prediction live on the same CDN.
_PID_CDN_URL = 'https://d1afc1j4569hs1.cloudfront.net'

# Maps the local/example file name to the URL it is fetched from.
# Note: remote names use hyphens for the three diagrams, an underscore for the visual.
IMAGE_TO_URL = {
    'factory_pid.png': f'{_PID_CDN_URL}/factory-pid.png',
    'plant_pid.png': f'{_PID_CDN_URL}/plant-pid.png',
    'processing_pid.png': f'{_PID_CDN_URL}/processing-pid.png',
    'prediction_visual.png': f'{_PID_CDN_URL}/prediction_visual.png',
}
|
|
|
_APP_TITLE = "P&ID Object Detection"

# Page configuration is issued before any other st.* call, as Streamlit expects.
st.set_page_config(
    page_title=_APP_TITLE,
    layout="wide",
    initial_sidebar_state="expanded",
)

# Header: same text for the browser tab and the on-page title.
st.title(_APP_TITLE)
st.subheader(' Identify valves and pumps with deep learning model ', divider='rainbow')
st.caption('Developed by Deep Drawings Co.')
|
|
|
@st.cache_resource(show_spinner=False)
def get_model():
    """Build the SAHI-wrapped YOLOv8 P&ID detector.

    The weights are fetched from the Hugging Face Hub repo
    ``DanielCerda/pid_yolov8`` and loaded on the CPU with a confidence
    threshold of 0.75. ``st.cache_resource`` keeps a single model object
    that later reruns of the script reuse instead of reloading.

    Returns:
        The detection model produced by ``AutoDetectionModel.from_pretrained``.
    """
    weights_path = download_from_hub('DanielCerda/pid_yolov8')
    return AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path=weights_path,
        confidence_threshold=0.75,
        device="cpu",
    )
|
|
|
@st.cache_data(show_spinner=False)
def download_comparison_images():
    """Fetch the two images that seed the before/after comparison widget.

    Each file is written to a relative path, i.e. into the process's working
    directory, under the name that ``Image.open`` is later called with.
    Because of ``st.cache_data`` the body runs once; later reruns reuse the
    cached (``None``) result.
    """
    downloads = (
        ('https://d1afc1j4569hs1.cloudfront.net/plant-pid.png', 'plant_pid.png'),
        ('https://d1afc1j4569hs1.cloudfront.net/prediction_visual.png', 'prediction_visual.png'),
    )
    for source_url, target_path in downloads:
        sahi.utils.file.download_from_url(source_url, target_path)
|
|
|
download_comparison_images()

# Seed the comparison widget with a bundled example on the first run of a
# session; later runs keep whatever a prediction has stored there.
for _state_key, _image_path in (
    ("output_1", 'plant_pid.png'),
    ("output_2", 'prediction_visual.png'),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = Image.open(_image_path)
|
|
|
col1, col2, col3 = st.columns(3, gap='medium')

# Collapsible usage instructions, rendered inside the first column only.
with col1, st.expander('How to use it'):
    st.markdown(
        '''
        1) Upload your P&ID or select example diagrams π¬
        2) Set confidence threshold π
        3) Press to perform inference π
        4) Visualize model predictions π
        '''
    )

# Vertical spacer before the input/preview/parameter row.
st.write('##')
|
|
|
@st.cache_data(show_spinner=False)
def _load_example_image(name):
    """Return the example diagram ``name`` (a key of IMAGE_TO_URL) as a PIL image.

    Cached because Streamlit re-executes this script on every widget
    interaction; without caching, the example would be re-downloaded each time.
    """
    return sahi.utils.cv.read_image_as_pil(IMAGE_TO_URL[name])


col1, col2, col3 = st.columns(3, gap='large')

with col1:
    st.markdown('##### Input File')

    # PIL cannot read PDFs, so only accept raster formats that Image.open handles.
    uploaded_file = st.file_uploader("Upload your diagram", type=["png", "jpg", "jpeg"])

    def radio_func(option):
        """Map an example file name to the short label shown on the radio button."""
        option_to_id = {
            'factory_pid.png': 'A',
            'plant_pid.png': 'B',
            'processing_pid.png': 'C',
        }
        return option_to_id[option]

    radio = st.radio(
        'Or select from the following examples',
        options=['factory_pid.png', 'plant_pid.png', 'processing_pid.png'],
        format_func=radio_func,
    )

with col2:
    st.markdown('##### Preview')

    if uploaded_file is not None:
        try:
            # The uploaded file object is file-like, so PIL reads it directly. The
            # previous temp-dir copy leaked a directory per rerun and was never used.
            # Convert to RGB so RGBA/palette images become 3-channel like the examples.
            image = Image.open(uploaded_file).convert("RGB")
        except OSError:
            # PIL signals undecodable data with UnidentifiedImageError (an OSError).
            st.error('The uploaded file could not be read as an image.')
            st.stop()
    else:
        image = _load_example_image(radio)

    with st.container(border=True):
        st.image(image, use_column_width=True)

with col3:
    st.markdown('##### Set model parameters')
    # Applied to the detection model before inference (see the submit branch).
    confidence_threshold = st.slider(
        label='Select confidence threshold',
        min_value=0.0,
        max_value=1.0,
        value=0.75,
        step=0.25,
    )
    # Forwarded to the inference helper as postprocess_match_threshold.
    iou_threshold = st.slider(
        label='Select IoU threshold',
        min_value=0.0,
        max_value=1.0,
        value=0.75,
        step=0.25,
    )

st.write('##')

col1, col2, col3 = st.columns([3, 1, 3])
with col2:
    submit = st.button("π Perform Prediction")

if submit:
    with st.spinner(text="Downloading model weight ... "):
        detection_model = get_model()

    # get_model() builds the model with confidence_threshold=0.75; apply the
    # slider value so the "confidence threshold" control actually takes effect.
    # NOTE(review): the model is shared through st.cache_resource, so this
    # assignment is visible to every session — confirm that is acceptable.
    detection_model.confidence_threshold = confidence_threshold

    image_size = 1280
    # These were referenced but never defined (NameError on submit).
    # Placeholder tiling: 512 px slices with 20% overlap — TODO confirm for P&IDs.
    slice_size = 512
    overlap_ratio = 0.2

    with st.spinner(text="Performing prediction ... "):
        output_1, output_2 = sahi_yolov8m_inference(
            image,
            detection_model,
            image_size=image_size,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap_ratio,
            overlap_width_ratio=overlap_ratio,
            # Assumes the helper forwards these to SAHI's postprocessing — confirm
            # against utils.sahi_yolov8m_inference's signature.
            postprocess_match_metric="IOU",
            postprocess_match_threshold=iou_threshold,
        )

    st.session_state["output_1"] = output_1
    st.session_state["output_2"] = output_2
|
|
|
st.write('##')

# Centered result area: a wide middle column flanked by narrow spacers.
col1, col2, col3 = st.columns([1, 4, 1])
with col2:
    # Plain string: the previous f-string had no placeholders (ruff F541).
    st.markdown("#### Object Detection Result")
    with st.container(border=True):
        # Before/after slider over the two images kept in session state.
        # NOTE(review): after a prediction, output_1 is whatever
        # sahi_yolov8m_inference returns first — confirm it is the un-annotated
        # input, otherwise the 'Uploaded Diagram' label is misleading.
        # The component's return value was bound to an unused name; dropped.
        image_comparison(
            img1=st.session_state["output_1"],
            img2=st.session_state["output_2"],
            label1='Uploaded Diagram',
            label2='Model Inference',
            width=820,
            starting_position=50,
            show_labels=True,
            make_responsive=True,
            in_memory=True,
        )
|
|