Hector Lopez commited on
Commit
de2e31f
·
1 Parent(s): 3ab2a3b

feature: Implemented gradio

Browse files
Files changed (3) hide show
  1. app.py +40 -106
  2. model.py +5 -12
  3. utils.py +79 -0
app.py CHANGED
@@ -1,121 +1,55 @@
1
- import streamlit as st
2
- import matplotlib.pyplot as plt
3
- import numpy as np
4
- import cv2
5
  import PIL
6
  import torch
7
 
 
8
  from classifier import CustomEfficientNet, CustomViT
9
  from model import get_model, predict, prepare_prediction, predict_class
10
 
11
- print('Creating the model')
12
- model = get_model('efficientDet_icevision.ckpt')
13
- print('Loading the classifier')
14
- classifier = CustomViT(target_size=7, pretrained=False)
15
- classifier.load_state_dict(torch.load('class_ViT_taco_7_class.pth', map_location='cpu'))
16
- # Set eval mode to deactivate dropout and BN layers
17
- classifier.eval()
18
 
19
- def plot_img_no_mask(image, boxes, labels):
20
- colors = {
21
- 0: (255,255,0),
22
- 1: (255, 0, 0),
23
- 2: (0, 0, 255),
24
- 3: (0,128,0),
25
- 4: (255,165,0),
26
- 5: (230,230,250),
27
- 6: (192,192,192)
28
- }
29
-
30
- texts = {
31
- 0: 'plastic',
32
- 1: 'dangerous',
33
- 2: 'carton',
34
- 3: 'glass',
35
- 4: 'organic',
36
- 5: 'rest',
37
- 6: 'other'
38
- }
39
-
40
- # Show image
41
- boxes = boxes.cpu().detach().numpy().astype(np.int32)
42
- fig, ax = plt.subplots(1, 1, figsize=(12, 6))
43
-
44
- for i, box in enumerate(boxes):
45
- color = colors[labels[i]]
46
-
47
- [x1, y1, x2, y2] = np.array(box).astype(int)
48
- # Si no se hace la copia da error en cv2.rectangle
49
- image = np.array(image).copy()
50
-
51
- pt1 = (x1, y1)
52
- pt2 = (x2, y2)
53
- cv2.rectangle(image, pt1, pt2, color, thickness=5)
54
- cv2.putText(image, texts[labels[i]], (x1, y1-10),
55
- cv2.FONT_HERSHEY_SIMPLEX, 4, thickness=5, color=color)
56
-
57
-
58
- plt.axis('off')
59
- ax.imshow(image)
60
- fig.savefig("img.png", bbox_inches='tight')
61
-
62
- st.subheader('Upload Custom Image')
63
-
64
- image_file = st.file_uploader("Upload Images", type=["png","jpg","jpeg"])
65
-
66
- st.subheader('Example Images')
67
-
68
- example_imgs = [
69
- 'example_imgs/basura_4_2.jpg',
70
- 'example_imgs/basura_1.jpg',
71
- 'example_imgs/basura_3.jpg'
72
- ]
73
-
74
- with st.container() as cont:
75
- st.image(example_imgs[0], width=150, caption='1')
76
- if st.button('Select Image', key='Image_1'):
77
- image_file = example_imgs[0]
78
-
79
- with st.container() as cont:
80
- st.image(example_imgs[1], width=150, caption='2')
81
- if st.button('Select Image', key='Image_2'):
82
- image_file = example_imgs[1]
83
-
84
- with st.container() as cont:
85
- st.image(example_imgs[2], width=150, caption='2')
86
- if st.button('Select Image', key='Image_3'):
87
- image_file = example_imgs[2]
88
-
89
- st.subheader('Detection parameters')
90
-
91
- detection_threshold = st.slider('Detection threshold',
92
- min_value=0.0,
93
- max_value=1.0,
94
- value=0.5,
95
- step=0.1)
96
-
97
- nms_threshold = st.slider('NMS threshold',
98
- min_value=0.0,
99
- max_value=1.0,
100
- value=0.3,
101
- step=0.1)
102
-
103
- st.subheader('Prediction')
104
-
105
- if image_file is not None:
106
  print('Getting predictions')
