File size: 1,561 Bytes
699b9c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4ee5c3
699b9c3
 
 
 
 
 
 
c4ee5c3
699b9c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import functools
import logging
import typing

import beartype
import torch
from jaxtyping import Float, jaxtyped
from torch import Tensor
from torchvision.transforms import v2

# Module-level logger. NOTE(review): named with the literal string "modeling.py"
# rather than the conventional __name__ — confirm downstream logging config
# keys off this exact name before changing it.
logger = logging.getLogger("modeling.py")


@jaxtyped(typechecker=beartype.beartype)
class SplitDinov2(torch.nn.Module):
    """DINOv2 ViT-B/14 (with registers) split into two sequential stages.

    ``forward_end(forward_start(x))`` is intended to equal a full forward
    pass; the split point lets callers inspect or modify the intermediate
    activations between the two stages.
    """

    def __init__(self, *, split_at: int):
        """Load the pretrained backbone and record the split point.

        Args:
            split_at: Number of transformer blocks executed by
                ``forward_start``; the remaining blocks run in ``forward_end``.
        """
        super().__init__()

        # torch.hub downloads/loads the pretrained DINOv2 backbone; .eval()
        # puts it in inference mode (disables dropout etc.).
        self.vit = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg").eval()
        self.split_at = split_at

    def forward_start(
        self, x: Float[Tensor, "batch channels width height"]
    ) -> Float[Tensor, "batch total_patches dim"]:
        """Embed images into tokens and run the first ``split_at`` blocks."""
        x_BPD = self.vit.prepare_tokens_with_masks(x)
        for blk in self.vit.blocks[: self.split_at]:
            x_BPD = blk(x_BPD)

        return x_BPD

    def forward_end(
        self, x_BPD: Float[Tensor, "batch total_patches dim"]
    ) -> Float[Tensor, "batch patches dim"]:
        """Run the remaining blocks, apply the final norm, and return patch tokens.

        Bug fix: the original sliced ``blocks[-split_at:]`` (the *last*
        ``split_at`` blocks), which re-executes blocks already run by
        ``forward_start`` whenever ``split_at != len(blocks) / 2`` — e.g. with
        the 12-block ViT-B and ``split_at=11``, blocks 1–10 ran twice.
        Slicing from ``split_at`` onward makes the two stages traverse each
        block exactly once between them.
        """
        for blk in self.vit.blocks[self.split_at :]:
            x_BPD = blk(x_BPD)

        x_BPD = self.vit.norm(x_BPD)
        # Drop the CLS token (+1) and the register tokens, keeping only the
        # spatial patch tokens.
        return x_BPD[:, self.vit.num_register_tokens + 1 :]


@functools.cache
def load_vit(device: str) -> tuple[SplitDinov2, typing.Callable]:
    """Build the split DINOv2 model on ``device`` and its input transform.

    Cached per device string, so repeated calls with the same device reuse
    the same model instance and transform.
    """
    model = SplitDinov2(split_at=11).to(device)

    # Standard ImageNet-style preprocessing: resize, center-crop to 224x224,
    # convert to a float tensor scaled to [0, 1], then normalize with the
    # ImageNet channel mean/std.
    steps = [
        v2.Resize(size=(256, 256)),
        v2.CenterCrop(size=(224, 224)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
    ]
    transform = v2.Compose(steps)

    logger.info("Loaded ViT.")

    return model, transform