1inkusFace committed on
Commit
b1bfbda
·
verified ·
1 Parent(s): a127d4d

Update pipeline_stable_diffusion_3_ipa.py

Browse files
Files changed (1) hide show
  1. pipeline_stable_diffusion_3_ipa.py +22 -81
pipeline_stable_diffusion_3_ipa.py CHANGED
@@ -1148,72 +1148,51 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
1148
  print('Using primary image.')
1149
  clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
1150
  #clip_image_embeds_1 = self.encode_clip_image_emb(clip_image, device, dtype)
1151
- clip_image_embeds_1 = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values
1152
- print('clip output size: ', clip_image_embeds_1.shape)
1153
- clip_image_embeds_1 = clip_image_embeds_1.to(device, dtype=dtype)
1154
- clip_image_embeds_1 = self.image_encoder(clip_image_embeds_1, output_hidden_states=True).hidden_states[-2]
 
1155
  print('encoder output size: ', clip_image_embeds_1.shape)
1156
  clip_image_embeds_1 = clip_image_embeds_1 * scale_1
1157
  image_prompt_embeds_list.append(clip_image_embeds_1)
1158
  if clip_image_2 != None:
1159
  print('Using secondary image.')
1160
  clip_image_2 = clip_image_2.resize((max(clip_image_2.size), max(clip_image_2.size)))
1161
- #clip_image_embeds_2 = self.encode_clip_image_emb(clip_image, device, dtype)
1162
- clip_image_embeds_2 = self.clip_image_processor(images=clip_image_2, return_tensors="pt").pixel_values
1163
- clip_image_embeds_2 = clip_image_embeds_2.to(device, dtype=dtype)
1164
- clip_image_embeds_2 = self.image_encoder(clip_image_embeds_2, output_hidden_states=True).hidden_states[-2]
1165
  clip_image_embeds_2 = clip_image_embeds_2 * scale_2
1166
  image_prompt_embeds_list.append(clip_image_embeds_2)
1167
  if clip_image_3 != None:
1168
  print('Using tertiary image.')
1169
  clip_image_3 = clip_image_3.resize((max(clip_image_3.size), max(clip_image_3.size)))
1170
- #clip_image_embeds_3 = self.encode_clip_image_emb(clip_image, device, dtype)
1171
- clip_image_embeds_3 = self.clip_image_processor(images=clip_image_3, return_tensors="pt").pixel_values
1172
- clip_image_embeds_3 = clip_image_embeds_3.to(device, dtype=dtype)
1173
- clip_image_embeds_3 = self.image_encoder(clip_image_embeds_3, output_hidden_states=True).hidden_states[-2]
1174
  clip_image_embeds_3 = clip_image_embeds_3 * scale_3
1175
  image_prompt_embeds_list.append(clip_image_embeds_3)
1176
  if clip_image_4 != None:
1177
  print('Using quaternary image.')
1178
  clip_image_4 = clip_image_4.resize((max(clip_image_4.size), max(clip_image_4.size)))
1179
- #clip_image_embeds_4 = self.encode_clip_image_emb(clip_image, device, dtype)
1180
- clip_image_embeds_4 = self.clip_image_processor(images=clip_image_4, return_tensors="pt").pixel_values
1181
- clip_image_embeds_4 = clip_image_embeds_4.to(device, dtype=dtype)
1182
- clip_image_embeds_2 = self.image_encoder(clip_image_embeds_4, output_hidden_states=True).hidden_states[-2]
1183
  clip_image_embeds_4 = clip_image_embeds_4 * scale_4
1184
  image_prompt_embeds_list.append(clip_image_embeds_4)
1185
  if clip_image_5 != None:
1186
  print('Using quinary image.')
1187
  clip_image_5 = clip_image_5.resize((max(clip_image_5.size), max(clip_image_5.size)))
1188
- #clip_image_embeds_5 = self.encode_clip_image_emb(clip_image, device, dtype)
1189
- clip_image_embeds_5 = self.clip_image_processor(images=clip_image_5, return_tensors="pt").pixel_values
1190
- clip_image_embeds_5 = clip_image_embeds_5.to(device, dtype=dtype)
1191
- clip_image_embeds_5 = self.image_encoder(clip_image_embeds_5, output_hidden_states=True).hidden_states[-2]
1192
  clip_image_embeds_5 = clip_image_embeds_5 * scale_5
1193
  image_prompt_embeds_list.append(clip_image_embeds_5)
1194
 
