Spaces:
Sleeping
Sleeping
""" PyTorch EfficientDet support benches | |
Hacked together by Ross Wightman | |
""" | |
from typing import Optional, Dict, List | |
import torch | |
import torch.nn as nn | |
from timm.utils import ModelEma | |
from .anchors import Anchors, AnchorLabeler, generate_detections, MAX_DETECTION_POINTS | |
from .loss import DetectionLoss | |
def _post_process(
        cls_outputs: List[torch.Tensor],
        box_outputs: List[torch.Tensor],
        num_levels: int,
        num_classes: int,
        max_detection_points: int = MAX_DETECTION_POINTS,
):
    """Selects top-k predictions across all feature levels.

    Post-proc code adapted from Tensorflow version at: https://github.com/google/automl/tree/master/efficientdet
    and optimized for PyTorch.

    Args:
        cls_outputs: list indexed by feature level holding the class logits. Each entry is a
            4-D tensor whose dim 0 is the batch; dim 1 is moved last and then split into rows
            of `num_classes` values.
        box_outputs: list indexed by feature level holding the box regression outputs. Each
            entry is a 4-D tensor whose dim 0 is the batch; dim 1 is moved last and then split
            into rows of 4 values.
        num_levels (int): number of feature levels (leading entries of the lists that are used)
        num_classes (int): number of output classes
        max_detection_points (int): upper bound on the number of candidates kept per image

    Returns:
        A tuple `(cls_topk, box_topk, indices, classes)` where, with
        `k = min(max_detection_points, total number of class logits per image)`:
            cls_topk: `[batch_size, k, 1]` logit of each selected candidate
            box_topk: `[batch_size, k, 4]` box regression output of each selected candidate
            indices:  `[batch_size, k]` row index of each candidate in the concatenated levels
            classes:  `[batch_size, k]` class index of each candidate
    """
    batch_size = cls_outputs[0].shape[0]

    # Flatten every level to [batch, rows, num_classes] / [batch, rows, 4] and join the levels
    # along the row dimension so a single top-k can rank candidates from all levels together.
    cls_outputs_all = torch.cat([
        cls_outputs[level].permute(0, 2, 3, 1).reshape([batch_size, -1, num_classes])
        for level in range(num_levels)], 1)
    box_outputs_all = torch.cat([
        box_outputs[level].permute(0, 2, 3, 1).reshape([batch_size, -1, 4])
        for level in range(num_levels)], 1)

    cls_flat = cls_outputs_all.reshape(batch_size, -1)
    # torch.topk raises if k exceeds the size of the reduced dim, which happens for small
    # inputs and/or few classes; never request more candidates than actually exist.
    k = min(max_detection_points, cls_flat.shape[1])
    _, cls_topk_indices_all = torch.topk(cls_flat, dim=1, k=k)

    # A flat index encodes (row, class) as row * num_classes + class.
    indices_all = cls_topk_indices_all // num_classes
    classes_all = cls_topk_indices_all % num_classes

    box_outputs_all_after_topk = torch.gather(
        box_outputs_all, 1, indices_all.unsqueeze(2).expand(-1, -1, 4))

    # First pull the full per-class logit row of every selected candidate, then pick the
    # single logit belonging to the selected class.
    cls_outputs_all_after_topk = torch.gather(
        cls_outputs_all, 1, indices_all.unsqueeze(2).expand(-1, -1, num_classes))
    cls_outputs_all_after_topk = torch.gather(
        cls_outputs_all_after_topk, 2, classes_all.unsqueeze(2))

    return cls_outputs_all_after_topk, box_outputs_all_after_topk, indices_all, classes_all
def _batch_detection(
        batch_size: int, class_out, box_out, anchor_boxes, indices, classes,
        img_scale: Optional[torch.Tensor] = None, img_size: Optional[torch.Tensor] = None):
    """Run `generate_detections` once per sample and stack the results along a new dim 0.

    Args:
        batch_size: number of samples to process; each per-sample argument is indexed 0..batch_size-1.
        class_out, box_out, indices, classes: per-sample inputs, indexed on their first dim
            and forwarded to `generate_detections`.
        anchor_boxes: passed unchanged to every `generate_detections` call.
        img_scale: optional per-sample value; when None, None is forwarded for every sample.
        img_size: optional per-sample value; when None, None is forwarded for every sample.

    Returns:
        The per-sample results of `generate_detections`, stacked with `torch.stack(..., dim=0)`.
    """
    per_sample = []
    # FIXME we may be able to do this as a batch with some tensor reshaping/indexing, PR welcome
    for idx in range(batch_size):
        scale_for_sample = img_scale[idx] if img_scale is not None else None
        size_for_sample = img_size[idx] if img_size is not None else None
        per_sample.append(generate_detections(
            class_out[idx], box_out[idx], anchor_boxes, indices[idx], classes[idx],
            scale_for_sample, size_for_sample))
    return torch.stack(per_sample, dim=0)
