import numpy as np
import torch
from fastai.vision.all import load_learner

from efficientdet.efficientdet import get_transforms, rescale_bboxes, set_model


def localize_trash(im, det_name, det_checkpoint, device, prob_threshold):
    """Run the detector on ``im`` and keep detections above a score threshold.

    Args:
        im: Input image. ``im.size`` is forwarded to ``rescale_bboxes`` as the
            target scale (presumably a PIL image, i.e. ``(width, height)`` --
            TODO confirm against ``rescale_bboxes``).
        det_name: Detector architecture name, forwarded to ``set_model``.
        det_checkpoint: Detector weights, forwarded to ``set_model``.
        device: Torch device the model input is moved to.
        prob_threshold: Detections whose score (column 4) is not strictly
            greater than this value are discarded.

    Returns:
        ``(probas, bboxes_scaled)``. ``probas`` holds columns ``4:`` of each
        kept detection (score first). It is a view into the detector output,
        so in-place edits propagate. ``bboxes_scaled`` is the result of
        ``rescale_bboxes`` applied to the first four columns.
    """
    # NOTE(review): the detector is rebuilt from the checkpoint on every call;
    # consider caching it if this is invoked repeatedly.
    detector = set_model(det_name, 1, det_checkpoint, device)
    detector.eval()

    # Inference only: a local no_grad keeps this function safe to call on its
    # own without building an autograd graph or touching global grad state.
    with torch.no_grad():
        # mean-std normalize the input image (batch-size: 1)
        img = get_transforms(im)
        outputs = detector(img.to(device))

        # Batch size is 1, so only the first item is inspected; column 4 is
        # compared against the threshold as the detection score.
        detections = outputs[0]
        bboxes_keep = detections[detections[:, 4] > prob_threshold]
        probas = bboxes_keep[:, 4:]

        # Convert boxes from the network input resolution to image scale.
        bboxes_scaled = rescale_bboxes(
            bboxes_keep[:, :4], im.size, tuple(img.size()[2:])
        )
    return probas, bboxes_scaled


def classify_trash(im, clas_checkpoint, cls_th, probas, bboxes_scaled):
    """Classify each detected box with a fastai learner and filter by confidence.

    Each box in ``bboxes_scaled`` is cropped from ``im`` and passed to the
    learner loaded from ``clas_checkpoint``. The matching row ``p`` of
    ``probas`` is overwritten IN PLACE:

    - ``p[0]`` is set to the top class probability as a truncated percentage.
    - ``p[1]`` is set to the index of that class.

    Rows must therefore have at least two columns.

    Args:
        im: Image to crop from (must support ``crop((xmin, ymin, xmax, ymax))``).
        clas_checkpoint: Path passed to fastai's ``load_learner``.
        cls_th: Minimum classifier probability as a fraction. A box is kept
            when its truncated percentage ``p[0] >= cls_th * 100``.
        probas: Per-detection rows as returned by ``localize_trash``.
        bboxes_scaled: Boxes aligned with ``probas``; must support ``tolist()``.

    Returns:
        ``(cls_prob, bboxes_final)``: parallel lists holding the updated rows
        and the ``(xmin, ymin, xmax, ymax)`` tuples of boxes that passed.
    """
    # NOTE(review): the learner is reloaded on every call; consider caching.
    classifier = load_learner(clas_checkpoint)

    bboxes_final = []
    cls_prob = []
    for p, (xmin, ymin, xmax, ymax) in zip(probas, bboxes_scaled.tolist()):
        crop = im.crop((xmin, ymin, xmax, ymax))
        # Learner.predict returns (decoded, class_index, probabilities).
        class_probs = classifier.predict(crop)[2]
        p[1] = torch.topk(class_probs, k=1).indices.item()
        # Stay in torch: np.trunc on a tensor round-trips through NumPy and
        # fails outright for tensors that live on a GPU.
        p[0] = torch.trunc(class_probs * 100).max()
        if p[0] >= cls_th * 100:
            bboxes_final.append((xmin, ymin, xmax, ymax))
            cls_prob.append(p)
    return cls_prob, bboxes_final


def detect_trash(
    im, det_name, det_checkpoint, clas_checkpoint, device, prob_threshold, cls_th
):
    """Two-stage pipeline: localize candidate boxes, then classify each one.

    Args:
        im: Input image (see ``localize_trash`` / ``classify_trash``).
        det_name: Detector architecture name.
        det_checkpoint: Detector weights.
        clas_checkpoint: Classifier checkpoint path for ``load_learner``.
        device: Torch device for the detector.
        prob_threshold: Detector score threshold (strictly greater than).
        cls_th: Classifier confidence threshold as a fraction.

    Returns:
        ``(cls_prob, bboxes_final)`` as returned by ``classify_trash``.
    """
    # Scope gradient disabling to this call. The previous
    # torch.set_grad_enabled(False) switched autograd off for the calling
    # thread permanently, silently affecting any later training/backprop code.
    with torch.no_grad():
        # 1) Localize
        probas, bboxes_scaled = localize_trash(
            im, det_name, det_checkpoint, device, prob_threshold
        )
        # 2) Classify
        cls_prob, bboxes_final = classify_trash(
            im, clas_checkpoint, cls_th, probas, bboxes_scaled
        )
    return cls_prob, bboxes_final