File size: 1,561 Bytes
699b9c3 c4ee5c3 699b9c3 c4ee5c3 699b9c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import functools
import logging
import typing
import beartype
import torch
from jaxtyping import Float, jaxtyped
from torch import Tensor
from torchvision.transforms import v2
logger = logging.getLogger("modeling.py")
@jaxtyped(typechecker=beartype.beartype)
class SplitDinov2(torch.nn.Module):
    """DINOv2 ViT-B/14 (with registers) split into two halves so that
    intermediate activations can be captured or edited between the halves.

    Contract: ``forward_end(forward_start(x))`` is equivalent to one full
    forward pass — ``forward_start`` runs transformer blocks ``[0, split_at)``
    and ``forward_end`` runs the remaining blocks ``[split_at, len(blocks))``.
    """

    def __init__(self, *, split_at: int):
        """
        Args:
            split_at: Index of the first transformer block executed by
                ``forward_end``; ``forward_start`` runs every block before it.
        """
        super().__init__()
        # NOTE(review): downloads pretrained weights via torch.hub at
        # construction time (network access on first call).
        self.vit = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg").eval()
        self.split_at = split_at

    def forward_start(
        self, x: Float[Tensor, "batch channels width height"]
    ) -> Float[Tensor, "batch total_patches dim"]:
        """Embed `x` into tokens and run the first `split_at` blocks."""
        x_BPD = self.vit.prepare_tokens_with_masks(x)
        for blk in self.vit.blocks[: self.split_at]:
            x_BPD = blk(x_BPD)
        return x_BPD

    def forward_end(
        self, x_BPD: Float[Tensor, "batch total_patches dim"]
    ) -> Float[Tensor, "batch patches dim"]:
        """Run the remaining blocks and final norm, then drop special tokens.

        Bug fix: the original iterated ``blocks[-self.split_at:]``, which
        complements ``blocks[:split_at]`` only when ``split_at`` is exactly
        half the depth. With ``split_at=11`` on the 12-block ViT-B used by
        ``load_vit``, it re-ran blocks 1-10 a second time. ``blocks[split_at:]``
        is the correct second half for any ``split_at``.
        """
        for blk in self.vit.blocks[self.split_at :]:
            x_BPD = blk(x_BPD)
        x_BPD = self.vit.norm(x_BPD)
        # Strip the CLS token plus register tokens, keeping only patch tokens.
        return x_BPD[:, self.vit.num_register_tokens + 1 :]
@functools.cache
def load_vit(device: str) -> tuple[SplitDinov2, typing.Callable]:
    """Build the split DINOv2 model on `device` and its input preprocessing.

    Memoized via ``functools.cache`` so repeated calls with the same device
    string return the same model instance and transform.
    """
    model = SplitDinov2(split_at=11).to(device)
    # Standard ImageNet-style preprocessing: resize, center-crop to 224x224,
    # convert to a float tensor scaled to [0, 1], then channel-normalize.
    steps = [
        v2.Resize(size=(256, 256)),
        v2.CenterCrop(size=(224, 224)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
    ]
    logger.info("Loaded ViT.")
    return model, v2.Compose(steps)
|