File size: 3,828 Bytes
bf0045a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import tempfile

import PIL
import PIL.Image
import streamlit as st
import torch
from super_gradients.training import models

from apd_utils import write_video, convert_video

# Detector class names; their count is used as the model's num_classes.
CLASSES = [
    'Dust Mask',
    'Eye Wear',
    'Glove',
    'Protective Boots',
    'Protective Helmet',
    'Safety Vest',
    'Shield',
]
# Input kinds offered by the sidebar radio.
SOURCES = ['Images', 'Videos']

# Page-level settings: wide layout with the sidebar open on load.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="PPE Object Detection using YOLO-NAS",
    page_icon="👷",
)
st.title("PPE Object Detection using YOLO-NAS")

# --- Sidebar: detector settings ---------------------------------------------
st.sidebar.header("YOLO-NAS Model Config")
# The slider works in whole percent; convert it to a 0-1 fraction.
confidence = st.sidebar.slider(
    "Select Model Confidence", min_value=0, max_value=100, value=40,
) / 100.0

# --- Sidebar: input selection -----------------------------------------------
st.sidebar.header("Image/Video Config")
source_radio = st.sidebar.radio("Select Source", options=SOURCES)

# Placeholders for the uploaded image / video.
source_img = source_vid = None

# NOTE: a disabled startup step fetched the checkpoint with download_model()
# from https://drive.google.com/file/d/1XOq3OkpQ3OgibjHmYOCMsQPBtqjdf2i3/view?usp=sharing ;
# the checkpoint is now read from the local path given to load_model() below.


@st.cache_resource
def load_model(checkpoint_path="./ckpt_best_yolonas.pth", num_classes=len(CLASSES)):
    """Return the YOLO-NAS-M detector with weights loaded from *checkpoint_path*.

    Streamlit re-executes this whole script on every widget interaction, so
    the loader is wrapped in ``st.cache_resource``: the checkpoint is read
    once per server process and the same model object is reused by later
    reruns and sessions instead of being rebuilt each time.
    """
    return models.get('yolo_nas_m',
                      num_classes=num_classes,
                      checkpoint_path=checkpoint_path)


# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# nn.Module.to() moves the cached model in place, so later .to(device)
# calls with the same device do not copy it again.
model = load_model().to(device)

# Render the chosen source type: the left column shows the input (or a
# placeholder image) and the right column shows detections once the sidebar
# "Detect Objects" button is pressed.
if source_radio == 'Images':
    source_img = st.sidebar.file_uploader(
        "Choose an image...", type=("jpg", "jpeg", "png", 'bmp', 'webp'))
    col1, col2 = st.columns(2)

    # Stays None unless the upload was opened successfully, so the detection
    # column never references an undefined or undecodable image.
    uploaded_image = None

    with col1:
        try:
            if source_img is None:
                st.image('default_img.png', caption="Default Image",
                         use_column_width=True)
            else:
                uploaded_image = PIL.Image.open(source_img)
                st.image(source_img, caption="Uploaded Image",
                         use_column_width=True)
        except Exception as ex:
            # UI boundary: report the failure in the page instead of crashing.
            st.error("Error occurred while opening the image.")
            st.error(ex)
    with col2:
        if source_img is None:
            st.image('default_img_res.png', caption="Detected Objects",
                     use_column_width=True)
        elif uploaded_image is not None and st.sidebar.button('Detect Objects'):
            res = model.to(device).predict(uploaded_image,
                                           conf=confidence)
            st.image(res.draw(), caption='Detected Image',
                     use_column_width=True)

elif source_radio == 'Videos':
    source_vid = st.sidebar.file_uploader(
        "Choose a video ...", type=("mp4", "mov", "webM"))

    col1, col2 = st.columns(2)

    with col1:
        if source_vid is None:
            st.image('default_img.png', caption="Default Image",
                     use_column_width=True)
        else:
            try:
                uploaded_video = source_vid.getvalue()
                st.video(uploaded_video)
            except Exception as ex:
                st.error("Error occurred while opening the video.")
                st.error(ex)
    with col2:
        if source_vid is None:
            st.image('default_img_res.png', caption="Detected Objects",
                     use_column_width=True)
        elif st.sidebar.button('Detect Objects'):
            # write_video persists the upload and returns its path; it is
            # removed in `finally` even if detection or conversion fails.
            temp_uploaded_path = write_video(source_vid)
            try:
                # A per-request temporary directory keeps concurrent sessions
                # from overwriting each other's results and is always cleaned
                # up, whether or not a ./temp directory exists.
                with st.spinner('Processing video ...'), \
                        tempfile.TemporaryDirectory() as result_dir:
                    in_temp_res_path = os.path.join(result_dir, "result.mp4")
                    out_temp_res_path = os.path.join(result_dir, "result2.mp4")

                    res = model.to(device).predict(temp_uploaded_path,
                                                   conf=confidence)
                    res.save(in_temp_res_path)
                    convert_video(in_temp_res_path, out_temp_res_path)
                    # Read the result before the directory is deleted.
                    with open(out_temp_res_path, "rb") as result_file:
                        result_video = result_file.read()
                st.video(result_video)
            finally:
                os.remove(temp_uploaded_path)
else:
    st.error("Please select a valid source type!")