1inkusFace commited on
Commit
cda906f
Β·
verified Β·
1 Parent(s): 71dcf58

Update ip_adapter/ip_adapter.py

Browse files
Files changed (1) hide show
  1. ip_adapter/ip_adapter.py +3 -10
ip_adapter/ip_adapter.py CHANGED
@@ -270,16 +270,9 @@ class IPAdapterXL(IPAdapter):
270
  image_prompt_embeds_list.append(image_prompt_embeds_5)
271
  uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_5)
272
 
273
- image_prompt_embeds = torch.cat(image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
274
- print('catted embeds list with mean and unsqueeze shape: ',image_prompt_embeds.shape)
275
- bs_embed, seq_len, _ = image_prompt_embeds.shape
276
- image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
277
- print('catted embeds repeat: ',image_prompt_embeds.shape)
278
- image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
279
- print('viewed embeds: ',image_prompt_embeds.shape)
280
- uncond_image_prompt_embeds = torch.cat(uncond_image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
281
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
282
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
283
 
284
  with torch.inference_mode():
285
  prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt(
 
270
  image_prompt_embeds_list.append(image_prompt_embeds_5)
271
  uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_5)
272
 
273
+ image_prompt_embeds = torch.cat(image_prompt_embeds_list).mean(dim=0)
274
+ image_prompt_embeds = image_prompt_embeds.mean(dim=0, keepdim=True)
275
+ uncond_image_prompt_embeds = torch.cat(uncond_image_prompt_embeds_list).mean(dim=0)
 
 
 
 
 
 
 
276
 
277
  with torch.inference_mode():
278
  prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt(