Create analogy_encoder.py
Browse files
analogy_encoder/analogy_encoder.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
ADOBE CONFIDENTIAL
|
3 |
+
Copyright 2024 Adobe
|
4 |
+
All Rights Reserved.
|
5 |
+
NOTICE: All information contained herein is, and remains
|
6 |
+
the property of Adobe and its suppliers, if any. The intellectual
|
7 |
+
and technical concepts contained herein are proprietary to Adobe
|
8 |
+
and its suppliers and are protected by all applicable intellectual
|
9 |
+
property laws, including trade secret and copyright laws.
|
10 |
+
Dissemination of this information or reproduction of this material
|
11 |
+
is strictly forbidden unless prior written permission is obtained
|
12 |
+
from Adobe.
|
13 |
+
"""
|
14 |
+
|
15 |
+
import torch as th
|
16 |
+
from diffusers import ModelMixin
|
17 |
+
from transformers import AutoModel, SiglipVisionConfig, Dinov2Config
|
18 |
+
from transformers import SiglipVisionModel
|
19 |
+
|
20 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
21 |
+
|
22 |
+
class AnalogyEncoder(ModelMixin, ConfigMixin):
    """Frozen two-branch image encoder: a DINOv2 backbone plus a SigLIP
    vision tower. ``forward`` produces one feature tensor per branch, each
    concatenating the (norm-scaled) final hidden states with the layer-0
    embeddings along the feature axis.
    """

    @register_to_config
    def __init__(self, load_pretrained=False,
                 dino_config_dict=None, siglip_config_dict=None):
        """Instantiate both vision towers, freeze them, and register them.

        Args:
            load_pretrained: When True, load the published fp16 checkpoints
                from the Hub; otherwise build fresh (randomly initialized)
                models from the supplied config dicts.
            dino_config_dict: ``Dinov2Config`` as a plain dict. Used only when
                ``load_pretrained`` is False (must be provided then — TODO
                confirm callers never pass None in that branch).
            siglip_config_dict: ``SiglipVisionConfig`` as a plain dict; same
                caveat as ``dino_config_dict``.
        """
        super().__init__()

        if load_pretrained:
            dino = AutoModel.from_pretrained('facebook/dinov2-large', torch_dtype=th.float16)
            siglip = SiglipVisionModel.from_pretrained("google/siglip-large-patch16-256", torch_dtype=th.float16, attn_implementation="sdpa")
        else:
            dino = AutoModel.from_config(Dinov2Config.from_dict(dino_config_dict))
            siglip = AutoModel.from_config(SiglipVisionConfig.from_dict(siglip_config_dict))

        # Both towers are used as fixed feature extractors: no gradients,
        # channels_last memory layout.
        dino.requires_grad_(False)
        siglip.requires_grad_(False)
        self.image_encoder_dino = dino.to(memory_format=th.channels_last)
        self.image_encoder_siglip = siglip.to(memory_format=th.channels_last)

    def dino_normalization(self, encoder_output):
        """Scale the DINO token sequence by the norm of its first token."""
        tokens = encoder_output.last_hidden_state
        # The leading token (CLS position, presumably — verify against the
        # DINOv2 output layout) serves as the normalization reference.
        anchor = tokens[:, 0:1]
        return tokens / th.norm(anchor, dim=-1, keepdim=True)

    def siglip_normalization(self, encoder_output):
        """Prepend the pooled SigLIP embedding, then scale by its norm."""
        pooled = encoder_output.pooler_output[:, None, :]
        tokens = th.cat([pooled, encoder_output.last_hidden_state], dim=1)
        # The anchor is the pooled vector that was just prepended at index 0.
        anchor = tokens[:, 0:1]
        return tokens / th.norm(anchor, dim=-1, keepdim=True)

    def forward(self, dino_in, siglip_in):
        """Encode a batch through both frozen towers.

        Args:
            dino_in: input batch for the DINOv2 branch (assumed preprocessed
                for that model — confirm against the caller's pipeline).
            siglip_in: input batch for the SigLIP branch (same assumption).

        Returns:
            Tuple ``(dino_embd, siglip_embd)``: per branch, the normalized
            final hidden states concatenated with the layer-0 (embedding)
            hidden states along the last axis.
        """
        dino_out = self.image_encoder_dino(dino_in, output_hidden_states=True)
        dino_layer0 = dino_out.hidden_states[0]
        dino_tokens = self.dino_normalization(dino_out)

        siglip_out = self.image_encoder_siglip(siglip_in, output_hidden_states=True)
        siglip_layer0 = siglip_out.hidden_states[0]
        # Mirror siglip_normalization: prepend a (mean-)pooled token so the
        # layer-0 sequence lines up token-for-token with the normalized one.
        siglip_layer0 = th.cat([th.mean(siglip_layer0, dim=1, keepdim=True), siglip_layer0], 1)
        siglip_tokens = self.siglip_normalization(siglip_out)

        dino_embd = th.cat([dino_tokens, dino_layer0], -1)
        siglip_embd = th.cat([siglip_tokens, siglip_layer0], -1)
        return dino_embd, siglip_embd