""" PyTorch EfficientDet support benches
Hacked together by Ross Wightman
"""
from typing import Optional, Dict, List
import torch
import torch.nn as nn
from timm.utils import ModelEma
from .anchors import Anchors, AnchorLabeler, generate_detections, MAX_DETECTION_POINTS
from .loss import DetectionLoss


def _post_process(
cls_outputs: List[torch.Tensor],
box_outputs: List[torch.Tensor],
num_levels: int,
num_classes: int,
max_detection_points: int = MAX_DETECTION_POINTS,
):
"""Selects top-k predictions.
Post-proc code adapted from Tensorflow version at: https://github.com/google/automl/tree/master/efficientdet
and optimized for PyTorch.
Args:
cls_outputs: an OrderDict with keys representing levels and values
representing logits in [batch_size, height, width, num_anchors].
box_outputs: an OrderDict with keys representing levels and values
representing box regression targets in [batch_size, height, width, num_anchors * 4].
num_levels (int): number of feature levels
num_classes (int): number of output classes
"""
    batch_size = cls_outputs[0].shape[0]
    # flatten all levels into [batch, total_anchors, num_classes] / [batch, total_anchors, 4]
    cls_outputs_all = torch.cat([
        cls_outputs[level].permute(0, 2, 3, 1).reshape([batch_size, -1, num_classes])
        for level in range(num_levels)], 1)
    box_outputs_all = torch.cat([
        box_outputs[level].permute(0, 2, 3, 1).reshape([batch_size, -1, 4])
        for level in range(num_levels)], 1)

    # top-k over the flattened (anchor, class) scores, then recover the anchor
    # index and class id of each selected entry
    _, cls_topk_indices_all = torch.topk(cls_outputs_all.reshape(batch_size, -1), dim=1, k=max_detection_points)
    indices_all = cls_topk_indices_all // num_classes
    classes_all = cls_topk_indices_all % num_classes

    # gather the box targets and class logits belonging to the selected anchors
    box_outputs_all_after_topk = torch.gather(
        box_outputs_all, 1, indices_all.unsqueeze(2).expand(-1, -1, 4))
    cls_outputs_all_after_topk = torch.gather(
        cls_outputs_all, 1, indices_all.unsqueeze(2).expand(-1, -1, num_classes))
    cls_outputs_all_after_topk = torch.gather(
        cls_outputs_all_after_topk, 2, classes_all.unsqueeze(2))

    return cls_outputs_all_after_topk, box_outputs_all_after_topk, indices_all, classes_all
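
# A minimal usage sketch of _post_process (illustrative only: the batch size,
# class count, anchor count, and level resolutions below are hypothetical, not
# tied to any real EfficientDet config):
def _example_post_process():
    num_classes, num_anchors = 10, 9
    cls_outputs = [torch.randn(2, num_classes * num_anchors, s, s) for s in (8, 4, 2)]
    box_outputs = [torch.randn(2, 4 * num_anchors, s, s) for s in (8, 4, 2)]
    cls_topk, box_topk, indices, classes = _post_process(
        cls_outputs, box_outputs, num_levels=3, num_classes=num_classes, max_detection_points=100)
    # cls_topk: [2, 100, 1] selected class logits, box_topk: [2, 100, 4] matching
    # box targets, indices/classes: [2, 100] anchor indices and class ids
    return cls_topk, box_topk, indices, classes
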
@torch.jit.script
def _batch_detection(
batch_size: int, class_out, box_out, anchor_boxes, indices, classes,
img_scale: Optional[torch.Tensor] = None, img_size: Optional[torch.Tensor] = None):
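    # class_out: [B, K, 1] top-k class logits and box_out: [B, K, 4] matching box
    # targets from _post_process; anchor_boxes: [total_anchors, 4] anchor table.
    # Runs per-image detection generation and stacks the results into one tensor.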
batch_detections = []
# FIXME we may be able to do this as a batch with some tensor reshaping/indexing, PR welcome
for i in range(batch_size):
img_scale_i = None if img_scale is None else img_scale[i]
img_size_i = None if img_size is None else img_size[i]
detections = generate_detections(
class_out[i], box_out[i], anchor_boxes, indices[i], classes[i], img_scale_i, img_size_i)
batch_detections.append(detections)
return torch.stack(batch_detections, dim=0)
class DetBenchPredict(nn.Module):
def __init__(self, model):
super(DetBenchPredict, self).__init__()
self.model = model
self.config = model.config # FIXME remove this when we can use @property (torchscript limitation)
self.num_levels = model.config.num_levels
self.num_classes = model.config.num_classes
self.anchors = Anchors.from_config(model.config)

    def forward(self, x, img_info: Optional[Dict[str, torch.Tensor]] = None):
class_out, box_out = self.model(x)
class_out, box_out, indices, classes = _post_process(
class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes)
if img_info is None:
img_scale, img_size = None, None
else:
img_scale, img_size = img_info['img_scale'], img_info['img_size']
return _batch_detection(
x.shape[0], class_out, box_out, self.anchors.boxes, indices, classes, img_scale, img_size)
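
# Inference sketch (illustrative; assumes `model` is an EfficientDet network
# whose `config` carries num_levels/num_classes, e.g. one built by this
# package's model factory, and `images` is a [B, 3, H, W] float tensor):
def _example_predict(model, images, img_info=None):
    bench = DetBenchPredict(model).eval()
    with torch.no_grad():
        # result: [B, max detections per image, 6]; each row holds box
        # coordinates, a confidence score and a class id (see generate_detections
        # for the exact column layout)
        return bench(images, img_info)
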
class DetBenchTrain(nn.Module):
def __init__(self, model, create_labeler=True):
super(DetBenchTrain, self).__init__()
self.model = model
self.config = model.config # FIXME remove this when we can use @property (torchscript limitation)
self.num_levels = model.config.num_levels
self.num_classes = model.config.num_classes
self.anchors = Anchors.from_config(model.config)
self.anchor_labeler = None
if create_labeler:
self.anchor_labeler = AnchorLabeler(self.anchors, self.num_classes, match_threshold=0.5)
self.loss_fn = DetectionLoss(model.config)

    def forward(self, x, target: Dict[str, torch.Tensor]):
class_out, box_out = self.model(x)
if self.anchor_labeler is None:
# target should contain pre-computed anchor labels if labeler not present in bench
assert 'label_num_positives' in target
cls_targets = [target[f'label_cls_{l}'] for l in range(self.num_levels)]
box_targets = [target[f'label_bbox_{l}'] for l in range(self.num_levels)]
num_positives = target['label_num_positives']
else:
cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors(
target['bbox'], target['cls'])
loss, class_loss, box_loss = self.loss_fn(class_out, box_out, cls_targets, box_targets, num_positives)
output = {'loss': loss, 'class_loss': class_loss, 'box_loss': box_loss}
if not self.training:
# if eval mode, output detections for evaluation
class_out_pp, box_out_pp, indices, classes = _post_process(
class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes)
output['detections'] = _batch_detection(
x.shape[0], class_out_pp, box_out_pp, self.anchors.boxes, indices, classes,
target['img_scale'], target['img_size'])
return output
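
# Training-step sketch (illustrative, hypothetical names; `target` follows the
# layout this bench consumes when created with create_labeler=True: per-image
# 'bbox' and 'cls' tensors, plus 'img_scale'/'img_size' for eval-mode detections):
def _example_train_step(bench: DetBenchTrain, images, target, optimizer):
    output = bench(images, target)
    optimizer.zero_grad()
    output['loss'].backward()
    optimizer.step()
    return output['loss'].detach()
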
def unwrap_bench(model):
    # Unwrap a model wrapped in a support bench (and/or EMA, DDP) so that other
    # functions can access the weights and attributes of the underlying model directly.
if isinstance(model, ModelEma): # unwrap ModelEma
return unwrap_bench(model.ema)
elif hasattr(model, 'module'): # unwrap DDP
return unwrap_bench(model.module)
elif hasattr(model, 'model'): # unwrap Bench -> model
return unwrap_bench(model.model)
else:
return model
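
# Example: persist only the core network's weights however the model is wrapped
# (a sketch; the checkpoint path is illustrative):
def _example_save_weights(bench, checkpoint_path='model_weights.pth'):
    # unwrap any ModelEma / DDP / bench layers before saving
    torch.save(unwrap_bench(bench).state_dict(), checkpoint_path)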