Update pipeline_stable_diffusion_3_ipa.py
Browse files
pipeline_stable_diffusion_3_ipa.py
CHANGED
|
@@ -1154,7 +1154,6 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
|
| 1154 |
clip_image_embeds_1 = clip_image_embeds_1.to(device, dtype=dtype)
|
| 1155 |
clip_image_embeds_1 = self.image_encoder(clip_image_embeds_1, output_hidden_states=True).hidden_states[-2]
|
| 1156 |
print('encoder output shape: ', clip_image_embeds_1.shape)
|
| 1157 |
-
clip_image_embeds_1 = self.image_proj_model(clip_image_embeds_1)
|
| 1158 |
|
| 1159 |
print('projection model output shape: ', clip_image_embeds_1.shape)
|
| 1160 |
|
|
@@ -1167,7 +1166,6 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
|
| 1167 |
clip_image_embeds_2 = self.clip_image_processor(images=clip_image_2, return_tensors="pt").pixel_values
|
| 1168 |
clip_image_embeds_2 = clip_image_embeds_2.to(device, dtype=dtype)
|
| 1169 |
clip_image_embeds_2 = self.image_encoder(clip_image_embeds_2, output_hidden_states=True).hidden_states[-2]
|
| 1170 |
-
clip_image_embeds_2 = self.image_proj_model(clip_image_embeds_2)
|
| 1171 |
clip_image_embeds_2 = clip_image_embeds_2 * scale_2
|
| 1172 |
image_prompt_embeds_list.append(clip_image_embeds_2)
|
| 1173 |
if clip_image_3 != None:
|
|
@@ -1177,7 +1175,6 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
|
| 1177 |
clip_image_embeds_3 = self.clip_image_processor(images=clip_image_3, return_tensors="pt").pixel_values
|
| 1178 |
clip_image_embeds_3 = clip_image_embeds_3.to(device, dtype=dtype)
|
| 1179 |
clip_image_embeds_3 = self.image_encoder(clip_image_embeds_3, output_hidden_states=True).hidden_states[-2]
|
| 1180 |
-
clip_image_embeds_3 = self.image_proj_model(clip_image_embeds_3)
|
| 1181 |
clip_image_embeds_3 = clip_image_embeds_3 * scale_3
|
| 1182 |
image_prompt_embeds_list.append(clip_image_embeds_3)
|
| 1183 |
if clip_image_4 != None:
|
|
@@ -1187,7 +1184,6 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
|
| 1187 |
clip_image_embeds_4 = self.clip_image_processor(images=clip_image_4, return_tensors="pt").pixel_values
|
| 1188 |
clip_image_embeds_4 = clip_image_embeds_4.to(device, dtype=dtype)
|
| 1189 |
clip_image_embeds_4 = self.image_encoder(clip_image_embeds_4, output_hidden_states=True).hidden_states[-2]
|
| 1190 |
-
clip_image_embeds_4 = self.image_proj_model(clip_image_embeds_4)
|
| 1191 |
clip_image_embeds_4 = clip_image_embeds_4 * scale_4
|
| 1192 |
image_prompt_embeds_list.append(clip_image_embeds_4)
|
| 1193 |
if clip_image_5 != None:
|
|
@@ -1197,11 +1193,10 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
|
| 1197 |
clip_image_embeds_5 = self.clip_image_processor(images=clip_image_5, return_tensors="pt").pixel_values
|
| 1198 |
clip_image_embeds_5 = clip_image_embeds_5.to(device, dtype=dtype)
|
| 1199 |
clip_image_embeds_5 = self.image_encoder(clip_image_embeds_5, output_hidden_states=True).hidden_states[-2]
|
| 1200 |
-
clip_image_embeds_5 = self.image_proj_model(clip_image_embeds_5)
|
| 1201 |
clip_image_embeds_5 = clip_image_embeds_5 * scale_5
|
| 1202 |
image_prompt_embeds_list.append(clip_image_embeds_5)
|
| 1203 |
|
| 1204 |
-
clip_image_embeds_cat_list = torch.cat(image_prompt_embeds_list).mean(dim=0)
|
| 1205 |
print('catted embeds list with mean and unsqueeze: ',clip_image_embeds_cat_list.shape)
|
| 1206 |
seq_len, _ = clip_image_embeds_cat_list.shape
|
| 1207 |
clip_image_embeds_cat_list_repeat = clip_image_embeds_cat_list.repeat(1, 1, 1)
|
|
|
|
| 1154 |
clip_image_embeds_1 = clip_image_embeds_1.to(device, dtype=dtype)
|
| 1155 |
clip_image_embeds_1 = self.image_encoder(clip_image_embeds_1, output_hidden_states=True).hidden_states[-2]
|
| 1156 |
print('encoder output shape: ', clip_image_embeds_1.shape)
|
|
|
|
| 1157 |
|
| 1158 |
print('projection model output shape: ', clip_image_embeds_1.shape)
|
| 1159 |
|
|
|
|
| 1166 |
clip_image_embeds_2 = self.clip_image_processor(images=clip_image_2, return_tensors="pt").pixel_values
|
| 1167 |
clip_image_embeds_2 = clip_image_embeds_2.to(device, dtype=dtype)
|
| 1168 |
clip_image_embeds_2 = self.image_encoder(clip_image_embeds_2, output_hidden_states=True).hidden_states[-2]
|
|
|
|
| 1169 |
clip_image_embeds_2 = clip_image_embeds_2 * scale_2
|
| 1170 |
image_prompt_embeds_list.append(clip_image_embeds_2)
|
| 1171 |
if clip_image_3 != None:
|
|
|
|
| 1175 |
clip_image_embeds_3 = self.clip_image_processor(images=clip_image_3, return_tensors="pt").pixel_values
|
| 1176 |
clip_image_embeds_3 = clip_image_embeds_3.to(device, dtype=dtype)
|
| 1177 |
clip_image_embeds_3 = self.image_encoder(clip_image_embeds_3, output_hidden_states=True).hidden_states[-2]
|
|
|
|
| 1178 |
clip_image_embeds_3 = clip_image_embeds_3 * scale_3
|
| 1179 |
image_prompt_embeds_list.append(clip_image_embeds_3)
|
| 1180 |
if clip_image_4 != None:
|
|
|
|
| 1184 |
clip_image_embeds_4 = self.clip_image_processor(images=clip_image_4, return_tensors="pt").pixel_values
|
| 1185 |
clip_image_embeds_4 = clip_image_embeds_4.to(device, dtype=dtype)
|
| 1186 |
clip_image_embeds_4 = self.image_encoder(clip_image_embeds_4, output_hidden_states=True).hidden_states[-2]
|
|
|
|
| 1187 |
clip_image_embeds_4 = clip_image_embeds_4 * scale_4
|
| 1188 |
image_prompt_embeds_list.append(clip_image_embeds_4)
|
| 1189 |
if clip_image_5 != None:
|
|
|
|
| 1193 |
clip_image_embeds_5 = self.clip_image_processor(images=clip_image_5, return_tensors="pt").pixel_values
|
| 1194 |
clip_image_embeds_5 = clip_image_embeds_5.to(device, dtype=dtype)
|
| 1195 |
clip_image_embeds_5 = self.image_encoder(clip_image_embeds_5, output_hidden_states=True).hidden_states[-2]
|
|
|
|
| 1196 |
clip_image_embeds_5 = clip_image_embeds_5 * scale_5
|
| 1197 |
image_prompt_embeds_list.append(clip_image_embeds_5)
|
| 1198 |
|
| 1199 |
+
clip_image_embeds_cat_list = torch.cat(image_prompt_embeds_list).mean(dim=0)
|
| 1200 |
print('catted embeds list with mean and unsqueeze: ',clip_image_embeds_cat_list.shape)
|
| 1201 |
seq_len, _ = clip_image_embeds_cat_list.shape
|
| 1202 |
clip_image_embeds_cat_list_repeat = clip_image_embeds_cat_list.repeat(1, 1, 1)
|