Spaces:
Runtime error
Runtime error
Update pipeline_stable_diffusion_3_ipa.py
Browse files
pipeline_stable_diffusion_3_ipa.py
CHANGED
@@ -1148,72 +1148,51 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
|
1148 |
print('Using primary image.')
|
1149 |
clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
|
1150 |
#clip_image_embeds_1 = self.encode_clip_image_emb(clip_image, device, dtype)
|
1151 |
-
|
1152 |
-
|
1153 |
-
|
1154 |
-
|
|
|
1155 |
print('encoder output size: ', clip_image_embeds_1.shape)
|
1156 |
clip_image_embeds_1 = clip_image_embeds_1 * scale_1
|
1157 |
image_prompt_embeds_list.append(clip_image_embeds_1)
|
1158 |
if clip_image_2 != None:
|
1159 |
print('Using secondary image.')
|
1160 |
clip_image_2 = clip_image_2.resize((max(clip_image_2.size), max(clip_image_2.size)))
|
1161 |
-
|
1162 |
-
|
1163 |
-
|
1164 |
-
|
1165 |
clip_image_embeds_2 = clip_image_embeds_2 * scale_2
|
1166 |
image_prompt_embeds_list.append(clip_image_embeds_2)
|
1167 |
if clip_image_3 != None:
|
1168 |
print('Using tertiary image.')
|
1169 |
clip_image_3 = clip_image_3.resize((max(clip_image_3.size), max(clip_image_3.size)))
|
1170 |
-
|
1171 |
-
|
1172 |
-
|
1173 |
-
|
1174 |
clip_image_embeds_3 = clip_image_embeds_3 * scale_3
|
1175 |
image_prompt_embeds_list.append(clip_image_embeds_3)
|
1176 |
if clip_image_4 != None:
|
1177 |
print('Using quaternary image.')
|
1178 |
clip_image_4 = clip_image_4.resize((max(clip_image_4.size), max(clip_image_4.size)))
|
1179 |
-
|
1180 |
-
|
1181 |
-
|
1182 |
-
|
1183 |
clip_image_embeds_4 = clip_image_embeds_4 * scale_4
|
1184 |
image_prompt_embeds_list.append(clip_image_embeds_4)
|
1185 |
if clip_image_5 != None:
|
1186 |
print('Using quinary image.')
|
1187 |
clip_image_5 = clip_image_5.resize((max(clip_image_5.size), max(clip_image_5.size)))
|
1188 |
-
|
1189 |
-
|
1190 |
-
|
1191 |
-
|
1192 |
clip_image_embeds_5 = clip_image_embeds_5 * scale_5
|
1193 |
image_prompt_embeds_list.append(clip_image_embeds_5)
|
1194 |
|
1195 |
-
# Concatenate the image embeddings
|
1196 |
-
## clip_image_embeds = torch.mean(torch.stack(image_prompt_embeds_list), dim=0)
|
1197 |
-
|
1198 |
-
|
1199 |
-
# clip_image_embeds = torch.cat(image_prompt_embeds_list, dim=0).mean(dim=0) #.unsqueeze(0)
|
1200 |
-
#bs_embed, seq_len = clip_image_embeds.shape
|
1201 |
-
#clip_image_embeds = clip_image_embeds.view(bs_embed, seq_len) # Simplified reshape
|
1202 |
-
|
1203 |
-
# experimental way
|
1204 |
-
#clip_image_embeds = torch.cat([torch.zeros_like(torch.stack(image_prompt_embeds_list)), torch.stack(image_prompt_embeds_list)], dim=0).mean(dim=0)
|
1205 |
-
# FAILS clip_image_embeds = torch.cat(torch.stack(image_prompt_embeds_list), dim=0).mean(dim=0)
|
1206 |
-
# FAILS TIMESTEPS clip_image_embeds = torch.cat(image_prompt_embeds_list, dim=0).mean(dim=0)
|
1207 |
-
|
1208 |
-
clip_image_embeds_stack_list = torch.stack(image_prompt_embeds_list).mean(dim=0)
|
1209 |
-
print('stacked with mean dim 0 shape: ', clip_image_embeds_stack_list.shape)
|
1210 |
-
zeroes_tensor = torch.zeros_like(clip_image_embeds_stack_list)
|
1211 |
-
print('zeros shape: ', zeroes_tensor.shape)
|
1212 |
-
clip_image_embeds = torch.cat([zeroes_tensor, clip_image_embeds_stack_list], dim=0)
|
1213 |
-
print('embeds shape old: ', clip_image_embeds.shape)
|
1214 |
-
|
1215 |
-
|
1216 |
-
|
1217 |
clip_image_embeds_cat_list = torch.cat(image_prompt_embeds_list).mean(dim=0)
|
1218 |
print('catted embeds list with mean: ',clip_image_embeds_cat_list.shape)
|
1219 |
seq_len, _ = clip_image_embeds_cat_list.shape
|
@@ -1224,45 +1203,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
|
1224 |
zeros_tensor = torch.zeros_like(clip_image_embeds_view)
|
1225 |
print('zeros: ',zeros_tensor.shape)
|
1226 |
clip_image_embeds = torch.cat([zeros_tensor, clip_image_embeds_view], dim=0)
|
1227 |
-
print('embeds shape: ', clip_image_embeds.shape)
|
1228 |
-
|
1229 |
-
# 1. Stack the image embeddings
|
1230 |
-
stacked_image_embeds = torch.cat(image_prompt_embeds_list, dim=1)
|
1231 |
-
print('shape 1: ', stacked_image_embeds.shape)
|
1232 |
-
# 2. Calculate the mean of the stacked embeddings
|
1233 |
-
average_image_embed = torch.mean(stacked_image_embeds, dim=0) #.unsqueeze(0) # Add batch dimension after averaging
|
1234 |
-
print('shape 2: ', average_image_embed.shape)
|
1235 |
-
# 3. Create a tensor of zeros with the same shape as the averaged embedding
|
1236 |
-
zeros_tensor = torch.zeros_like(average_image_embed)
|
1237 |
-
#print('shape 3: ', zeros_tensor.shape)
|
1238 |
-
zeros_tensor_repeat = zeros_tensor.repeat(1, 1, 1)
|
1239 |
-
print('shape 3.1: ', zeros_tensor_repeat.shape)
|
1240 |
-
clip_image_embeds_repeat = average_image_embed.repeat(1, 1, 1)
|
1241 |
-
print('shape 3.5: ', clip_image_embeds_repeat.shape)
|
1242 |
-
# 4. Concatenate the zeros and the average embedding
|
1243 |
-
clip_image_embeds_cat = torch.cat([zeros_tensor, average_image_embed], dim=0)
|
1244 |
-
print('shape 4: ', clip_image_embeds_cat.shape)
|
1245 |
-
clip_image_embeds_cat_repeat = clip_image_embeds_cat.repeat(1, 1, 1)
|
1246 |
-
print('shape 4.1: ', clip_image_embeds_cat_repeat.shape)
|
1247 |
-
clip_image_embeds_repeat_cat = torch.cat([zeros_tensor_repeat, clip_image_embeds_repeat], dim=0)
|
1248 |
-
print('shape 4a: ', clip_image_embeds_repeat_cat.shape)
|
1249 |
-
clip_image_embeds_repeat_cat_1 = torch.cat([zeros_tensor_repeat, clip_image_embeds_repeat], dim=1)
|
1250 |
-
print('shape 4b: ', clip_image_embeds_repeat_cat_1.shape)
|
1251 |
-
#clip_image_embeds = clip_image_embeds_repeat_cat
|
1252 |
-
'''
|
1253 |
-
#clip_image_embeds = clip_image_embeds.unsqueeze(0) # Add a dimension at the beginning so now you have [1, 2*seq_len_img, embed_dim_img]
|
1254 |
-
print('shape 5: ', clip_image_embeds.shape)
|
1255 |
-
|
1256 |
-
bs_embed, seq_len, _ = clip_image_embeds.shape
|
1257 |
-
|
1258 |
-
clip_image_embedsa = clip_image_embeds.view(bs_embed, 1, -1)
|
1259 |
-
print('shape 7: ', clip_image_embedsa.shape)
|
1260 |
-
clip_image_embedsb = clip_image_embeds.view(seq_len, -1)
|
1261 |
-
print('shape 8: ', clip_image_embedsb.shape)
|
1262 |
-
|
1263 |
-
clip_image_embeds = clip_image_embedsb
|
1264 |
-
'''
|
1265 |
-
#clip_image_embeds = torch.cat([torch.stack(image_prompt_embeds_list)], dim=0).mean(dim=0)
|
1266 |
|
1267 |
# 4. Prepare timesteps
|
1268 |
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
|
|
1148 |
print('Using primary image.')
|
1149 |
clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
|
1150 |
#clip_image_embeds_1 = self.encode_clip_image_emb(clip_image, device, dtype)
|
1151 |
+
with torch.inference_mode():
|
1152 |
+
clip_image_embeds_1 = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values
|
1153 |
+
print('clip output size: ', clip_image_embeds_1.shape)
|
1154 |
+
clip_image_embeds_1 = clip_image_embeds_1.to(device, dtype=dtype)
|
1155 |
+
clip_image_embeds_1 = self.image_encoder(clip_image_embeds_1, output_hidden_states=True).hidden_states[-2]
|
1156 |
print('encoder output size: ', clip_image_embeds_1.shape)
|
1157 |
clip_image_embeds_1 = clip_image_embeds_1 * scale_1
|
1158 |
image_prompt_embeds_list.append(clip_image_embeds_1)
|
1159 |
if clip_image_2 != None:
|
1160 |
print('Using secondary image.')
|
1161 |
clip_image_2 = clip_image_2.resize((max(clip_image_2.size), max(clip_image_2.size)))
|
1162 |
+
with torch.inference_mode():
|
1163 |
+
clip_image_embeds_2 = self.clip_image_processor(images=clip_image_2, return_tensors="pt").pixel_values
|
1164 |
+
clip_image_embeds_2 = clip_image_embeds_2.to(device, dtype=dtype)
|
1165 |
+
clip_image_embeds_2 = self.image_encoder(clip_image_embeds_2, output_hidden_states=True).hidden_states[-2]
|
1166 |
clip_image_embeds_2 = clip_image_embeds_2 * scale_2
|
1167 |
image_prompt_embeds_list.append(clip_image_embeds_2)
|
1168 |
if clip_image_3 != None:
|
1169 |
print('Using tertiary image.')
|
1170 |
clip_image_3 = clip_image_3.resize((max(clip_image_3.size), max(clip_image_3.size)))
|
1171 |
+
with torch.inference_mode():
|
1172 |
+
clip_image_embeds_3 = self.clip_image_processor(images=clip_image_3, return_tensors="pt").pixel_values
|
1173 |
+
clip_image_embeds_3 = clip_image_embeds_3.to(device, dtype=dtype)
|
1174 |
+
clip_image_embeds_3 = self.image_encoder(clip_image_embeds_3, output_hidden_states=True).hidden_states[-2]
|
1175 |
clip_image_embeds_3 = clip_image_embeds_3 * scale_3
|
1176 |
image_prompt_embeds_list.append(clip_image_embeds_3)
|
1177 |
if clip_image_4 != None:
|
1178 |
print('Using quaternary image.')
|
1179 |
clip_image_4 = clip_image_4.resize((max(clip_image_4.size), max(clip_image_4.size)))
|
1180 |
+
with torch.inference_mode():
|
1181 |
+
clip_image_embeds_4 = self.clip_image_processor(images=clip_image_4, return_tensors="pt").pixel_values
|
1182 |
+
clip_image_embeds_4 = clip_image_embeds_4.to(device, dtype=dtype)
|
1183 |
+
clip_image_embeds_4 = self.image_encoder(clip_image_embeds_4, output_hidden_states=True).hidden_states[-2]
|
1184 |
clip_image_embeds_4 = clip_image_embeds_4 * scale_4
|
1185 |
image_prompt_embeds_list.append(clip_image_embeds_4)
|
1186 |
if clip_image_5 != None:
|
1187 |
print('Using quinary image.')
|
1188 |
clip_image_5 = clip_image_5.resize((max(clip_image_5.size), max(clip_image_5.size)))
|
1189 |
+
with torch.inference_mode():
|
1190 |
+
clip_image_embeds_5 = self.clip_image_processor(images=clip_image_5, return_tensors="pt").pixel_values
|
1191 |
+
clip_image_embeds_5 = clip_image_embeds_5.to(device, dtype=dtype)
|
1192 |
+
clip_image_embeds_5 = self.image_encoder(clip_image_embeds_5, output_hidden_states=True).hidden_states[-2]
|
1193 |
clip_image_embeds_5 = clip_image_embeds_5 * scale_5
|
1194 |
image_prompt_embeds_list.append(clip_image_embeds_5)
|
1195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1196 |
clip_image_embeds_cat_list = torch.cat(image_prompt_embeds_list).mean(dim=0)
|
1197 |
print('catted embeds list with mean: ',clip_image_embeds_cat_list.shape)
|
1198 |
seq_len, _ = clip_image_embeds_cat_list.shape
|
|
|
1203 |
zeros_tensor = torch.zeros_like(clip_image_embeds_view)
|
1204 |
print('zeros: ',zeros_tensor.shape)
|
1205 |
clip_image_embeds = torch.cat([zeros_tensor, clip_image_embeds_view], dim=0)
|
1206 |
+
print('embeds shape: ', clip_image_embeds.shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1207 |
|
1208 |
# 4. Prepare timesteps
|
1209 |
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|