1inkusFace committed
Commit 53e5941 · verified · 1 Parent(s): 04e520e

Update pipeline_stable_diffusion_3_ipa.py

Files changed (1)
  1. pipeline_stable_diffusion_3_ipa.py +12 -4
pipeline_stable_diffusion_3_ipa.py CHANGED
@@ -1150,10 +1150,14 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
             #clip_image_embeds_1 = self.encode_clip_image_emb(clip_image, device, dtype)
             with torch.inference_mode():
                 clip_image_embeds_1 = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values
-                print('clip output size: ', clip_image_embeds_1.shape)
+                print('clip output shape: ', clip_image_embeds_1.shape)
                 clip_image_embeds_1 = clip_image_embeds_1.to(device, dtype=dtype)
                 clip_image_embeds_1 = self.image_encoder(clip_image_embeds_1, output_hidden_states=True).hidden_states[-2]
-                print('encoder output size: ', clip_image_embeds_1.shape)
+                print('encoder output shape: ', clip_image_embeds_1.shape)
+                clip_image_embeds_1 = self.image_proj_model(clip_image_embeds_1)
+
+                print('projection model output shape: ', clip_image_embeds_1.shape)
+
                 clip_image_embeds_1 = clip_image_embeds_1 * scale_1
                 image_prompt_embeds_list.append(clip_image_embeds_1)
             if clip_image_2 != None:
@@ -1163,6 +1167,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
                 clip_image_embeds_2 = self.clip_image_processor(images=clip_image_2, return_tensors="pt").pixel_values
                 clip_image_embeds_2 = clip_image_embeds_2.to(device, dtype=dtype)
                 clip_image_embeds_2 = self.image_encoder(clip_image_embeds_2, output_hidden_states=True).hidden_states[-2]
+                clip_image_embeds_2 = self.image_proj_model(clip_image_embeds_2)
                 clip_image_embeds_2 = clip_image_embeds_2 * scale_2
                 image_prompt_embeds_list.append(clip_image_embeds_2)
             if clip_image_3 != None:
@@ -1172,6 +1177,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
                 clip_image_embeds_3 = self.clip_image_processor(images=clip_image_3, return_tensors="pt").pixel_values
                 clip_image_embeds_3 = clip_image_embeds_3.to(device, dtype=dtype)
                 clip_image_embeds_3 = self.image_encoder(clip_image_embeds_3, output_hidden_states=True).hidden_states[-2]
+                clip_image_embeds_3 = self.image_proj_model(clip_image_embeds_3)
                 clip_image_embeds_3 = clip_image_embeds_3 * scale_3
                 image_prompt_embeds_list.append(clip_image_embeds_3)
             if clip_image_4 != None:
@@ -1181,6 +1187,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
                 clip_image_embeds_4 = self.clip_image_processor(images=clip_image_4, return_tensors="pt").pixel_values
                 clip_image_embeds_4 = clip_image_embeds_4.to(device, dtype=dtype)
                 clip_image_embeds_4 = self.image_encoder(clip_image_embeds_4, output_hidden_states=True).hidden_states[-2]
+                clip_image_embeds_4 = self.image_proj_model(clip_image_embeds_4)
                 clip_image_embeds_4 = clip_image_embeds_4 * scale_4
                 image_prompt_embeds_list.append(clip_image_embeds_4)
             if clip_image_5 != None:
@@ -1190,11 +1197,12 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
                 clip_image_embeds_5 = self.clip_image_processor(images=clip_image_5, return_tensors="pt").pixel_values
                 clip_image_embeds_5 = clip_image_embeds_5.to(device, dtype=dtype)
                 clip_image_embeds_5 = self.image_encoder(clip_image_embeds_5, output_hidden_states=True).hidden_states[-2]
+                clip_image_embeds_5 = self.image_proj_model(clip_image_embeds_5)
                 clip_image_embeds_5 = clip_image_embeds_5 * scale_5
                 image_prompt_embeds_list.append(clip_image_embeds_5)

-            clip_image_embeds_cat_list = torch.cat(image_prompt_embeds_list).mean(dim=0)
-            print('catted embeds list with mean: ',clip_image_embeds_cat_list.shape)
+            clip_image_embeds_cat_list = torch.cat(image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
+            print('catted embeds list with mean and unsqueeze: ',clip_image_embeds_cat_list.shape)
             seq_len, _ = clip_image_embeds_cat_list.shape
             clip_image_embeds_cat_list_repeat = clip_image_embeds_cat_list.repeat(1, 1, 1)
             print('catted embeds repeat: ',clip_image_embeds_cat_list_repeat.shape)
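The change is the same at every branch: each reference image's CLIP embedding is now routed through self.image_proj_model before the per-image scale is applied, and the pooled result gains a batch axis via .unsqueeze(0). A minimal sketch of the resulting flow, with the five near-identical branches collapsed into a loop. The attribute names (clip_image_processor, image_encoder, image_proj_model) and the order of operations come from the diff; the helper itself and the shape comments are assumptions for illustration, not part of the commit.

import torch

def encode_image_prompts(self, clip_images, scales, device, dtype):
    """Hypothetical helper: encode up to five reference images and
    average their projected embeddings, as the diff does inline."""
    image_prompt_embeds_list = []
    with torch.inference_mode():
        for clip_image, scale in zip(clip_images, scales):
            if clip_image is None:
                continue
            # Preprocess to pixel values, e.g. (1, 3, 224, 224) for a CLIP ViT.
            pixel_values = self.clip_image_processor(
                images=clip_image, return_tensors="pt"
            ).pixel_values.to(device, dtype=dtype)
            # Penultimate hidden state, e.g. (1, num_patches + 1, hidden_dim).
            embeds = self.image_encoder(
                pixel_values, output_hidden_states=True
            ).hidden_states[-2]
            # New in this commit: project into the adapter's token space
            # before applying the per-image scale.
            embeds = self.image_proj_model(embeds)
            image_prompt_embeds_list.append(embeds * scale)
    # Stack along the batch axis and average over images:
    # (num_images, tokens, dim) -> (tokens, dim), then restore the
    # batch axis with unsqueeze(0): (1, tokens, dim).
    return torch.cat(image_prompt_embeds_list).mean(dim=0).unsqueeze(0)

One caveat, assuming the projected embeddings are 3-D as sketched above: after the added .unsqueeze(0), clip_image_embeds_cat_list is also 3-D, so the unchanged two-value unpack seq_len, _ = clip_image_embeds_cat_list.shape that follows would raise ValueError: too many values to unpack.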