# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# DUST3R default transforms
# --------------------------------------------------------
import torch
import torchvision.transforms as tvf
import torchvision.transforms.functional as F

from dust3r.utils.image import ImgNorm

# define the standard image transforms
ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])


def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
    """Validate a jitter strength and turn it into a ``(low, high)`` range.

    Args:
        value: Either a single non-negative number ``v``, expanded to
            ``[center - v, center + v]``, or a 2-element list/tuple that is
            used as ``[low, high]`` directly.
        center: Neutral value of the factor; the range is built around it.
        bound: Inclusive ``(min, max)`` limits the resulting range must
            satisfy.
        clip_first_on_zero: When ``value`` is a single number, clamp the
            lower end of the range to ``0.0``.

    Returns:
        A ``(low, high)`` tuple of floats, or ``None`` when the range
        collapses onto ``center`` (i.e. the adjustment would be a no-op).

    Raises:
        ValueError: If a single number is negative, or if the range is not
            ordered or falls outside ``bound``.
        TypeError: If ``value`` is neither a number nor a list/tuple of
            length 2.
    """
    if isinstance(value, (int, float)):
        if value < 0:
            raise ValueError("If is a single number, it must be non negative.")
        value = [center - float(value), center + float(value)]
        if clip_first_on_zero:
            value[0] = max(value[0], 0.0)
    elif isinstance(value, (tuple, list)) and len(value) == 2:
        value = [float(value[0]), float(value[1])]
    else:
        raise TypeError("should be a single number or a list/tuple with length 2.")

    if not bound[0] <= value[0] <= value[1] <= bound[1]:
        raise ValueError(f"values should be between {bound}, but got {value}.")

    # if value is 0 or (1., 1.) for brightness/contrast/saturation
    # or (0., 0.) for hue, do nothing
    if value[0] == value[1] == center:
        return None
    return tuple(value)


def _sample_factor(value_range):
    """Draw one float uniformly from ``value_range``; ``None`` passes through."""
    if value_range is None:
        return None
    return float(torch.empty(1).uniform_(value_range[0], value_range[1]))


def SeqColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1):
    """Return a color jitter transform with same random parameters.

    The adjustment order and the four factors are sampled once, when this
    function is called. The returned callable then applies exactly those
    adjustments to every image it receives, so several images can share one
    jitter. Call ``SeqColorJitter()`` again to draw new parameters.

    Args:
        brightness: Jitter strength for brightness, validated by
            ``_check_input`` around a center of 1.
        contrast: Jitter strength for contrast, same convention.
        saturation: Jitter strength for saturation, same convention.
        hue: Jitter strength for hue, validated around a center of 0 and
            limited to ``[-0.5, 0.5]``.

    Returns:
        A function ``img -> ImgNorm(jittered img)``. ``ImgNorm`` comes from
        ``dust3r.utils.image``.

    Raises:
        ValueError, TypeError: Propagated from ``_check_input`` for invalid
            strengths.
    """
    brightness = _check_input(brightness)
    contrast = _check_input(contrast)
    saturation = _check_input(saturation)
    hue = _check_input(hue, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    # Draw order: permutation first, then the four factors. Converting the
    # permutation to plain ints once avoids tensor comparisons on every call
    # of the returned closure and consumes no extra random numbers.
    fn_idx = torch.randperm(4).tolist()

    brightness_factor = _sample_factor(brightness)
    contrast_factor = _sample_factor(contrast)
    saturation_factor = _sample_factor(saturation)
    hue_factor = _sample_factor(hue)

    def _color_jitter(img):
        # A factor of None means that adjustment is a no-op and is skipped.
        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                img = F.adjust_brightness(img, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                img = F.adjust_contrast(img, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                img = F.adjust_saturation(img, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                img = F.adjust_hue(img, hue_factor)
        return ImgNorm(img)

    return _color_jitter