Hector Lopez commited on
Commit
ba92502
·
1 Parent(s): e4cd286

Implemented gradio

Browse files
Files changed (2) hide show
  1. app.py +37 -55
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import streamlit as st
2
  import PIL
3
  import torch
4
 
@@ -9,65 +9,47 @@ from model import get_model, predict, prepare_prediction, predict_class
9
  DET_CKPT = 'efficientDet_icevision.ckpt'
10
  CLASS_CKPT = 'class_ViT_taco_7_class.pth'
11
 
12
- det_model, classifier = get_models(DET_CKPT, CLASS_CKPT)
13
-
14
- st.subheader('Upload Custom Image')
15
-
16
- image_file = st.file_uploader("Upload Images", type=["png","jpg","jpeg"])
17
-
18
- st.subheader('Example Images')
19
-
20
- example_imgs = [
21
- 'example_imgs/basura_4_2.jpg',
22
- 'example_imgs/basura_1.jpg',
23
- 'example_imgs/basura_3.jpg'
24
- ]
25
-
26
- with st.container() as cont:
27
- st.image(example_imgs[0], width=150, caption='1')
28
- if st.button('Select Image', key='Image_1'):
29
- image_file = example_imgs[0]
30
-
31
- with st.container() as cont:
32
- st.image(example_imgs[1], width=150, caption='2')
33
- if st.button('Select Image', key='Image_2'):
34
- image_file = example_imgs[1]
35
-
36
- with st.container() as cont:
37
- st.image(example_imgs[2], width=150, caption='2')
38
- if st.button('Select Image', key='Image_3'):
39
- image_file = example_imgs[2]
40
-
41
- st.subheader('Detection parameters')
42
-
43
- detection_threshold = st.slider('Detection threshold',
44
- min_value=0.0,
45
- max_value=1.0,
46
- value=0.5,
47
- step=0.1)
48
-
49
- nms_threshold = st.slider('NMS threshold',
50
- min_value=0.0,
51
- max_value=1.0,
52
- value=0.3,
53
- step=0.1)
54
-
55
- st.subheader('Prediction')
56
-
57
- if image_file is not None:
58
  print('Getting predictions')
59
- if isinstance(image_file, str):
60
- data = image_file
61
- else:
62
- data = image_file.read()
63
- pred_dict = predict(det_model, data, detection_threshold)
64
  print('Fixing the preds')
65
  boxes, image = prepare_prediction(pred_dict, nms_threshold)
66
 
67
  print('Predicting classes')
68
  labels = predict_class(classifier, image, boxes)
69
  print('Plotting')
70
- plot_img_no_mask(image, boxes, labels)
71
 
72
- img = PIL.Image.open('img.png')
73
- st.image(img,width=750)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
  import PIL
3
  import torch
4
 
 
9
  DET_CKPT = 'efficientDet_icevision.ckpt'
10
  CLASS_CKPT = 'class_ViT_taco_7_class.pth'
11
 
12
+ def waste_detector_interface(
13
+ image,
14
+ detection_threshold,
15
+ nms_threshold
16
+ ):
17
+ det_model, classifier = get_models(DET_CKPT, CLASS_CKPT)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  print('Getting predictions')
19
+ pred_dict = predict(det_model, image, detection_threshold)
 
 
 
 
20
  print('Fixing the preds')
21
  boxes, image = prepare_prediction(pred_dict, nms_threshold)
22
 
23
  print('Predicting classes')
24
  labels = predict_class(classifier, image, boxes)
25
  print('Plotting')
 
26
 
27
+ return plot_img_no_mask(image, boxes, labels)
28
+
29
+ inputs = [
30
+ gr.inputs.Image(type="pil", label="Original Image"),
31
+ gr.inputs.Number(default=0.5, label="detection_threshold"),
32
+ gr.inputs.Number(default=0.5, label="nms_threshold"),
33
+ ]
34
+
35
+ outputs = [
36
+ gr.outputs.Image(type="plot", label="Prediction"),
37
+ ]
38
+
39
+ title = 'Waste Detection'
40
+ description = 'Demo for waste object detection. It detects and classify waste in images according to which rubbish bin the waste should be thrown. Upload an image or click an image to use.'
41
+ examples = [
42
+ ['example_imgs/basura_4_2.jpg', 0.5, 0.5],
43
+ ['example_imgs/basura_1.jpg', 0.5, 0.5],
44
+ ['example_imgs/basura_3.jpg', 0.5, 0.5]
45
+ ]
46
+
47
+ gr.Interface(
48
+ waste_detector_interface,
49
+ inputs,
50
+ outputs,
51
+ title=title,
52
+ description=description,
53
+ examples=examples,
54
+ theme="huggingface",
55
+ ).launch(share=True, enable_queue=True)
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  icevision[all]
2
  matplotlib
3
  effdet
 
4
  streamlit==1.2.0
5
  Pillow==8.4.0
 
1
  icevision[all]
2
  matplotlib
3
  effdet
4
+ gradio
5
  streamlit==1.2.0
6
  Pillow==8.4.0