import os
import random
import tempfile

import numpy as np
import pandas as pd
import sahi.utils.cv
import sahi.utils.file
import streamlit as st
from PIL import Image
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from streamlit_image_comparison import image_comparison
from ultralyticsplus.hf_utils import download_from_hub

from utils import convert_pdf_file

# Example diagrams (plus the pre-computed comparison visual) hosted on the CDN.
IMAGE_TO_URL = {
    'factory_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/factory-pid.png',
    'plant_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/plant-pid.png',
    'processing_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/processing-pid.png',
    'prediction_visual.png': 'https://d1afc1j4569hs1.cloudfront.net/prediction_visual.png',
}

# Short label displayed by the example radio button for each example file.
EXAMPLE_TO_LABEL = {
    'factory_pid.png': 'A',
    'plant_pid.png': 'B',
    'processing_pid.png': 'C',
}

# Inference resolution handed to the YOLOv8 model.
IMAGE_SIZE = 1280
# Sliced-inference settings. NOTE(review): the original script referenced
# `slice_size` / `overlap_ratio` without ever defining them; these values are a
# starting point and should be tuned to the scale of typical P&ID symbols.
SLICE_SIZE = 1280
OVERLAP_RATIO = 0.1

st.set_page_config(
    page_title="P&ID Object Detection",
    layout="wide",
    initial_sidebar_state="expanded",
)

st.title('P&ID Object Detection')
st.subheader(' Identify valves and pumps with deep learning model ', divider='rainbow')
st.caption('Developed by Deep Drawings Co.')


@st.cache_resource(show_spinner=False)
def get_model(confidence_threshold=0.75, image_size=None):
    """Download the P&ID YOLOv8 weights and wrap them in a SAHI detection model.

    The resource cache is keyed on the arguments, so each distinct threshold /
    size combination is loaded once and the shared model object never has to
    be mutated per request (which would leak settings across user sessions).

    Args:
        confidence_threshold: minimum score a detection needs to be kept.
        image_size: inference resolution; ``None`` keeps SAHI's default.

    Returns:
        A SAHI detection model running on CPU.
    """
    yolov8_model_path = download_from_hub('DanielCerda/pid_yolov8')
    return AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path=yolov8_model_path,
        confidence_threshold=confidence_threshold,
        image_size=image_size,
        device="cpu",
    )


@st.cache_data(show_spinner=False)
def download_comparison_images():
    """Fetch the two images shown in the initial before/after comparison.

    Each file is written to the working directory under its dictionary key;
    ``st.cache_data`` stops this from re-running on every Streamlit rerun.
    """
    for file_name in ('plant_pid.png', 'prediction_visual.png'):
        sahi.utils.file.download_from_url(IMAGE_TO_URL[file_name], file_name)


@st.cache_data(show_spinner=False)
def load_example_image(file_name):
    """Download an example diagram by name and return it as a PIL image.

    Cached so that moving a slider (which reruns the script) does not
    re-download the preview each time.
    """
    return sahi.utils.cv.read_image_as_pil(IMAGE_TO_URL[file_name])


