skytnt commited on
Commit
aa296ba
·
1 Parent(s): 1cd0e50

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +134 -66
pipeline.py CHANGED
@@ -5,14 +5,55 @@ from typing import Callable, List, Optional, Union
5
  import numpy as np
6
  import torch
7
 
 
8
  import PIL
9
  from diffusers import OnnxStableDiffusionPipeline, SchedulerMixin
10
- from diffusers.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
11
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
12
- from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
 
13
  from transformers import CLIPFeatureExtractor, CLIPTokenizer
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
17
 
18
  re_attention = re.compile(
@@ -390,30 +431,59 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
390
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
391
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
392
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
 
394
- def __init__(
395
- self,
396
- vae_encoder: OnnxRuntimeModel,
397
- vae_decoder: OnnxRuntimeModel,
398
- text_encoder: OnnxRuntimeModel,
399
- tokenizer: CLIPTokenizer,
400
- unet: OnnxRuntimeModel,
401
- scheduler: SchedulerMixin,
402
- safety_checker: OnnxRuntimeModel,
403
- feature_extractor: CLIPFeatureExtractor,
404
- requires_safety_checker: bool = True,
405
- ):
406
- super().__init__(
407
- vae_encoder=vae_encoder,
408
- vae_decoder=vae_decoder,
409
- text_encoder=text_encoder,
410
- tokenizer=tokenizer,
411
- unet=unet,
412
- scheduler=scheduler,
413
- safety_checker=safety_checker,
414
- feature_extractor=feature_extractor,
415
- requires_safety_checker=requires_safety_checker,
416
- )
 
 
 
417
  self.unet_in_channels = 4
418
  self.vae_scale_factor = 8
419
 
@@ -741,49 +811,47 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
741
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
742
 
743
  # 8. Denoising loop
744
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
745
- with self.progress_bar(total=num_inference_steps) as progress_bar:
746
- for i, t in enumerate(timesteps):
747
- # expand the latents if we are doing classifier free guidance
748
- latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
749
- latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
750
- latent_model_input = latent_model_input.numpy()
751
-
752
- # predict the noise residual
753
- noise_pred = self.unet(
754
- sample=latent_model_input,
755
- timestep=np.array([t], dtype=timestep_dtype),
756
- encoder_hidden_states=text_embeddings,
757
- )
758
- noise_pred = noise_pred[0]
759
 
760
- # perform guidance
761
- if do_classifier_free_guidance:
762
- noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
763
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
764
 
765
- # compute the previous noisy sample x_t -> x_t-1
766
- scheduler_output = self.scheduler.step(
767
- torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
768
- )
769
- latents = scheduler_output.prev_sample.numpy()
770
-
771
- if mask is not None:
772
- # masking
773
- init_latents_proper = self.scheduler.add_noise(
774
- torch.from_numpy(init_latents_orig),
775
- torch.from_numpy(noise),
776
- t,
777
- ).numpy()
778
- latents = (init_latents_proper * mask) + (latents * (1 - mask))
779
-
780
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
781
- progress_bar.update()
782
- if i % callback_steps == 0:
783
- if callback is not None:
784
- callback(i, t, latents)
785
- if is_cancelled_callback is not None and is_cancelled_callback():
786
- return None
787
  # 9. Post-processing
788
  image = self.decode_latents(latents)
789
 
 
5
  import numpy as np
6
  import torch
7
 
8
+ import diffusers
9
  import PIL
10
  from diffusers import OnnxStableDiffusionPipeline, SchedulerMixin
11
+ from diffusers.onnx_utils import OnnxRuntimeModel
12
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
13
+ from diffusers.utils import deprecate, logging
14
+ from packaging import version
15
  from transformers import CLIPFeatureExtractor, CLIPTokenizer
16
 
17
 
18
+ try:
19
+ from diffusers.onnx_utils import ORT_TO_NP_TYPE
20
+ except ImportError:
21
+ ORT_TO_NP_TYPE = {
22
+ "tensor(bool)": np.bool_,
23
+ "tensor(int8)": np.int8,
24
+ "tensor(uint8)": np.uint8,
25
+ "tensor(int16)": np.int16,
26
+ "tensor(uint16)": np.uint16,
27
+ "tensor(int32)": np.int32,
28
+ "tensor(uint32)": np.uint32,
29
+ "tensor(int64)": np.int64,
30
+ "tensor(uint64)": np.uint64,
31
+ "tensor(float16)": np.float16,
32
+ "tensor(float)": np.float32,
33
+ "tensor(double)": np.float64,
34
+ }
35
+
36
+ try:
37
+ from diffusers.utils import PIL_INTERPOLATION
38
+ except ImportError:
39
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
40
+ PIL_INTERPOLATION = {
41
+ "linear": PIL.Image.Resampling.BILINEAR,
42
+ "bilinear": PIL.Image.Resampling.BILINEAR,
43
+ "bicubic": PIL.Image.Resampling.BICUBIC,
44
+ "lanczos": PIL.Image.Resampling.LANCZOS,
45
+ "nearest": PIL.Image.Resampling.NEAREST,
46
+ }
47
+ else:
48
+ PIL_INTERPOLATION = {
49
+ "linear": PIL.Image.LINEAR,
50
+ "bilinear": PIL.Image.BILINEAR,
51
+ "bicubic": PIL.Image.BICUBIC,
52
+ "lanczos": PIL.Image.LANCZOS,
53
+ "nearest": PIL.Image.NEAREST,
54
+ }
55
+ # ------------------------------------------------------------------------------
56
+
57
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
58
 
59
  re_attention = re.compile(
 
431
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
432
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
433
  """
434
+ if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
435
+
436
+ def __init__(
437
+ self,
438
+ vae_encoder: OnnxRuntimeModel,
439
+ vae_decoder: OnnxRuntimeModel,
440
+ text_encoder: OnnxRuntimeModel,
441
+ tokenizer: CLIPTokenizer,
442
+ unet: OnnxRuntimeModel,
443
+ scheduler: SchedulerMixin,
444
+ safety_checker: OnnxRuntimeModel,
445
+ feature_extractor: CLIPFeatureExtractor,
446
+ requires_safety_checker: bool = True,
447
+ ):
448
+ super().__init__(
449
+ vae_encoder=vae_encoder,
450
+ vae_decoder=vae_decoder,
451
+ text_encoder=text_encoder,
452
+ tokenizer=tokenizer,
453
+ unet=unet,
454
+ scheduler=scheduler,
455
+ safety_checker=safety_checker,
456
+ feature_extractor=feature_extractor,
457
+ requires_safety_checker=requires_safety_checker,
458
+ )
459
+ self.__init__additional__()
460
 
461
+ else:
462
+
463
+ def __init__(
464
+ self,
465
+ vae_encoder: OnnxRuntimeModel,
466
+ vae_decoder: OnnxRuntimeModel,
467
+ text_encoder: OnnxRuntimeModel,
468
+ tokenizer: CLIPTokenizer,
469
+ unet: OnnxRuntimeModel,
470
+ scheduler: SchedulerMixin,
471
+ safety_checker: OnnxRuntimeModel,
472
+ feature_extractor: CLIPFeatureExtractor,
473
+ ):
474
+ super().__init__(
475
+ vae_encoder=vae_encoder,
476
+ vae_decoder=vae_decoder,
477
+ text_encoder=text_encoder,
478
+ tokenizer=tokenizer,
479
+ unet=unet,
480
+ scheduler=scheduler,
481
+ safety_checker=safety_checker,
482
+ feature_extractor=feature_extractor,
483
+ )
484
+ self.__init__additional__()
485
+
486
+ def __init__additional__(self):
487
  self.unet_in_channels = 4
488
  self.vae_scale_factor = 8
489
 
 
811
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
812
 
813
  # 8. Denoising loop
814
+ for i, t in enumerate(self.progress_bar(timesteps)):
815
+ # expand the latents if we are doing classifier free guidance
816
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
817
+ latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
818
+ latent_model_input = latent_model_input.numpy()
819
+
820
+ # predict the noise residual
821
+ noise_pred = self.unet(
822
+ sample=latent_model_input,
823
+ timestep=np.array([t], dtype=timestep_dtype),
824
+ encoder_hidden_states=text_embeddings,
825
+ )
826
+ noise_pred = noise_pred[0]
 
 
827
 
828
+ # perform guidance
829
+ if do_classifier_free_guidance:
830
+ noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
831
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
832
+
833
+ # compute the previous noisy sample x_t -> x_t-1
834
+ scheduler_output = self.scheduler.step(
835
+ torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
836
+ )
837
+ latents = scheduler_output.prev_sample.numpy()
838
+
839
+ if mask is not None:
840
+ # masking
841
+ init_latents_proper = self.scheduler.add_noise(
842
+ torch.from_numpy(init_latents_orig),
843
+ torch.from_numpy(noise),
844
+ t,
845
+ ).numpy()
846
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
847
+
848
+ # call the callback, if provided
849
+ if i % callback_steps == 0:
850
+ if callback is not None:
851
+ callback(i, t, latents)
852
+ if is_cancelled_callback is not None and is_cancelled_callback():
853
+ return None
854
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
855
  # 9. Post-processing
856
  image = self.decode_latents(latents)
857