bardofcodes committed on
Commit b4de0ac · verified · 1 Parent(s): c34c90c

Create analogy_encoder.py

Files changed (1):
  1. analogy_encoder/analogy_encoder.py +68 -0
analogy_encoder/analogy_encoder.py ADDED
"""
ADOBE CONFIDENTIAL
Copyright 2024 Adobe
All Rights Reserved.
NOTICE: All information contained herein is, and remains
the property of Adobe and its suppliers, if any. The intellectual
and technical concepts contained herein are proprietary to Adobe
and its suppliers and are protected by all applicable intellectual
property laws, including trade secret and copyright laws.
Dissemination of this information or reproduction of this material
is strictly forbidden unless prior written permission is obtained
from Adobe.
"""

import torch as th
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from transformers import AutoModel, Dinov2Config, SiglipVisionConfig, SiglipVisionModel


class AnalogyEncoder(ModelMixin, ConfigMixin):
    """Frozen dual image encoder pairing DINOv2 and SigLIP features."""

    @register_to_config
    def __init__(self, load_pretrained=False,
                 dino_config_dict=None, siglip_config_dict=None):
        super().__init__()
        if load_pretrained:
            # Load pretrained weights from the Hub in fp16.
            image_encoder_dino = AutoModel.from_pretrained(
                'facebook/dinov2-large', torch_dtype=th.float16)
            image_encoder_siglip = SiglipVisionModel.from_pretrained(
                "google/siglip-large-patch16-256", torch_dtype=th.float16,
                attn_implementation="sdpa")
        else:
            # Build randomly initialized encoders from config dicts;
            # weights are expected to be loaded separately.
            image_encoder_dino = AutoModel.from_config(Dinov2Config.from_dict(dino_config_dict))
            image_encoder_siglip = AutoModel.from_config(SiglipVisionConfig.from_dict(siglip_config_dict))

        # Both encoders are frozen feature extractors.
        image_encoder_dino.requires_grad_(False)
        image_encoder_dino = image_encoder_dino.to(memory_format=th.channels_last)

        image_encoder_siglip.requires_grad_(False)
        image_encoder_siglip = image_encoder_siglip.to(memory_format=th.channels_last)
        self.image_encoder_dino = image_encoder_dino
        self.image_encoder_siglip = image_encoder_siglip

    def dino_normalization(self, encoder_output):
        # Scale all tokens by the norm of the CLS token (position 0).
        embeds = encoder_output.last_hidden_state
        embeds_pooled = embeds[:, 0:1]
        embeds = embeds / th.norm(embeds_pooled, dim=-1, keepdim=True)
        return embeds

    def siglip_normalization(self, encoder_output):
        # Prepend the pooled embedding as a pseudo-CLS token, then scale
        # all tokens by its norm.
        embeds = th.cat([encoder_output.pooler_output[:, None, :], encoder_output.last_hidden_state], dim=1)
        embeds_pooled = embeds[:, 0:1]
        embeds = embeds / th.norm(embeds_pooled, dim=-1, keepdim=True)
        return embeds

    def forward(self, dino_in, siglip_in):
        # DINOv2 branch: normalized last-layer tokens concatenated with the
        # first hidden state (patch embeddings) along the channel dimension.
        x_1 = self.image_encoder_dino(dino_in, output_hidden_states=True)
        x_1_first = x_1.hidden_states[0]
        x_1 = self.dino_normalization(x_1)
        # SigLIP branch: mean-pool the first hidden state to mirror the
        # prepended pooler token, so token counts match for concatenation.
        x_2 = self.image_encoder_siglip(siglip_in, output_hidden_states=True)
        x_2_first = x_2.hidden_states[0]
        x_2_first_pool = th.mean(x_2_first, dim=1, keepdim=True)
        x_2_first = th.cat([x_2_first_pool, x_2_first], 1)
        x_2 = self.siglip_normalization(x_2)
        dino_embd = th.cat([x_1, x_1_first], -1)
        siglip_embd = th.cat([x_2, x_2_first], -1)
        return dino_embd, siglip_embd
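A minimal usage sketch (not part of the commit): it instantiates the encoder with pretrained weights and runs a forward pass on dummy inputs. The input resolutions and output shapes are assumptions based on the default preprocessing of facebook/dinov2-large (224×224, patch size 14) and google/siglip-large-patch16-256 (256×256, patch size 16); real inputs would come from the matching image processors.

import torch as th

device = "cuda" if th.cuda.is_available() else "cpu"
# fp16 convolutions may be unsupported on CPU, so upcast there (assumption).
dtype = th.float16 if device == "cuda" else th.float32

encoder = AnalogyEncoder(load_pretrained=True).to(device=device, dtype=dtype).eval()

with th.no_grad():
    dino_in = th.randn(1, 3, 224, 224, device=device, dtype=dtype)    # DINOv2 input
    siglip_in = th.randn(1, 3, 256, 256, device=device, dtype=dtype)  # SigLIP input
    dino_embd, siglip_embd = encoder(dino_in, siglip_in)

print(dino_embd.shape)    # expected (1, 257, 2048): 1 CLS + 256 patches, 1024 + 1024 channels
print(siglip_embd.shape)  # expected (1, 257, 2048): 1 pooled + 256 patches, 1024 + 1024 channels

Both branches concatenate the normalized final-layer tokens with the first hidden state along the channel dimension, so each output carries both high-level and patch-embedding features per token.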