def load_uploaded_pdf(uploaded_file):
    """Rasterise an uploaded PDF diagram into an in-memory PIL image.

    ``convert_pdf_file`` takes a filesystem path, so the upload is written to
    a temporary directory that is deleted afterwards instead of leaking a new
    directory on every rerun.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # basename() drops any directory components from the client-supplied name.
        file_name = os.path.basename(uploaded_file.name) or 'upload.pdf'
        path = os.path.join(temp_dir, file_name)
        with open(path, 'wb') as f:
            f.write(uploaded_file.getvalue())
        image = Image.open(convert_pdf_file(path=path))
        # Pull pixel data into memory before the temporary directory is removed.
        image.load()
    return image


def run_sliced_inference(
    image,
    detection_model,
    slice_size=SLICE_SIZE,
    overlap_ratio=OVERLAP_RATIO,
    iou_threshold=0.5,
):
    """Run SAHI sliced inference on ``image`` and render the detections.

    Args:
        image: PIL image of the diagram.
        detection_model: SAHI detection model returned by ``get_model``.
        slice_size: height and width in pixels of each slice.
        overlap_ratio: fractional overlap between neighbouring slices.
        iou_threshold: IoU above which overlapping predictions from adjacent
            slices are merged.

    Returns:
        ``(input_image, visualised_image)``: the RGB input and a copy with the
        predicted boxes drawn on it, both as PIL images.
    """
    # Normalise to 3-channel RGB (PNG/PDF rasters may be RGBA or palette).
    rgb_image = image.convert('RGB')
    result = get_sliced_prediction(
        rgb_image,
        detection_model,
        slice_height=slice_size,
        slice_width=slice_size,
        overlap_height_ratio=overlap_ratio,
        overlap_width_ratio=overlap_ratio,
        postprocess_match_metric='IOU',
        postprocess_match_threshold=iou_threshold,
        verbose=0,
    )
    visual = sahi.utils.cv.visualize_object_predictions(
        image=np.array(rgb_image),
        object_prediction_list=result.object_prediction_list,
    )
    return rgb_image, Image.fromarray(visual['image'])


download_comparison_images()

# Seed the comparison widget with a pre-computed example until the user predicts.
if "output_1" not in st.session_state:
    st.session_state["output_1"] = Image.open('plant_pid.png')
if "output_2" not in st.session_state:
    st.session_state["output_2"] = Image.open('prediction_visual.png')

how_to_col, _, _ = st.columns(3, gap='medium')
with how_to_col:
    with st.expander('How to use it'):
        st.markdown(
            '''
            1) Upload your P&ID or select example diagrams 📬
            2) Set confidence and IoU thresholds 📈
            3) Press to perform inference 🚀
            4) Visualize model predictions 🔎
            '''
        )

st.write('##')

input_col, preview_col, params_col = st.columns(3, gap='large')

with input_col:
    st.markdown('##### Input File')
    uploaded_file = st.file_uploader("Upload your diagram", type="pdf")
    example_name = st.radio(
        'Or select from the following examples',
        options=list(EXAMPLE_TO_LABEL),
        format_func=EXAMPLE_TO_LABEL.get,
    )

with preview_col:
    st.markdown('##### Preview')
    # An uploaded PDF takes precedence over the selected example.
    if uploaded_file is not None:
        image = load_uploaded_pdf(uploaded_file)
    else:
        image = load_example_image(example_name)
    with st.container(border=True):
        st.image(image, use_column_width=True)

with params_col:
    st.markdown('##### Set model parameters')
    confidence_threshold = st.slider(
        label='Select confidence threshold',
        min_value=0.0,
        max_value=1.0,
        value=0.75,
        step=0.25,
    )
    iou_threshold = st.slider(
        label='Select IoU threshold',
        min_value=0.0,
        max_value=1.0,
        value=0.75,
        step=0.25,
    )

st.write('##')

_, button_col, _ = st.columns([3, 1, 3])
with button_col:
    submit = st.button("🚀 Perform Prediction")

if submit:
    with st.spinner(text="Downloading model weight ... "):
        detection_model = get_model(
            confidence_threshold=confidence_threshold,
            image_size=IMAGE_SIZE,
        )
    with st.spinner(text="Performing prediction ... "):
        output_1, output_2 = run_sliced_inference(
            image,
            detection_model,
            slice_size=SLICE_SIZE,
            overlap_ratio=OVERLAP_RATIO,
            iou_threshold=iou_threshold,
        )
    st.session_state["output_1"] = output_1
    st.session_state["output_2"] = output_2

st.write('##')

_, result_col, _ = st.columns([1, 4, 1])
with result_col:
    st.markdown("#### Object Detection Result")
    with st.container(border=True):
        image_comparison(
            img1=st.session_state["output_1"],
            img2=st.session_state["output_2"],
            label1='Uploaded Diagram',
            label2='Model Inference',
            width=800,
            starting_position=50,
            show_labels=True,
            make_responsive=True,
            in_memory=True,
        )