1inkusFace committed on
Commit 4c32653 · verified · 1 Parent(s): c2041c6

Update pipeline_stable_diffusion_3_ipa.py

Files changed (1)
  1. pipeline_stable_diffusion_3_ipa.py +5 -11
pipeline_stable_diffusion_3_ipa.py CHANGED
@@ -1206,32 +1206,26 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  # FAILS TIMESTEPS clip_image_embeds = torch.cat(image_prompt_embeds_list, dim=0).mean(dim=0)

  # 1. Stack the image embeddings
- stacked_image_embedsg = torch.stack(image_prompt_embeds_list)
  stacked_image_embeds = torch.cat(image_prompt_embeds_list, dim=1)
  print('shape 1: ', stacked_image_embeds.shape)
- print('shape 1a: ', stacked_image_embedsg.shape)
  # 2. Calculate the mean of the stacked embeddings
  average_image_embed = torch.mean(stacked_image_embeds, dim=0) #.unsqueeze(0) # Add batch dimension after averaging
  print('shape 2: ', average_image_embed.shape)
- average_image_embedf = torch.mean(stacked_image_embeds, dim=1).unsqueeze(0) # Add batch dimension after averaging
- print('shape 2a: ', average_image_embedf.shape)
-
  # 3. Create a tensor of zeros with the same shape as the averaged embedding
  zeros_tensor = torch.zeros_like(average_image_embed)
- print('shape 3: ', zeros_tensor.shape)
- zeros_tensor = torch.zeros_like(average_image_embed)
- zeros_tensora = average_image_embed.repeat(1, 1, 1)
+ #print('shape 3: ', zeros_tensor.shape)
+ zeros_tensora = zeros_tensor.repeat(1, 1, 1)
  print('shape 3.1: ', zeros_tensora.shape)
  clip_image_embedsa = average_image_embed.repeat(1, 1, 1)
  print('shape 3.5: ', clip_image_embedsa.shape)
- clip_image_embedse = torch.cat([zeros_tensora, clip_image_embedsa], dim=0)
- print('shape 3.8: ', clip_image_embedse.shape)
  # 4. Concatenate the zeros and the average embedding
  clip_image_embeds2 = torch.cat([zeros_tensor, average_image_embed], dim=0)
+ clip_image_embeds3 = clip_image_embeds2.repeat(1, 1, 1)
  print('shape 4: ', clip_image_embeds2.shape)
  clip_image_embeds = torch.cat([zeros_tensora, clip_image_embedsa], dim=0)
  print('shape 4a: ', clip_image_embeds.shape)
-
+ clip_image_embeds = torch.cat([zeros_tensora, clip_image_embedsa], dim=0)
+ print('shape 4b: ', clip_image_embeds3.shape)
  '''
  #clip_image_embeds = clip_image_embeds.unsqueeze(0) # Add a dimension at the beginning so now you have [1, 2*seq_len_img, embed_dim_img]
  print('shape 5: ', clip_image_embeds.shape)
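To make the surviving code path easier to follow, here is a minimal, self-contained sketch of what the retained lines appear to compute: concatenate the per-image IP-Adapter embeddings, average them, and pair the result with a zeros tensor along dim 0. The helper name build_clip_image_embeds, the dummy shapes, and the classifier-free-guidance reading of the final concatenation are assumptions for illustration, not something this commit states.

import torch

def build_clip_image_embeds(image_prompt_embeds_list):
    # Hypothetical helper mirroring the retained lines of the diff above.
    # image_prompt_embeds_list: per-image embeddings, each assumed to be
    # shaped [batch, seq_len_img, embed_dim_img].

    # 1. Stack the image embeddings along the sequence dimension.
    stacked_image_embeds = torch.cat(image_prompt_embeds_list, dim=1)

    # 2. Average over the batch dimension to get one combined embedding.
    average_image_embed = torch.mean(stacked_image_embeds, dim=0)

    # 3. Build a zeros tensor of the same shape; repeat(1, 1, 1) prepends
    #    a batch dimension to both tensors (zeros_tensora / clip_image_embedsa
    #    in the diff).
    zeros_tensor = torch.zeros_like(average_image_embed)
    uncond_embeds = zeros_tensor.repeat(1, 1, 1)
    cond_embeds = average_image_embed.repeat(1, 1, 1)

    # 4. Concatenate zeros and the averaged embedding along dim 0, the usual
    #    negative/positive pairing for classifier-free guidance (assumed).
    return torch.cat([uncond_embeds, cond_embeds], dim=0)

# Usage sketch with dummy embeddings of shape [1, 4, 8]:
embeds = build_clip_image_embeds([torch.randn(1, 4, 8), torch.randn(1, 4, 8)])
print(embeds.shape)  # torch.Size([2, 8, 8]) for the dummy shapes above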