ford442 commited on
Commit
301b68a
·
verified ·
1 Parent(s): 5c38014

Update ip_adapter/ip_adapter.py

Browse files
Files changed (1) hide show
  1. ip_adapter/ip_adapter.py +16 -7
ip_adapter/ip_adapter.py CHANGED
@@ -189,6 +189,7 @@ class IPAdapterXL(IPAdapter):
189
  def generate(
190
  self,
191
  pil_image,
 
192
  prompt=None,
193
  negative_prompt=None,
194
  scale=1.0,
@@ -213,14 +214,22 @@ class IPAdapterXL(IPAdapter):
213
  prompt = [prompt] * num_prompts
214
  if not isinstance(negative_prompt, List):
215
  negative_prompt = [negative_prompt] * num_prompts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
- image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
218
- bs_embed, seq_len, _ = image_prompt_embeds.shape
219
- image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
220
- image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
221
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
222
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
223
-
224
  with torch.inference_mode():
225
  prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt(
226
  prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
 
189
  def generate(
190
  self,
191
  pil_image,
192
+ pil_image_2=None,
193
  prompt=None,
194
  negative_prompt=None,
195
  scale=1.0,
 
214
  prompt = [prompt] * num_prompts
215
  if not isinstance(negative_prompt, List):
216
  negative_prompt = [negative_prompt] * num_prompts
217
+ if pil_image_2 != None:
218
+ print('Using secondary image.')
219
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
220
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
221
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
222
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
223
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
224
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
225
+ else:
226
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
227
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
228
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
229
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
230
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
231
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
232
 
 
 
 
 
 
 
 
233
  with torch.inference_mode():
234
  prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt(
235
  prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)