|
|
|
|
|
|
|
|
|
|
|
from enum import Enum |
|
from functools import partial |
|
from typing import Optional, Tuple, Union |
|
|
|
import torch |
|
|
|
from .backbones import _make_dinov2_model |
|
from .depth import BNHead, DepthEncoderDecoder, DPTHead |
|
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding |
|
|
|
|
|
class Weights(Enum): |
|
NYU = "NYU" |
|
KITTI = "KITTI" |
|
|
|
|
|
def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]: |
|
if not pretrained: |
|
return (0.001, 10.0) |
|
|
|
|
|
if weights == Weights.KITTI: |
|
return (0.001, 80.0) |
|
|
|
if weights == Weights.NYU: |
|
return (0.001, 10.0) |
|
|
|
return (0.001, 10.0) |
|
|
|
|
|
def _make_dinov2_linear_depth_head(
    *,
    embed_dim: int,
    layers: int,
    min_depth: float,
    max_depth: float,
    **kwargs,
):
    """Build a linear (BN-head) depth decoder over DINOv2 features.

    Args:
        embed_dim: channel width of each backbone feature map.
        layers: number of intermediate backbone layers fed to the head
            (1 or 4); they are resized and concatenated channel-wise.
        min_depth: lower bound of the predicted depth range, in meters.
        max_depth: upper bound of the predicted depth range, in meters.

    Returns:
        A configured ``BNHead`` instance.

    Raises:
        AssertionError: if ``layers`` is not 1 or 4.
    """
    if layers not in (1, 4):
        raise AssertionError(f"Unsupported number of layers: {layers}")

    # Single-layer heads read only feature 0; four-layer heads read all four.
    in_index = [0] if layers == 1 else [0, 1, 2, 3]

    return BNHead(
        classify=True,
        n_bins=256,
        bins_strategy="UD",
        norm_strategy="linear",
        upsample=4,
        in_channels=[embed_dim] * len(in_index),
        in_index=in_index,
        input_transform="resize_concat",
        # Each layer contributes patch tokens plus a class token, doubling width.
        channels=embed_dim * len(in_index) * 2,
        align_corners=False,
        # Bug fix: the requested depth range was previously ignored and
        # hard-coded to (0.001, 80), which is wrong for the NYU (10 m) range
        # supplied by callers via _get_depth_range.
        min_depth=min_depth,
        max_depth=max_depth,
        loss_decode=(),
    )
|
|
|
|
|
def _make_dinov2_linear_depther(
    *,
    arch_name: str = "vit_large",
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.NYU,
    depth_range: Optional[Tuple[float, float]] = None,
    **kwargs,
):
    """Assemble a DINOv2 backbone with a linear depth head.

    Args:
        arch_name: backbone architecture name (vit_small/base/large/giant2).
        layers: number of intermediate backbone layers used by the head (1 or 4).
        pretrained: load pretrained backbone weights and the matching head checkpoint.
        weights: pretrained weight set, as a ``Weights`` member or its name.
        depth_range: optional ``(min_depth, max_depth)`` override; defaults to
            the range matching ``weights`` via ``_get_depth_range``.
        **kwargs: forwarded to the backbone constructor.

    Returns:
        A ``DepthEncoderDecoder`` model.

    Raises:
        AssertionError: on unsupported ``layers`` or ``weights``.
    """
    if layers not in (1, 4):
        raise AssertionError(f"Unsupported number of layers: {layers}")
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    if depth_range is None:
        depth_range = _get_depth_range(pretrained, weights)
    min_depth, max_depth = depth_range

    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

    embed_dim = backbone.embed_dim
    patch_size = backbone.patch_size
    model_name = _make_dinov2_model_name(arch_name, patch_size)
    linear_depth_head = _make_dinov2_linear_depth_head(
        embed_dim=embed_dim,
        layers=layers,
        min_depth=min_depth,
        max_depth=max_depth,
    )

    layer_count = {
        "vit_small": 12,
        "vit_base": 12,
        "vit_large": 24,
        "vit_giant2": 40,
    }[arch_name]

    if layers == 4:
        # Four evenly-spread blocks per architecture.
        out_index = {
            "vit_small": [2, 5, 8, 11],
            "vit_base": [2, 5, 8, 11],
            "vit_large": [4, 11, 17, 23],
            "vit_giant2": [9, 19, 29, 39],
        }[arch_name]
    else:
        assert layers == 1
        # Single-layer heads read only the last block.
        out_index = [layer_count - 1]

    model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head)
    # The decode head expects reshaped intermediate feature maps plus class
    # tokens, so replace the backbone's forward with the layer extractor.
    model.backbone.forward = partial(
        backbone.get_intermediate_layers,
        n=out_index,
        reshape=True,
        return_class_token=True,
        norm=False,
    )
    # Pad inputs so height and width are multiples of the patch size.
    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0]))

    if pretrained:
        # Checkpoint naming: "linear4" for 4-layer heads, plain "linear" for 1-layer.
        layers_str = str(layers) if layers == 4 else ""
        weights_str = weights.value.lower()
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        # Bug fix: previously `state_dict` was only bound when the checkpoint
        # contained a "state_dict" key, so a bare state-dict checkpoint raised
        # NameError at the unconditional load below.
        state_dict = checkpoint.get("state_dict", checkpoint)
        model.load_state_dict(state_dict, strict=False)

    return model
