giulio98 committed on
Commit 46d8a6f · verified · 1 Parent(s): 112b10b

Create conditional_unet_model.py

Files changed (1)
  1. unet/conditional_unet_model.py +862 -0
unet/conditional_unet_model.py ADDED
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput
from diffusers.models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block


@dataclass
class UNet2DOutput(BaseOutput):
    """
    The output of [`UNet2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output from the last layer of the model.
    """

    sample: torch.FloatTensor


class UNet2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
            1)`.
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos for Fourier time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
            Tuple of downsample block types.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
            Block type for the middle of the UNet; it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
        mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
        downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
        downsample_type (`str`, *optional*, defaults to `conv`):
            The downsample type for downsampling layers. Choose between "conv" and "resnet".
        upsample_type (`str`, *optional*, defaults to `conv`):
            The upsample type for upsampling layers. Choose between "conv" and "resnet".
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
        norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
        attn_norm_num_groups (`int`, *optional*, defaults to `None`):
            If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
            given number of groups. If left as `None`, the group norm layer will only be created if
            `resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
        norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use, which is ultimately summed with the time embeddings. Choose from
            `None`, `"timestep"`, or `"identity"`.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing
            class conditioning with `class_embed_type` equal to `None`.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
        in_channels: int = 3,
        out_channels: int = 3,
        center_input_sample: bool = False,
        time_embedding_type: str = "positional",
        freq_shift: int = 0,
        flip_sin_to_cos: bool = True,
        down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
        up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
        block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
        layers_per_block: int = 2,
        mid_block_scale_factor: float = 1,
        downsample_padding: int = 1,
        downsample_type: str = "conv",
        upsample_type: str = "conv",
        dropout: float = 0.0,
        act_fn: str = "silu",
        attention_head_dim: Optional[int] = 8,
        norm_num_groups: int = 32,
        attn_norm_num_groups: Optional[int] = None,
        norm_eps: float = 1e-5,
        resnet_time_scale_shift: str = "default",
        add_attention: bool = True,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        num_train_timesteps: Optional[int] = None,
        set_W_to_weight: Optional[bool] = True,
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        # input
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time
        if time_embedding_type == "fourier":
            self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16, set_W_to_weight=set_W_to_weight)
            timestep_input_dim = 2 * block_out_channels[0]
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        elif time_embedding_type == "learned":
            self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
            timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                downsample_padding=downsample_padding,
                resnet_time_scale_shift=resnet_time_scale_shift,
                downsample_type=downsample_type,
                dropout=dropout,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            dropout=dropout,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
            resnet_groups=norm_num_groups,
            attn_groups=attn_norm_num_groups,
            add_attention=add_attention,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                resnet_time_scale_shift=resnet_time_scale_shift,
                upsample_type=upsample_type,
                dropout=dropout,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DOutput, Tuple]:
        r"""
        The [`UNet2DModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_2d.UNet2DOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the sample tensor.
        """
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when doing class conditioning")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb
        elif self.class_embedding is None and class_labels is not None:
            raise ValueError("class_embedding needs to be initialized in order to use class conditioning")

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, emb)

        # 5. up
        skip_sample = None
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        if self.config.time_embedding_type == "fourier":
            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
            sample = sample / timesteps

        if not return_dict:
            return (sample,)

        return UNet2DOutput(sample=sample)


class MultiLabelConditionalUNet2DModelForCelebaHQ(ModelMixin, ConfigMixin):
    r"""
    A 2D UNet model that takes a noisy sample, a timestep and, optionally, a set of CelebA-HQ attribute labels, and
    returns a sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
            1)`.
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos for Fourier time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
            Tuple of downsample block types.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
            Block type for the middle of the UNet; it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
        mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
        downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
        downsample_type (`str`, *optional*, defaults to `conv`):
            The downsample type for downsampling layers. Choose between "conv" and "resnet".
        upsample_type (`str`, *optional*, defaults to `conv`):
            The upsample type for upsampling layers. Choose between "conv" and "resnet".
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
        norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
        attn_norm_num_groups (`int`, *optional*, defaults to `None`):
            If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
            given number of groups. If left as `None`, the group norm layer will only be created if
            `resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
        norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use, which is ultimately summed with the time embeddings. Choose from
            `None`, `"timestep"`, or `"identity"`.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
        in_channels: int = 3,
        out_channels: int = 3,
        center_input_sample: bool = False,
        time_embedding_type: str = "positional",
        freq_shift: int = 0,
        flip_sin_to_cos: bool = True,
        down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
        up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
        block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
        layers_per_block: int = 2,
        mid_block_scale_factor: float = 1,
        downsample_padding: int = 1,
        downsample_type: str = "conv",
        upsample_type: str = "conv",
        dropout: float = 0.0,
        act_fn: str = "silu",
        attention_head_dim: Optional[int] = 8,
        norm_num_groups: int = 32,
        attn_norm_num_groups: Optional[int] = None,
        norm_eps: float = 1e-5,
        resnet_time_scale_shift: str = "default",
        add_attention: bool = True,
        class_embed_type: Optional[str] = None,
        num_train_timesteps: Optional[int] = None,
        set_W_to_weight: Optional[bool] = True,
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        # input
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time
        if time_embedding_type == "fourier":
            self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16, set_W_to_weight=set_W_to_weight)
            timestep_input_dim = 2 * block_out_channels[0]
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        elif time_embedding_type == "learned":
            self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
            timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # class embedding
        if class_embed_type is None:
            # self.class_embedding_5_o_Clock_Shadow = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Arched_Eyebrows = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Attractive = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Bags_Under_Eyes = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Bald = nn.Embedding(2, time_embed_dim)
            self.class_embedding_Bangs = nn.Embedding(3, time_embed_dim)
            # self.class_embedding_Big_Lips = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Big_Nose = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Black_Hair = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Blond_Hair = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Blurry = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Brown_Hair = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Bushy_Eyebrows = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Chubby = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Double_Chin = nn.Embedding(2, time_embed_dim)
            self.class_embedding_Eyeglasses = nn.Embedding(3, time_embed_dim)
            # self.class_embedding_Goatee = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Gray_Hair = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Heavy_Makeup = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_High_Cheekbones = nn.Embedding(2, time_embed_dim)
            self.class_embedding_Male = nn.Embedding(3, time_embed_dim)
            # self.class_embedding_Mouth_Slightly_Open = nn.Embedding(2, time_embed_dim)
            self.class_embedding_Mustache = nn.Embedding(3, time_embed_dim)
            # self.class_embedding_Narrow_Eyes = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_No_Beard = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Oval_Face = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Pale_Skin = nn.Embedding(2, time_embed_dim)
            self.class_embedding_Pointy_Nose = nn.Embedding(3, time_embed_dim)
            # self.class_embedding_Receding_Hairline = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Rosy_Cheeks = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Sideburns = nn.Embedding(2, time_embed_dim)
            self.class_embedding_Smiling = nn.Embedding(3, time_embed_dim)
            # self.class_embedding_Straight_Hair = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Wavy_Hair = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Wearing_Earrings = nn.Embedding(2, time_embed_dim)
            self.class_embedding_Wearing_Hat = nn.Embedding(3, time_embed_dim)
            # self.class_embedding_Wearing_Lipstick = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Wearing_Necklace = nn.Embedding(2, time_embed_dim)
            # self.class_embedding_Wearing_Necktie = nn.Embedding(2, time_embed_dim)
            self.class_embedding_Young = nn.Embedding(3, time_embed_dim)
            self.apply_class_emb = True
        elif class_embed_type == "timestep":
            # self.class_embedding_5_o_Clock_Shadow = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Arched_Eyebrows = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Attractive = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Bags_Under_Eyes = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Bald = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            self.class_embedding_Bangs = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Big_Lips = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Big_Nose = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Black_Hair = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Blond_Hair = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Blurry = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Brown_Hair = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Bushy_Eyebrows = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Chubby = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Double_Chin = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            self.class_embedding_Eyeglasses = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Goatee = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Gray_Hair = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Heavy_Makeup = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_High_Cheekbones = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            self.class_embedding_Male = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Mouth_Slightly_Open = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            self.class_embedding_Mustache = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Narrow_Eyes = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_No_Beard = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Oval_Face = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Pale_Skin = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            self.class_embedding_Pointy_Nose = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Receding_Hairline = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Rosy_Cheeks = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Sideburns = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            self.class_embedding_Smiling = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Straight_Hair = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Wavy_Hair = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Wearing_Earrings = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            self.class_embedding_Wearing_Hat = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Wearing_Lipstick = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Wearing_Necklace = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            # self.class_embedding_Wearing_Necktie = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            self.class_embedding_Young = TimestepEmbedding(timestep_input_dim, time_embed_dim)
            self.apply_class_emb = True
        elif class_embed_type == "identity":
            # self.class_embedding_5_o_Clock_Shadow = nn.Identity(time_embed_dim)
            # self.class_embedding_Arched_Eyebrows = nn.Identity(time_embed_dim)
            # self.class_embedding_Attractive = nn.Identity(time_embed_dim)
            # self.class_embedding_Bags_Under_Eyes = nn.Identity(time_embed_dim)
            # self.class_embedding_Bald = nn.Identity(time_embed_dim)
            self.class_embedding_Bangs = nn.Identity(time_embed_dim)
            # self.class_embedding_Big_Lips = nn.Identity(time_embed_dim)
            # self.class_embedding_Big_Nose = nn.Identity(time_embed_dim)
            # self.class_embedding_Black_Hair = nn.Identity(time_embed_dim)
            # self.class_embedding_Blond_Hair = nn.Identity(time_embed_dim)
            # self.class_embedding_Blurry = nn.Identity(time_embed_dim)
            # self.class_embedding_Brown_Hair = nn.Identity(time_embed_dim)
            # self.class_embedding_Bushy_Eyebrows = nn.Identity(time_embed_dim)
            # self.class_embedding_Chubby = nn.Identity(time_embed_dim)
            # self.class_embedding_Double_Chin = nn.Identity(time_embed_dim)
            self.class_embedding_Eyeglasses = nn.Identity(time_embed_dim)
            # self.class_embedding_Goatee = nn.Identity(time_embed_dim)
            # self.class_embedding_Gray_Hair = nn.Identity(time_embed_dim)
            # self.class_embedding_Heavy_Makeup = nn.Identity(time_embed_dim)
            # self.class_embedding_High_Cheekbones = nn.Identity(time_embed_dim)
            self.class_embedding_Male = nn.Identity(time_embed_dim)
            # self.class_embedding_Mouth_Slightly_Open = nn.Identity(time_embed_dim)
            self.class_embedding_Mustache = nn.Identity(time_embed_dim)
            # self.class_embedding_Narrow_Eyes = nn.Identity(time_embed_dim)
            # self.class_embedding_No_Beard = nn.Identity(time_embed_dim)
            # self.class_embedding_Oval_Face = nn.Identity(time_embed_dim)
            # self.class_embedding_Pale_Skin = nn.Identity(time_embed_dim)
            self.class_embedding_Pointy_Nose = nn.Identity(time_embed_dim)
            # self.class_embedding_Receding_Hairline = nn.Identity(time_embed_dim)
            # self.class_embedding_Rosy_Cheeks = nn.Identity(time_embed_dim)
            # self.class_embedding_Sideburns = nn.Identity(time_embed_dim)
            self.class_embedding_Smiling = nn.Identity(time_embed_dim)
            # self.class_embedding_Straight_Hair = nn.Identity(time_embed_dim)
            # self.class_embedding_Wavy_Hair = nn.Identity(time_embed_dim)
            # self.class_embedding_Wearing_Earrings = nn.Identity(time_embed_dim)
            self.class_embedding_Wearing_Hat = nn.Identity(time_embed_dim)
            # self.class_embedding_Wearing_Lipstick = nn.Identity(time_embed_dim)
            # self.class_embedding_Wearing_Necklace = nn.Identity(time_embed_dim)
            # self.class_embedding_Wearing_Necktie = nn.Identity(time_embed_dim)
            self.class_embedding_Young = nn.Identity(time_embed_dim)
            self.apply_class_emb = True
        else:
            self.apply_class_emb = False

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                downsample_padding=downsample_padding,
                resnet_time_scale_shift=resnet_time_scale_shift,
                downsample_type=downsample_type,
                dropout=dropout,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            dropout=dropout,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
            resnet_groups=norm_num_groups,
            attn_groups=attn_norm_num_groups,
            add_attention=add_attention,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                resnet_time_scale_shift=resnet_time_scale_shift,
                upsample_type=upsample_type,
                dropout=dropout,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DOutput, Tuple]:
        r"""
        The [`MultiLabelConditionalUNet2DModelForCelebaHQ`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_2d.UNet2DOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the sample tensor.
        """
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.apply_class_emb:
            if class_labels is None:
                raise ValueError("class_labels should be provided when doing class conditioning")
            # class_labels_5_o_Clock_Shadow = class_labels[:, 0]
            # class_labels_Arched_Eyebrows = class_labels[:, 1]
            # class_labels_Attractive = class_labels[:, 2]
            # class_labels_Bags_Under_Eyes = class_labels[:, 3]
            # class_labels_Bald = class_labels[:, 4]
            class_labels_Bangs = class_labels[:, 5]
            # class_labels_Big_Lips = class_labels[:, 6]
            # class_labels_Big_Nose = class_labels[:, 7]
            # class_labels_Black_Hair = class_labels[:, 8]
            # class_labels_Blond_Hair = class_labels[:, 9]
            # class_labels_Blurry = class_labels[:, 10]
            # class_labels_Brown_Hair = class_labels[:, 11]
            # class_labels_Bushy_Eyebrows = class_labels[:, 12]
            # class_labels_Chubby = class_labels[:, 13]
            # class_labels_Double_Chin = class_labels[:, 14]
            class_labels_Eyeglasses = class_labels[:, 15]
            # class_labels_Goatee = class_labels[:, 16]
            # class_labels_Gray_Hair = class_labels[:, 17]
            # class_labels_Heavy_Makeup = class_labels[:, 18]
            # class_labels_High_Cheekbones = class_labels[:, 19]
            class_labels_Male = class_labels[:, 20]
            # class_labels_Mouth_Slightly_Open = class_labels[:, 21]
            class_labels_Mustache = class_labels[:, 22]
            # class_labels_Narrow_Eyes = class_labels[:, 23]
            # class_labels_No_Beard = class_labels[:, 24]
            # class_labels_Oval_Face = class_labels[:, 25]
            # class_labels_Pale_Skin = class_labels[:, 26]
            class_labels_Pointy_Nose = class_labels[:, 27]
            # class_labels_Receding_Hairline = class_labels[:, 28]
            # class_labels_Rosy_Cheeks = class_labels[:, 29]
            # class_labels_Sideburns = class_labels[:, 30]
            class_labels_Smiling = class_labels[:, 31]
            # class_labels_Straight_Hair = class_labels[:, 32]
            # class_labels_Wavy_Hair = class_labels[:, 33]
            # class_labels_Wearing_Earrings = class_labels[:, 34]
            class_labels_Wearing_Hat = class_labels[:, 35]
            # class_labels_Wearing_Lipstick = class_labels[:, 36]
            # class_labels_Wearing_Necklace = class_labels[:, 37]
            # class_labels_Wearing_Necktie = class_labels[:, 38]
            class_labels_Young = class_labels[:, 39]

            # Apply time projection if configured
            if self.config.class_embed_type == "timestep":
                # class_labels_5_o_Clock_Shadow = self.time_proj(class_labels_5_o_Clock_Shadow)
                # class_labels_Arched_Eyebrows = self.time_proj(class_labels_Arched_Eyebrows)
                # class_labels_Attractive = self.time_proj(class_labels_Attractive)
                # class_labels_Bags_Under_Eyes = self.time_proj(class_labels_Bags_Under_Eyes)
                # class_labels_Bald = self.time_proj(class_labels_Bald)
                class_labels_Bangs = self.time_proj(class_labels_Bangs)
                # class_labels_Big_Lips = self.time_proj(class_labels_Big_Lips)
                # class_labels_Big_Nose = self.time_proj(class_labels_Big_Nose)
                # class_labels_Black_Hair = self.time_proj(class_labels_Black_Hair)
                # class_labels_Blond_Hair = self.time_proj(class_labels_Blond_Hair)
                # class_labels_Blurry = self.time_proj(class_labels_Blurry)
                # class_labels_Brown_Hair = self.time_proj(class_labels_Brown_Hair)
                # class_labels_Bushy_Eyebrows = self.time_proj(class_labels_Bushy_Eyebrows)
                # class_labels_Chubby = self.time_proj(class_labels_Chubby)
                # class_labels_Double_Chin = self.time_proj(class_labels_Double_Chin)
                class_labels_Eyeglasses = self.time_proj(class_labels_Eyeglasses)
                # class_labels_Goatee = self.time_proj(class_labels_Goatee)
                # class_labels_Gray_Hair = self.time_proj(class_labels_Gray_Hair)
                # class_labels_Heavy_Makeup = self.time_proj(class_labels_Heavy_Makeup)
                # class_labels_High_Cheekbones = self.time_proj(class_labels_High_Cheekbones)
                class_labels_Male = self.time_proj(class_labels_Male)
                # class_labels_Mouth_Slightly_Open = self.time_proj(class_labels_Mouth_Slightly_Open)
                class_labels_Mustache = self.time_proj(class_labels_Mustache)
                # class_labels_Narrow_Eyes = self.time_proj(class_labels_Narrow_Eyes)
                # class_labels_No_Beard = self.time_proj(class_labels_No_Beard)
                # class_labels_Oval_Face = self.time_proj(class_labels_Oval_Face)
                # class_labels_Pale_Skin = self.time_proj(class_labels_Pale_Skin)
                class_labels_Pointy_Nose = self.time_proj(class_labels_Pointy_Nose)
                # class_labels_Receding_Hairline = self.time_proj(class_labels_Receding_Hairline)
                # class_labels_Rosy_Cheeks = self.time_proj(class_labels_Rosy_Cheeks)
                # class_labels_Sideburns = self.time_proj(class_labels_Sideburns)
                class_labels_Smiling = self.time_proj(class_labels_Smiling)
                # class_labels_Straight_Hair = self.time_proj(class_labels_Straight_Hair)
                # class_labels_Wavy_Hair = self.time_proj(class_labels_Wavy_Hair)
                # class_labels_Wearing_Earrings = self.time_proj(class_labels_Wearing_Earrings)
                class_labels_Wearing_Hat = self.time_proj(class_labels_Wearing_Hat)
                # class_labels_Wearing_Lipstick = self.time_proj(class_labels_Wearing_Lipstick)
                # class_labels_Wearing_Necklace = self.time_proj(class_labels_Wearing_Necklace)
                # class_labels_Wearing_Necktie = self.time_proj(class_labels_Wearing_Necktie)
                class_labels_Young = self.time_proj(class_labels_Young)

            if self.class_embedding_Bangs:
                emb += self.class_embedding_Bangs(class_labels_Bangs + 1)
            if self.class_embedding_Male:
                emb += self.class_embedding_Male(class_labels_Male + 1)
            if self.class_embedding_Eyeglasses:
                emb += self.class_embedding_Eyeglasses(class_labels_Eyeglasses + 1)
            if self.class_embedding_Mustache:
                emb += self.class_embedding_Mustache(class_labels_Mustache + 1)
            if self.class_embedding_Pointy_Nose:
                emb += self.class_embedding_Pointy_Nose(class_labels_Pointy_Nose + 1)
            if self.class_embedding_Smiling:
                emb += self.class_embedding_Smiling(class_labels_Smiling + 1)
            if self.class_embedding_Wearing_Hat:
                emb += self.class_embedding_Wearing_Hat(class_labels_Wearing_Hat + 1)
            if self.class_embedding_Young:
                emb += self.class_embedding_Young(class_labels_Young + 1)
        elif not self.apply_class_emb and class_labels is not None:
            raise ValueError("class_embedding needs to be initialized in order to use class conditioning")

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, emb)

        # 5. up
        skip_sample = None
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        if self.config.time_embedding_type == "fourier":
            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
            sample = sample / timesteps

        if not return_dict:
            return (sample,)

        return UNet2DOutput(sample=sample)
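
Usage sketch (not part of the committed file): the snippet below shows how the multi-label model above can be driven, assuming the repository root is on the import path, 64x64 RGB samples, and attribute labels in {-1, 0, 1} laid out in the 40-column CelebA order that `forward` indexes (Bangs at 5, Eyeglasses at 15, Male at 20, Mustache at 22, Pointy_Nose at 27, Smiling at 31, Wearing_Hat at 35, Young at 39). The sizes and block widths are illustrative assumptions, not the training configuration.

import torch

from unet.conditional_unet_model import MultiLabelConditionalUNet2DModelForCelebaHQ

# Illustrative configuration; block widths stay divisible by the default norm_num_groups=32.
model = MultiLabelConditionalUNet2DModelForCelebaHQ(
    sample_size=64,
    in_channels=3,
    out_channels=3,
    block_out_channels=(128, 256, 384, 512),
)

noisy = torch.randn(2, 3, 64, 64)              # (batch, channels, height, width)
timesteps = torch.tensor([10, 500])            # one diffusion timestep per sample
labels = torch.zeros(2, 40, dtype=torch.long)  # 40 CelebA attribute columns, values in {-1, 0, 1}
labels[:, 31] = 1                              # Smiling present
labels[:, 20] = -1                             # Male absent; forward shifts labels by +1 before the embedding lookup

out = model(noisy, timesteps, class_labels=labels).sample
print(out.shape)  # torch.Size([2, 3, 64, 64])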