1inkusFace committed
Commit 679d53d · verified · 1 Parent(s): 53e5941

Update pipeline_stable_diffusion_3_ipa.py

Files changed (1):
  1. pipeline_stable_diffusion_3_ipa.py +1 -6
pipeline_stable_diffusion_3_ipa.py CHANGED

@@ -1154,7 +1154,6 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
             clip_image_embeds_1 = clip_image_embeds_1.to(device, dtype=dtype)
             clip_image_embeds_1 = self.image_encoder(clip_image_embeds_1, output_hidden_states=True).hidden_states[-2]
             print('encoder output shape: ', clip_image_embeds_1.shape)
-            clip_image_embeds_1 = self.image_proj_model(clip_image_embeds_1)
 
             print('projection model output shape: ', clip_image_embeds_1.shape)
 
@@ -1167,7 +1166,6 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
             clip_image_embeds_2 = self.clip_image_processor(images=clip_image_2, return_tensors="pt").pixel_values
             clip_image_embeds_2 = clip_image_embeds_2.to(device, dtype=dtype)
             clip_image_embeds_2 = self.image_encoder(clip_image_embeds_2, output_hidden_states=True).hidden_states[-2]
-            clip_image_embeds_2 = self.image_proj_model(clip_image_embeds_2)
             clip_image_embeds_2 = clip_image_embeds_2 * scale_2
             image_prompt_embeds_list.append(clip_image_embeds_2)
         if clip_image_3 != None:
@@ -1177,7 +1175,6 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
             clip_image_embeds_3 = self.clip_image_processor(images=clip_image_3, return_tensors="pt").pixel_values
             clip_image_embeds_3 = clip_image_embeds_3.to(device, dtype=dtype)
             clip_image_embeds_3 = self.image_encoder(clip_image_embeds_3, output_hidden_states=True).hidden_states[-2]
-            clip_image_embeds_3 = self.image_proj_model(clip_image_embeds_3)
             clip_image_embeds_3 = clip_image_embeds_3 * scale_3
             image_prompt_embeds_list.append(clip_image_embeds_3)
         if clip_image_4 != None:
@@ -1187,7 +1184,6 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
             clip_image_embeds_4 = self.clip_image_processor(images=clip_image_4, return_tensors="pt").pixel_values
             clip_image_embeds_4 = clip_image_embeds_4.to(device, dtype=dtype)
             clip_image_embeds_4 = self.image_encoder(clip_image_embeds_4, output_hidden_states=True).hidden_states[-2]
-            clip_image_embeds_4 = self.image_proj_model(clip_image_embeds_4)
             clip_image_embeds_4 = clip_image_embeds_4 * scale_4
             image_prompt_embeds_list.append(clip_image_embeds_4)
         if clip_image_5 != None:
@@ -1197,11 +1193,10 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
             clip_image_embeds_5 = self.clip_image_processor(images=clip_image_5, return_tensors="pt").pixel_values
             clip_image_embeds_5 = clip_image_embeds_5.to(device, dtype=dtype)
             clip_image_embeds_5 = self.image_encoder(clip_image_embeds_5, output_hidden_states=True).hidden_states[-2]
-            clip_image_embeds_5 = self.image_proj_model(clip_image_embeds_5)
             clip_image_embeds_5 = clip_image_embeds_5 * scale_5
             image_prompt_embeds_list.append(clip_image_embeds_5)
 
-        clip_image_embeds_cat_list = torch.cat(image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
+        clip_image_embeds_cat_list = torch.cat(image_prompt_embeds_list).mean(dim=0)
         print('catted embeds list with mean and unsqueeze: ',clip_image_embeds_cat_list.shape)
         seq_len, _ = clip_image_embeds_cat_list.shape
         clip_image_embeds_cat_list_repeat = clip_image_embeds_cat_list.repeat(1, 1, 1)
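
For context, a minimal sketch of what the revised fusion step computes. It assumes the image encoder is a CLIP ViT whose penultimate hidden state has shape (1, 257, 1024), as for ViT-L/14; the shapes, scales, and variable names below are illustrative stand-ins, not taken from the pipeline:

import torch

# Stand-ins for self.image_encoder(...).hidden_states[-2] per reference image,
# already multiplied by their per-image scales (scale_2 ... scale_5 in the diff).
embeds = [torch.randn(1, 257, 1024) * s for s in (1.0, 0.6, 0.4)]

# New path: concatenate along dim 0, then average across images.
fused = torch.cat(embeds).mean(dim=0)   # shape (257, 1024), 2-D
seq_len, _ = fused.shape                # two-value unpack works on a 2-D shape

# repeat(1, 1, 1) on a 2-D tensor prepends a size-1 batch dim, so the result
# is (1, 257, 1024) even without an explicit unsqueeze(0).
fused_batched = fused.repeat(1, 1, 1)
print(fused_batched.shape)              # torch.Size([1, 257, 1024])

# Old path: .unsqueeze(0) made the tensor 3-D at this point, so the two-value
# unpack above would raise "too many values to unpack (expected 2)".
fused_old = torch.cat(embeds).mean(dim=0).unsqueeze(0)   # (1, 257, 1024)
# seq_len, _ = fused_old.shape          # ValueError on a 3-D shape

With the five per-image self.image_proj_model calls removed, the averaging now happens in encoder space rather than on projected tokens; presumably the projection is applied once to the fused tensor further down the pipeline, outside these hunks.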