Update pipeline.py

pipeline.py  +134 -66  CHANGED
@@ -5,14 +5,55 @@ from typing import Callable, List, Optional, Union
 import numpy as np
 import torch

+import diffusers
 import PIL
 from diffusers import OnnxStableDiffusionPipeline, SchedulerMixin
-from diffusers.onnx_utils import …
+from diffusers.onnx_utils import OnnxRuntimeModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.utils import …
+from diffusers.utils import deprecate, logging
+from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer


+try:
+    from diffusers.onnx_utils import ORT_TO_NP_TYPE
+except ImportError:
+    ORT_TO_NP_TYPE = {
+        "tensor(bool)": np.bool_,
+        "tensor(int8)": np.int8,
+        "tensor(uint8)": np.uint8,
+        "tensor(int16)": np.int16,
+        "tensor(uint16)": np.uint16,
+        "tensor(int32)": np.int32,
+        "tensor(uint32)": np.uint32,
+        "tensor(int64)": np.int64,
+        "tensor(uint64)": np.uint64,
+        "tensor(float16)": np.float16,
+        "tensor(float)": np.float32,
+        "tensor(double)": np.float64,
+    }
+
+try:
+    from diffusers.utils import PIL_INTERPOLATION
+except ImportError:
+    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+        PIL_INTERPOLATION = {
+            "linear": PIL.Image.Resampling.BILINEAR,
+            "bilinear": PIL.Image.Resampling.BILINEAR,
+            "bicubic": PIL.Image.Resampling.BICUBIC,
+            "lanczos": PIL.Image.Resampling.LANCZOS,
+            "nearest": PIL.Image.Resampling.NEAREST,
+        }
+    else:
+        PIL_INTERPOLATION = {
+            "linear": PIL.Image.LINEAR,
+            "bilinear": PIL.Image.BILINEAR,
+            "bicubic": PIL.Image.BICUBIC,
+            "lanczos": PIL.Image.LANCZOS,
+            "nearest": PIL.Image.NEAREST,
+        }
+# ------------------------------------------------------------------------------
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 re_attention = re.compile(
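
Note (not part of the diff): the two try/except blocks above only backfill constants that newer diffusers releases already export. A minimal sketch of how such fallbacks are typically consumed, assuming ORT_TO_NP_TYPE and PIL_INTERPOLATION from the hunk above are in scope; the model path and mask file are hypothetical placeholders:

    import numpy as np
    import PIL.Image
    import onnxruntime as ort

    session = ort.InferenceSession("unet/model.onnx")  # hypothetical path

    # Map the declared ONNX type of the "timestep" input to a NumPy dtype,
    # defaulting to float32 when that input is not found.
    timestep_type = next(
        (inp.type for inp in session.get_inputs() if inp.name == "timestep"), "tensor(float)"
    )
    timestep = np.array([999], dtype=ORT_TO_NP_TYPE[timestep_type])

    # Resize a mask with the version-independent interpolation table.
    mask = PIL.Image.open("mask.png").resize((64, 64), PIL_INTERPOLATION["nearest"])  # hypothetical file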
@@ -390,30 +431,59 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
     """
+    if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
+
+        def __init__(
+            self,
+            vae_encoder: OnnxRuntimeModel,
+            vae_decoder: OnnxRuntimeModel,
+            text_encoder: OnnxRuntimeModel,
+            tokenizer: CLIPTokenizer,
+            unet: OnnxRuntimeModel,
+            scheduler: SchedulerMixin,
+            safety_checker: OnnxRuntimeModel,
+            feature_extractor: CLIPFeatureExtractor,
+            requires_safety_checker: bool = True,
+        ):
+            super().__init__(
+                vae_encoder=vae_encoder,
+                vae_decoder=vae_decoder,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                unet=unet,
+                scheduler=scheduler,
+                safety_checker=safety_checker,
+                feature_extractor=feature_extractor,
+                requires_safety_checker=requires_safety_checker,
+            )
+            self.__init__additional__()

-    [23 lines removed here (old lines 394-416); their content is not visible in this view]
+    else:
+
+        def __init__(
+            self,
+            vae_encoder: OnnxRuntimeModel,
+            vae_decoder: OnnxRuntimeModel,
+            text_encoder: OnnxRuntimeModel,
+            tokenizer: CLIPTokenizer,
+            unet: OnnxRuntimeModel,
+            scheduler: SchedulerMixin,
+            safety_checker: OnnxRuntimeModel,
+            feature_extractor: CLIPFeatureExtractor,
+        ):
+            super().__init__(
+                vae_encoder=vae_encoder,
+                vae_decoder=vae_decoder,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                unet=unet,
+                scheduler=scheduler,
+                safety_checker=safety_checker,
+                feature_extractor=feature_extractor,
+            )
+            self.__init__additional__()
+
+    def __init__additional__(self):
         self.unet_in_channels = 4
         self.vae_scale_factor = 8

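
Note (not part of the diff): the gate above parses the version string twice because base_version strips pre-release and dev suffixes, so development builds of diffusers still compare cleanly against "0.9.0", the cutoff this diff uses for the signature that accepts requires_safety_checker. A small self-contained illustration of that pattern:

    from packaging import version

    for v in ("0.8.1", "0.9.0", "0.10.0.dev0"):
        base = version.parse(v).base_version  # e.g. "0.10.0.dev0" -> "0.10.0"
        print(v, version.parse(base) >= version.parse("0.9.0"))
    # prints: 0.8.1 False, 0.9.0 True, 0.10.0.dev0 True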
@@ -741,49 +811,47 @@
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 8. Denoising loop
-        [13 lines removed here (old lines 744-756); their content is not visible in this view]
-            )
-            noise_pred = noise_pred[0]
+        for i, t in enumerate(self.progress_bar(timesteps)):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+            latent_model_input = latent_model_input.numpy()
+
+            # predict the noise residual
+            noise_pred = self.unet(
+                sample=latent_model_input,
+                timestep=np.array([t], dtype=timestep_dtype),
+                encoder_hidden_states=text_embeddings,
+            )
+            noise_pred = noise_pred[0]

-        [4 lines removed here (old lines 760-763); their content is not visible in this view]
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            scheduler_output = self.scheduler.step(
+                torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
+            )
+            latents = scheduler_output.prev_sample.numpy()
+
+            if mask is not None:
+                # masking
+                init_latents_proper = self.scheduler.add_noise(
+                    torch.from_numpy(init_latents_orig),
+                    torch.from_numpy(noise),
+                    t,
+                ).numpy()
+                latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+            # call the callback, if provided
+            if i % callback_steps == 0:
+                if callback is not None:
+                    callback(i, t, latents)
+                if is_cancelled_callback is not None and is_cancelled_callback():
+                    return None

-            # compute the previous noisy sample x_t -> x_t-1
-            scheduler_output = self.scheduler.step(
-                torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
-            )
-            latents = scheduler_output.prev_sample.numpy()
-
-            if mask is not None:
-                # masking
-                init_latents_proper = self.scheduler.add_noise(
-                    torch.from_numpy(init_latents_orig),
-                    torch.from_numpy(noise),
-                    t,
-                ).numpy()
-                latents = (init_latents_proper * mask) + (latents * (1 - mask))
-
-            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                progress_bar.update()
-            if i % callback_steps == 0:
-                if callback is not None:
-                    callback(i, t, latents)
-                if is_cancelled_callback is not None and is_cancelled_callback():
-                    return None
         # 9. Post-processing
         image = self.decode_latents(latents)

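
Note (not part of the diff): a toy NumPy sketch of the two per-step updates the rewritten loop performs, classifier-free guidance and the inpainting mask blend. Shapes and values are arbitrary stand-ins for the real latents and predictions:

    import numpy as np

    guidance_scale = 7.5

    # unconditional and text-conditioned predictions stacked along the batch axis
    noise_pred = np.random.randn(2, 4, 64, 64).astype(np.float32)
    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
    guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # blend the re-noised original latents with the denoised latents according to the mask
    latents = np.random.randn(1, 4, 64, 64).astype(np.float32)
    init_latents_proper = np.random.randn(1, 4, 64, 64).astype(np.float32)
    mask = (np.random.rand(1, 1, 64, 64) > 0.5).astype(np.float32)
    latents = init_latents_proper * mask + latents * (1 - mask)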