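# NOTE: hosting-page residue, not part of the module: "YouLiXiya's picture"
# NOTE: hosting-page residue, not part of the module: "Upload 22 files"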
import PIL
from PIL.Image import Image
from typing import Union
from sklearn.decomposition import PCA
import torch
from torch import nn
from torchvision import transforms as tfs
# Per-channel mean / standard deviation handed to ``tfs.Normalize`` when a PIL
# image is preprocessed.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# ``torch.hub`` repository and the entry points accepted by ``DINO``.
DINO_MODEL_HUB = 'facebookresearch/dino:main'
# NOTE(review): only 'dino_vits16' survives in this copy of the file; the rest
# of the list was lost. The other names are the ViT entry points published in
# the hub repository's hubconf -- confirm against the upstream file.
DINO_MODEL_TYPE = [
    'dino_vits16',
    'dino_vits8',
    'dino_vitb16',
    'dino_vitb8',
]

# ``torch.hub`` repository and the entry points accepted by ``DINOV2``.
DINOV2_MODEL_HUB = 'facebookresearch/dinov2:main'
# NOTE(review): only 'dinov2_vits14' survives in this copy of the file; the
# other names are the backbone entry points published in the hub repository's
# hubconf -- confirm against the upstream file.
DINOV2_MODEL_TYPE = [
    'dinov2_vits14',
    'dinov2_vitb14',
    'dinov2_vitl14',
    'dinov2_vitg14',
]
class DINO(nn.Module):
    """Frozen DINO backbone that returns a spatial grid of patch features.

    The backbone is fetched through ``torch.hub`` and every parameter has
    ``requires_grad`` switched off. Optionally the channel dimension of the
    patch features is reduced with a scikit-learn ``PCA``.
    """

    def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
        """Load the hub model and store the preprocessing settings.

        Args:
            model_type: Hub entry point name; must be in ``DINO_MODEL_TYPE``.
            device: Device the backbone (and each input batch) is moved to.
            img_size: Side length PIL inputs are resized to (square resize).
            pca_dim: Number of PCA components; a falsy value disables PCA.
        """
        super(DINO, self).__init__()
        assert model_type in DINO_MODEL_TYPE, 'Given DINO model type must in DINO_MODEL_TYPE!'
        self.model = torch.hub.load(DINO_MODEL_HUB, model_type).to(device)
        self.device = device
        # The backbone is used as a fixed feature extractor.
        for param in self.model.parameters():
            param.requires_grad = False
        self.img_size = img_size
        self.pca_dim = pca_dim
        self.pca = self.set_pca(pca_dim) if pca_dim else None

    def set_pca(self, dim=64):
        """Return a new, unfitted ``PCA`` with ``dim`` components."""
        return PCA(n_components=dim)

    def extract_features(
        self, img: "Union[Image, torch.Tensor]", transform=True, size=None
    ):
        """Compute the patch-feature grid for ``img``.

        Args:
            img: A PIL image, or a tensor whose dims 2 and 3 are read as
                height and width (i.e. indexed as N x 3 x H x W). A PIL image
                is only converted when ``transform`` is true; tensors are
                passed to the backbone as given.
            transform: Whether a PIL input is resized/normalized with
                ``self.transform`` before being fed to the backbone.
            size: Optional target size forwarded to bilinear
                ``interpolate`` to resample the feature grid.

        Returns:
            Tensor of shape ``(N, h, w, C)`` where ``h``/``w`` are the input
            height/width divided by the patch size (or ``size`` when given)
            and ``C`` is the feature width (``pca_dim`` when PCA is enabled).
        """
        if transform and isinstance(img, Image):
            img = self.transform(img, self.img_size).unsqueeze(0)  # Nx3xHxW
        with torch.no_grad():
            out = self.model.get_intermediate_layers(img.to(self.device), n=1)[0]
            out = out[:, 1:, :]  # we discard the [CLS] token
            patch_size = self.model.patch_embed.patch_size
            h, w = img.shape[2] // patch_size, img.shape[3] // patch_size
            dim = out.shape[-1]
            out = out.reshape(-1, h, w, dim)
            dtype = out.dtype
            if size is not None:
                # interpolate expects channels-first, so move C forward and back.
                out = torch.nn.functional.interpolate(
                    out.permute(0, 3, 1, 2), size=size, mode='bilinear'
                ).permute(0, 2, 3, 1)
        if self.pca is not None:
            B, H, W, C = out.shape
            # reshape (not view): the permuted tensor above is not guaranteed
            # to have a memory layout that can be flattened without a copy.
            flat = out.reshape(-1, C).cpu().numpy()
            # The PCA is re-fitted on every call, so the projection basis is
            # specific to this batch.
            flat = self.pca.fit_transform(flat)
            out = torch.tensor(flat.reshape(B, H, W, self.pca_dim), dtype=dtype).to(self.device)
        return out

    def forward(self, img: "Union[Image, torch.Tensor]", transform=True, size=None):
        """Alias for :meth:`extract_features`."""
        return self.extract_features(img, transform, size)

    @staticmethod
    def transform(img, image_size):
        """Resize ``img`` to a square, convert it to a tensor and normalize it.

        Args:
            img: Input accepted by the torchvision transforms (a PIL image in
                ``extract_features``).
            image_size: Target side length; the output is
                ``image_size x image_size``.

        Returns:
            The normalized image tensor (no batch dimension is added here).
        """
        transforms = tfs.Compose(
            [tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
        )
        img = transforms(img)
        return img
class DINOV2(nn.Module):
    """Frozen DINOv2 backbone that returns a spatial grid of patch features.

    The backbone is fetched through ``torch.hub`` and every parameter has
    ``requires_grad`` switched off. Optionally the channel dimension of the
    patch features is reduced with a scikit-learn ``PCA``.
    """

    def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
        """Load the hub model and store the preprocessing settings.

        Args:
            model_type: Hub entry point name; must be in
                ``DINOV2_MODEL_TYPE``.
            device: Device the backbone (and each input batch) is moved to.
            img_size: Side length PIL inputs are resized to (square resize).
            pca_dim: Number of PCA components; a falsy value disables PCA.
        """
        super(DINOV2, self).__init__()
        # The message now names the list that is actually checked.
        assert model_type in DINOV2_MODEL_TYPE, 'Given DINOv2 model type must in DINOV2_MODEL_TYPE!'
        self.model = torch.hub.load(DINOV2_MODEL_HUB, model_type).to(device)
        self.device = device
        # The backbone is used as a fixed feature extractor.
        for param in self.model.parameters():
            param.requires_grad = False
        self.img_size = img_size
        self.pca_dim = pca_dim
        self.pca = self.set_pca(pca_dim) if pca_dim else None

    def set_pca(self, dim=64):
        """Return a new, unfitted ``PCA`` with ``dim`` components."""
        return PCA(n_components=dim)

    def extract_features(
        self, img: "Union[Image, torch.Tensor]", transform=True, size=None
    ):
        """Compute the patch-feature grid for ``img``.

        Args:
            img: A PIL image, or a tensor whose dims 2 and 3 are read as
                height and width (i.e. indexed as N x 3 x H x W). A PIL image
                is only converted when ``transform`` is true; tensors are
                passed to the backbone as given.
            transform: Whether a PIL input is resized/normalized with
                ``self.transform`` before being fed to the backbone.
            size: Optional target size forwarded to bilinear
                ``interpolate`` to resample the feature grid.

        Returns:
            Tensor of shape ``(N, h, w, C)`` where ``h``/``w`` are the input
            height/width divided by the patch size (or ``size`` when given)
            and ``C`` is the feature width (``pca_dim`` when PCA is enabled).
        """
        if transform and isinstance(img, Image):
            img = self.transform(img, self.img_size).unsqueeze(0)  # Nx3xHxW
        with torch.no_grad():
            # Only the patch tokens are kept from the backbone's output dict.
            out = self.model.forward_features(img.to(self.device))['x_norm_patchtokens']
            patch_size = self.model.patch_size
            h, w = img.shape[2] // patch_size, img.shape[3] // patch_size
            dim = out.shape[-1]
            out = out.reshape(-1, h, w, dim)
            dtype = out.dtype
            if size is not None:
                # interpolate expects channels-first, so move C forward and back.
                out = torch.nn.functional.interpolate(
                    out.permute(0, 3, 1, 2), size=size, mode='bilinear'
                ).permute(0, 2, 3, 1)
        if self.pca is not None:
            B, H, W, C = out.shape
            # reshape (not view): the permuted tensor above is not guaranteed
            # to have a memory layout that can be flattened without a copy.
            flat = out.reshape(-1, C).cpu().numpy()
            # The PCA is re-fitted on every call, so the projection basis is
            # specific to this batch.
            flat = self.pca.fit_transform(flat)
            out = torch.tensor(flat.reshape(B, H, W, self.pca_dim), dtype=dtype).to(self.device)
        return out

    def forward(self, img: "Union[Image, torch.Tensor]", transform=True, size=None):
        """Alias for :meth:`extract_features`."""
        return self.extract_features(img, transform, size)

    @staticmethod
    def transform(img, image_size):
        """Resize ``img`` to a square, convert it to a tensor and normalize it.

        Args:
            img: Input accepted by the torchvision transforms (a PIL image in
                ``extract_features``).
            image_size: Target side length; the output is
                ``image_size x image_size``.

        Returns:
            The normalized image tensor (no batch dimension is added here).
        """
        transforms = tfs.Compose(
            [tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
        )
        img = transforms(img)
        return img