from typing import Union

import torch
from torch import nn
from torchvision import transforms as tfs
from PIL.Image import Image
from sklearn.decomposition import PCA

MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

DINO_MODEL_HUB = 'facebookresearch/dino:main'
DINO_MODEL_TYPE = ['dino_vits16',
                   'dino_vits8',
                   'dino_vitb16',
                   'dino_vitb8',
                   'dino_xcit_small_12_p16',
                   'dino_xcit_small_12_p8',
                   'dino_xcit_medium_24_p16',
                   'dino_xcit_medium_24_p8',
                   'dino_resnet50']

DINOV2_MODEL_HUB = 'facebookresearch/dinov2:main'
DINOV2_MODEL_TYPE = ['dinov2_vits14',
                     'dinov2_vitb14',
                     'dinov2_vitl14',
                     'dinov2_vitg14']


class DINO(nn.Module):
    """Frozen DINO backbone that returns patch-level feature maps."""

    def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
        super().__init__()
        assert model_type in DINO_MODEL_TYPE, 'Given DINO model type must be in DINO_MODEL_TYPE!'
        self.model = torch.hub.load(DINO_MODEL_HUB, model_type).to(device)
        self.device = device
        # Freeze the backbone: features are extracted, never fine-tuned.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.img_size = img_size
        self.pca_dim = pca_dim
        self.pca = self.set_pca(pca_dim) if pca_dim else None

    def set_pca(self, dim=64):
        return PCA(n_components=dim)

    def extract_features(
        self, img: Union[Image, torch.Tensor], transform=True, size=None
    ):
        if transform and isinstance(img, Image):
            img = self.transform(img, self.img_size).unsqueeze(0)  # Nx3xHxW
        with torch.no_grad():
            out = self.model.get_intermediate_layers(img.to(self.device), n=1)[0]
            out = out[:, 1:, :]  # discard the [CLS] token
            h, w = int(img.shape[2] / self.model.patch_embed.patch_size), int(
                img.shape[3] / self.model.patch_embed.patch_size
            )
            dim = out.shape[-1]
            out = out.reshape(-1, h, w, dim)  # NxHxWxC feature map
            dtype = out.dtype
            if size is not None:
                # Bilinearly resample the feature map to the requested spatial size.
                out = torch.nn.functional.interpolate(
                    out.permute(0, 3, 1, 2), size=size, mode='bilinear'
                ).permute(0, 2, 3, 1)
            if self.pca:
                B, H, W, C = out.shape
                # reshape (not view): the tensor is non-contiguous after permute.
                out = out.reshape(-1, C).cpu().numpy()
                # Note: the PCA is refit on every call, so projected features
                # are not directly comparable across calls.
                out = self.pca.fit_transform(out)
                out = torch.tensor(out.reshape(B, H, W, self.pca_dim), dtype=dtype).to(self.device)
        return out

    def forward(self, img: Union[Image, torch.Tensor], transform=True, size=None):
        return self.extract_features(img, transform, size)

    @staticmethod
    def transform(img, image_size):
        # Static: called as self.transform(img, size), so it must not bind self.
        transforms = tfs.Compose(
            [tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
        )
        return transforms(img)
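
# Usage sketch (not part of the original module): feature maps come back as
# NxHxWxC, where H and W are the input size divided by the ViT patch size
# (224 / 8 = 28 for 'dino_vits8', embedding dim 384). 'example.jpg' is a
# hypothetical file used only for illustration.
#
#   from PIL import Image as PILImage
#   extractor = DINO('dino_vits8', device='cuda')
#   img = PILImage.open('example.jpg').convert('RGB')
#   feats = extractor(img)                  # torch.Size([1, 28, 28, 384])
#   feats = extractor(img, size=(64, 64))   # torch.Size([1, 64, 64, 384])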


class DINOV2(nn.Module):
    """Frozen DINOv2 backbone that returns patch-level feature maps."""

    def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
        super().__init__()
        assert model_type in DINOV2_MODEL_TYPE, 'Given DINOv2 model type must be in DINOV2_MODEL_TYPE!'
        self.model = torch.hub.load(DINOV2_MODEL_HUB, model_type).to(device)
        self.device = device
        # Freeze the backbone: features are extracted, never fine-tuned.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.img_size = img_size
        self.pca_dim = pca_dim
        self.pca = self.set_pca(pca_dim) if pca_dim else None

    def set_pca(self, dim=64):
        return PCA(n_components=dim)

    def extract_features(
        self, img: Union[Image, torch.Tensor], transform=True, size=None
    ):
        if transform and isinstance(img, Image):
            img = self.transform(img, self.img_size).unsqueeze(0)  # Nx3xHxW
        with torch.no_grad():
            # DINOv2 exposes normalized patch tokens directly (no [CLS] to strip).
            out = self.model.forward_features(img.to(self.device))['x_norm_patchtokens']
            h, w = int(img.shape[2] / self.model.patch_size), int(
                img.shape[3] / self.model.patch_size
            )
            dim = out.shape[-1]
            out = out.reshape(-1, h, w, dim)  # NxHxWxC feature map
            dtype = out.dtype
            if size is not None:
                out = torch.nn.functional.interpolate(
                    out.permute(0, 3, 1, 2), size=size, mode='bilinear'
                ).permute(0, 2, 3, 1)
            if self.pca:
                B, H, W, C = out.shape
                out = out.reshape(-1, C).cpu().numpy()  # reshape: non-contiguous after permute
                out = self.pca.fit_transform(out)  # refit on every call
                out = torch.tensor(out.reshape(B, H, W, self.pca_dim), dtype=dtype).to(self.device)
        return out

    def forward(self, img: Union[Image, torch.Tensor], transform=True, size=None):
        return self.extract_features(img, transform, size)

    @staticmethod
    def transform(img, image_size):
        transforms = tfs.Compose(
            [tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
        )
        return transforms(img)
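

if __name__ == '__main__':
    # Minimal smoke test, added as a sketch: it assumes torch.hub can reach
    # the facebookresearch/dinov2 repo and that a local 'example.jpg' exists
    # (hypothetical path). With pca_dim=64 and size=(60, 60), the returned
    # features have shape (1, 60, 60, 64).
    from PIL import Image as PILImage

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    extractor = DINOV2('dinov2_vits14', device=device, pca_dim=64)
    img = PILImage.open('example.jpg').convert('RGB')
    feats = extractor(img, size=(60, 60))
    print(feats.shape, feats.dtype)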