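"""Streamlit demo: P&ID object detection.

Runs a YOLOv8 model (DanielCerda/pid_yolov8) with SAHI helpers to detect
valves and pumps on piping and instrumentation diagrams, and displays the
result next to the input diagram in a before/after comparison slider.
"""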
import pandas as pd
import numpy as np
import streamlit as st
import random
import sahi.utils.cv
import sahi.utils.file
import tempfile
import os
from PIL import Image
from sahi import AutoDetectionModel
from utils import convert_pdf_file  # assumed local helper: renders the uploaded PDF as an image
from utils import sahi_yolov8m_inference
from streamlit_image_comparison import image_comparison
from ultralyticsplus.hf_utils import download_from_hub
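# Example diagrams and the default prediction visual, keyed by the local
# filename they are saved under and pointing to their hosted copies.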
IMAGE_TO_URL = {
    'factory_pid.png' : 'https://d1afc1j4569hs1.cloudfront.net/factory-pid.png',
    'plant_pid.png' : 'https://d1afc1j4569hs1.cloudfront.net/plant-pid.png',
    'processing_pid.png' : 'https://d1afc1j4569hs1.cloudfront.net/processing-pid.png',
    'prediction_visual.png' : 'https://d1afc1j4569hs1.cloudfront.net/prediction_visual.png'
}
st.set_page_config(
    page_title="P&ID Object Detection",
    layout="wide",
    initial_sidebar_state="expanded"
)
st.title('P&ID Object Detection')
st.subheader('Identify valves and pumps with a deep learning model', divider='rainbow')
st.caption('Developed by Deep Drawings Co.')
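# Load the YOLOv8 P&ID detector once and keep it cached across Streamlit reruns.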
@st.cache_resource(show_spinner=False)
def get_model():
    yolov8_model_path = download_from_hub('DanielCerda/pid_yolov8')
    detection_model = AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path=yolov8_model_path,
        confidence_threshold=0.75,
        device="cpu",
    )
    return detection_model
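# Download the default before/after images shown in the comparison widget.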
@st.cache_data(show_spinner=False)
def download_comparison_images():
    sahi.utils.file.download_from_url(
        'https://d1afc1j4569hs1.cloudfront.net/plant-pid.png',
        'plant_pid.png',
    )
    sahi.utils.file.download_from_url(
        'https://d1afc1j4569hs1.cloudfront.net/prediction_visual.png',
        'prediction_visual.png',
    )
download_comparison_images()
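# Seed the comparison widget with the example output until a prediction is run.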
if "output_1" not in st.session_state:
st.session_state["output_1"] = Image.open('plant_pid.png')
if "output_2" not in st.session_state:
st.session_state["output_2"] = Image.open('prediction_visual.png')
col1, col2, col3 = st.columns(3, gap='medium')
with col1:
    with st.expander('How to use it'):
        st.markdown(
            '''
            1) Upload your P&ID or select example diagrams 📬
            2) Set confidence threshold 📈
            3) Press to perform inference 🚀
            4) Visualize model predictions 🔎
            '''
        )
st.write('##')
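# Second row: diagram input, preview, and model parameters.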
col1, col2, col3 = st.columns(3, gap='large')
with col1:
    st.markdown('##### Input File')
    # set input image by upload
    uploaded_file = st.file_uploader("Upload your diagram", type="pdf")
    if uploaded_file:
        temp_dir = tempfile.mkdtemp()
        path = os.path.join(temp_dir, uploaded_file.name)
        with open(path, "wb") as f:
            f.write(uploaded_file.getvalue())
    # set input images from examples
    def radio_func(option):
        option_to_id = {
            'factory_pid.png' : 'A',
            'plant_pid.png' : 'B',
            'processing_pid.png' : 'C',
        }
        return option_to_id[option]
    radio = st.radio(
        'Or select from the following examples',
        options = ['factory_pid.png', 'plant_pid.png', 'processing_pid.png'],
        format_func = radio_func,
    )
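# Preview of the uploaded or selected diagram.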
with col2:
    st.markdown('##### Preview')
    # visualize input image
    if uploaded_file is not None:
        # convert the uploaded PDF to an image before opening it
        # (assumes the convert_pdf_file helper imported from the local utils module)
        image_file = convert_pdf_file(path=path)
        image = Image.open(image_file)
    else:
        image = sahi.utils.cv.read_image_as_pil(IMAGE_TO_URL[radio])
    with st.container(border = True):
        st.image(image, use_column_width = True)
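# Post-processing thresholds exposed to the user.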
with col3:
    st.markdown('##### Set model parameters')
    postprocess_match_threshold = st.slider(
        label = 'Select confidence threshold',
        min_value = 0.0,
        max_value = 1.0,
        value = 0.75,
        step = 0.25
    )
    # collected from the UI but not currently passed to the inference call below
    postprocess_match_metric = st.slider(
        label = 'Select IoU threshold',
        min_value = 0.0,
        max_value = 1.0,
        value = 0.75,
        step = 0.25
    )
st.write('##')
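# Centered prediction button.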
col1, col2, col3 = st.columns([3, 1, 3])
with col2:
    submit = st.button("🚀 Perform Prediction")
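# Run YOLOv8 inference through the sahi_yolov8m_inference helper when the button is pressed.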
if submit:
    # perform prediction
    with st.spinner(text="Downloading model weights ... "):
        detection_model = get_model()
        image_size = 1280
    with st.spinner(text="Performing prediction ... "):
        output_1, output_2 = sahi_yolov8m_inference(
            image,
            detection_model,
            image_size=image_size,
            #slice_height=slice_size,
            #slice_width=slice_size,
            #overlap_height_ratio=overlap_ratio,
            #overlap_width_ratio=overlap_ratio,
            postprocess_match_threshold=postprocess_match_threshold
        )
    st.session_state["output_1"] = output_1
    st.session_state["output_2"] = output_2
st.write('##')
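# Results: slide between the input diagram and the model's predictions.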
col1, col2, col3 = st.columns([1, 4, 1])
with col2:
    st.markdown("#### Object Detection Result")
    with st.container(border = True):
        static_component = image_comparison(
            img1=st.session_state["output_1"],
            img2=st.session_state["output_2"],
            label1='Uploaded Diagram',
            label2='Model Inference',
            width=820,
            starting_position=50,
            show_labels=True,
            make_responsive=True,
            in_memory=True,
        )