Spaces:
Build error
Build error
Update pipline_StableDiffusion_ConsistentID.py
Browse files
pipline_StableDiffusion_ConsistentID.py
CHANGED
|
@@ -419,6 +419,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
|
|
| 419 |
class_tokens_mask: Optional[torch.LongTensor] = None,
|
| 420 |
prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
|
| 421 |
retouching: bool=False,
|
|
|
|
| 422 |
):
|
| 423 |
# 0. Default height and width to unet
|
| 424 |
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
|
@@ -604,9 +605,12 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
|
|
| 604 |
image = self.decode_latents(latents)
|
| 605 |
|
| 606 |
# 9.2 Run safety checker
|
| 607 |
-
|
| 608 |
-
image,
|
| 609 |
-
|
|
|
|
|
|
|
|
|
|
| 610 |
|
| 611 |
# 9.3 Convert to PIL
|
| 612 |
image = self.numpy_to_pil(image)
|
|
|
|
| 419 |
class_tokens_mask: Optional[torch.LongTensor] = None,
|
| 420 |
prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
|
| 421 |
retouching: bool=False,
|
| 422 |
+
need_safetycheck: bool=True,
|
| 423 |
):
|
| 424 |
# 0. Default height and width to unet
|
| 425 |
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
|
|
|
| 605 |
image = self.decode_latents(latents)
|
| 606 |
|
| 607 |
# 9.2 Run safety checker
|
| 608 |
+
if need_safetycheck:
|
| 609 |
+
image, has_nsfw_concept = self.run_safety_checker(
|
| 610 |
+
image, device, prompt_embeds.dtype
|
| 611 |
+
)
|
| 612 |
+
else:
|
| 613 |
+
has_nsfw_concept = None
|
| 614 |
|
| 615 |
# 9.3 Convert to PIL
|
| 616 |
image = self.numpy_to_pil(image)
|