from typing import Optional

import torch
import torch.nn as nn
from accelerate.logging import get_logger

from lam.models.encoders.dinov2_unet import DINOBase

logger = get_logger(__name__)

class Dinov2UnetWrapper(nn.Module):
    """
    DINOv2-UNet wrapper around the original implementation, extended with
    optional modulation support.
    """

    def __init__(self, model_name: str, modulation_dim: Optional[int] = None, freeze: bool = True,
                 encoder_feat_dim: int = 384):
        super().__init__()
        self.modulation_dim = modulation_dim

        self.model = DINOBase(output_dim=encoder_feat_dim)
        assert model_name in ["no_avg", "avg_2"], f"Unsupported model_name: {model_name}"
        self.model_name = model_name

        if freeze:
            if modulation_dim is not None:
                raise ValueError("Modulated Dinov2 requires training; freezing is not allowed.")
            self._freeze()
        else:
            # Keep the mask token frozen even when the rest of the backbone is trainable.
            for name, param in self.model.dino_model.named_parameters():
                if name == "mask_token":
                    param.requires_grad = False

    def _freeze(self):
        logger.warning("======== Freezing Dinov2UnetWrapper ========")
        self.model.dino_model.eval()
        for param in self.model.dino_model.parameters():
            param.requires_grad = False

    @staticmethod
    def _build_dinov2(model_name: str, modulation_dim: Optional[int] = None, pretrained: bool = True):
        # Dynamically import the requested DINOv2 constructor from the bundled hub module.
        from importlib import import_module
        dinov2_hub = import_module(".dinov2.hub.backbones", package=__package__)
        model_fn = getattr(dinov2_hub, model_name)
        logger.debug(f"Modulation dim for Dinov2 is {modulation_dim}.")
        model = model_fn(modulation_dim=modulation_dim, pretrained=pretrained)
        return model

    @torch.compile
    def forward(self, image: torch.Tensor, mod: Optional[torch.Tensor] = None):
        # image: [B, C_img, H, W]; mod is only accepted when the wrapper was
        # built with a modulation dimension.
        if self.modulation_dim is None:
            assert mod is None, "Unexpected modulation input in dinov2 forward."
            outs = self.model(image, is_training=True)
        else:
            assert mod is not None, "Modulation input is required in modulated dinov2 forward."
            outs = self.model(image, mod=mod, is_training=True)

        out_local, out_global = outs

        # Optionally halve the spatial resolution of the local feature map.
        if self.model_name == "avg_2":
            out_local = nn.functional.avg_pool2d(out_local, kernel_size=2, stride=2)

        # Flatten local features [B, C, h, w] -> [B, h*w, C]; append the global
        # feature as one extra token when it is present.
        if out_global is not None:
            ret = torch.cat([out_local.permute(0, 2, 3, 1).flatten(1, 2), out_global.unsqueeze(1)], dim=1)
        else:
            ret = out_local.permute(0, 2, 3, 1).flatten(1, 2)
        return ret