|
|
|
|
|
def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Linear-depth model over a DINOv2 ViT-S/14 backbone."""
    return _make_dinov2_linear_depther(
        arch_name="vit_small",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
|
def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Linear-depth model over a DINOv2 ViT-B/14 backbone."""
    return _make_dinov2_linear_depther(
        arch_name="vit_base",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
|
def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Linear-depth model over a DINOv2 ViT-L/14 backbone."""
    return _make_dinov2_linear_depther(
        arch_name="vit_large",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
|
def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Linear-depth model over a DINOv2 ViT-g/14 backbone (SwiGLU FFN)."""
    return _make_dinov2_linear_depther(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
|
def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float):
    """Build a DPT depth head over four DINOv2 feature maps.

    Args:
        embed_dim: channel width of each backbone feature map.
        min_depth: lower bound of the predicted depth range, in meters.
        max_depth: upper bound of the predicted depth range, in meters.
    """
    # Post-processing pyramid widths: embed_dim/8, /4, /2, and embed_dim.
    pyramid_channels = [embed_dim // 8, embed_dim // 4, embed_dim // 2, embed_dim]
    return DPTHead(
        in_channels=[embed_dim, embed_dim, embed_dim, embed_dim],
        channels=256,
        embed_dims=embed_dim,
        post_process_channels=pyramid_channels,
        readout_type="project",
        min_depth=min_depth,
        max_depth=max_depth,
        loss_decode=(),
    )
|
|
|
|
|
def _make_dinov2_dpt_depther(
    *,
    arch_name: str = "vit_large",
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.NYU,
    depth_range: Optional[Tuple[float, float]] = None,
    **kwargs,
):
    """Assemble a DINOv2 backbone with a DPT depth head.

    Args:
        arch_name: backbone architecture name (vit_small/base/large/giant2).
        pretrained: load pretrained backbone weights and the matching head checkpoint.
        weights: pretrained weight set, as a ``Weights`` member or its name.
        depth_range: optional ``(min_depth, max_depth)`` override; defaults to
            the range matching ``weights`` via ``_get_depth_range``.
        **kwargs: forwarded to the backbone constructor.

    Returns:
        A ``DepthEncoderDecoder`` model.

    Raises:
        AssertionError: on an unsupported ``weights`` name.
    """
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    if depth_range is None:
        depth_range = _get_depth_range(pretrained, weights)
    min_depth, max_depth = depth_range

    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

    model_name = _make_dinov2_model_name(arch_name, backbone.patch_size)
    dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth)

    # Four evenly-spread backbone blocks per architecture.
    out_index = {
        "vit_small": [2, 5, 8, 11],
        "vit_base": [2, 5, 8, 11],
        "vit_large": [4, 11, 17, 23],
        "vit_giant2": [9, 19, 29, 39],
    }[arch_name]

    model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head)
    # The decode head expects reshaped intermediate feature maps plus class
    # tokens, so replace the backbone's forward with the layer extractor.
    model.backbone.forward = partial(
        backbone.get_intermediate_layers,
        n=out_index,
        reshape=True,
        return_class_token=True,
        norm=False,
    )
    # Pad inputs so height and width are multiples of the patch size.
    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0]))

    if pretrained:
        weights_str = weights.value.lower()
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        # Bug fix: previously `state_dict` was only bound when the checkpoint
        # contained a "state_dict" key, so a bare state-dict checkpoint raised
        # NameError at the unconditional load below.
        state_dict = checkpoint.get("state_dict", checkpoint)
        model.load_state_dict(state_dict, strict=False)

    return model
|
|
|
|
|
def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """DPT-depth model over a DINOv2 ViT-S/14 backbone."""
    return _make_dinov2_dpt_depther(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
|
def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """DPT-depth model over a DINOv2 ViT-B/14 backbone."""
    return _make_dinov2_dpt_depther(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
|
def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """DPT-depth model over a DINOv2 ViT-L/14 backbone."""
    return _make_dinov2_dpt_depther(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
|
def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """DPT-depth model over a DINOv2 ViT-g/14 backbone (SwiGLU FFN)."""
    return _make_dinov2_dpt_depther(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|