|
import torch |
|
import detection.transforms as T |
|
|
|
|
|
class DetectionPresetTrain:
    """Callable training preset that applies a composed transform pipeline.

    The pipeline is assembled from ``detection.transforms`` (``T``) according
    to the selected policy name, and always ends with
    ``RandomHorizontalFlip``, ``PILToTensor`` and
    ``ConvertImageDtype(torch.float)``.

    Args:
        data_augmentation: Policy name. One of ``"hflip"``, ``"ssd"`` or
            ``"ssdlite"``.
        hflip_prob: Forwarded as ``p`` to ``T.RandomHorizontalFlip``.
        mean: Converted to a list and forwarded as ``fill`` to
            ``T.RandomZoomOut``; only read by the ``"ssd"`` policy.

    Raises:
        ValueError: If ``data_augmentation`` is not a known policy name.
    """

    def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
        # Pick the policy-specific transforms that run before the shared tail.
        if data_augmentation == "hflip":
            leading = []
        elif data_augmentation == "ssd":
            leading = [
                T.RandomPhotometricDistort(),
                T.RandomZoomOut(fill=list(mean)),
                T.RandomIoUCrop(),
            ]
        elif data_augmentation == "ssdlite":
            leading = [T.RandomIoUCrop()]
        else:
            raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')

        # Every policy finishes with the same flip + tensor conversion steps.
        shared_tail = [
            T.RandomHorizontalFlip(p=hflip_prob),
            T.PILToTensor(),
            T.ConvertImageDtype(torch.float),
        ]
        self.transforms = T.Compose(leading + shared_tail)

    def __call__(self, img, target):
        """Run the composed pipeline on ``img`` and ``target`` and return its result."""
        pipeline = self.transforms
        return pipeline(img, target)
|
|
|
|
|
class DetectionPresetEval:
    """Callable evaluation preset that applies ``T.ToTensor`` to its inputs."""

    def __init__(self):
        # Evaluation uses a single transform; no augmentation is configured here.
        to_tensor = T.ToTensor()
        self.transforms = to_tensor

    def __call__(self, img, target):
        """Apply the stored transform to ``img`` and ``target`` and return its result."""
        transform = self.transforms
        return transform(img, target)
|
|