1195
- # Concatenate the image embeddings
1196
- ## clip_image_embeds = torch.mean(torch.stack(image_prompt_embeds_list), dim=0)
1197
-
1198
-
1199
- # clip_image_embeds = torch.cat(image_prompt_embeds_list, dim=0).mean(dim=0) #.unsqueeze(0)
1200
- #bs_embed, seq_len = clip_image_embeds.shape
1201
- #clip_image_embeds = clip_image_embeds.view(bs_embed, seq_len) # Simplified reshape
1202
-
1203
- # experimental way
1204
- #clip_image_embeds = torch.cat([torch.zeros_like(torch.stack(image_prompt_embeds_list)), torch.stack(image_prompt_embeds_list)], dim=0).mean(dim=0)
1205
- # FAILS clip_image_embeds = torch.cat(torch.stack(image_prompt_embeds_list), dim=0).mean(dim=0)
1206
- # FAILS TIMESTEPS clip_image_embeds = torch.cat(image_prompt_embeds_list, dim=0).mean(dim=0)
1207
-
1208
- clip_image_embeds_stack_list = torch.stack(image_prompt_embeds_list).mean(dim=0)
1209
- print('stacked with mean dim 0 shape: ', clip_image_embeds_stack_list.shape)
1210
- zeroes_tensor = torch.zeros_like(clip_image_embeds_stack_list)
1211
- print('zeros shape: ', zeroes_tensor.shape)
1212
- clip_image_embeds = torch.cat([zeroes_tensor, clip_image_embeds_stack_list], dim=0)
1213
- print('embeds shape old: ', clip_image_embeds.shape)
1214
-
1215
-
1216
-
1217
  clip_image_embeds_cat_list = torch.cat(image_prompt_embeds_list).mean(dim=0)
1218
  print('catted embeds list with mean: ',clip_image_embeds_cat_list.shape)
1219
  seq_len, _ = clip_image_embeds_cat_list.shape
@@ -1224,45 +1203,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
1224
  zeros_tensor = torch.zeros_like(clip_image_embeds_view)
1225
  print('zeros: ',zeros_tensor.shape)
1226
  clip_image_embeds = torch.cat([zeros_tensor, clip_image_embeds_view], dim=0)
1227
- print('embeds shape new: ', clip_image_embeds.shape)
1228
-
1229
- # 1. Stack the image embeddings
1230
- stacked_image_embeds = torch.cat(image_prompt_embeds_list, dim=1)
1231
- print('shape 1: ', stacked_image_embeds.shape)
1232
- # 2. Calculate the mean of the stacked embeddings
1233
- average_image_embed = torch.mean(stacked_image_embeds, dim=0) #.unsqueeze(0) # Add batch dimension after averaging
1234
- print('shape 2: ', average_image_embed.shape)
1235
- # 3. Create a tensor of zeros with the same shape as the averaged embedding
1236
- zeros_tensor = torch.zeros_like(average_image_embed)
1237
- #print('shape 3: ', zeros_tensor.shape)
1238
- zeros_tensor_repeat = zeros_tensor.repeat(1, 1, 1)
1239
- print('shape 3.1: ', zeros_tensor_repeat.shape)
1240
- clip_image_embeds_repeat = average_image_embed.repeat(1, 1, 1)
1241
- print('shape 3.5: ', clip_image_embeds_repeat.shape)
1242
- # 4. Concatenate the zeros and the average embedding
1243
- clip_image_embeds_cat = torch.cat([zeros_tensor, average_image_embed], dim=0)
1244
- print('shape 4: ', clip_image_embeds_cat.shape)
1245
- clip_image_embeds_cat_repeat = clip_image_embeds_cat.repeat(1, 1, 1)
1246
- print('shape 4.1: ', clip_image_embeds_cat_repeat.shape)
1247
- clip_image_embeds_repeat_cat = torch.cat([zeros_tensor_repeat, clip_image_embeds_repeat], dim=0)
1248
- print('shape 4a: ', clip_image_embeds_repeat_cat.shape)
1249
- clip_image_embeds_repeat_cat_1 = torch.cat([zeros_tensor_repeat, clip_image_embeds_repeat], dim=1)
1250
- print('shape 4b: ', clip_image_embeds_repeat_cat_1.shape)
1251
- #clip_image_embeds = clip_image_embeds_repeat_cat
1252
- '''
1253
- #clip_image_embeds = clip_image_embeds.unsqueeze(0) # Add a dimension at the beginning so now you have [1, 2*seq_len_img, embed_dim_img]
1254
- print('shape 5: ', clip_image_embeds.shape)
1255
-
1256
- bs_embed, seq_len, _ = clip_image_embeds.shape
1257
-
1258
- clip_image_embedsa = clip_image_embeds.view(bs_embed, 1, -1)
1259
- print('shape 7: ', clip_image_embedsa.shape)
1260
- clip_image_embedsb = clip_image_embeds.view(seq_len, -1)
1261
- print('shape 8: ', clip_image_embedsb.shape)
1262
-
1263
- clip_image_embeds = clip_image_embedsb
1264
- '''
1265
- #clip_image_embeds = torch.cat([torch.stack(image_prompt_embeds_list)], dim=0).mean(dim=0)
1266
 
