import random
from enum import Enum
from typing import Tuple, List, Type, Iterator

import PIL
import torch
import torch.utils.data.dataset
import torch.utils.data.sampler
from PIL import Image  # noqa: F401 -- binds the `Image` submodule so the `PIL.Image.Image` annotation resolves
from torch import Tensor
from torch.nn import functional as F
from torchvision.transforms import transforms


class Base(torch.utils.data.dataset.Dataset):

    class Mode(Enum):
        TRAIN = 'train'
        EVAL = 'eval'

    OPTIONS = ['voc2007', 'coco2017', 'voc2007-cat-dog', 'coco2017-person', 'coco2017-car', 'coco2017-animal']

    @staticmethod
    def from_name(name: str) -> Type['Base']:
        # Resolve a dataset name from `OPTIONS` to its class; imports are deferred
        # so that unused dataset modules are never loaded.
        if name == 'voc2007':
            from dataset.voc2007 import VOC2007
            return VOC2007
        elif name == 'coco2017':
            from dataset.coco2017 import COCO2017
            return COCO2017
        elif name == 'voc2007-cat-dog':
            from dataset.voc2007_cat_dog import VOC2007CatDog
            return VOC2007CatDog
        elif name == 'coco2017-person':
            from dataset.coco2017_person import COCO2017Person
            return COCO2017Person
        elif name == 'coco2017-car':
            from dataset.coco2017_car import COCO2017Car
            return COCO2017Car
        elif name == 'coco2017-animal':
            from dataset.coco2017_animal import COCO2017Animal
            return COCO2017Animal
        else:
            raise ValueError(f'Unknown dataset name: {name}')
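
    # Typical use (hypothetical call site; 600/1000 are example side lengths):
    # look the class up by name, then instantiate it with the constructor
    # arguments defined below, e.g.
    #
    #   DatasetClass = Base.from_name('voc2007')
    #   dataset = DatasetClass('/path/to/data', Base.Mode.TRAIN,
    #                          image_min_side=600, image_max_side=1000)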

    def __init__(self, path_to_data_dir: str, mode: Mode, image_min_side: float, image_max_side: float):
        self._path_to_data_dir = path_to_data_dir
        self._mode = mode
        self._image_min_side = image_min_side
        self._image_max_side = image_max_side

    def __len__(self) -> int:
        raise NotImplementedError

    def __getitem__(self, index: int) -> Tuple[str, Tensor, Tensor, Tensor, Tensor]:
        raise NotImplementedError

    def evaluate(self, path_to_results_dir: str, image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: List[float]) -> Tuple[float, str]:
        raise NotImplementedError

    def _write_results(self, path_to_results_dir: str, image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: List[float]):
        raise NotImplementedError

    @property
    def image_ratios(self) -> List[float]:
        raise NotImplementedError

    @staticmethod
    def num_classes() -> int:
        raise NotImplementedError

    @staticmethod
    def preprocess(image: PIL.Image.Image, image_min_side: float, image_max_side: float) -> Tuple[Tensor, float]:
        # Resize according to the rules:
        #   1. scale the shorter side to IMAGE_MIN_SIDE
        #   2. after scaling, if the longer side exceeds IMAGE_MAX_SIDE, scale the longer side down to IMAGE_MAX_SIDE
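        # Worked example (assumed values): a 1200x800 image with image_min_side=600
        # and image_max_side=1000 gets a shorter-side scale of 600/800 = 0.75; the
        # longer side then becomes 1200 * 0.75 = 900 <= 1000, so the final scale
        # stays 0.75 and the output is 900x600.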
        scale_for_shorter_side = image_min_side / min(image.width, image.height)
        longer_side_after_scaling = max(image.width, image.height) * scale_for_shorter_side
        scale_for_longer_side = (image_max_side / longer_side_after_scaling) if longer_side_after_scaling > image_max_side else 1
        scale = scale_for_shorter_side * scale_for_longer_side

        transform = transforms.Compose([
            transforms.Resize((round(image.height * scale), round(image.width * scale))),  # interpolation `BILINEAR` is applied by default
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet statistics
        ])
        image = transform(image)

        return image, scale
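
    # Usage sketch (hypothetical values and names): keep the returned scale so
    # predicted boxes can be mapped back to original-image coordinates, e.g.
    #
    #   image = Image.open('/path/to/image.jpg').convert('RGB')
    #   image_tensor, scale = Base.preprocess(image, image_min_side=600, image_max_side=1000)
    #   original_bboxes = predicted_bboxes / scale  # `predicted_bboxes` is hypothetical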

    @staticmethod
    def padding_collate_fn(batch: List[Tuple[str, Tensor, Tensor, Tensor, Tensor]]) -> Tuple[List[str], Tensor, Tensor, Tensor, Tensor]:
        image_id_batch, image_batch, scale_batch, bboxes_batch, labels_batch = zip(*batch)

        max_image_width = max([it.shape[2] for it in image_batch])
        max_image_height = max([it.shape[1] for it in image_batch])
        max_bboxes_length = max([len(it) for it in bboxes_batch])
        max_labels_length = max([len(it) for it in labels_batch])

        padded_image_batch = []
        padded_bboxes_batch = []
        padded_labels_batch = []

        for image in image_batch:
            padded_image = F.pad(input=image, pad=(0, max_image_width - image.shape[2], 0, max_image_height - image.shape[1]))  # pad has format (left, right, top, bottom)
            padded_image_batch.append(padded_image)

        for bboxes in bboxes_batch:
            padded_bboxes = torch.cat([bboxes, torch.zeros(max_bboxes_length - len(bboxes), 4).to(bboxes)])
            padded_bboxes_batch.append(padded_bboxes)

        for labels in labels_batch:
            padded_labels = torch.cat([labels, torch.zeros(max_labels_length - len(labels)).to(labels)])
            padded_labels_batch.append(padded_labels)

        image_id_batch = list(image_id_batch)
        padded_image_batch = torch.stack(padded_image_batch, dim=0)
        scale_batch = torch.stack(scale_batch, dim=0)
        padded_bboxes_batch = torch.stack(padded_bboxes_batch, dim=0)
        padded_labels_batch = torch.stack(padded_labels_batch, dim=0)

        return image_id_batch, padded_image_batch, scale_batch, padded_bboxes_batch, padded_labels_batch
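
    # Typical use (hypothetical call site): pass this as `collate_fn` so images
    # of different sizes and annotation lists of different lengths can share a
    # batch; padding regions are zeros.
    #
    #   loader = torch.utils.data.DataLoader(dataset, batch_size=2,
    #                                        collate_fn=Base.padding_collate_fn)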


class NearestRatioRandomSampler(torch.utils.data.sampler.Sampler):
    # Yields indices so that every consecutive run of `num_neighbors` samples
    # shares an orientation (ratio < 1 = tall, ratio >= 1 = fat), which keeps the
    # padding added by `Base.padding_collate_fn` small.

    def __init__(self, image_ratios: List[float], num_neighbors: int):
        super().__init__(data_source=None)
        self._image_ratios = image_ratios
        self._num_neighbors = num_neighbors

    def __len__(self) -> int:
        return len(self._image_ratios)

    def __iter__(self) -> Iterator[int]:
        image_ratios = torch.tensor(self._image_ratios)
        tall_indices = (image_ratios < 1).nonzero().view(-1)
        fat_indices = (image_ratios >= 1).nonzero().view(-1)

        tall_indices_length = len(tall_indices)
        fat_indices_length = len(fat_indices)

        tall_indices = tall_indices[torch.randperm(tall_indices_length)]
        fat_indices = fat_indices[torch.randperm(fat_indices_length)]

        # Drop the remainder so each group divides evenly into chunks of `num_neighbors`
        num_tall_remainder = tall_indices_length % self._num_neighbors
        num_fat_remainder = fat_indices_length % self._num_neighbors
        tall_indices = tall_indices[:tall_indices_length - num_tall_remainder]
        fat_indices = fat_indices[:fat_indices_length - num_fat_remainder]

        tall_indices = tall_indices.view(-1, self._num_neighbors)
        fat_indices = fat_indices.view(-1, self._num_neighbors)
        merge_indices = torch.cat([tall_indices, fat_indices], dim=0)
        merge_indices = merge_indices[torch.randperm(len(merge_indices))].view(-1)

        return iter(merge_indices.tolist())
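

if __name__ == '__main__':
    # Minimal self-check with synthetic data (not part of the original training
    # flow): `padding_collate_fn` should pad images and annotations to common
    # sizes, and the sampler should emit orientation-grouped index pairs.
    fake_batch = [
        ('a', torch.rand(3, 600, 800), torch.tensor(1.0), torch.rand(2, 4), torch.randint(1, 5, (2,))),
        ('b', torch.rand(3, 500, 900), torch.tensor(1.2), torch.rand(3, 4), torch.randint(1, 5, (3,))),
    ]
    ids, images, scales, bboxes, labels = Base.padding_collate_fn(fake_batch)
    print(images.shape, bboxes.shape, labels.shape)  # (2, 3, 600, 900), (2, 3, 4), (2, 3)

    sampler = NearestRatioRandomSampler(image_ratios=[0.5, 2.0, 0.7, 1.5], num_neighbors=2)
    print(list(iter(sampler)))  # e.g. [1, 3, 0, 2]: a fat pair and a tall pair, shuffled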