File size: 3,057 Bytes
b4de0ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
ADOBE CONFIDENTIAL
Copyright 2024 Adobe
All Rights Reserved.
NOTICE: All information contained herein is, and remains
the property of Adobe and its suppliers, if any. The intellectual
and technical concepts contained herein are proprietary to Adobe
and its suppliers and are protected by all applicable intellectual
property laws, including trade secret and copyright laws.
Dissemination of this information or reproduction of this material
is strictly forbidden unless prior written permission is obtained
from Adobe.
"""

import torch as th
from diffusers import ModelMixin
from transformers import AutoModel, SiglipVisionConfig, Dinov2Config
from transformers import SiglipVisionModel

from diffusers.configuration_utils import ConfigMixin, register_to_config
    
class AnalogyEncoder(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, load_pretrained=False, 
                 dino_config_dict=None, siglip_config_dict=None):
        super().__init__()
        if load_pretrained:
            image_encoder_dino = AutoModel.from_pretrained('facebook/dinov2-large', torch_dtype=th.float16)
            image_encoder_siglip = SiglipVisionModel.from_pretrained("google/siglip-large-patch16-256", torch_dtype=th.float16, attn_implementation="sdpa")
        else:
            image_encoder_dino = AutoModel.from_config(Dinov2Config.from_dict(dino_config_dict))
            image_encoder_siglip = AutoModel.from_config(SiglipVisionConfig.from_dict(siglip_config_dict))
            
        image_encoder_dino.requires_grad_(False)
        image_encoder_dino = image_encoder_dino.to(memory_format=th.channels_last)

        image_encoder_siglip.requires_grad_(False)
        image_encoder_siglip = image_encoder_siglip.to(memory_format=th.channels_last)
        self.image_encoder_dino = image_encoder_dino
        self.image_encoder_siglip = image_encoder_siglip


    def dino_normalization(self, encoder_output):
        embeds = encoder_output.last_hidden_state
        embeds_pooled = embeds[:, 0:1]
        embeds = embeds / th.norm(embeds_pooled, dim=-1, keepdim=True)
        return embeds
    
    def siglip_normalization(self, encoder_output):
        embeds = th.cat ([encoder_output.pooler_output[:, None, :], encoder_output.last_hidden_state], dim=1)
        embeds_pooled = embeds[:, 0:1]
        embeds = embeds / th.norm(embeds_pooled, dim=-1, keepdim=True)
        return embeds
    
    def forward(self, dino_in, siglip_in):

        x_1 = self.image_encoder_dino(dino_in, output_hidden_states=True)
        x_1_first = x_1.hidden_states[0]
        x_1 = self.dino_normalization(x_1)
        x_2 = self.image_encoder_siglip(siglip_in, output_hidden_states=True)
        x_2_first = x_2.hidden_states[0]
        x_2_first_pool = th.mean(x_2_first, dim=1, keepdim=True)
        x_2_first = th.cat([x_2_first_pool, x_2_first], 1)
        x_2 = self.siglip_normalization(x_2)
        dino_embd = th.cat([x_1, x_1_first], -1)
        siglip_embd = th.cat([x_2, x_2_first], -1)
        return dino_embd, siglip_embd