ford442 commited on
Commit
9faeb57
·
verified ·
1 Parent(s): 6996882

Update ip_adapter/ip_adapter.py

Browse files
Files changed (1) hide show
  1. ip_adapter/ip_adapter.py +3 -1
ip_adapter/ip_adapter.py CHANGED
@@ -214,7 +214,9 @@ class IPAdapterXL(IPAdapter):
214
  negative_prompt = [negative_prompt] * num_prompts
215
  if pil_image_2 != None:
216
  print('Using secondary image.')
217
- image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
 
 
218
  bs_embed, seq_len, _ = image_prompt_embeds.shape
219
  image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
220
  image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
 
214
  negative_prompt = [negative_prompt] * num_prompts
215
  if pil_image_2 != None:
216
  print('Using secondary image.')
217
+ image_prompt_embeds_1, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
218
+ image_prompt_embeds_2, uncond_image_prompt_embeds_2 = self.get_image_embeds(pil_image_2)
219
+ image_prompt_embeds = torch.cat(image_prompt_embeds_1,image_prompt_embeds_2).mean(dim=0).unsqueeze(0)
220
  bs_embed, seq_len, _ = image_prompt_embeds.shape
221
  image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
222
  image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)