Image Feature Extraction
Transformers
Safetensors
dinov2
rad-dino / augmentations.py
fepegar's picture
Add files helpful for fine-tuning (#6)
428e563
# Copyright (c) Microsoft Corporation. All rights reserved.
# See LICENSE in the repo root for license information.
#
# Portions:
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import logging
from PIL import Image
from torchvision import transforms
from .transforms import (
GaussianBlur,
MaybeToTensor,
make_normalize_transform,
)
# Module-level logger registered under the fixed name "dinov2" rather than
# __name__ (presumably to share the DINOv2 logging configuration -- confirm).
logger = logging.getLogger("dinov2")
class DataAugmentationDINO:
    """Multi-crop augmentation: two global views plus several local views.

    Calling an instance on an image returns a dict with the keys
    ``"global_crops"``, ``"global_crops_teacher"``, ``"local_crops"`` and
    ``"offsets"`` (see ``__call__``).

    ``GaussianBlur``, ``MaybeToTensor`` and ``make_normalize_transform`` are
    imported from the sibling ``.transforms`` module; their exact behaviour
    is not visible from this file.
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_crops_size=224,
        local_crops_size=96,
    ):
        """Store the crop parameters, log them, and build the pipelines.

        Args:
            global_crops_scale: ``scale`` argument forwarded to the
                ``RandomResizedCrop`` used for global crops.
            local_crops_scale: ``scale`` argument forwarded to the
                ``RandomResizedCrop`` used for local crops.
            local_crops_number: How many local crops ``__call__`` produces.
            global_crops_size: Output size of each global crop.
            local_crops_size: Output size of each local crop.
        """
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number
        self.global_crops_size = global_crops_size
        self.local_crops_size = local_crops_size

        # Report the configuration between two banner lines.
        logger.info("###################################")
        logger.info("Using data augmentation parameters:")
        for label, value in (
            ("global_crops_scale", global_crops_scale),
            ("local_crops_scale", local_crops_scale),
            ("local_crops_number", local_crops_number),
            ("global_crops_size", global_crops_size),
            ("local_crops_size", local_crops_size),
        ):
            logger.info(f"{label}: {value}")
        logger.info("###################################")

        def random_crop_and_flip(size, scale):
            # Geometric part: bicubic random resized crop, then a 50% flip.
            return transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        size,
                        scale=scale,
                        interpolation=transforms.InterpolationMode.BICUBIC,
                    ),
                    transforms.RandomHorizontalFlip(p=0.5),
                ]
            )

        self.geometric_augmentation_global = random_crop_and_flip(global_crops_size, global_crops_scale)
        self.geometric_augmentation_local = random_crop_and_flip(local_crops_size, local_crops_scale)

        # Color distortions: jitter applied 80% of the time, grayscale 20%.
        jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)
        color_jittering = transforms.Compose(
            [
                transforms.RandomApply([jitter], p=0.8),
                transforms.RandomGrayscale(p=0.2),
            ]
        )

        # Blurring: each view type gets its own blur step with its own ``p``.
        blur_global_first = GaussianBlur(p=0.5)
        blur_global_second = transforms.Compose([GaussianBlur(p=0.1)])
        blur_local = GaussianBlur(p=0.5)

        # Normalization, shared by all three photometric pipelines.
        self.normalize = transforms.Compose([MaybeToTensor(), make_normalize_transform()])

        def photometric(blur_step):
            # Same color jitter and normalization everywhere; only blur differs.
            return transforms.Compose([color_jittering, blur_step, self.normalize])

        self.global_transfo1 = photometric(blur_global_first)
        self.global_transfo2 = photometric(blur_global_second)
        self.local_transfo = photometric(blur_local)

    def __call__(self, image):
        """Augment ``image`` into global and local views.

        Returns:
            dict with:
              * ``"global_crops"``: list of the two global views,
              * ``"global_crops_teacher"``: a separate list holding the same
                two global views,
              * ``"local_crops"``: list of ``local_crops_number`` local views,
              * ``"offsets"``: an empty tuple.
        """
        # Each global view gets its own geometric draw and photometric chain.
        first_view = self.global_transfo1(self.geometric_augmentation_global(image))
        second_view = self.global_transfo2(self.geometric_augmentation_global(image))

        local_views = []
        for _ in range(self.local_crops_number):
            cropped = self.geometric_augmentation_local(image)
            local_views.append(self.local_transfo(cropped))

        return {
            "global_crops": [first_view, second_view],
            "global_crops_teacher": [first_view, second_view],
            "local_crops": local_views,
            "offsets": (),
        }
def get_online_classification_augmentation_from_config(cfg) -> transforms.Compose:
    """Build the online-classification augmentation pipeline from ``cfg``.

    Reads ``cfg.evaluation.online.augmentation`` (fields ``interpolation``,
    ``degrees``, ``scale``, ``shear``, ``horizontal_flip``) and
    ``cfg.crops.global_crops_size``.

    ``interpolation`` is looked up by name on ``PIL.Image.Resampling`` and
    used for both the resize and the random affine transform. The same
    ``global_crops_size`` value serves as resize size and center-crop size.

    Returns:
        A ``transforms.Compose`` of: resize, center crop, random affine,
        ``MaybeToTensor``, normalization, and -- only when
        ``horizontal_flip`` is truthy -- a trailing ``RandomHorizontalFlip``.
    """
    aug_cfg = cfg.evaluation.online.augmentation
    resampling = getattr(Image.Resampling, aug_cfg.interpolation)
    side = cfg.crops.global_crops_size

    pipeline = [
        transforms.Resize(side, interpolation=resampling),
        transforms.CenterCrop(side),
        transforms.RandomAffine(
            degrees=aug_cfg.degrees,
            scale=aug_cfg.scale,
            shear=aug_cfg.shear,
            interpolation=resampling,
        ),
        MaybeToTensor(),
        make_normalize_transform(),
    ]

    if not aug_cfg.horizontal_flip:
        return transforms.Compose(pipeline)

    # The optional flip goes last, i.e. after tensor conversion/normalization.
    return transforms.Compose([*pipeline, transforms.RandomHorizontalFlip()])