Spaces:
Runtime error
Runtime error
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# DUST3R default transforms
# --------------------------------------------------------
import torchvision.transforms as tvf | |
from dust3r.utils.image import ImgNorm | |
# Standard image transform: random colour jitter followed by ImgNorm.
ColorJitter = tvf.Compose(
    [
        tvf.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
        ImgNorm,
    ]
)
def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True): | |
if isinstance(value, (int, float)): | |
if value < 0: | |
raise ValueError(f"If is a single number, it must be non negative.") | |
value = [center - float(value), center + float(value)] | |
if clip_first_on_zero: | |
value[0] = max(value[0], 0.0) | |
elif isinstance(value, (tuple, list)) and len(value) == 2: | |
value = [float(value[0]), float(value[1])] | |
else: | |
raise TypeError(f"should be a single number or a list/tuple with length 2.") | |
if not bound[0] <= value[0] <= value[1] <= bound[1]: | |
raise ValueError(f"values should be between {bound}, but got {value}.") | |
# if value is 0 or (1., 1.) for brightness/contrast/saturation | |
# or (0., 0.) for hue, do nothing | |
if value[0] == value[1] == center: | |
return None | |
else: | |
return tuple(value) | |
import torch | |
import torchvision.transforms.functional as F | |
def SeqColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1):
    """Build a colour-jitter transform whose random parameters are fixed.

    The order of the four adjustments and each adjustment factor are sampled
    once, when this factory is called. The returned callable therefore
    applies the same jitter to every image passed to it.

    Args:
        brightness: Jitter strength for brightness, validated by
            ``_check_input`` with its default centre and bounds.
        contrast: Jitter strength for contrast, validated the same way.
        saturation: Jitter strength for saturation, validated the same way.
        hue: Jitter strength for hue, validated around centre 0 and
            restricted to ``[-0.5, 0.5]``.

    Returns:
        A function ``img -> ImgNorm(jittered img)``.

    Raises:
        ValueError, TypeError: Propagated from ``_check_input`` when a
            strength is invalid.
    """
    brightness = _check_input(brightness)
    contrast = _check_input(contrast)
    saturation = _check_input(saturation)
    hue = _check_input(hue, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    def _sample(limits):
        # A disabled adjustment (None) draws nothing from the RNG.
        if limits is None:
            return None
        return float(torch.empty(1).uniform_(limits[0], limits[1]))

    # Draw order is kept fixed: permutation first, then brightness, contrast,
    # saturation and hue, so seeded runs stay reproducible.
    fn_idx = torch.randperm(4).tolist()
    brightness_factor = _sample(brightness)
    contrast_factor = _sample(contrast)
    saturation_factor = _sample(saturation)
    hue_factor = _sample(hue)

    # Indexed by the ids in fn_idx: 0=brightness, 1=contrast, 2=saturation, 3=hue.
    adjustments = (
        (F.adjust_brightness, brightness_factor),
        (F.adjust_contrast, contrast_factor),
        (F.adjust_saturation, saturation_factor),
        (F.adjust_hue, hue_factor),
    )

    def _color_jitter(img):
        for fn_id in fn_idx:
            adjust, factor = adjustments[fn_id]
            if factor is not None:
                img = adjust(img, factor)
        return ImgNorm(img)

    return _color_jitter