Spaces:
Runtime error
Runtime error
Added tensor dim expansion for edit directions.
Browse files
styleclip/styleclip_global.py
CHANGED
@@ -89,6 +89,7 @@ imagenet_templates = [
|
|
89 |
'a tattoo of the {}.',
|
90 |
]
|
91 |
|
|
|
92 |
FFHQ_CODE_INDICES = [(0, 512), (512, 1024), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584), (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680), (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896), (8960, 9024), (9024, 9056)] + \
|
93 |
[(2048, 2560), (3584, 4096), (5120, 5632), (6656, 7168), (7936, 8192), (8576, 8704), (8896, 8960), (9056, 9088)]
|
94 |
|
@@ -107,6 +108,16 @@ def zeroshot_classifier(model, classnames, templates, device):
|
|
107 |
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
|
108 |
return zeroshot_weights
|
109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
def get_direction(neutral_class, target_class, beta, di, clip_model=None):
|
112 |
|
@@ -157,6 +168,8 @@ def style_dict_to_style_tensor(style_dict, reference_generator):
|
|
157 |
def project_code_with_styleclip(source_latent, source_class, target_class, alpha, beta, reference_generator, di, clip_model=None):
|
158 |
edit_direction = get_direction(source_class, target_class, beta, di, clip_model)
|
159 |
|
|
|
|
|
160 |
source_s = style_dict_to_style_tensor(source_latent, reference_generator)
|
161 |
|
162 |
-
return source_s + alpha *
|
|
|
89 |
'a tattoo of the {}.',
|
90 |
]
|
91 |
|
92 |
+
CONV_CODE_INDICES = [(0, 512), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584), (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680), (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896), (8960, 9024), (9024, 9056)]
|
93 |
FFHQ_CODE_INDICES = [(0, 512), (512, 1024), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584), (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680), (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896), (8960, 9024), (9024, 9056)] + \
|
94 |
[(2048, 2560), (3584, 4096), (5120, 5632), (6656, 7168), (7936, 8192), (8576, 8704), (8896, 8960), (9056, 9088)]
|
95 |
|
|
|
108 |
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
|
109 |
return zeroshot_weights
|
110 |
|
111 |
+
def expand_to_full_dim(partial_tensor):
|
112 |
+
full_dim_tensor = torch.zeros(size=(1, 9088))
|
113 |
+
|
114 |
+
start_idx = 0
|
115 |
+
for conv_start, conv_end in CONV_CODE_INDICES:
|
116 |
+
length = conv_end - conv_start
|
117 |
+
full_dim_tensor[:, conv_start:conv_end] = partial_tensor[:, start_idx:start_idx + length]
|
118 |
+
start_idx += length
|
119 |
+
|
120 |
+
return full_dim_tensor
|
121 |
|
122 |
def get_direction(neutral_class, target_class, beta, di, clip_model=None):
|
123 |
|
|
|
168 |
def project_code_with_styleclip(source_latent, source_class, target_class, alpha, beta, reference_generator, di, clip_model=None):
|
169 |
edit_direction = get_direction(source_class, target_class, beta, di, clip_model)
|
170 |
|
171 |
+
edit_full_dim = expand_to_full_dim(edit_direction)
|
172 |
+
|
173 |
source_s = style_dict_to_style_tensor(source_latent, reference_generator)
|
174 |
|
175 |
+
return source_s + alpha * edit_full_dim
|