liguang0115's picture
Add initial project structure with core files, configurations, and sample images
2df809d
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# DUST3R default transforms
# --------------------------------------------------------
import torchvision.transforms as tvf
from dust3r.utils.image import ImgNorm
# Standard per-image augmentation pipeline: torchvision's ColorJitter, called
# positionally with 0.5, 0.5, 0.5, 0.1 (brightness, contrast, saturation, hue
# in torchvision's signature), followed by ImgNorm from dust3r.utils.image.
# NOTE(review): torchvision's ColorJitter re-samples its parameters on every
# call, so images of one sequence are jittered independently through this
# transform; SeqColorJitter below freezes the parameters instead.
ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
if isinstance(value, (int, float)):
if value < 0:
raise ValueError(f"If is a single number, it must be non negative.")
value = [center - float(value), center + float(value)]
if clip_first_on_zero:
value[0] = max(value[0], 0.0)
elif isinstance(value, (tuple, list)) and len(value) == 2:
value = [float(value[0]), float(value[1])]
else:
raise TypeError(f"should be a single number or a list/tuple with length 2.")
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"values should be between {bound}, but got {value}.")
# if value is 0 or (1., 1.) for brightness/contrast/saturation
# or (0., 0.) for hue, do nothing
if value[0] == value[1] == center:
return None
else:
return tuple(value)
import torch
import torchvision.transforms.functional as F
def SeqColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1):
    """Build a color-jitter transform whose random parameters are frozen.

    The order of the four adjustments and the factor of each one are sampled
    once, when this function is called. The returned callable then applies
    that exact same jitter to every image it receives, so it can be reused
    across all frames of a sequence to keep them photometrically consistent.
    Call ``SeqColorJitter()`` again to draw a new set of parameters.

    Args:
        brightness: Jitter strength, either a single non-negative number
            (sampled from ``[max(0, 1 - v), 1 + v]``) or a ``(min, max)``
            pair.
        contrast: Same convention as ``brightness``.
        saturation: Same convention as ``brightness``.
        hue: Jitter strength for hue, either a single number ``h`` (sampled
            from ``[-h, h]``) or a ``(min, max)`` pair; the range must lie
            inside ``[-0.5, 0.5]``.

    Returns:
        A function ``img -> ImgNorm(jittered img)``. The image is passed to
        the ``torchvision.transforms.functional.adjust_*`` helpers and then
        to ``ImgNorm`` (imported from ``dust3r.utils.image``).

    Raises:
        ValueError, TypeError: Propagated from ``_check_input`` when a
            setting is malformed or out of range.
    """
    brightness_range = _check_input(brightness)
    contrast_range = _check_input(contrast)
    saturation_range = _check_input(saturation)
    hue_range = _check_input(
        hue, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False
    )

    def _draw(value_range):
        # A ``None`` range means the adjustment is disabled: no random draw.
        if value_range is None:
            return None
        return float(torch.empty(1).uniform_(value_range[0], value_range[1]))

    # Draw the permutation first and the factors afterwards, in this fixed
    # order, so that the consumption of the torch RNG stream is stable.
    # ``tolist()`` yields plain ints, so the closure does no tensor work.
    order = torch.randperm(4).tolist()
    brightness_factor = _draw(brightness_range)
    contrast_factor = _draw(contrast_range)
    saturation_factor = _draw(saturation_range)
    hue_factor = _draw(hue_range)

    # Index i of this table is the adjustment identified by fn id i.
    adjustments = (
        (F.adjust_brightness, brightness_factor),
        (F.adjust_contrast, contrast_factor),
        (F.adjust_saturation, saturation_factor),
        (F.adjust_hue, hue_factor),
    )
    # Resolve the permuted, enabled adjustments once instead of on each call.
    pipeline = [
        adjustments[fn_id] for fn_id in order if adjustments[fn_id][1] is not None
    ]

    def _color_jitter(img):
        for adjust, factor in pipeline:
            img = adjust(img, factor)
        return ImgNorm(img)

    return _color_jitter