File size: 1,737 Bytes
ba92502
9a933a3
cd4c90e
9fbf078
cd4c90e
de2e31f
e5bb367
9fbf078
cd4c90e
de2e31f
 
9fbf078
ba92502
 
 
 
 
 
cd4c90e
ba92502
cd4c90e
b80c100
9fbf078
 
 
cd4c90e
 
ba92502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95d9d45
9a933a3
95d9d45
ba92502
 
 
 
 
 
 
 
84c5052
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from functools import lru_cache

import gradio as gr
from gradio.networking import get_first_available_port
import PIL
import torch

from utils import plot_img_no_mask, get_models
from classifier import CustomEfficientNet, CustomViT
from model import get_model, predict, prepare_prediction, predict_class

# Checkpoint files handed to get_models(). Judging by the names (and by how
# the result is unpacked into det_model / classifier), the first is the
# EfficientDet detector and the second the ViT classifier.
DET_CKPT: str = "efficientDet_icevision.ckpt"
CLASS_CKPT: str = "class_ViT_taco_7_class.pth"

@lru_cache(maxsize=1)
def _load_models():
    """Load the detector and classifier checkpoints once and reuse them.

    Returns:
        Whatever ``get_models(DET_CKPT, CLASS_CKPT)`` returns; the caller
        unpacks it as ``(det_model, classifier)``.
    """
    # Checkpoint loading is expensive; without this cache it would run on
    # every single Gradio request.
    return get_models(DET_CKPT, CLASS_CKPT)


def waste_detector_interface(
    image,
    detection_threshold,
    nms_threshold
):
    """Detect waste objects in ``image``, classify them, and plot the result.

    Args:
        image: The input image as delivered by the Gradio
            ``Image(type="pil")`` input.
        detection_threshold: Threshold forwarded to ``predict``.
        nms_threshold: Threshold forwarded to ``prepare_prediction``
            (presumably the NMS IoU threshold -- confirm in ``model``).

    Returns:
        The object returned by ``plot_img_no_mask`` for the detected boxes
        and their predicted labels (presumably a plot object, matching the
        ``type="plot"`` output component -- confirm in ``utils``).
    """
    det_model, classifier = _load_models()
    print('Getting predictions')
    pred_dict = predict(det_model, image, detection_threshold)
    print('Fixing the preds')
    # prepare_prediction hands back an image alongside the boxes; from here
    # on that returned image replaces the input (presumably the image the
    # boxes are expressed against -- confirm in model.py).
    boxes, image = prepare_prediction(pred_dict, nms_threshold)

    print('Predicting classes')
    labels = predict_class(classifier, image, boxes)
    print('Plotting')

    return plot_img_no_mask(image, boxes, labels)

# Gradio input components, in the same order as the parameters of
# waste_detector_interface: image, detection_threshold, nms_threshold.
inputs = [
    gr.inputs.Image(type="pil", label="Original Image"),
    gr.inputs.Number(default=0.5, label="detection_threshold"),
    gr.inputs.Number(default=0.5, label="nms_threshold"),
]

# A single output component showing what waste_detector_interface returns.
outputs = [
    gr.outputs.Image(type="plot", label="Prediction"),
]

title = 'Waste Detection'
description = (
    'Demo for waste object detection. It detects and classifies waste in '
    'images according to which rubbish bin the waste should be thrown into. '
    'Upload an image or click one of the examples below.'
)
# Each example row supplies one value per input component:
# [image path, detection_threshold, nms_threshold].
examples = [
    ['example_imgs/basura_4_2.jpg', 0.5, 0.5],
    ['example_imgs/basura_1.jpg', 0.5, 0.5],
    ['example_imgs/basura_3.jpg', 0.5, 0.5]
]

# Start the free-port search at Gradio's conventional port 7860 rather than
# inside the privileged (<1024) range.
FIRST_PORT = 7860
LAST_PORT = 9000

# Close any interfaces still running in this interpreter before launching.
gr.close_all()
port = get_first_available_port(FIRST_PORT, LAST_PORT)

gr.Interface(
    waste_detector_interface,
    inputs,
    outputs,
    title=title,
    description=description,
    examples=examples,
    theme="huggingface",
).launch(server_port=port)