""" Object detection loader/collate | |
Hacked together by / Copyright 2020 Ross Wightman | |
""" | |
import torch.utils.data | |
from .transforms import * | |
from .transforms_albumentation import get_transform | |
from .random_erasing import RandomErasing | |
from effdet.anchors import AnchorLabeler | |
from timm.data.distributed_sampler import OrderedDistributedSampler | |
import os | |
MAX_NUM_INSTANCES = 100 | |


class DetectionFastCollate:
    """ A detection specific, optimized collate function w/ a bit of state.

    Optionally performs anchor labelling. Doing this here offloads some work from the
    GPU and the main training process thread and increases the load on the dataloader
    threads.
    """
    def __init__(
            self,
            instance_keys=None,
            instance_shapes=None,
            instance_fill=-1,
            max_instances=MAX_NUM_INSTANCES,
            anchor_labeler=None,
    ):
        instance_keys = instance_keys or {'bbox', 'bbox_ignore', 'cls'}
        instance_shapes = instance_shapes or dict(
            bbox=(max_instances, 4), bbox_ignore=(max_instances, 4), cls=(max_instances,))
        self.instance_info = {k: dict(fill=instance_fill, shape=instance_shapes[k]) for k in instance_keys}
        self.max_instances = max_instances
        self.anchor_labeler = anchor_labeler

    def __call__(self, batch):
        batch_size = len(batch)
        target = dict()
        labeler_outputs = dict()
        img_tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            img_tensor[i] += torch.from_numpy(batch[i][0])
            labeler_inputs = {}
            for tk, tv in batch[i][1].items():
                instance_info = self.instance_info.get(tk, None)
                if instance_info is not None:
                    # target tensor is associated with a detection instance
                    tv = torch.from_numpy(tv).to(dtype=torch.float32)
                    if self.anchor_labeler is None:
                        if i == 0:
                            shape = (batch_size,) + instance_info['shape']
                            target_tensor = torch.full(shape, instance_info['fill'], dtype=torch.float32)
                            target[tk] = target_tensor
                        else:
                            target_tensor = target[tk]
                        num_elem = min(tv.shape[0], self.max_instances)
                        target_tensor[i, 0:num_elem] = tv[0:num_elem]
                    else:
                        # no need to pass gt tensors through when labeler in use
                        if tk in ('bbox', 'cls'):
                            labeler_inputs[tk] = tv
                else:
                    # target tensor is an image-level annotation / metadata
                    if i == 0:
                        # first batch elem, create destination tensors
                        if isinstance(tv, (tuple, list)):
                            # per batch elem sequence
                            shape = (batch_size, len(tv))
                            dtype = torch.float32 if isinstance(tv[0], (float, np.floating)) else torch.int32
                        else:
                            # per batch elem scalar
                            shape = batch_size,
                            dtype = torch.float32 if isinstance(tv, (float, np.floating)) else torch.int64
                        target_tensor = torch.zeros(shape, dtype=dtype)
                        target[tk] = target_tensor
                    else:
                        target_tensor = target[tk]
                    target_tensor[i] = torch.tensor(tv, dtype=target_tensor.dtype)

            if self.anchor_labeler is not None:
                cls_targets, box_targets, num_positives = self.anchor_labeler.label_anchors(
                    labeler_inputs['bbox'], labeler_inputs['cls'], filter_valid=False)
                if i == 0:
                    # first batch elem, create destination tensors, separate key per level
                    for j, (ct, bt) in enumerate(zip(cls_targets, box_targets)):
                        labeler_outputs[f'label_cls_{j}'] = torch.zeros(
                            (batch_size,) + ct.shape, dtype=torch.int64)
                        labeler_outputs[f'label_bbox_{j}'] = torch.zeros(
                            (batch_size,) + bt.shape, dtype=torch.float32)
                    labeler_outputs['label_num_positives'] = torch.zeros(batch_size)
                for j, (ct, bt) in enumerate(zip(cls_targets, box_targets)):
                    labeler_outputs[f'label_cls_{j}'][i] = ct
                    labeler_outputs[f'label_bbox_{j}'][i] = bt
                labeler_outputs['label_num_positives'][i] = num_positives

        if labeler_outputs:
            target.update(labeler_outputs)
        return img_tensor, target
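

# The sketch below is illustrative only (`_demo_collate` is not part of the original module):
# it builds a fake two-image batch in the layout the transforms produce (CHW uint8 numpy
# images plus a per-image annotation dict) and runs it through DetectionFastCollate without
# anchor labelling, so per-instance targets come back padded to MAX_NUM_INSTANCES.
def _demo_collate():
    batch = []
    for _ in range(2):
        img = np.random.randint(0, 256, (3, 128, 128), dtype=np.uint8)
        target = dict(
            bbox=np.array([[10., 20., 50., 60.]], dtype=np.float32),  # single box per image
            cls=np.array([1.], dtype=np.float32),
            img_scale=1.0,  # image-level scalar, collated into a (batch_size,) tensor
        )
        batch.append((img, target))
    img_tensor, target = DetectionFastCollate()(batch)
    # img_tensor: (2, 3, 128, 128) uint8; target['bbox']: (2, MAX_NUM_INSTANCES, 4) padded with -1
    return img_tensor, target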


class PrefetchLoader:

    def __init__(self,
                 loader,
                 mean=IMAGENET_DEFAULT_MEAN,
                 std=IMAGENET_DEFAULT_STD,
                 re_prob=0.,
                 re_mode='pixel',
                 re_count=1,
                 ):
        self.loader = loader
        self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
        if re_prob > 0.:
            self.random_erasing = RandomErasing(probability=re_prob, mode=re_mode, max_count=re_count)
        else:
            self.random_erasing = None

    def __iter__(self):
        stream = torch.cuda.Stream()
        first = True

        for next_input, next_target in self.loader:
            with torch.cuda.stream(stream):
                # copy and normalize the next batch on a side stream so it overlaps
                # with compute on the current batch
                next_input = next_input.cuda(non_blocking=True)
                next_input = next_input.float().sub_(self.mean).div_(self.std)
                next_target = {k: v.cuda(non_blocking=True) for k, v in next_target.items()}
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input, next_target)

            if not first:
                yield input, target
            else:
                first = False

            torch.cuda.current_stream().wait_stream(stream)
            input = next_input
            target = next_target

        yield input, target

    def __len__(self):
        return len(self.loader)

    @property
    def sampler(self):
        return self.loader.sampler

    @property
    def dataset(self):
        return self.loader.dataset
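

# Illustrative sketch (`_demo_prefetch` is not part of the original module; it assumes a
# CUDA device is available, since PrefetchLoader moves the normalization constants to the
# GPU): wrap an existing DataLoader so batches arrive on the GPU already converted to
# float and normalized with the ImageNet statistics.
def _demo_prefetch(base_loader):
    prefetch = PrefetchLoader(base_loader, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
    for images, targets in prefetch:
        # images: float32 CUDA tensor normalized on a side stream
        # targets: dict of CUDA tensors as produced by DetectionFastCollate
        break
    return prefetch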


def create_loader(
        dataset,
        input_size,
        batch_size,
        is_training=False,
        use_prefetcher=True,
        re_prob=0.,
        re_mode='pixel',
        re_count=1,
        interpolation='bilinear',
        fill_color='mean',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=1,
        distributed=False,
        pin_mem=False,
        anchor_labeler=None,
):
    if isinstance(input_size, tuple):
        img_size = input_size[-2:]
    else:
        img_size = input_size

    if is_training:
        transforms = get_transform()
        transform = transforms_coco_train(
            img_size,
            interpolation=interpolation,
            use_prefetcher=use_prefetcher,
            fill_color=fill_color,
            mean=mean,
            std=std)
    else:
        transforms = None
        transform = transforms_coco_eval(
            img_size,
            interpolation=interpolation,
            use_prefetcher=use_prefetcher,
            fill_color=fill_color,
            mean=mean,
            std=std)
    dataset.transforms = transforms
    dataset.transform = transform

    sampler = None
    if distributed:
        if is_training:
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            # This will add extra duplicate entries to result in equal num
            # of samples per-process, will slightly alter validation results
            sampler = OrderedDistributedSampler(dataset)

    collate_fn = DetectionFastCollate(anchor_labeler=anchor_labeler)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        pin_memory=pin_mem,
        collate_fn=collate_fn,
    )
    if use_prefetcher:
        if is_training:
            loader = PrefetchLoader(loader, mean=mean, std=std, re_prob=re_prob, re_mode=re_mode, re_count=re_count)
        else:
            loader = PrefetchLoader(loader, mean=mean, std=std)
    return loader
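

# End-to-end usage sketch (illustrative; `_demo_create_loader` and `my_dataset` are
# hypothetical, not part of the original module). It assumes `my_dataset` is a detection
# dataset exposing `transform`/`transforms` attributes and yielding samples compatible
# with the COCO-style transforms above, and that a CUDA device is available for the
# prefetcher.
def _demo_create_loader(my_dataset):
    loader = create_loader(
        my_dataset,
        input_size=(3, 512, 512),  # only the trailing (H, W) are used for the transforms
        batch_size=8,
        is_training=False,
        use_prefetcher=True,
        num_workers=2,
    )
    for images, targets in loader:
        # images: (8, 3, 512, 512) float32 on the GPU; targets: dict of padded target tensors
        break
    return loader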