107
- if isinstance(image_file, str):
108
- data = image_file
109
- else:
110
- data = image_file.read()
111
- pred_dict = predict(model, data, detection_threshold)
112
  print('Fixing the preds')
113
  boxes, image = prepare_prediction(pred_dict, nms_threshold)
114
 
115
  print('Predicting classes')
116
  labels = predict_class(classifier, image, boxes)
117
  print('Plotting')
118
- plot_img_no_mask(image, boxes, labels)
119
 
120
- img = PIL.Image.open('img.png')
121
- st.image(img,width=750)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
 
 
 
2
  import PIL
3
  import torch
4
 
5
+ from utils import plot_img_no_mask, get_models
6
  from classifier import CustomEfficientNet, CustomViT
7
  from model import get_model, predict, prepare_prediction, predict_class
8
 
9
+ DET_CKPT = 'efficientDet_icevision.ckpt'
10
+ CLASS_CKPT = 'class_ViT_taco_7_class.pth'
 
 
 
 
 
11
 
12
+ def waste_detector_interface(
13
+ image,
14
+ detection_threshold,
15
+ nms_threshold
16
+ ):
17
+ det_model, classifier = get_models(DET_CKPT, CLASS_CKPT)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  print('Getting predictions')
19
+ pred_dict = predict(det_model, image, detection_threshold)
 
 
 
 
20
  print('Fixing the preds')
21
  boxes, image = prepare_prediction(pred_dict, nms_threshold)
22
 
23
  print('Predicting classes')
24
  labels = predict_class(classifier, image, boxes)
25
  print('Plotting')
 
26
 
27
+ return plot_img_no_mask(image, boxes, labels)
28
+
29
+ inputs = [
30
+ gr.inputs.Image(type="pil", label="Original Image"),
31
+ gr.inputs.Number(default=0.5, label="detection_threshold"),
32
+ gr.inputs.Number(default=0.5, label="nms_threshold"),
33
+ ]
34
+
35
+ outputs = [
36
+ gr.outputs.Image(type="plot", label="Prediction"),
37
+ ]
38
+
39
+ title = 'Waste Detection'
40
+ description = 'Demo for waste object detection. It detects and classify waste in images according to which rubbish bin the waste should be thrown. Upload an image or click an image to use.'
41
+ examples = [
42
+ ['example_imgs/basura_4_2.jpg', 0.5, 0.5],
43
+ ['example_imgs/basura_1.jpg', 0.5, 0.5],
44
+ ['example_imgs/basura_3.jpg', 0.5, 0.5]
45
+ ]
46
+
47
+ gr.Interface(
48
+ waste_detector_interface,
49
+ inputs,
50
+ outputs,
51
+ title=title,
52
+ description=description,
53
+ examples=examples,
54
+ theme="huggingface",
55
+ ).launch(debug=True, enable_queue=True)
model.py CHANGED
@@ -1,4 +1,5 @@
1
  from io import BytesIO
 
2
  from icevision import *
3
  import collections
4
  import PIL
@@ -12,7 +13,7 @@ import icevision.models.ross.efficientdet
12
 
13
  MODEL_TYPE = icevision.models.ross.efficientdet
14
 
15
- def get_model(checkpoint_path):
16
  extra_args = {}
17
  backbone = MODEL_TYPE.backbones.d0
18
  # The efficientdet model requires an img_size parameter
@@ -27,8 +28,8 @@ def get_model(checkpoint_path):
27
 
28
  return model
29
 
30
- def get_checkpoint(checkpoint_path):
31
- ckpt = torch.load('checkpoint.ckpt', map_location=torch.device('cpu'))
32
 
33
  fixed_state_dict = collections.OrderedDict()
34
 
@@ -38,15 +39,7 @@ def get_checkpoint(checkpoint_path):
38
 
39
  return fixed_state_dict
40
 
