Shiroi-max commited on
Commit
8ab7af4
·
verified ·
1 Parent(s): 4e65884

Upload pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +62 -0
pipeline.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+ from diffusers import DiffusionPipeline, ImagePipelineOutput
3
+ from diffusers.utils.torch_utils import randn_tensor
4
+
5
+ import torch
6
+
7
+
8
+ class DDPMConditionalPipeline(DiffusionPipeline):
9
+ model_cpu_offload_seq = "unet"
10
+
11
+ def __init__(self, unet, scheduler):
12
+ super().__init__()
13
+ self.register_modules(unet=unet, scheduler=scheduler)
14
+
15
+ @torch.no_grad()
16
+ def __call__(
17
+ self,
18
+ label,
19
+ batch_size: int = 1,
20
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
21
+ num_inference_steps: int = 1000,
22
+ output_type: Optional[str] = "pil",
23
+ return_dict: bool = True,
24
+ ) -> Union[ImagePipelineOutput, Tuple]:
25
+ # Sample gaussian noise to begin loop
26
+ if isinstance(self.unet.sample_size, int):
27
+ image_shape = (
28
+ batch_size,
29
+ self.unet.in_channels,
30
+ self.unet.sample_size,
31
+ self.unet.sample_size,
32
+ )
33
+ else:
34
+ image_shape = (
35
+ batch_size,
36
+ self.unet.in_channels,
37
+ *self.unet.sample_size,
38
+ )
39
+
40
+ image = randn_tensor(image_shape, generator=generator)
41
+
42
+ # set step values
43
+ self.scheduler.set_timesteps(num_inference_steps)
44
+
45
+ for t in self.progress_bar(self.scheduler.timesteps):
46
+ # 1. predict noise model_output
47
+ model_output = self.unet(image, t, label).sample
48
+
49
+ # 2. compute previous image: x_t -> x_t-1
50
+ image = self.scheduler.step(
51
+ model_output, t, image, generator=generator
52
+ ).prev_sample
53
+
54
+ image = (image / 2 + 0.5).clamp(0, 1)
55
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
56
+ if output_type == "pil":
57
+ image = self.numpy_to_pil(image)
58
+
59
+ if not return_dict:
60
+ return (image,)
61
+
62
+ return ImagePipelineOutput(images=image)