Spaces: Runtime error

Hector Lopez committed · Commit de2e31f · 1 Parent(s): 3ab2a3b

feature: Implemented gradio
app.py CHANGED
@@ -1,121 +1,55 @@
-import streamlit as st
-import matplotlib.pyplot as plt
-import numpy as np
-import cv2
+import gradio as gr
 import PIL
 import torch
 
+from utils import plot_img_no_mask, get_models
 from classifier import CustomEfficientNet, CustomViT
 from model import get_model, predict, prepare_prediction, predict_class
 
-
-
-print('Loading the classifier')
-classifier = CustomViT(target_size=7, pretrained=False)
-classifier.load_state_dict(torch.load('class_ViT_taco_7_class.pth', map_location='cpu'))
-# Set eval mode to deactivate dropout and BN layers
-classifier.eval()
+DET_CKPT = 'efficientDet_icevision.ckpt'
+CLASS_CKPT = 'class_ViT_taco_7_class.pth'
 
-def plot_img_no_mask(image, boxes, labels):
-    colors = {
-        0: (255,255,0),
-        1: (255, 0, 0),
-        2: (0, 0, 255),
-        3: (0,128,0),
-        4: (255,165,0),
-        5: (230,230,250),
-        6: (192,192,192)
-    }
-
-    texts = {
-        0: 'plastic',
-        1: 'dangerous',
-        2: 'carton',
-        3: 'glass',
-        4: 'organic',
-        5: 'rest',
-        6: 'other'
-    }
-
-    # Show image
-    boxes = boxes.cpu().detach().numpy().astype(np.int32)
-    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
-
-    for i, box in enumerate(boxes):
-        color = colors[labels[i]]
-
-        [x1, y1, x2, y2] = np.array(box).astype(int)
-        # Without the copy, cv2.rectangle raises an error
-        image = np.array(image).copy()
-
-        pt1 = (x1, y1)
-        pt2 = (x2, y2)
-        cv2.rectangle(image, pt1, pt2, color, thickness=5)
-        cv2.putText(image, texts[labels[i]], (x1, y1-10),
-                    cv2.FONT_HERSHEY_SIMPLEX, 4, thickness=5, color=color)
-
-    plt.axis('off')
-    ax.imshow(image)
-    fig.savefig("img.png", bbox_inches='tight')
-
-st.subheader('Upload Custom Image')
-
-image_file = st.file_uploader("Upload Images", type=["png","jpg","jpeg"])
-
-st.subheader('Example Images')
-
-example_imgs = [
-    'example_imgs/basura_4_2.jpg',
-    'example_imgs/basura_1.jpg',
-    'example_imgs/basura_3.jpg'
-]
-
-with st.container() as cont:
-    st.image(example_imgs[0], width=150, caption='1')
-    if st.button('Select Image', key='Image_1'):
-        image_file = example_imgs[0]
-
-with st.container() as cont:
-    st.image(example_imgs[1], width=150, caption='2')
-    if st.button('Select Image', key='Image_2'):
-        image_file = example_imgs[1]
-
-with st.container() as cont:
-    st.image(example_imgs[2], width=150, caption='2')
-    if st.button('Select Image', key='Image_3'):
-        image_file = example_imgs[2]
-
-st.subheader('Detection parameters')
-
-detection_threshold = st.slider('Detection threshold',
-                                min_value=0.0,
-                                max_value=1.0,
-                                value=0.5,
-                                step=0.1)
-
-nms_threshold = st.slider('NMS threshold',
-                          min_value=0.0,
-                          max_value=1.0,
-                          value=0.3,
-                          step=0.1)
-
-st.subheader('Prediction')
-
-if image_file is not None:
+def waste_detector_interface(
+    image,
+    detection_threshold,
+    nms_threshold
+):
+    det_model, classifier = get_models(DET_CKPT, CLASS_CKPT)
     print('Getting predictions')
-
-    data = image_file
-else:
-    data = image_file.read()
-pred_dict = predict(model, data, detection_threshold)
+    pred_dict = predict(det_model, image, detection_threshold)
     print('Fixing the preds')
     boxes, image = prepare_prediction(pred_dict, nms_threshold)
 
     print('Predicting classes')
     labels = predict_class(classifier, image, boxes)
     print('Plotting')
-    plot_img_no_mask(image, boxes, labels)
 
-
-
+    return plot_img_no_mask(image, boxes, labels)
+
+inputs = [
+    gr.inputs.Image(type="pil", label="Original Image"),
+    gr.inputs.Number(default=0.5, label="detection_threshold"),
+    gr.inputs.Number(default=0.5, label="nms_threshold"),
+]
+
+outputs = [
+    gr.outputs.Image(type="plot", label="Prediction"),
+]
+
+title = 'Waste Detection'
+description = 'Demo for waste object detection. It detects and classifies waste in images according to which rubbish bin the waste should be thrown into. Upload an image or click an example image to use it.'
+examples = [
+    ['example_imgs/basura_4_2.jpg', 0.5, 0.5],
+    ['example_imgs/basura_1.jpg', 0.5, 0.5],
+    ['example_imgs/basura_3.jpg', 0.5, 0.5]
+]
+
+gr.Interface(
+    waste_detector_interface,
+    inputs,
+    outputs,
+    title=title,
+    description=description,
+    examples=examples,
+    theme="huggingface",
+).launch(debug=True, enable_queue=True)
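The rewritten app.py wires the pipeline into Gradio's pre-3.0 gr.inputs/gr.outputs API. The entry point can also be exercised without launching the UI; a minimal smoke-test sketch (not part of the commit), assuming waste_detector_interface is in scope and the example image and checkpoints are present in the Space:

# Hypothetical smoke test: waste_detector_interface() returns the
# matplotlib figure built by plot_img_no_mask(), so it can be saved
# to disk directly instead of being rendered by Gradio.
import PIL

img = PIL.Image.open('example_imgs/basura_1.jpg')
fig = waste_detector_interface(img, detection_threshold=0.5, nms_threshold=0.5)
fig.savefig('prediction.png', bbox_inches='tight')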
model.py CHANGED
@@ -1,4 +1,5 @@
 from io import BytesIO
+from typing import Union
 from icevision import *
 import collections
 import PIL
@@ -12,7 +13,7 @@ import icevision.models.ross.efficientdet
 
 MODEL_TYPE = icevision.models.ross.efficientdet
 
-def get_model(checkpoint_path):
+def get_model(checkpoint_path : str):
     extra_args = {}
     backbone = MODEL_TYPE.backbones.d0
     # The efficientdet model requires an img_size parameter
@@ -27,8 +28,8 @@ def get_model(checkpoint_path):
 
     return model
 
-def get_checkpoint(checkpoint_path):
-    ckpt = torch.load(
+def get_checkpoint(checkpoint_path : str):
+    ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
 
     fixed_state_dict = collections.OrderedDict()
 
@@ -38,15 +39,7 @@ def get_checkpoint(checkpoint_path):
 
     return fixed_state_dict
 
-def predict(model, image, detection_threshold):
-    if isinstance(image, str):
-        img = PIL.Image.open(image)
-    else:
-        img = PIL.Image.open(BytesIO(image))
-
-    img = np.array(img)
-    img = PIL.Image.fromarray(img)
-
+def predict(model : object, img : Union[str, BytesIO], detection_threshold : float):
     class_map = ClassMap(classes=['Waste'])
     transforms = tfms.A.Adapter([
         *tfms.A.resize_and_pad(512),
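The hunk elides the loop body inside get_checkpoint(). As a point of reference only, a hypothetical reconstruction, assuming the .ckpt file is a PyTorch Lightning checkpoint that nests weights under a 'state_dict' key with a 'model.' prefix:

# Hypothetical sketch of get_checkpoint(), not the commit's exact code.
import collections
import torch

def get_checkpoint_sketch(checkpoint_path: str):
    ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    # Strip the assumed Lightning 'model.' prefix so that
    # load_state_dict() accepts the keys.
    fixed_state_dict = collections.OrderedDict()
    for key, value in ckpt['state_dict'].items():
        fixed_state_dict[key.replace('model.', '', 1)] = value
    return fixed_state_dict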
utils.py ADDED
@@ -0,0 +1,79 @@
+from typing import Tuple
+import matplotlib.pyplot as plt
+import numpy as np
+import cv2
+import torch
+
+from classifier import CustomViT
+from model import get_model
+
+def plot_img_no_mask(image : np.ndarray, boxes, labels):
+    colors = {
+        0: (255,255,0),
+        1: (255, 0, 0),
+        2: (0, 0, 255),
+        3: (0,128,0),
+        4: (255,165,0),
+        5: (230,230,250),
+        6: (192,192,192)
+    }
+
+    texts = {
+        0: 'plastic',
+        1: 'dangerous',
+        2: 'carton',
+        3: 'glass',
+        4: 'organic',
+        5: 'rest',
+        6: 'other'
+    }
+
+    # Show image
+    boxes = boxes.cpu().detach().numpy().astype(np.int32)
+    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
+
+    for i, box in enumerate(boxes):
+        color = colors[labels[i]]
+
+        [x1, y1, x2, y2] = np.array(box).astype(int)
+        # Without the copy, cv2.rectangle raises an error
+        image = np.array(image).copy()
+
+        pt1 = (x1, y1)
+        pt2 = (x2, y2)
+        cv2.rectangle(image, pt1, pt2, color, thickness=5)
+        cv2.putText(image, texts[labels[i]], (x1, y1-10),
+                    cv2.FONT_HERSHEY_SIMPLEX, 4, thickness=5, color=color)
+
+
+    plt.axis('off')
+    ax.imshow(image)
+
+    return fig
+
+def get_models(
+    detection_ckpt : str,
+    classifier_ckpt : str
+) -> Tuple[torch.nn.Module, torch.nn.Module]:
+    """
+    Get the detection and classifier models
+
+    Args:
+        detection_ckpt (str): Detection model checkpoint
+        classifier_ckpt (str): Classifier model checkpoint
+
+    Returns:
+        tuple: Tuple containing:
+            - (torch.nn.Module): Detection model
+            - (torch.nn.Module): Classifier model
+    """
+    print('Loading the detection model')
+    det_model = get_model(detection_ckpt)
+    det_model.eval()
+
+    print('Loading the classifier model')
+    classifier = CustomViT(target_size=7, pretrained=False)
+    classifier.load_state_dict(torch.load(classifier_ckpt, map_location='cpu'))
+    classifier.eval()
+
+    return det_model, classifier
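Together the new helpers reproduce the old Streamlit flow outside any UI. An end-to-end sketch (assumed paths, and assuming predict() still accepts a path per its Union[str, BytesIO] hint):

# End-to-end pipeline sketch using the new helpers; the file paths are
# the ones referenced elsewhere in this commit.
from model import predict, prepare_prediction, predict_class
from utils import get_models, plot_img_no_mask

det_model, classifier = get_models('efficientDet_icevision.ckpt',
                                   'class_ViT_taco_7_class.pth')
pred_dict = predict(det_model, 'example_imgs/basura_3.jpg', 0.5)
boxes, image = prepare_prediction(pred_dict, 0.5)
labels = predict_class(classifier, image, boxes)
fig = plot_img_no_mask(image, boxes, labels)
fig.savefig('img.png', bbox_inches='tight')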