41
- def predict(model, image, detection_threshold):
42
- if isinstance(image, str):
43
- img = PIL.Image.open(image)
44
- else:
45
- img = PIL.Image.open(BytesIO(image))
46
-
47
- img = np.array(img)
48
- img = PIL.Image.fromarray(img)
49
-
50
  class_map = ClassMap(classes=['Waste'])
51
  transforms = tfms.A.Adapter([
52
  *tfms.A.resize_and_pad(512),
 
1
  from io import BytesIO
2
+ from typing import Union
3
  from icevision import *
4
  import collections
5
  import PIL
 
13
 
14
  MODEL_TYPE = icevision.models.ross.efficientdet
15
 
16
+ def get_model(checkpoint_path : str):
17
  extra_args = {}
18
  backbone = MODEL_TYPE.backbones.d0
19
  # The efficientdet model requires an img_size parameter
 
28
 
29
  return model
30
 
31
+ def get_checkpoint(checkpoint_path : str):
32
+ ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
33
 
34
  fixed_state_dict = collections.OrderedDict()
35
 
 
39
 
40
  return fixed_state_dict
41
 
42
+ def predict(model : object, img : Union[str, BytesIO], detection_threshold : float):
 
 
 
 
 
 
 
 
43
  class_map = ClassMap(classes=['Waste'])
44
  transforms = tfms.A.Adapter([
45
  *tfms.A.resize_and_pad(512),
utils.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import cv2
5
+ import torch
6
+
7
+ from classifier import CustomViT
8
+ from model import get_model
9
+
10
+ def plot_img_no_mask(image : np.ndarray, boxes, labels):
11
+ colors = {
12
+ 0: (255,255,0),
13
+ 1: (255, 0, 0),
14
+ 2: (0, 0, 255),
15
+ 3: (0,128,0),
16
+ 4: (255,165,0),
17
+ 5: (230,230,250),
18
+ 6: (192,192,192)
19
+ }
20
+
21
+ texts = {
22
+ 0: 'plastic',
23
+ 1: 'dangerous',
24
+ 2: 'carton',
25
+ 3: 'glass',
26
+ 4: 'organic',
27
+ 5: 'rest',
28
+ 6: 'other'
29
+ }
30
+
31
+ # Show image
32
+ boxes = boxes.cpu().detach().numpy().astype(np.int32)
33
+ fig, ax = plt.subplots(1, 1, figsize=(12, 6))
34
+
35
+ for i, box in enumerate(boxes):
36
+ color = colors[labels[i]]
37
+
38
+ [x1, y1, x2, y2] = np.array(box).astype(int)
39
+ # Si no se hace la copia da error en cv2.rectangle
40
+ image = np.array(image).copy()
41
+
42
+ pt1 = (x1, y1)
43
+ pt2 = (x2, y2)
44
+ cv2.rectangle(image, pt1, pt2, color, thickness=5)
45
+ cv2.putText(image, texts[labels[i]], (x1, y1-10),
46
+ cv2.FONT_HERSHEY_SIMPLEX, 4, thickness=5, color=color)
47
+
48
+
49
+ plt.axis('off')
50
+ ax.imshow(image)
51
+
52
+ return fig
53
+
54
+ def get_models(
55
+ detection_ckpt : str,
56
+ classifier_ckpt : str
57
+ ) -> Tuple[torch.nn.Module, torch.nn.Module]:
58
+ """
59
+ Get the detection and classifier models
60
+
61
+ Args:
62
+ detection_ckpt (str): Detection model checkpoint
63
+ classifier_ckpt (str): Classifier model checkpoint
64
+
65
+ Returns:
66
+ tuple: Tuple containing:
67
+ - (torch.nn.Module): Detection model
68
+ - (torch.nn.Module): Classifier model
69
+ """
70
+ print('Loading the detection model')
71
+ det_model = get_model(detection_ckpt)
72
+ det_model.eval()
73
+
74
+ print('Loading the classifier model')
75
+ classifier = CustomViT(target_size=7, pretrained=False)
76
+ classifier.load_state_dict(torch.load(classifier_ckpt, map_location='cpu'))
77
+ classifier.eval()
78
+
79
+ return det_model, classifier