File size: 1,800 Bytes
ba92502
9a933a3
cd4c90e
9fbf078
ad55672
cd4c90e
de2e31f
e5bb367
9fbf078
cd4c90e
c06425f
ad55672
de2e31f
 
9fbf078
ba92502
 
 
 
 
 
cd4c90e
ba92502
cd4c90e
b80c100
9fbf078
 
 
cd4c90e
 
ba92502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18a5c74
ca86eaa
95d9d45
86d9b44
 
 
 
 
 
 
ab0c2de
 
6438514
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import functools
import os

import gradio as gr
from gradio.networking import get_first_available_port
import PIL
import torch

from utils import plot_img_no_mask, get_models
from classifier import CustomEfficientNet, CustomViT
from model import get_model, predict, prepare_prediction, predict_class

# NOTE: do not "free the port" with `pkill -9 python` here: the pattern also
# matches this interpreter (and any other Python process on the host), so
# the script would SIGKILL itself. Gradio already falls back to the next free
# port, and gr.close_all() below closes apps started by this interpreter.

# Checkpoint files passed to get_models(). The file names suggest an
# EfficientDet detector and a ViT waste classifier. The paths are relative, so
# they are presumably resolved against the working directory — confirm.
DET_CKPT = 'efficientDet_icevision.ckpt'
CLASS_CKPT = 'class_ViT_taco_7_class.pth'

@functools.lru_cache(maxsize=1)
def _load_models():
    """Load the detector and classifier once and reuse them across requests.

    Returns:
        Whatever ``get_models(DET_CKPT, CLASS_CKPT)`` returns; the caller
        unpacks it as ``(det_model, classifier)``.
    """
    # lru_cache does not store exceptions, so a failed load is retried on the
    # next request rather than remembered.
    return get_models(DET_CKPT, CLASS_CKPT)


def waste_detector_interface(
    image,
    detection_threshold,
    nms_threshold
):
    """Detect waste objects in ``image``, classify them, and return a plot.

    The models are loaded on the first call and then shared by every later
    call, instead of being rebuilt from the checkpoints per request.
    NOTE(review): sharing assumes ``predict`` / ``predict_class`` do not
    mutate model state between requests — confirm.

    Args:
        image: Image from the Gradio image input (declared with
            ``type="pil"``, so presumably a ``PIL.Image``).
        detection_threshold: Forwarded to ``predict``; presumably the
            minimum score a detection needs to be kept.
        nms_threshold: Forwarded to ``prepare_prediction``; presumably the
            IoU threshold for non-maximum suppression.

    Returns:
        The result of ``plot_img_no_mask(image, boxes, labels)``, rendered by
        the ``"plot"`` output component.
    """
    det_model, classifier = _load_models()
    print('Getting predictions')
    pred_dict = predict(det_model, image, detection_threshold)
    print('Fixing the preds')
    # prepare_prediction returns its own image, which replaces the input
    # image for the classification and plotting steps.
    boxes, image = prepare_prediction(pred_dict, nms_threshold)

    print('Predicting classes')
    labels = predict_class(classifier, image, boxes)
    print('Plotting')

    return plot_img_no_mask(image, boxes, labels)

# Gradio input components, in the same order as waste_detector_interface's
# parameters: (image, detection_threshold, nms_threshold).
inputs = [
    gr.inputs.Image(type="pil", label="Original Image"),
    gr.inputs.Number(default=0.5, label="detection_threshold"),
    gr.inputs.Number(default=0.5, label="nms_threshold"),
]

# The returned value is rendered as a plot.
outputs = [
    gr.outputs.Image(type="plot", label="Prediction"),
]

title = 'Waste Detection'
description = 'Demo for waste object detection. It detects and classifies waste in images according to which rubbish bin the waste should be thrown into. Upload an image or click an example image to use it.'
# Each example row supplies one value per input component, in the same order
# as `inputs`: [image path, detection_threshold, nms_threshold].
examples = [
    ['example_imgs/basura_4_2.jpg', 0.5, 0.5],
    ['example_imgs/basura_1.jpg', 0.5, 0.5],
    ['example_imgs/basura_3.jpg', 0.5, 0.5],
]

# Close any Gradio apps this interpreter is still running (e.g. from an
# earlier run of the same notebook cell) before building a new one.
gr.close_all()

demo = gr.Interface(
    waste_detector_interface,
    inputs,
    outputs,
    title=title,
    description=description,
    examples=examples,
    theme="huggingface",
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the
    # local server. Restarting the app after launch() returns is left to
    # the hosting supervisor. Re-running `python3 app.py` from here would
    # nest a new process under this one on every exit, so Ctrl-C could
    # never stop the app.
    demo.launch(share=True)