ford442 committed
Commit 86b3ed9 · verified · 1 parent: 633ba49

Update ip_adapter/ip_adapter.py

Files changed (1): ip_adapter/ip_adapter.py (+49 -18)
ip_adapter/ip_adapter.py CHANGED

@@ -185,18 +185,25 @@ class IPAdapterXL(IPAdapter):
 
     def generate(
         self,
-        pil_image,
+        pil_image_1,
         pil_image_2=None,
+        pil_image_3=None,
+        pil_image_4=None,
+        pil_image_5=None,
         prompt=None,
         negative_prompt=None,
-        scale=1.0,
+        scale_1=1.0,
+        scale_2=1.0,
+        scale_3=1.0,
+        scale_4=1.0,
+        scale_5=1.0,
         num_samples=4,
         seed=-1,
         num_inference_steps=30,
         guidance_scale=7.5,
         **kwargs,
     ):
-        self.set_scale(scale)
+        self.set_scale(scale_1)
 
         if isinstance(pil_image, Image.Image):
             num_prompts = 1
@@ -212,24 +219,48 @@ class IPAdapterXL(IPAdapter):
             prompt = [prompt] * num_prompts
         if not isinstance(negative_prompt, List):
             negative_prompt = [negative_prompt] * num_prompts
+
+        image_prompt_embeds_list = []
+        uncond_image_prompt_embeds_list = []
+
+        image_prompt_embeds_1, uncond_image_prompt_embeds_1 = self.get_image_embeds(pil_image)
+        image_prompt_embeds_list.append(image_prompt_embeds_1)
+        uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_1)
+
         if pil_image_2 != None:
             print('Using secondary image.')
-            image_prompt_embeds_1, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
+            self.set_scale(scale_2)
             image_prompt_embeds_2, uncond_image_prompt_embeds_2 = self.get_image_embeds(pil_image_2)
-            image_prompt_embeds = torch.cat((image_prompt_embeds_1,image_prompt_embeds_2)).mean(dim=0).unsqueeze(0)
-            bs_embed, seq_len, _ = image_prompt_embeds.shape
-            image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
-            image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-            uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
-            uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-        else:
-            image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
-            bs_embed, seq_len, _ = image_prompt_embeds.shape
-            image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
-            image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-            uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
-            uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
+            image_prompt_embeds_list.append(image_prompt_embeds_2)
+            uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_2)
+        if pil_image_3 != None:
+            print('Using secondary image.')
+            self.set_scale(scale_3)
+            image_prompt_embeds_3, uncond_image_prompt_embeds_3 = self.get_image_embeds(pil_image_3)
+            image_prompt_embeds_list.append(image_prompt_embeds_3)
+            uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_3)
+        if pil_image_4 != None:
+            print('Using secondary image.')
+            self.set_scale(scale_4)
+            image_prompt_embeds_4, uncond_image_prompt_embeds_4 = self.get_image_embeds(pil_image_4)
+            image_prompt_embeds_list.append(image_prompt_embeds_4)
+            uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_4)
+        if pil_image_5 != None:
+            print('Using secondary image.')
+            self.set_scale(scale_5)
+            image_prompt_embeds_5, uncond_image_prompt_embeds_5 = self.get_image_embeds(pil_image_5)
+            image_prompt_embeds_list.append(image_prompt_embeds_5)
+            uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_5)
+
+        image_prompt_embeds = torch.cat(image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
+        bs_embed, seq_len, _ = image_prompt_embeds.shape
+        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+
+        uncond_image_prompt_embeds = torch.cat(uncond_image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
+        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
+        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+
         with torch.inference_mode():
             prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt(
                 prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
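
The heart of the change: instead of a two-image special case with an else branch, every provided image's prompt embeddings are appended to a list and averaged into one conditioning tensor. A minimal, self-contained sketch of that tensor arithmetic, using random stand-ins for the (1, seq_len, dim) tensors that get_image_embeds returns (the shapes here are illustrative only):

import torch

num_samples = 4
seq_len, dim = 16, 2048
# Stand-ins for three reference images' embeddings, each (1, seq_len, dim).
embeds_list = [torch.randn(1, seq_len, dim) for _ in range(3)]

# Stack along the batch axis -> (3, seq_len, dim), average the images
# away -> (seq_len, dim), then restore the batch axis -> (1, seq_len, dim).
image_prompt_embeds = torch.cat(embeds_list).mean(dim=0).unsqueeze(0)

# Duplicate the single conditioning row once per requested sample:
# (1, seq_len, dim) -> (1, num_samples * seq_len, dim) -> (num_samples, seq_len, dim).
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

assert image_prompt_embeds.shape == (num_samples, seq_len, dim)

The repeat/view pair duplicates the averaged embedding once per sample, matching the num_images_per_prompt=num_samples duplication that encode_prompt applies to the text embeddings.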
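A hypothetical call against the new signature (not from the commit; ip_model stands for an already-constructed IPAdapterXL, and the file names are placeholders). One caveat: the commit renames the first parameter to pil_image_1, but the body still reads pil_image in the isinstance check and the first get_image_embeds call, so this sketch assumes those references are renamed to match:

from PIL import Image

style_a = Image.open("style_a.png")  # placeholder paths
style_b = Image.open("style_b.png")
style_c = Image.open("style_c.png")

images = ip_model.generate(          # ip_model: an IPAdapterXL instance (assumed)
    pil_image_1=style_a,
    pil_image_2=style_b,
    pil_image_3=style_c,             # pil_image_4 / pil_image_5 default to None and are skipped
    prompt="best quality, high quality",
    scale_1=1.0,
    scale_2=0.6,
    scale_3=0.3,
    num_samples=2,
    num_inference_steps=30,
)

If set_scale behaves as in the upstream IP-Adapter implementation, it configures the UNet attention processors used at denoising time rather than the image-embedding path, so the last scale set before the diffusion loop is the one that takes effect; weighting individual images would instead mean weighting them inside the mean.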
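The five numbered parameters and near-identical if blocks could also collapse into a loop. A hypothetical helper sketch (collect_image_embeds, pil_images, and scales are names invented for illustration, not part of the repository):

import torch

def collect_image_embeds(self, pil_images, scales=None):
    # Hypothetical: gather conditional/unconditional embeddings for any
    # number of reference images, mirroring the committed per-image logic.
    scales = scales if scales is not None else [1.0] * len(pil_images)
    image_embeds_list, uncond_embeds_list = [], []
    for pil_image, scale in zip(pil_images, scales):
        self.set_scale(scale)
        image_embeds, uncond_embeds = self.get_image_embeds(pil_image)
        image_embeds_list.append(image_embeds)
        uncond_embeds_list.append(uncond_embeds)
    # Average across images exactly as the commit does.
    image_prompt_embeds = torch.cat(image_embeds_list).mean(dim=0).unsqueeze(0)
    uncond_prompt_embeds = torch.cat(uncond_embeds_list).mean(dim=0).unsqueeze(0)
    return image_prompt_embeds, uncond_prompt_embeds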