class DetBenchPredict(nn.Module):
    """Prediction bench: runs the wrapped model and converts its outputs into detections."""

    def __init__(self, model):
        """Wrap `model`, caching the config fields and anchors needed for post-processing."""
        super().__init__()
        cfg = model.config
        self.model = model
        self.config = cfg  # FIXME remove this when we can use @property (torchscript limitation)
        self.num_levels = cfg.num_levels
        self.num_classes = cfg.num_classes
        self.anchors = Anchors.from_config(cfg)

    def forward(self, x, img_info: Optional[Dict[str, torch.Tensor]] = None):
        """Return batched detections for input `x`.

        Args:
            x: model input; its first dim is used as the batch size.
            img_info: optional dict providing 'img_scale' and 'img_size'. When omitted,
                None is passed for both to the detection step.
        """
        raw_cls, raw_box = self.model(x)
        top_cls, top_box, top_indices, top_classes = _post_process(
            raw_cls, raw_box, num_levels=self.num_levels, num_classes=self.num_classes)

        img_scale: Optional[torch.Tensor] = None
        img_size: Optional[torch.Tensor] = None
        if img_info is not None:
            img_scale = img_info['img_scale']
            img_size = img_info['img_size']

        return _batch_detection(
            x.shape[0], top_cls, top_box, self.anchors.boxes, top_indices, top_classes,
            img_scale, img_size)
class DetBenchTrain(nn.Module):
    """Training bench: runs the wrapped model, computes the detection loss and, in eval mode,
    also produces detections."""

    def __init__(self, model, create_labeler=True):
        """Wrap `model` and build its anchors, loss and (optionally) an anchor labeler.

        Args:
            model: detection model exposing a `config` attribute.
            create_labeler: when True an `AnchorLabeler` is created and used in `forward`;
                when False `forward` expects pre-computed anchor labels in `target`.
        """
        super().__init__()
        cfg = model.config
        self.model = model
        self.config = cfg  # FIXME remove this when we can use @property (torchscript limitation)
        self.num_levels = cfg.num_levels
        self.num_classes = cfg.num_classes
        self.anchors = Anchors.from_config(cfg)
        self.anchor_labeler = (
            AnchorLabeler(self.anchors, self.num_classes, match_threshold=0.5)
            if create_labeler else None
        )
        self.loss_fn = DetectionLoss(cfg)

    def forward(self, x, target: Dict[str, torch.Tensor]):
        """Compute losses for `x` against `target`.

        Args:
            x: model input; its first dim is used as the batch size in eval mode.
            target: with a labeler, must provide 'bbox' and 'cls'; without one, must provide
                'label_num_positives' plus 'label_cls_{i}' / 'label_bbox_{i}' for each level.
                In eval mode 'img_scale' and 'img_size' are read as well.

        Returns:
            Dict with 'loss', 'class_loss' and 'box_loss'; in eval mode also 'detections'.
        """
        class_out, box_out = self.model(x)

        if self.anchor_labeler is not None:
            cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors(
                target['bbox'], target['cls'])
        else:
            # target should contain pre-computed anchor labels if labeler not present in bench
            assert 'label_num_positives' in target
            levels = range(self.num_levels)
            cls_targets = [target[f'label_cls_{level}'] for level in levels]
            box_targets = [target[f'label_bbox_{level}'] for level in levels]
            num_positives = target['label_num_positives']

        loss, class_loss, box_loss = self.loss_fn(class_out, box_out, cls_targets, box_targets, num_positives)
        output = {'loss': loss, 'class_loss': class_loss, 'box_loss': box_loss}
        if self.training:
            return output

        # if eval mode, output detections for evaluation
        top_cls, top_box, top_indices, top_classes = _post_process(
            class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes)
        output['detections'] = _batch_detection(
            x.shape[0], top_cls, top_box, self.anchors.boxes, top_indices, top_classes,
            target['img_scale'], target['img_size'])
        return output
def unwrap_bench(model):
    """Peel wrapper layers off `model` until the underlying model is reached.

    Lets other functions access the weights and attributes of the underlying model directly.
    Wrappers are removed repeatedly, checked in this order at each step: a `ModelEma`
    instance (via `.ema`), anything with a `.module` attribute, then anything with a
    `.model` attribute. The first object matching none of these is returned.
    """
    current = model
    while True:
        if isinstance(current, ModelEma):  # unwrap ModelEma
            current = current.ema
        elif hasattr(current, 'module'):  # unwrap DDP
            current = current.module
        elif hasattr(current, 'model'):  # unwrap Bench -> model
            current = current.model
        else:
            return current