ford442 commited on
Commit
a8007c5
·
verified ·
1 Parent(s): d9e3d60

Update ip_adapter/ip_adapter.py

Browse files
Files changed (1) hide show
  1. ip_adapter/ip_adapter.py +11 -8
ip_adapter/ip_adapter.py CHANGED
@@ -203,6 +203,8 @@ class IPAdapterXL(IPAdapter):
203
  pil_image_5=None,
204
  prompt=None,
205
  negative_prompt=None,
 
 
206
  scale_1=1.0,
207
  scale_2=1.0,
208
  scale_3=1.0,
@@ -214,8 +216,8 @@ class IPAdapterXL(IPAdapter):
214
  guidance_scale=7.5,
215
  **kwargs,
216
  ):
217
- self.get_scale()
218
- self.set_scale(scale_1)
219
 
220
  if isinstance(pil_image_1, Image.Image):
221
  num_prompts = 1
@@ -241,26 +243,26 @@ class IPAdapterXL(IPAdapter):
241
 
242
  if pil_image_2 != None:
243
  print('Using secondary image.')
244
- self.set_scale(scale_2)
245
  image_prompt_embeds_2, uncond_image_prompt_embeds_2 = self.get_image_embeds(pil_image_2)
 
246
  image_prompt_embeds_list.append(image_prompt_embeds_2)
247
  uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_2)
248
  if pil_image_3 != None:
249
- print('Using secondary image.')
250
- self.set_scale(scale_3)
251
  image_prompt_embeds_3, uncond_image_prompt_embeds_3 = self.get_image_embeds(pil_image_3)
 
252
  image_prompt_embeds_list.append(image_prompt_embeds_3)
253
  uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_3)
254
  if pil_image_4 != None:
255
  print('Using secondary image.')
256
- self.set_scale(scale_4)
257
  image_prompt_embeds_4, uncond_image_prompt_embeds_4 = self.get_image_embeds(pil_image_4)
 
258
  image_prompt_embeds_list.append(image_prompt_embeds_4)
259
  uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_4)
260
  if pil_image_5 != None:
261
  print('Using secondary image.')
262
- self.set_scale(scale_5)
263
  image_prompt_embeds_5, uncond_image_prompt_embeds_5 = self.get_image_embeds(pil_image_5)
 
264
  image_prompt_embeds_list.append(image_prompt_embeds_5)
265
  uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_5)
266
 
@@ -276,7 +278,8 @@ class IPAdapterXL(IPAdapter):
276
  with torch.inference_mode():
277
  prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt(
278
  prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
279
- prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
 
280
  negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
281
 
282
  generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
 
203
  pil_image_5=None,
204
  prompt=None,
205
  negative_prompt=None,
206
+ text_scale=1.0,
207
+ ip_scale=1.0,
208
  scale_1=1.0,
209
  scale_2=1.0,
210
  scale_3=1.0,
 
216
  guidance_scale=7.5,
217
  **kwargs,
218
  ):
219
+ #self.get_scale()
220
+ self.set_scale(ip_scale)
221
 
222
  if isinstance(pil_image_1, Image.Image):
223
  num_prompts = 1
 
243
 
244
  if pil_image_2 != None:
245
  print('Using secondary image.')
 
246
  image_prompt_embeds_2, uncond_image_prompt_embeds_2 = self.get_image_embeds(pil_image_2)
247
+ image_prompt_embeds_2 = image_prompt_embeds_2 * scale_2
248
  image_prompt_embeds_list.append(image_prompt_embeds_2)
249
  uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_2)
250
  if pil_image_3 != None:
251
+ print('Using tertiary image.')
 
252
  image_prompt_embeds_3, uncond_image_prompt_embeds_3 = self.get_image_embeds(pil_image_3)
253
+ image_prompt_embeds_3 = image_prompt_embeds_3 * scale_3
254
  image_prompt_embeds_list.append(image_prompt_embeds_3)
255
  uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_3)
256
  if pil_image_4 != None:
257
  print('Using secondary image.')
 
258
  image_prompt_embeds_4, uncond_image_prompt_embeds_4 = self.get_image_embeds(pil_image_4)
259
+ image_prompt_embeds_4 = image_prompt_embeds_4 * scale_4
260
  image_prompt_embeds_list.append(image_prompt_embeds_4)
261
  uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_4)
262
  if pil_image_5 != None:
263
  print('Using secondary image.')
 
264
  image_prompt_embeds_5, uncond_image_prompt_embeds_5 = self.get_image_embeds(pil_image_5)
265
+ image_prompt_embeds_5 = image_prompt_embeds_5 * scale_5
266
  image_prompt_embeds_list.append(image_prompt_embeds_5)
267
  uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_5)
268
 
 
278
  with torch.inference_mode():
279
  prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt(
280
  prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
281
+ prompt_embeds = prompt_embeds * text_scale
282
+ prompt_embeds = prompt_embeds * text_scale = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
283
  negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
284
 
285
  generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None