Hector Lopez committed · Commit e4cd286
1 Parent(s): 528f436

refactor: Using streamlit again

app.py CHANGED
@@ -1,4 +1,4 @@
-import gradio as gr
+import streamlit as st
 import PIL
 import torch
 
@@ -9,47 +9,65 @@ from model import get_model, predict, prepare_prediction, predict_class
 DET_CKPT = 'efficientDet_icevision.ckpt'
 CLASS_CKPT = 'class_ViT_taco_7_class.pth'
 
-
-
-
-
-)
-
+det_model, classifier = get_models(DET_CKPT, CLASS_CKPT)
+
+st.subheader('Upload Custom Image')
+
+image_file = st.file_uploader("Upload Images", type=["png","jpg","jpeg"])
+
+st.subheader('Example Images')
+
+example_imgs = [
+    'example_imgs/basura_4_2.jpg',
+    'example_imgs/basura_1.jpg',
+    'example_imgs/basura_3.jpg'
+]
+
+with st.container() as cont:
+    st.image(example_imgs[0], width=150, caption='1')
+    if st.button('Select Image', key='Image_1'):
+        image_file = example_imgs[0]
+
+with st.container() as cont:
+    st.image(example_imgs[1], width=150, caption='2')
+    if st.button('Select Image', key='Image_2'):
+        image_file = example_imgs[1]
+
+with st.container() as cont:
+    st.image(example_imgs[2], width=150, caption='2')
+    if st.button('Select Image', key='Image_3'):
+        image_file = example_imgs[2]
+
+st.subheader('Detection parameters')
+
+detection_threshold = st.slider('Detection threshold',
+                                min_value=0.0,
+                                max_value=1.0,
+                                value=0.5,
+                                step=0.1)
+
+nms_threshold = st.slider('NMS threshold',
+                          min_value=0.0,
+                          max_value=1.0,
+                          value=0.3,
+                          step=0.1)
+
+st.subheader('Prediction')
+
+if image_file is not None:
     print('Getting predictions')
-
+    if isinstance(image_file, str):
+        data = image_file
+    else:
+        data = image_file.read()
+    pred_dict = predict(det_model, data, detection_threshold)
     print('Fixing the preds')
     boxes, image = prepare_prediction(pred_dict, nms_threshold)
 
     print('Predicting classes')
     labels = predict_class(classifier, image, boxes)
     print('Plotting')
+    plot_img_no_mask(image, boxes, labels)
 
-
-
-inputs = [
-    gr.inputs.Image(type="pil", label="Original Image"),
-    gr.inputs.Number(default=0.5, label="detection_threshold"),
-    gr.inputs.Number(default=0.5, label="nms_threshold"),
-]
-
-outputs = [
-    gr.outputs.Image(type="plot", label="Prediction"),
-]
-
-title = 'Waste Detection'
-description = 'Demo for waste object detection. It detects and classify waste in images according to which rubbish bin the waste should be thrown. Upload an image or click an image to use.'
-examples = [
-    ['example_imgs/basura_4_2.jpg', 0.5, 0.5],
-    ['example_imgs/basura_1.jpg', 0.5, 0.5],
-    ['example_imgs/basura_3.jpg', 0.5, 0.5]
-]
-
-gr.Interface(
-    waste_detector_interface,
-    inputs,
-    outputs,
-    title=title,
-    description=description,
-    examples=examples,
-    theme="huggingface",
-).launch(share=True, enable_queue=True)
+    img = PIL.Image.open('img.png')
+    st.image(img,width=750)
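The prediction is displayed by calling plot_img_no_mask(image, boxes, labels) and then re-opening 'img.png' with PIL, which suggests the helper writes the figure to disk. Streamlit can also render a matplotlib figure directly with st.pyplot, avoiding the round-trip through the file system. A sketch under that assumption: plot_detections_fig is a hypothetical stand-in for plot_img_no_mask, and the (x_min, y_min, x_max, y_max) box format is an assumption, not taken from the repo.

import matplotlib.pyplot as plt
import streamlit as st
from matplotlib.patches import Rectangle

def plot_detections_fig(image, boxes, labels):
    # Hypothetical variant of plot_img_no_mask: draw the detections and
    # return the Figure instead of saving it to 'img.png'.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    for (x_min, y_min, x_max, y_max), label in zip(boxes, labels):
        ax.add_patch(Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                               fill=False, edgecolor='red', linewidth=2))
        ax.text(x_min, y_min, str(label), color='red')
    ax.axis('off')
    return fig

# Inside the `if image_file is not None:` block, the figure could then be
# shown with: st.pyplot(plot_detections_fig(image, boxes, labels))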