1inkusFace committed on
Commit
5d6dc06
·
verified ·
1 Parent(s): 154abbe

Update pipeline_stable_diffusion_3_ipa.py

Browse files
Files changed (1) hide show
  1. pipeline_stable_diffusion_3_ipa.py +12 -3
pipeline_stable_diffusion_3_ipa.py CHANGED
@@ -1204,9 +1204,18 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
1204
  # FAILS TIMESTEPS clip_image_embeds = torch.cat(image_prompt_embeds_list, dim=0).mean(dim=0)
1205
 
1206
 
1207
- #clip_image_embeds = torch.mean(torch.stack(image_prompt_embeds_list), dim=0) # working
1208
-
1209
- clip_image_embeds = torch.cat([torch.zeros_like(image_prompt_embeds_list),image_prompt_embeds_list]).mean(dim=0).unsqueeze(0)
 
 
 
 
 
 
 
 
 
1210
  bs_embed, seq_len, _ = clip_image_embeds.shape
1211
  clip_image_embeds = clip_image_embeds.repeat(1, 1, 1)
1212
  clip_image_embeds = clip_image_embeds.view(2, -1)
 
1204
  # FAILS TIMESTEPS clip_image_embeds = torch.cat(image_prompt_embeds_list, dim=0).mean(dim=0)
1205
 
1206
 
1207
+ # 1. Stack the image embeddings
1208
+ stacked_image_embeds = torch.stack(image_prompt_embeds_list)
1209
+
1210
+ # 2. Calculate the mean of the stacked embeddings
1211
+ average_image_embed = torch.mean(stacked_image_embeds, dim=0).unsqueeze(0) # Add batch dimension after averaging
1212
+
1213
+ # 3. Create a tensor of zeros with the same shape as the averaged embedding
1214
+ zeros_tensor = torch.zeros_like(average_image_embed)
1215
+
1216
+ # 4. Concatenate the zeros and the average embedding
1217
+ clip_image_embeds = torch.cat([zeros_tensor, average_image_embed], dim=0)
1218
+
1219
  bs_embed, seq_len, _ = clip_image_embeds.shape
1220
  clip_image_embeds = clip_image_embeds.repeat(1, 1, 1)
1221
  clip_image_embeds = clip_image_embeds.view(2, -1)