File size: 1,018 Bytes
077dc3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

from .abstract_embedder import AbstractImageEmbedder


class DinoV2Embedder(AbstractImageEmbedder):
    """Image embedder backed by the pretrained DINOv2 ViT-S/14 backbone."""

    def __init__(self, device: str = "cpu"):
        """Load the DINOv2 ViT-S/14 model from torch.hub and set up preprocessing.

        Args:
            device: Torch device string the model and inputs are moved to
                (e.g. ``"cpu"`` or ``"cuda"``).
        """
        super().__init__(device)
        # Downloads (and caches) the pretrained weights via torch.hub on first use.
        self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(self.device)
        self.model.eval()
        # Standard ImageNet-style preprocessing: resize, center-crop to 224,
        # then normalize with ImageNet channel statistics.
        self.transforms = T.Compose([
            T.Resize((256, 256), interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def embed(self, image: Image.Image) -> np.ndarray:
        """Compute the DINOv2 embedding for a single PIL image.

        Args:
            image: Input image; converted to RGB first, so grayscale/RGBA
                inputs are handled.

        Returns:
            A 1-D numpy array with the model's feature vector for the image
            (presumably the CLS-token features — confirm against the DINOv2
            model documentation).
        """
        image = image.convert("RGB")
        # Keep the preprocessed tensor in a separate name instead of
        # re-binding `image` to a different type.
        batch = self.transforms(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # Index [0] drops the batch dimension of the single-image batch.
            output = self.model(batch)[0].cpu().numpy()
        return output