# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
from collections import OrderedDict

import torch
import torch.nn as nn

from .dinov2.hub.backbones import (
    dinov2_vits14,
    dinov2_vitb14,
    dinov2_vitl14,
    dinov2_vitg14,
)


class FrozenDinoV2ImageEmbedder(nn.Module):
    """
    Wraps the DINOv2 image encoder with camera modulation.

    Despite the name, the encoder is not actually frozen; to freeze it,
    set cond_stage_trainable=False in the config.
    """

    # Supported backbone constructors, keyed by version string.
    _BACKBONES = {
        'dinov2_vits14': dinov2_vits14,
        'dinov2_vitb14': dinov2_vitb14,
        'dinov2_vitl14': dinov2_vitl14,
        'dinov2_vitg14': dinov2_vitg14,
    }

    def __init__(
        self,
        version='dinov2_vitb14',
        ckpt_path=None,
        lrm_mode='plain_lrm',
    ):
        super().__init__()
        self.lrm_mode = lrm_mode
        assert version in self._BACKBONES, f'Unknown DINOv2 version: {version}'
        # Build the backbone that matches the requested version.
        self.model = self._BACKBONES[version](pretrained=False)
        if ckpt_path is not None:
            self.load_pretrained(ckpt_path)
        else:
            print('No pretrained weights for dinov2 encoder ...')

    def load_pretrained(self, ckpt_path):
        print('Loading dinov2 encoder ...')
        orig_state_dict = torch.load(ckpt_path, map_location='cpu')
        try:
            # First, try the checkpoint as a raw DINOv2 state dict.
            ret = self.model.load_state_dict(orig_state_dict, strict=False)
            print(ret)
            print('Successfully loaded original state dict')
        except Exception:
            # Otherwise assume a training checkpoint whose encoder weights are
            # nested under 'state_dict' with an 'img_encoder.model.' prefix.
            new_state_dict = OrderedDict()
            for k, v in orig_state_dict['state_dict'].items():
                if 'img_encoder' in k:
                    new_state_dict[k.replace('img_encoder.model.', '')] = v
            ret = self.model.load_state_dict(new_state_dict, strict=False)
            print(ret)
            print('Successfully loaded remapped state dict')

    def forward(self, x, *args, **kwargs):
        # Camera-modulation arguments are forwarded to the backbone untouched.
        ret = self.model.forward_features_with_camera(x, *args, **kwargs)
        # Prepend the class token to the patch tokens: (B, 1 + N, C).
        output = torch.cat([ret['x_norm_clstoken'].unsqueeze(1), ret['x_norm_patchtokens']], dim=1)
        return output
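

# A minimal sketch of the output token layout, not part of the original
# module. It assumes the published dinov2_vitb14 defaults (embed dim 768,
# patch size 14), so a 224x224 input yields 16x16 = 256 patch tokens. The
# dummy `ret` dict mimics the dictionary returned by
# forward_features_with_camera; the exact camera arguments that method
# expects are repo-specific and deliberately not shown here.
if __name__ == '__main__':
    B, C = 2, 768
    num_patches = (224 // 14) ** 2  # 256 patch tokens for a 224x224 image
    ret = {
        'x_norm_clstoken': torch.randn(B, C),
        'x_norm_patchtokens': torch.randn(B, num_patches, C),
    }
    # Same concatenation as FrozenDinoV2ImageEmbedder.forward.
    output = torch.cat([ret['x_norm_clstoken'].unsqueeze(1), ret['x_norm_patchtokens']], dim=1)
    assert output.shape == (B, 1 + num_patches, C)  # class token prepended
    print(output.shape)  # torch.Size([2, 257, 768])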