"""PyTorch DINOv2 model.""" |
|
|
|
from typing import Dict, List, Optional, Set, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from .modeling_dinov2 import ( |
|
Dinov2Config, |
|
Dinov2Layer, |
|
Dinov2Model, |
|
Dinov2Embeddings, |
|
BaseModelOutput, |
|
BaseModelOutputWithPooling, |
|
) |
|
|
|
|
|
class ModLN(nn.Module): |
|
def __init__(self, inner_dim: int, mod_dim: int = 1024): |
|
super().__init__() |
|
self.mlp = nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear(mod_dim, inner_dim * 2), |
|
) |
|
|
|
for m in self.modules(): |
|
if isinstance(m, nn.Linear): |
|
nn.init.zeros_(m.weight) |
|
nn.init.zeros_(m.bias) |
|
|
|
def forward(self, x: torch.Tensor, condition: torch.Tensor): |
|
""" |
|
x: [N, M, C_in], M: num of tokens |
|
condition: [N, C_mod] |
|
""" |
|
shift, scale = self.mlp(condition).unsqueeze(1).chunk(2, dim=-1) |
|
return x * (1 + scale) + shift |
|
|
|
|
|
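

# Minimal usage sketch for ModLN (illustrative only; the shapes below are
# assumptions chosen for this example, not values fixed elsewhere in this file):
#
#   mod = ModLN(inner_dim=768, mod_dim=1024)
#   tokens = torch.randn(2, 257, 768)     # [N, M, C_in]
#   condition = torch.randn(2, 1024)      # [N, C_mod]
#   out = mod(tokens, condition)          # same shape as `tokens`
#
# Because the final Linear is zero-initialized, a freshly constructed ModLN
# returns `tokens` unchanged (shift = 0, scale = 0).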


class ConditionalDinov2Config(Dinov2Config):
    """Dinov2Config extended with the dimensionality of the conditioning vector
    used for per-layer modulation."""

    def __init__(self, modulation_dim: int = 1024, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.modulation_dim = modulation_dim
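

# Example of constructing the conditional config (the hyperparameter values here
# are illustrative assumptions; any standard Dinov2Config keyword argument passes
# through unchanged):
#
#   cfg = ConditionalDinov2Config(
#       modulation_dim=1024,
#       hidden_size=768,
#       num_hidden_layers=12,
#       num_attention_heads=12,
#   )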


class ConditionalDinov2Layer(Dinov2Layer):
    """This corresponds to the Block class in the original implementation."""

    def __init__(self, config: ConditionalDinov2Config) -> None:
        super().__init__(config)
        self.mod_norm1 = ModLN(config.hidden_size, config.modulation_dim)
        self.mod_norm2 = ModLN(config.hidden_size, config.modulation_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # In DINOv2, layernorm is applied before self-attention; here the normalized
        # tokens are additionally modulated by the condition vector.
        self_attention_outputs = self.attention(
            self.mod_norm1(self.norm1(hidden_states), condition),
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]

        attention_output = self.layer_scale1(attention_output)
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = self.drop_path(attention_output) + hidden_states

        # layernorm (and modulation) are also applied before the MLP
        layer_output = self.mod_norm2(self.norm2(hidden_states), condition)
        layer_output = self.mlp(layer_output)
        layer_output = self.layer_scale2(layer_output)

        # second residual connection
        layer_output = self.drop_path(layer_output) + hidden_states

        outputs = (layer_output,) + outputs

        return outputs
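

# Sketch of a single conditional layer call (shapes and hyperparameters are
# assumptions for illustration):
#
#   cfg = ConditionalDinov2Config(modulation_dim=1024, hidden_size=768, num_attention_heads=12)
#   layer = ConditionalDinov2Layer(cfg)
#   hidden = torch.randn(2, 257, cfg.hidden_size)   # [N, tokens, hidden_size]
#   cond = torch.randn(2, cfg.modulation_dim)       # [N, modulation_dim]
#   (hidden,) = layer(hidden, condition=cond)       # 1-tuple when output_attentions=False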


class ConditionalDinov2Encoder(nn.Module):
    def __init__(self, config: ConditionalDinov2Config) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [ConditionalDinov2Layer(config) for _ in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        condition: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    condition,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    layer_head_mask,
                    condition,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class ConditionalDinov2Model(Dinov2Model):
    config_class = ConditionalDinov2Config

    def __init__(self, config: ConditionalDinov2Config):
        super().__init__(config)
        self.config = config

        self.embeddings = Dinov2Embeddings(config)
        self.encoder = ConditionalDinov2Encoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        condition (`torch.Tensor` of shape `(batch_size, modulation_dim)`):
            Conditioning vector used to modulate every encoder layer via ModLN.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos
        )

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            condition=condition,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = sequence_output[:, 0, :]  # CLS token

        if not return_dict:
            head_outputs = (sequence_output, pooled_output)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
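

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module: builds a conditional
    # model from a default-sized config (assumed hyperparameters) with random weights
    # and runs one forward pass with a per-sample condition vector.
    # Run with `python -m <package>.<this_module>` so the relative import resolves.
    cfg = ConditionalDinov2Config(modulation_dim=1024)
    model = ConditionalDinov2Model(cfg).eval()

    pixel_values = torch.randn(2, 3, cfg.image_size, cfg.image_size)  # [N, C, H, W]
    condition = torch.randn(2, cfg.modulation_dim)                    # [N, modulation_dim]

    with torch.no_grad():
        outputs = model(pixel_values, condition=condition)

    print(outputs.last_hidden_state.shape)  # [N, 1 + num_patches, hidden_size]
    print(outputs.pooler_output.shape)      # [N, hidden_size]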