1267
  # 4. Prepare timesteps
1268
  timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
 
1148
  print('Using primary image.')
1149
  clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
1150
  #clip_image_embeds_1 = self.encode_clip_image_emb(clip_image, device, dtype)
1151
+ with torch.inference_mode():
1152
+ clip_image_embeds_1 = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values
1153
+ print('clip output size: ', clip_image_embeds_1.shape)
1154
+ clip_image_embeds_1 = clip_image_embeds_1.to(device, dtype=dtype)
1155
+ clip_image_embeds_1 = self.image_encoder(clip_image_embeds_1, output_hidden_states=True).hidden_states[-2]
1156
  print('encoder output size: ', clip_image_embeds_1.shape)
1157
  clip_image_embeds_1 = clip_image_embeds_1 * scale_1
1158
  image_prompt_embeds_list.append(clip_image_embeds_1)
1159
  if clip_image_2 != None:
1160
  print('Using secondary image.')
1161
  clip_image_2 = clip_image_2.resize((max(clip_image_2.size), max(clip_image_2.size)))
1162
+ with torch.inference_mode():
1163
+ clip_image_embeds_2 = self.clip_image_processor(images=clip_image_2, return_tensors="pt").pixel_values
1164
+ clip_image_embeds_2 = clip_image_embeds_2.to(device, dtype=dtype)
1165
+ clip_image_embeds_2 = self.image_encoder(clip_image_embeds_2, output_hidden_states=True).hidden_states[-2]
1166
  clip_image_embeds_2 = clip_image_embeds_2 * scale_2
1167
  image_prompt_embeds_list.append(clip_image_embeds_2)
1168
  if clip_image_3 != None:
1169
  print('Using tertiary image.')
1170
  clip_image_3 = clip_image_3.resize((max(clip_image_3.size), max(clip_image_3.size)))
1171
+ with torch.inference_mode():
1172
+ clip_image_embeds_3 = self.clip_image_processor(images=clip_image_3, return_tensors="pt").pixel_values
1173
+ clip_image_embeds_3 = clip_image_embeds_3.to(device, dtype=dtype)
1174
+ clip_image_embeds_3 = self.image_encoder(clip_image_embeds_3, output_hidden_states=True).hidden_states[-2]
1175
  clip_image_embeds_3 = clip_image_embeds_3 * scale_3
1176
  image_prompt_embeds_list.append(clip_image_embeds_3)
1177
  if clip_image_4 != None:
1178
  print('Using quaternary image.')
1179
  clip_image_4 = clip_image_4.resize((max(clip_image_4.size), max(clip_image_4.size)))
1180
+ with torch.inference_mode():
1181
+ clip_image_embeds_4 = self.clip_image_processor(images=clip_image_4, return_tensors="pt").pixel_values
1182
+ clip_image_embeds_4 = clip_image_embeds_4.to(device, dtype=dtype)
1183
+ clip_image_embeds_4 = self.image_encoder(clip_image_embeds_4, output_hidden_states=True).hidden_states[-2]
1184
  clip_image_embeds_4 = clip_image_embeds_4 * scale_4
1185
  image_prompt_embeds_list.append(clip_image_embeds_4)
1186
  if clip_image_5 != None:
1187
  print('Using quinary image.')
1188
  clip_image_5 = clip_image_5.resize((max(clip_image_5.size), max(clip_image_5.size)))
1189
+ with torch.inference_mode():
1190
+ clip_image_embeds_5 = self.clip_image_processor(images=clip_image_5, return_tensors="pt").pixel_values
1191
+ clip_image_embeds_5 = clip_image_embeds_5.to(device, dtype=dtype)
1192
+ clip_image_embeds_5 = self.image_encoder(clip_image_embeds_5, output_hidden_states=True).hidden_states[-2]
1193
  clip_image_embeds_5 = clip_image_embeds_5 * scale_5
1194
  image_prompt_embeds_list.append(clip_image_embeds_5)
1195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1196
  clip_image_embeds_cat_list = torch.cat(image_prompt_embeds_list).mean(dim=0)
1197
  print('catted embeds list with mean: ',clip_image_embeds_cat_list.shape)
1198
  seq_len, _ = clip_image_embeds_cat_list.shape
 
1203
  zeros_tensor = torch.zeros_like(clip_image_embeds_view)
1204
  print('zeros: ',zeros_tensor.shape)
1205
  clip_image_embeds = torch.cat([zeros_tensor, clip_image_embeds_view], dim=0)
1206
+ print('embeds shape: ', clip_image_embeds.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1207
 
1208
  # 4. Prepare timesteps
1209
  timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)