""" PyTorch EfficientDet support benches
Hacked together by Ross Wightman
"""
from typing import Optional, Dict, List
import torch
import torch.nn as nn
from timm.utils import ModelEma
from .anchors import Anchors, AnchorLabeler, generate_detections, MAX_DETECTION_POINTS
from .loss import DetectionLoss
def _post_process(
        cls_outputs: List[torch.Tensor],
        box_outputs: List[torch.Tensor],
        num_levels: int,
        num_classes: int,
        max_detection_points: int = MAX_DETECTION_POINTS,
):
    """Selects top-k predictions.

    Post-proc code adapted from Tensorflow version at: https://github.com/google/automl/tree/master/efficientdet
    and optimized for PyTorch.

    Args:
        cls_outputs: a list of tensors, one per feature level, each holding class logits
            with shape [batch_size, num_classes * num_anchors, height, width].
        box_outputs: a list of tensors, one per feature level, each holding box regression
            targets with shape [batch_size, 4 * num_anchors, height, width].
        num_levels (int): number of feature levels
        num_classes (int): number of output classes
        max_detection_points (int): number of top-scoring predictions to keep per image

    Returns:
        cls_outputs_after_topk: [batch_size, max_detection_points, 1] logit of the selected class
        box_outputs_after_topk: [batch_size, max_detection_points, 4] box targets for the selections
        indices_all: [batch_size, max_detection_points] flat anchor indices of the selections
        classes_all: [batch_size, max_detection_points] class indices of the selections
    """
    batch_size = cls_outputs[0].shape[0]
    cls_outputs_all = torch.cat([
        cls_outputs[level].permute(0, 2, 3, 1).reshape([batch_size, -1, num_classes])
        for level in range(num_levels)], 1)
    box_outputs_all = torch.cat([
        box_outputs[level].permute(0, 2, 3, 1).reshape([batch_size, -1, 4])
        for level in range(num_levels)], 1)
    _, cls_topk_indices_all = torch.topk(cls_outputs_all.reshape(batch_size, -1), dim=1, k=max_detection_points)
    indices_all = cls_topk_indices_all // num_classes
    classes_all = cls_topk_indices_all % num_classes
    box_outputs_all_after_topk = torch.gather(
        box_outputs_all, 1, indices_all.unsqueeze(2).expand(-1, -1, 4))
    cls_outputs_all_after_topk = torch.gather(
        cls_outputs_all, 1, indices_all.unsqueeze(2).expand(-1, -1, num_classes))
    cls_outputs_all_after_topk = torch.gather(
        cls_outputs_all_after_topk, 2, classes_all.unsqueeze(2))
    return cls_outputs_all_after_topk, box_outputs_all_after_topk, indices_all, classes_all
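# A worked example of the top-k index decode above (illustrative numbers, not from the
# original module): with num_classes = 90, a flat top-k index of 1234 taken over the
# [batch_size, num_anchors * num_classes] score view decodes to anchor 1234 // 90 = 13
# and class 1234 % 90 = 64; the gathers then keep box row 13 and the single class-64 logit.
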
@torch.jit.script
def _batch_detection(
        batch_size: int, class_out, box_out, anchor_boxes, indices, classes,
        img_scale: Optional[torch.Tensor] = None, img_size: Optional[torch.Tensor] = None):
    batch_detections = []
    # FIXME we may be able to do this as a batch with some tensor reshaping/indexing, PR welcome
    for i in range(batch_size):
        img_scale_i = None if img_scale is None else img_scale[i]
        img_size_i = None if img_size is None else img_size[i]
        detections = generate_detections(
            class_out[i], box_out[i], anchor_boxes, indices[i], classes[i], img_scale_i, img_size_i)
        batch_detections.append(detections)
    return torch.stack(batch_detections, dim=0)

class DetBenchPredict(nn.Module):
    def __init__(self, model):
        super(DetBenchPredict, self).__init__()
        self.model = model
        self.config = model.config  # FIXME remove this when we can use @property (torchscript limitation)
        self.num_levels = model.config.num_levels
        self.num_classes = model.config.num_classes
        self.anchors = Anchors.from_config(model.config)

    def forward(self, x, img_info: Optional[Dict[str, torch.Tensor]] = None):
        class_out, box_out = self.model(x)
        class_out, box_out, indices, classes = _post_process(
            class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes)
        if img_info is None:
            img_scale, img_size = None, None
        else:
            img_scale, img_size = img_info['img_scale'], img_info['img_size']
        return _batch_detection(
            x.shape[0], class_out, box_out, self.anchors.boxes, indices, classes, img_scale, img_size)
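# Usage sketch (illustrative, not part of the original module; assumes `model` is an
# EfficientDet instance from this repo with `model.config` populated, and that `images`,
# `scales` and `sizes` are appropriately shaped tensors):
#
#   bench = DetBenchPredict(model).eval()
#   with torch.no_grad():
#       detections = bench(images, img_info={'img_scale': scales, 'img_size': sizes})
#   # detections: [batch_size, num_detections, 6] tensor of boxes, scores and class ids
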
class DetBenchTrain(nn.Module):
    def __init__(self, model, create_labeler=True):
        super(DetBenchTrain, self).__init__()
        self.model = model
        self.config = model.config  # FIXME remove this when we can use @property (torchscript limitation)
        self.num_levels = model.config.num_levels
        self.num_classes = model.config.num_classes
        self.anchors = Anchors.from_config(model.config)
        self.anchor_labeler = None
        if create_labeler:
            self.anchor_labeler = AnchorLabeler(self.anchors, self.num_classes, match_threshold=0.5)
        self.loss_fn = DetectionLoss(model.config)

    def forward(self, x, target: Dict[str, torch.Tensor]):
        class_out, box_out = self.model(x)
        if self.anchor_labeler is None:
            # target should contain pre-computed anchor labels if labeler not present in bench
            assert 'label_num_positives' in target
            cls_targets = [target[f'label_cls_{l}'] for l in range(self.num_levels)]
            box_targets = [target[f'label_bbox_{l}'] for l in range(self.num_levels)]
            num_positives = target['label_num_positives']
        else:
            cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors(
                target['bbox'], target['cls'])
        loss, class_loss, box_loss = self.loss_fn(class_out, box_out, cls_targets, box_targets, num_positives)
        output = {'loss': loss, 'class_loss': class_loss, 'box_loss': box_loss}
        if not self.training:
            # if eval mode, output detections for evaluation
            class_out_pp, box_out_pp, indices, classes = _post_process(
                class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes)
            output['detections'] = _batch_detection(
                x.shape[0], class_out_pp, box_out_pp, self.anchors.boxes, indices, classes,
                target['img_scale'], target['img_size'])
        return output
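# Usage sketch (illustrative, not part of the original module): a training step with the
# labeler created in-bench, so `target` only needs raw ground truth. Tensor names here
# (`images`, `gt_boxes`, `gt_classes`, `scales`, `sizes`) are placeholder assumptions:
#
#   bench = DetBenchTrain(model, create_labeler=True)
#   target = {
#       'bbox': gt_boxes, 'cls': gt_classes,      # per-image ground truth
#       'img_scale': scales, 'img_size': sizes,   # only consumed in eval mode
#   }
#   output = bench(images, target)
#   output['loss'].backward()
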
def unwrap_bench(model):
    # Unwrap a model from its support bench (and common training wrappers) so that other
    # functions can access the weights and attributes of the underlying model directly.
    if isinstance(model, ModelEma):  # unwrap ModelEma
        return unwrap_bench(model.ema)
    elif hasattr(model, 'module'):  # unwrap DDP
        return unwrap_bench(model.module)
    elif hasattr(model, 'model'):  # unwrap Bench -> model
        return unwrap_bench(model.model)
    else:
        return model
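# Illustrative example (not part of the original module): unwrap_bench recurses through
# common wrappers, so for a ModelEma-wrapped bench it follows .ema, then .model:
#
#   bench = DetBenchTrain(model)
#   bench_ema = ModelEma(bench)
#   raw_model = unwrap_bench(bench_ema)  # the underlying (EMA copy of the) model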