bardofcodes commited on
Commit
e8f69d9
·
verified ·
1 Parent(s): 434d295

Create analogy_input_processor.py

Browse files
analogy_input_processor/analogy_input_processor.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ADOBE CONFIDENTIAL
3
+ Copyright 2024 Adobe
4
+ All Rights Reserved.
5
+ NOTICE: All information contained herein is, and remains
6
+ the property of Adobe and its suppliers, if any. The intellectual
7
+ and technical concepts contained herein are proprietary to Adobe
8
+ and its suppliers and are protected by all applicable intellectual
9
+ property laws, including trade secret and copyright laws.
10
+ Dissemination of this information or reproduction of this material
11
+ is strictly forbidden unless prior written permission is obtained
12
+ from Adobe.
13
+ """
14
+
15
+ import torch as th
16
+ from torchvision import transforms
17
+ from diffusers import ModelMixin
18
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
19
+
20
+ DINO_SIZE = 224
21
+ DINO_MEAN = [0.485, 0.456, 0.406]
22
+ DINO_STD = [0.229, 0.224, 0.225]
23
+
24
+ SIGLIP_SIZE = 256
25
+ SIGLIP_MEAN = [0.5]
26
+ SIGLIP_STD = [0.5]
27
+
28
+
29
+ class AnalogyInputProcessor(ModelMixin, ConfigMixin):
30
+
31
+ @register_to_config
32
+ def __init__(self,):
33
+ super(AnalogyInputProcessor, self).__init__()
34
+
35
+ self.dino_transform = transforms.Compose(
36
+ [
37
+ transforms.Resize((DINO_SIZE, DINO_SIZE)),
38
+ transforms.ToTensor(),
39
+ transforms.Normalize(DINO_MEAN, DINO_STD),
40
+ ]
41
+ )
42
+
43
+ self.siglip_transform = transforms.Compose(
44
+ [
45
+ transforms.Resize((SIGLIP_SIZE, SIGLIP_SIZE)),
46
+ transforms.ToTensor(),
47
+ transforms.Normalize(SIGLIP_MEAN, SIGLIP_STD),
48
+ ]
49
+ )
50
+
51
+ dino_mean = th.tensor(DINO_MEAN).view(1, 3, 1, 1)
52
+ dino_std = th.tensor(DINO_STD).view(1, 3, 1, 1)
53
+ siglip_mean = [SIGLIP_MEAN[0],] * 3
54
+ siglip_std = [SIGLIP_STD[0],] * 3
55
+ siglip_mean = th.tensor(siglip_mean).view(1, 3, 1, 1)
56
+ siglip_std = th.tensor(siglip_std).view(1, 3, 1, 1)
57
+ self.register_buffer("dino_mean", dino_mean)
58
+ self.register_buffer("dino_std", dino_std)
59
+ self.register_buffer("siglip_mean", siglip_mean)
60
+ self.register_buffer("siglip_std", siglip_std)
61
+
62
+ def __call__(self, analogy_prompt):
63
+ # List of tuples of (A, A*, B)
64
+ img_a_dino = []
65
+ img_a_siglip = []
66
+ img_a_star_dino = []
67
+ img_a_star_siglip = []
68
+ img_b_dino = []
69
+ img_b_siglip = []
70
+
71
+ for im_set in analogy_prompt:
72
+ img_a, img_a_star, img_b = im_set
73
+ img_a_dino.append(self.dino_transform(img_a))
74
+ img_a_siglip.append(self.siglip_transform(img_a))
75
+ img_a_star_dino.append(self.dino_transform(img_a_star))
76
+ img_a_star_siglip.append(self.siglip_transform(img_a_star))
77
+ img_b_dino.append(self.dino_transform(img_b))
78
+ img_b_siglip.append(self.siglip_transform(img_b))
79
+
80
+ img_a_dino = th.stack(img_a_dino, 0)
81
+ img_a_siglip = th.stack(img_a_siglip, 0)
82
+ img_a_star_dino = th.stack(img_a_star_dino, 0)
83
+ img_a_star_siglip = th.stack(img_a_star_siglip, 0)
84
+ img_b_dino = th.stack(img_b_dino, 0)
85
+ img_b_siglip = th.stack(img_b_siglip, 0)
86
+
87
+ dino_combined_input = th.stack([img_b_dino, img_a_dino, img_a_star_dino], 0)
88
+ siglip_combined_input = th.stack([img_b_siglip, img_a_siglip, img_a_star_siglip], 0)
89
+
90
+ return dino_combined_input, siglip_combined_input
91
+ def get_negative(self, dino_in, siglip_in):
92
+
93
+ dino_i = ((dino_in * 0 + 0.5) - self.dino_mean) / self.dino_std
94
+ siglip_i = ((siglip_in * 0 + 0.5) - self.siglip_mean) / self.siglip_std
95
+ return dino_i, siglip_i
96
+