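"""Detection and classification helpers for the waste-detection Space.

Loads an IceVision EfficientDet model from a checkpoint, runs end-to-end
detection on an image, filters the raw boxes with class-aware NMS, and
classifies each cropped detection with a separate image classifier.
"""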
from io import BytesIO
from typing import Union
from icevision import *
from icevision.models.checkpoint import model_from_checkpoint
from classifier import transform_image
from icevision.models import ross
import collections
import PIL
import torch
import numpy as np
import torchvision

MODEL_TYPE = ross.efficientdet


def get_model(checkpoint_path: str):
    """Load the single-class waste detector (EfficientDet d0, 512 px) from an IceVision checkpoint."""
    checkpoint_and_model = model_from_checkpoint(
        checkpoint_path,
        model_name='ross.efficientdet',
        backbone_name='d0',
        img_size=512,
        classes=['Waste'],
        revise_keys=[(r'^model\.', '')])
    model = checkpoint_and_model['model']
    return model


def get_checkpoint(checkpoint_path: str):
    """Load a training checkpoint on the CPU and strip the 'model.' prefix from state-dict keys."""
    ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    fixed_state_dict = collections.OrderedDict()
    for k, v in ckpt['state_dict'].items():
        new_k = k[6:]  # drop the leading 'model.' prefix (6 characters)
        fixed_state_dict[new_k] = v
    return fixed_state_dict


def predict(model: object, image: Union[str, BytesIO], detection_threshold: float):
    """Run end-to-end detection on a single image (path or in-memory bytes) and return IceVision's prediction dict."""
    img = PIL.Image.open(image)
    # Round-trip through numpy to materialise the pixel data as a plain array-backed image.
    img = np.array(img)
    img = PIL.Image.fromarray(img)
    class_map = ClassMap(classes=['Waste'])
    transforms = tfms.A.Adapter([
        *tfms.A.resize_and_pad(512),
        tfms.A.Normalize()
    ])
    pred_dict = MODEL_TYPE.end2end_detect(
        img,
        transforms,
        model,
        class_map=class_map,
        detection_threshold=detection_threshold,
        return_as_pil_img=False,
        return_img=True,
        display_bbox=False,
        display_score=False,
        display_label=False)
    return pred_dict


def prepare_prediction(pred_dict, threshold):
    """Run class-aware NMS over the raw detections and return the kept boxes plus the image."""
    boxes = [box.to_tensor() for box in pred_dict['detection']['bboxes']]
    boxes = torch.stack(boxes)
    scores = torch.as_tensor(pred_dict['detection']['scores'])
    labels = torch.as_tensor(pred_dict['detection']['label_ids'])
    image = np.array(pred_dict['img'])
    keep = torchvision.ops.batched_nms(boxes, scores, labels, threshold)
    boxes = boxes[keep, :]
    return boxes, image


def predict_class(classifier, image, bboxes):
    """Crop each detected box out of the image and classify the crop; returns one class id per box."""
    preds = []
    for bbox in bboxes:
        img = image.copy()
        bbox = np.array(bbox).astype(int)
        cropped_img = PIL.Image.fromarray(img).crop(bbox)
        cropped_img = np.array(cropped_img)
        tran_image = transform_image(cropped_img, 224)
        tran_image = tran_image.transpose(2, 0, 1)  # HWC -> CHW for the classifier
        tran_image = torch.as_tensor(tran_image, dtype=torch.float).unsqueeze(0)
        y_preds = classifier(tran_image)
        preds.append(y_preds.softmax(1).detach().numpy())
    preds = np.concatenate(preds).argmax(1)
    return preds
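

# --- Usage sketch ---
# A minimal, hedged example of how the helpers above chain together; the
# checkpoint path, image path, and classifier are hypothetical stand-ins for
# whatever the Space's app code actually provides.
if __name__ == "__main__":
    detector = get_model("checkpoints/efficientdet_d0.ckpt")  # hypothetical checkpoint path
    pred_dict = predict(detector, "sample.jpg", detection_threshold=0.5)  # hypothetical image
    boxes, image = prepare_prediction(pred_dict, threshold=0.5)  # NMS IoU threshold
    print(f"Kept {len(boxes)} boxes after NMS")
    # With a classifier loaded elsewhere (it must accept a 1x3x224x224 float tensor):
    # class_ids = predict_class(classifier, image, boxes)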