"""Streamlit front end for detecting valves and pumps in P&ID diagrams.

The user picks an example diagram and sets a confidence threshold and an IoU
(merge) threshold. A SAHI-wrapped YOLOv8 model is then run on the diagram, and
the result is shown in a before/after image comparison widget.
"""
import os
import random
import tempfile

import numpy as np
import pandas as pd
import sahi.utils.cv  # explicit: read_image_as_pil lives here
import sahi.utils.file
import streamlit as st
from PIL import Image
from sahi import AutoDetectionModel
from streamlit_image_comparison import image_comparison
from ultralyticsplus.hf_utils import download_from_hub

# from utils import convert_pdf_file
from utils import sahi_yolov8m_inference

# Example diagrams (and the default prediction visual) hosted on the CDN,
# keyed by the local file name used for them in this app.
IMAGE_TO_URL = {
    'factory_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/factory-pid.png',
    'plant_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/plant-pid.png',
    'processing_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/processing-pid.png',
    'prediction_visual.png': 'https://d1afc1j4569hs1.cloudfront.net/prediction_visual.png',
}

# Selectable example diagrams (in display order) and their short radio labels.
EXAMPLE_LABELS = {
    'factory_pid.png': 'A',
    'plant_pid.png': 'B',
    'processing_pid.png': 'C',
}

# Image size handed to the inference helper as ``image_size``.
IMAGE_SIZE = 1280

# Must be the first Streamlit command executed by the script.
st.set_page_config(
    page_title="P&ID Object Detection",
    layout="wide",
    initial_sidebar_state="expanded",
)


@st.cache_resource(show_spinner=False)
def get_model(confidence_threshold=0.75):
    """Download the P&ID YOLOv8 weights and wrap them in a SAHI detection model.

    Args:
        confidence_threshold: minimum score SAHI keeps for a detection.
            Defaults to 0.75, the value previously hard-coded here.

    Returns:
        A CPU SAHI ``AutoDetectionModel`` for the ``yolov8`` model type.

    Streamlit caches one model per distinct ``confidence_threshold``. Each
    threshold therefore gets its own instance, and the cached, session-shared
    model is never mutated.
    """
    yolov8_model_path = download_from_hub('DanielCerda/pid_yolov8')
    return AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path=yolov8_model_path,
        confidence_threshold=confidence_threshold,
        device="cpu",
    )


@st.cache_data(show_spinner=False)
def download_comparison_images():
    """Download the default before/after pair shown in the comparison widget."""
    for file_name in ('plant_pid.png', 'prediction_visual.png'):
        sahi.utils.file.download_from_url(IMAGE_TO_URL[file_name], file_name)


@st.cache_data(show_spinner=False)
def load_image(url):
    """Fetch an image by URL as a PIL image.

    The result is cached so Streamlit reruns (e.g. every slider move) do not
    re-download it.
    """
    return sahi.utils.cv.read_image_as_pil(url)


st.title('P&ID Object Detection')
st.subheader(' Identify valves and pumps with deep learning model ', divider='rainbow')
st.caption('Developed by Deep Drawings Co.')

download_comparison_images()

# Seed the comparison widget with a default before/after pair until the user
# runs a prediction of their own.
if "output_1" not in st.session_state:
    st.session_state["output_1"] = Image.open('plant_pid.png')
if "output_2" not in st.session_state:
    st.session_state["output_2"] = Image.open('prediction_visual.png')

col1, col2, col3 = st.columns(3, gap='medium')
with col1:
    with st.expander('How to use it'):
        st.markdown(
            "1) Select any example diagram 📬\n"
            "2) Set confidence and IoU thresholds 📈\n"
            "3) Press to perform inference 🚀\n"
            "4) Visualize model predictions 🔎"
        )

st.write('##')

col1, col2, col3 = st.columns(3, gap='large')
with col1:
    # st.markdown('##### Input File')
    # set input image by upload
    # uploaded_file = st.file_uploader("Upload your diagram", type="pdf")
    # if uploaded_file:
    #     temp_dir = tempfile.mkdtemp()
    #     path = os.path.join(temp_dir, uploaded_file.name)
    #     with open(path, "wb") as f:
    #         f.write(uploaded_file.getvalue())

    # set input images from examples
    radio = st.radio(
        'Select from the following examples',
        options=list(EXAMPLE_LABELS),
        format_func=EXAMPLE_LABELS.get,
    )

with col2:
    st.markdown('##### Preview')
    # visualize input image
    # if uploaded_file is not None:
    #     image_file = convert_pdf_file(path=path)
    #     image = Image.open(image_file)
    # else:
    image = load_image(IMAGE_TO_URL[radio])
    with st.container(border=True):
        st.image(image, use_column_width=True)

with col3:
    st.markdown('##### Set model parameters')
    # Minimum detection score; applied by the SAHI model built in get_model().
    confidence_threshold = st.slider(
        label='Select confidence threshold',
        min_value=0.0,
        max_value=1.0,
        value=0.75,
        step=0.25,
    )
    # Overlap threshold SAHI uses when merging predictions from image slices.
    postprocess_match_threshold = st.slider(
        label='Select IoU threshold',
        min_value=0.0,
        max_value=1.0,
        value=0.75,
        step=0.25,
    )

st.write('##')

col1, col2, col3 = st.columns([3, 1, 3])
with col2:
    submit = st.button("🚀 Perform Prediction")

if submit:
    # perform prediction
    with st.spinner(text="Downloading model weights ... "):
        detection_model = get_model(confidence_threshold)

    with st.spinner(text="Performing prediction ... "):
        output_1, output_2 = sahi_yolov8m_inference(
            image,
            detection_model,
            image_size=IMAGE_SIZE,
            # slice_height=slice_size,
            # slice_width=slice_size,
            # overlap_height_ratio=overlap_ratio,
            # overlap_width_ratio=overlap_ratio,
            postprocess_match_threshold=postprocess_match_threshold,
        )

    st.session_state["output_1"] = output_1
    st.session_state["output_2"] = output_2

st.write('##')

col1, col2, col3 = st.columns([1, 4, 1])
with col2:
    st.markdown("#### Object Detection Result")
    with st.container(border=True):
        image_comparison(
            img1=st.session_state["output_1"],
            img2=st.session_state["output_2"],
            label1='Uploaded Diagram',
            label2='Model Inference',
            width=820,
            starting_position=50,
            show_labels=True,
            make_responsive=True,
            in_memory=True,
        )