1inkusFace committed on
Commit
583adef
·
verified ·
1 Parent(s): fa468b5

Update pipeline_stable_diffusion_3_ipa.py

Browse files
Files changed (1) hide show
  1. pipeline_stable_diffusion_3_ipa.py +44 -2
pipeline_stable_diffusion_3_ipa.py CHANGED
@@ -965,7 +965,17 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
965
 
966
  # ipa
967
  clip_image=None,
 
 
 
 
 
968
  ipadapter_scale=1.0,
 
 
 
 
 
969
  ):
970
  r"""
971
  Function invoked when calling the pipeline for generation.
@@ -1126,10 +1136,42 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
1126
  if self.do_classifier_free_guidance:
1127
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1128
  pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
1129
-
 
 
 
 
1130
  # 3. prepare clip emb
1131
  clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
1132
- clip_image_embeds = self.encode_clip_image_emb(clip_image, device, dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1133
 
1134
  # 4. Prepare timesteps
1135
  timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
 
965
 
966
  # ipa
967
  clip_image=None,
968
+ clip_image_2=None,
969
+ clip_image_3=None,
970
+ clip_image_4=None,
971
+ clip_image_5=None,
972
+ text_scale=1.0,
973
  ipadapter_scale=1.0,
974
+ scale_1=1.0,
975
+ scale_2=1.0,
976
+ scale_3=1.0,
977
+ scale_4=1.0,
978
+ scale_5=1.0,
979
  ):
980
  r"""
981
  Function invoked when calling the pipeline for generation.
 
1136
  if self.do_classifier_free_guidance:
1137
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1138
  pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
1139
+
1140
+ prompt_embeds = prompt_embeds * text_scale
1141
+
1142
+ image_prompt_embeds_list = []
1143
+
1144
  # 3. prepare clip emb
1145
  clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
1146
+ clip_image_embeds_1 = self.encode_clip_image_emb(clip_image, device, dtype)
1147
+ image_prompt_embeds_list.append(clip_image_embeds_1)
1148
+
1149
+ if clip_image_2 != None:
1150
+ print('Using secondary image.')
1151
+ clip_image_2 = clip_image_2.resize((max(clip_image.size), max(clip_image.size)))
1152
+ image_prompt_embeds_2 = self.encode_clip_image_emb(clip_image, device, dtype)
1153
+ image_prompt_embeds_2 = image_prompt_embeds_2 * scale_2
1154
+ image_prompt_embeds_list.append(image_prompt_embeds_2)
1155
+ if clip_image_3 != None:
1156
+ print('Using tertiary image.')
1157
+ clip_image_3 = clip_image_3.resize((max(clip_image.size), max(clip_image.size)))
1158
+ image_prompt_embeds_3 = self.encode_clip_image_emb(clip_image, device, dtype)
1159
+ image_prompt_embeds_3 = image_prompt_embeds_3 * scale_3
1160
+ image_prompt_embeds_list.append(image_prompt_embeds_3)
1161
+ if clip_image_4 != None:
1162
+ print('Using quaternary image.')
1163
+ clip_image_4 = clip_image_4.resize((max(clip_image.size), max(clip_image.size)))
1164
+ image_prompt_embeds_4 = self.encode_clip_image_emb(clip_image, device, dtype)
1165
+ image_prompt_embeds_4 = image_prompt_embeds_4 * scale_4
1166
+ image_prompt_embeds_list.append(image_prompt_embeds_4)
1167
+ if clip_image_5 != None:
1168
+ print('Using quinary image.')
1169
+ clip_image_5 = clip_image_5.resize((max(clip_image.size), max(clip_image.size)))
1170
+ image_prompt_embeds_5 = self.encode_clip_image_emb(clip_image, device, dtype)
1171
+ image_prompt_embeds_5 = image_prompt_embeds_5 * scale_5
1172
+ image_prompt_embeds_list.append(image_prompt_embeds_5)
1173
+
1174
+ clip_image_embeds = torch.cat(image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
1175
 
1176
  # 4. Prepare timesteps
1177
  timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)