import random

import torch
from torch import nn
import numpy as np
import re
from einops import rearrange
from dataclasses import dataclass
from torchvision import transforms

from diffusers.models.modeling_utils import ModelMixin
from transformers import CLIPTokenizer, CLIPImageProcessor
from transformers import AutoImageProcessor, AutoModel
from transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer
from transformers.utils import ModelOutput
from typing import Iterable, Optional, Union, List

import step1x3d_geometry
from step1x3d_geometry.utils.typing import *
from .clip.modeling_clip import CLIPModel
from .clip.modeling_conditional_clip import ConditionalCLIPModel
from .base import BaseVisualEncoder, ImageType
from .dinov2.modeling_dinov2 import Dinov2Model
from .dinov2.modeling_conditional_dinov2 import ConditionalDinov2Model
from .dinov2_with_registers.modeling_dinov2_with_registers import (
    Dinov2WithRegistersModel,
)

CLIP_IMAGE_SIZE = 224


@dataclass
class CLIPEmbedOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    embeds: torch.FloatTensor = None


@dataclass
class DINOEmbedOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None


@step1x3d_geometry.register("dinov2-clip-encoder")
class Dinov2CLIPEncoder(BaseVisualEncoder, ModelMixin):
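    """Visual conditioning encoder that combines DINOv2 and CLIP ViT features.

    Images are encoded by both backbones and the token sequences are fused
    (currently only ``fuse_type="concat"``) after projecting the CLIP tokens to
    the DINOv2 hidden size. When ``encode_camera`` is enabled, the conditional
    (camera-modulated) variants of both backbones are used instead.
    """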

    @dataclass
    class Config(BaseVisualEncoder.Config):
        pretrained_model_name_or_path: Optional[str] = None
        pretrained_clip_name_or_path: Optional[str] = None
        pretrained_dino_name_or_path: Optional[str] = None
        pretrained_linear_proj: Optional[str] = None
        freeze_modulation_clip: bool = False
        freeze_modulation_dino: bool = False
        enable_gradient_checkpointing: bool = False
        image_size: int = CLIP_IMAGE_SIZE
        fuse_type: str = "concat"

        dino_type: Optional[str] = None
        clip_type: Optional[str] = None
        kwargs: Optional[dict] = None

    cfg: Config

    def configure(self) -> None:
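        """Instantiate the CLIP and DINOv2 backbones, preprocessing transforms,
        unconditional embeddings, and the CLIP-to-DINO projection layer."""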
        super().configure()

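        # Without camera conditioning, plain CLIP / DINOv2 backbones are loaded;
        # otherwise the camera-modulated (conditional) variants are used below.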
        if not self.cfg.encode_camera:
            if self.cfg.pretrained_clip_name_or_path is not None:
                self.cfg.clip_type = f"openai/{self.cfg.pretrained_clip_name_or_path.split('openai--')[-1].split('/')[0]}"
                self.clip_model: CLIPModel = CLIPModel.from_pretrained(
                    self.cfg.pretrained_clip_name_or_path
                )
            else:
                print("Loading CLIP model from openai/clip-vit-large-patch14")
                self.cfg.clip_type = "openai/clip-vit-large-patch14"
                self.clip_model: CLIPModel = CLIPModel(
                    config=ConditionalCLIPModel.config_class.from_pretrained(
                        "openai/clip-vit-large-patch14",
                    )
                )

            if self.cfg.pretrained_dino_name_or_path is not None:
                self.cfg.dino_type = f"facebook/{self.cfg.pretrained_dino_name_or_path.split('facebook--')[-1].split('/')[0]}"
                self.dino_model: Dinov2Model = AutoModel.from_pretrained(
                    self.cfg.pretrained_dino_name_or_path
                )
            else:
                if self.cfg.pretrained_model_name_or_path is None:
                    assert (
                        self.cfg.dino_type is not None
                    ), "The dino_type should be provided"
                    print(f"Loading Dinov2 model from {self.cfg.dino_type}")
                    if "reg" in self.cfg.dino_type:
                        self.dino_model: Dinov2WithRegistersModel = (
                            Dinov2WithRegistersModel(
                                config=Dinov2WithRegistersModel.config_class.from_pretrained(
                                    self.cfg.dino_type,
                                )
                            )
                        )
                    else:
                        self.dino_model: Dinov2Model = Dinov2Model(
                            config=Dinov2Model.config_class.from_pretrained(
                                self.cfg.dino_type,
                            )
                        )
                elif "dinov2base" in self.cfg.pretrained_model_name_or_path:
                    print("Loading Dinov2 model from facebook/dinov2-base")
                    self.cfg.dino_type = "facebook/dinov2-base"
                    self.dino_model: Dinov2Model = Dinov2Model(
                        config=Dinov2Model.config_class.from_pretrained(
                            "facebook/dinov2-base",
                        )
                    )
                elif "dinov2regbase" in self.cfg.pretrained_model_name_or_path:
                    print(
                        "Loading Dinov2 model from facebook/dinov2-with-registers-base"
                    )
                    self.cfg.dino_type = "facebook/dinov2-with-registers-base"
                    self.dino_model: Dinov2WithRegistersModel = (
                        Dinov2WithRegistersModel(
                            config=Dinov2WithRegistersModel.config_class.from_pretrained(
                                "facebook/dinov2-with-registers-base",
                            )
                        )
                    )
                elif "dinov2reglarge" in self.cfg.pretrained_model_name_or_path:
                    print(
                        "Loading Dinov2 model from facebook/dinov2-with-registers-large"
                    )
                    self.cfg.dino_type = "facebook/dinov2-with-registers-large"
                    self.dino_model: Dinov2WithRegistersModel = (
                        Dinov2WithRegistersModel(
                            config=Dinov2WithRegistersModel.config_class.from_pretrained(
                                "facebook/dinov2-with-registers-large",
                            )
                        )
                    )
                else:
                    raise ValueError(
                        f"Unknown Dinov2 model: {self.cfg.pretrained_model_name_or_path}"
                    )
        else:
            conditional_clip_config = ConditionalCLIPModel.config_class.from_pretrained(
                self.cfg.pretrained_clip_name_or_path,
            )
            conditional_clip_config.vision_config.modulation_dim = (
                self.cfg.camera_embeds_dim
            )
            self.clip_model: CLIPModel = ConditionalCLIPModel.from_pretrained(
                self.cfg.pretrained_clip_name_or_path,
                vision_config=conditional_clip_config.vision_config,
            )

            conditional_vit_config = (
                ConditionalDinov2Model.config_class.from_pretrained(
                    self.cfg.pretrained_dino_name_or_path,
                )
            )
            conditional_vit_config.modulation_dim = self.cfg.camera_embeds_dim
            self.dino_model: ConditionalDinov2Model = (
                ConditionalDinov2Model.from_pretrained(
                    self.cfg.pretrained_dino_name_or_path, config=conditional_vit_config
                )
            )

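        # Two preprocessing paths: HuggingFace image processors for PIL/list inputs,
        # and torchvision transforms for tensor inputs already in [0, 1]
        # (CLIP normalization stats for CLIP, ImageNet stats for DINOv2).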
        self.image_preprocess_clip = CLIPImageProcessor()
        self.image_preprocess_dino = AutoImageProcessor.from_pretrained(
            self.cfg.dino_type
            if self.cfg.pretrained_dino_name_or_path is None
            else self.cfg.pretrained_dino_name_or_path
        )
        self.transform_clip = transforms.Compose(
            [
                transforms.Resize(
                    CLIP_IMAGE_SIZE,
                    transforms.InterpolationMode.BICUBIC,
                    antialias=True,
                ),
                transforms.CenterCrop(CLIP_IMAGE_SIZE),
                transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711],
                ),
            ]
        )
        self.transform_dino = transforms.Compose(
            [
                transforms.Resize(
                    self.cfg.image_size,
                    transforms.InterpolationMode.BICUBIC,
                    antialias=True,
                ),
                transforms.CenterCrop(self.cfg.image_size),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )

        if self.cfg.enable_gradient_checkpointing:
            self.dino_model.encoder.gradient_checkpointing = True

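        # Unconditional ("empty image") embeddings, used as the null condition
        # (e.g. for classifier-free guidance): either all-zero tensors, or the
        # encodings of black images.
        # NOTE: the hardcoded feature dim of 1024 assumes ViT-L-sized backbones.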
        if self.cfg.zero_uncond_embeds:
            image_size = self.cfg.image_size
            self.empty_image_embeds_dino = torch.zeros(
                (self.cfg.n_views, (image_size // 14) ** 2 + 1, 1024)
            ).detach()
            self.empty_image_embeds_clip = torch.zeros(
                (self.cfg.n_views, (CLIP_IMAGE_SIZE // 14) ** 2 + 1, 1024)
            ).detach()
            if self.cfg.fuse_type == "concat":
                self.empty_image_embeds = torch.cat(
                    [self.empty_image_embeds_dino, self.empty_image_embeds_clip], dim=1
                )
            else:
                raise ValueError(f"Unsupported fuse_type: {self.cfg.fuse_type}")
        else:
            if self.cfg.encode_camera:
                self.empty_image_embeds_dino = self.encode_image_dino(
                    torch.zeros(
                        self.cfg.n_views, self.cfg.image_size, self.cfg.image_size, 3
                    ),
                    self.cameras[: self.cfg.n_views],
                ).detach()
                self.empty_image_embeds_clip = self.encode_image_clip(
                    torch.zeros(
                        self.cfg.n_views, self.cfg.image_size, self.cfg.image_size, 3
                    ),
                    self.cameras[: self.cfg.n_views],
                ).detach()
            else:
                self.empty_image_embeds_dino = self.encode_image_dino(
                    torch.zeros(
                        self.cfg.n_views, self.cfg.image_size, self.cfg.image_size, 3
                    )
                ).detach()
                self.empty_image_embeds_clip = self.encode_image_clip(
                    torch.zeros(
                        self.cfg.n_views, self.cfg.image_size, self.cfg.image_size, 3
                    )
                ).detach()
            self.empty_image_embeds_clip, self.empty_image_embeds_dino = (
                self.align_clip_dino(
                    self.empty_image_embeds_clip, self.empty_image_embeds_dino
                )
            )
            self.empty_image_embeds = torch.cat(
                [self.empty_image_embeds_dino, self.empty_image_embeds_clip], dim=1
            )

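        # Freeze the CLIP backbone; only its camera-modulation norm layers
        # (mod_norm1 / mod_norm2) stay trainable unless freeze_modulation_clip is set.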
        self.clip_model.eval()
        for k, p in self.clip_model.named_parameters():
            ks = k.split(".")
            if (
                "mod_norm1" in ks or "mod_norm2" in ks
            ) and not self.cfg.freeze_modulation_clip:
                p.requires_grad_(not self.cfg.freeze_modulation_clip)
            else:
                p.requires_grad_(False)

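        # Apply the same freezing policy to the DINOv2 backbone.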
        self.dino_model.eval()
        for k, p in self.dino_model.named_parameters():
            ks = k.split(".")
            if (
                "mod_norm1" in ks or "mod_norm2" in ks
            ) and not self.cfg.freeze_modulation_dino:
                p.requires_grad_(not self.cfg.freeze_modulation_dino)
            else:
                p.requires_grad_(False)

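        # Project CLIP tokens into the DINOv2 hidden size so the two token
        # sequences can be concatenated along the sequence dimension.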
        if (
            self.clip_model.config.vision_config.hidden_size
            != self.dino_model.config.hidden_size
        ):
            self.linear_proj = nn.Linear(
                self.clip_model.config.vision_config.hidden_size,
                self.dino_model.config.hidden_size,
                bias=False,
            )
        else:
            self.linear_proj = nn.Identity()

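        # Optionally restore this module's weights from a full pipeline checkpoint;
        # only keys under the "condition." prefix belong to the visual encoder.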
        if self.cfg.pretrained_model_name_or_path is not None:
            print(f"Loading ckpt from {self.cfg.pretrained_model_name_or_path}")
            ckpt = torch.load(
                self.cfg.pretrained_model_name_or_path, map_location="cpu"
            )["state_dict"]
            pretrained_model_ckpt = {}
            for k, v in ckpt.items():
                if k.startswith("condition."):
                    pretrained_model_ckpt[k.replace("condition.", "")] = v
            self.load_state_dict(pretrained_model_ckpt, strict=True)

    def encode_image_clip(
        self,
        images: Iterable[Optional[ImageType]],
        cameras: Optional[torch.Tensor] = None,
        force_none_camera_embeds: bool = False,
        return_dict: bool = False,
        **kwargs,
    ) -> torch.FloatTensor:
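        """Encode images with the CLIP vision tower.

        Tensor/ndarray inputs are expected in [0, 1] with channels last; other
        inputs (e.g. PIL images) go through the CLIP image processor. Returns
        the last hidden state, or a CLIPEmbedOutput when return_dict=True.
        """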
        camera_embeds = None
        if isinstance(images, (np.ndarray, torch.Tensor)):
            assert (
                images.min() >= 0.0 and images.max() <= 1.0
            ), "The pixel values should be in the range of [0, 1]"
            if self.cfg.encode_camera:
                assert cameras is not None, "The cameras should be provided"
                camera_embeds = self.encode_camera(cameras)
            pixel_values = self.transform_clip(images.permute(0, 3, 1, 2))
        else:
            if self.cfg.encode_camera:
                if cameras is None:
                    bs = len(images) // self.cfg.n_views
                    cameras = (
                        self.cameras[: self.cfg.n_views]
                        .repeat(bs, 1, 1)
                        .to(self.clip_model.device)
                    )
                camera_embeds = self.encode_camera(cameras)
            pixel_values = self.image_preprocess_clip.preprocess(
                images,
                return_tensors="pt",
                do_rescale=True,
                do_resize=True,
                size=CLIP_IMAGE_SIZE,
                crop_size=CLIP_IMAGE_SIZE,
            ).pixel_values

        if force_none_camera_embeds:
            camera_embeds = None

        if pixel_values.ndim == 4:
            pixel_values = pixel_values.unsqueeze(1)
            if camera_embeds is not None:
                camera_embeds = camera_embeds.unsqueeze(1)

        if self.cfg.encode_camera and camera_embeds is not None:
            vision_outputs = self.clip_model.vision_model(
                pixel_values=rearrange(
                    pixel_values.to(self.clip_model.device), "B N C H W -> (B N) C H W"
                ),
                condition=rearrange(camera_embeds, "B N C -> (B N) C"),
            )
        else:
            vision_outputs = self.clip_model.vision_model(
                pixel_values=rearrange(
                    pixel_values.to(self.clip_model.device), "B N C H W -> (B N) C H W"
                ),
            )

        if return_dict:
            pooler_output = vision_outputs[1]
            image_features = self.clip_model.visual_projection(pooler_output)
            clip_embeds = vision_outputs.last_hidden_state

            clip_embeds_dict = CLIPEmbedOutput(
                last_hidden_state=clip_embeds,
                pooler_output=pooler_output,
                embeds=image_features,
            )

            return clip_embeds_dict
        else:
            return vision_outputs.last_hidden_state

    def encode_image_dino(
        self,
        images: Iterable[Optional[ImageType]],
        cameras: Optional[torch.Tensor] = None,
        force_none_camera_embeds: bool = False,
        return_dict: bool = False,
        **kwargs,
    ) -> torch.FloatTensor:
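        """Encode images with the DINOv2 backbone.

        Mirrors encode_image_clip, but preprocesses at cfg.image_size with the
        DINOv2 image processor / ImageNet statistics. Returns the last hidden
        state, or a DINOEmbedOutput when return_dict=True.
        """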
        camera_embeds = None
        if isinstance(images, (np.ndarray, torch.Tensor)):
            assert (
                images.min() >= 0.0 and images.max() <= 1.0
            ), "The pixel values should be in the range of [0, 1]"
            if self.cfg.encode_camera:
                assert cameras is not None, "The cameras should be provided"
                camera_embeds = self.encode_camera(cameras)
            pixel_values = self.transform_dino(images.permute(0, 3, 1, 2))
        else:
            if self.cfg.encode_camera:
                if cameras is None:
                    bs = len(images) // self.cfg.n_views
                    cameras = (
                        self.cameras[: self.cfg.n_views]
                        .repeat(bs, 1, 1)
                        .to(self.dino_model.device)
                    )
                camera_embeds = self.encode_camera(cameras)
            pixel_values = self.image_preprocess_dino.preprocess(
                images,
                return_tensors="pt",
                do_rescale=True,
                do_resize=True,
                size=self.cfg.image_size,
                crop_size=self.cfg.image_size,
            ).pixel_values

        if force_none_camera_embeds:
            camera_embeds = None

        if pixel_values.ndim == 4:
            pixel_values = pixel_values.unsqueeze(1)
            if camera_embeds is not None:
                camera_embeds = camera_embeds.unsqueeze(1)

        if self.cfg.encode_camera and camera_embeds is not None:
            vision_outputs = self.dino_model(
                rearrange(
                    pixel_values.to(self.dino_model.device), "B N C H W -> (B N) C H W"
                ),
                condition=rearrange(camera_embeds, "B N C -> (B N) C"),
            )
        else:
            vision_outputs = self.dino_model(
                rearrange(
                    pixel_values.to(self.dino_model.device), "B N C H W -> (B N) C H W"
                ),
            )

        if return_dict:
            dino_embeds_dict = DINOEmbedOutput(
                last_hidden_state=vision_outputs.last_hidden_state,
                pooler_output=vision_outputs.pooler_output,
            )
            return dino_embeds_dict
        else:
            return vision_outputs.last_hidden_state

    def align_clip_dino(self, clip_embeds, dino_embeds):
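        """Resample CLIP patch tokens onto the DINOv2 patch grid when the two
        token counts differ, so the sequences can be concatenated per view.
        The CLS token is kept and only the patch tokens are interpolated."""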
        if clip_embeds.shape[-2] != dino_embeds.shape[-2]:
            assert (
                clip_embeds.shape[-2] == (CLIP_IMAGE_SIZE // 14) ** 2 + 1
            ), "The clip embeddings should have the shape of (n_views, (CLIP_IMAGE_SIZE // 14) ** 2 + 1, 1024)"
            clip_embeds_patch_tokens = clip_embeds[:, 1:].view(
                clip_embeds.shape[0],
                CLIP_IMAGE_SIZE // 14,
                CLIP_IMAGE_SIZE // 14,
                1024,
            )
            clip_embeds_patch_tokens = (
                torch.nn.functional.interpolate(
                    clip_embeds_patch_tokens.permute(0, 3, 1, 2),
                    size=(self.cfg.image_size // 14, self.cfg.image_size // 14),
                    mode="bilinear",
                    align_corners=False,
                )
                .permute(0, 2, 3, 1)
                .view(clip_embeds.shape[0], -1, 1024)
            )
            clip_embeds = torch.cat(
                [clip_embeds[:, :1], clip_embeds_patch_tokens], dim=1
            )
        return clip_embeds, dino_embeds

    def encode_image(
        self,
        images: Iterable[Optional[ImageType]],
        cameras: Optional[torch.Tensor] = None,
        force_none_camera_embeds: bool = False,
        return_dict: bool = False,
        **kwargs,
    ) -> torch.FloatTensor:
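        """Encode images with both backbones and fuse the token sequences.

        DINOv2 register tokens are dropped, CLIP tokens are projected to the
        DINOv2 hidden size, and the two sequences are concatenated along the
        token dimension (fuse_type "concat").
        """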
        clip_embeds = self.encode_image_clip(images, cameras)
        dino_embeds = self.encode_image_dino(images, cameras)
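        # For the register variant, keep the CLS token and patch tokens but drop
        # the register tokens so the sequence layout matches the plain DINOv2 case.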
        if self.dino_model.__class__.__name__ == "Dinov2WithRegistersModel":
            dino_embeds = torch.cat(
                [
                    dino_embeds[:, :1],
                    dino_embeds[:, self.dino_model.config.num_register_tokens + 1 :],
                ],
                dim=1,
            )

        clip_embeds = self.linear_proj(clip_embeds)

        if self.cfg.fuse_type == "concat":
            visual_embeds = torch.cat([dino_embeds, clip_embeds], dim=1)
        else:
            raise ValueError(f"Unsupported fuse_type: {self.cfg.fuse_type}")

        return visual_embeds