1inkusFace commited on
Commit
611d8c7
·
verified ·
1 Parent(s): 3a607cf

Update pipeline_stable_diffusion_3_ipa.py

Browse files
Files changed (1) hide show
  1. pipeline_stable_diffusion_3_ipa.py +16 -2
pipeline_stable_diffusion_3_ipa.py CHANGED
@@ -1210,8 +1210,22 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
1210
  zeroes_tensor = torch.zeros_like(clip_image_embeds_stack_list)
1211
  print('zeros shape: ', zeroes_tensor.shape)
1212
  clip_image_embeds = torch.cat([zeroes_tensor, clip_image_embeds_stack_list], dim=0)
1213
- print('embeds shape: ', clip_image_embeds.shape)
1214
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1215
  # 1. Stack the image embeddings
1216
  stacked_image_embeds = torch.cat(image_prompt_embeds_list, dim=1)
1217
  print('shape 1: ', stacked_image_embeds.shape)
 
1210
  zeroes_tensor = torch.zeros_like(clip_image_embeds_stack_list)
1211
  print('zeros shape: ', zeroes_tensor.shape)
1212
  clip_image_embeds = torch.cat([zeroes_tensor, clip_image_embeds_stack_list], dim=0)
1213
+ print('embeds shape old: ', clip_image_embeds.shape)
1214
+
1215
+
1216
+
1217
+ clip_image_embeds_cat_list = torch.cat(image_prompt_embeds_list).mean(dim=0)
1218
+ print('catted embeds list with mean: ',image_prompt_embeds.shape)
1219
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
1220
+ clip_image_embeds_cat_list_repeat = clip_image_embeds_cat_list.repeat(1, 1, 1)
1221
+ print('catted embeds repeat: ',clip_image_embeds_cat_list_repeat.shape)
1222
+ clip_image_embeds_view = clip_image_embeds_cat_list_repeat.view(bs_embed * 1, seq_len, -1)
1223
+ print('catted viewed: ',clip_image_embeds_view.shape)
1224
+ zeros_tensor = torch.zeros_like(clip_image_embeds_view)
1225
+ print('zeros: ',zeros_tensor.shape)
1226
+ clip_image_embeds = torch.cat([zeros_tensor, clip_image_embeds_view], dim=0)
1227
+ print('embeds shape new: ', clip_image_embeds.shape)
1228
+
1229
  # 1. Stack the image embeddings
1230
  stacked_image_embeds = torch.cat(image_prompt_embeds_list, dim=1)
1231
  print('shape 1: ', stacked_image_embeds.shape)