JiminHeo commited on
Commit
8732441
·
1 Parent(s): 0078e7b
ldm/models/autoencoder.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ import torch.nn.functional as F
4
+ from contextlib import contextmanager
5
+
6
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
7
+
8
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
9
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
10
+
11
+ from ldm.util import instantiate_from_config
12
+
13
+
14
+ class VQModel(pl.LightningModule):
15
+ def __init__(self,
16
+ ddconfig,
17
+ lossconfig,
18
+ n_embed,
19
+ embed_dim,
20
+ ckpt_path=None,
21
+ ignore_keys=[],
22
+ image_key="image",
23
+ colorize_nlabels=None,
24
+ monitor=None,
25
+ batch_resize_range=None,
26
+ scheduler_config=None,
27
+ lr_g_factor=1.0,
28
+ remap=None,
29
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
30
+ use_ema=False
31
+ ):
32
+ super().__init__()
33
+ self.embed_dim = embed_dim
34
+ self.n_embed = n_embed
35
+ self.image_key = image_key
36
+ self.encoder = Encoder(**ddconfig)
37
+ self.decoder = Decoder(**ddconfig)
38
+ self.loss = instantiate_from_config(lossconfig)
39
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
40
+ remap=remap,
41
+ sane_index_shape=sane_index_shape)
42
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
43
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
44
+ if colorize_nlabels is not None:
45
+ assert type(colorize_nlabels)==int
46
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
47
+ if monitor is not None:
48
+ self.monitor = monitor
49
+ self.batch_resize_range = batch_resize_range
50
+ if self.batch_resize_range is not None:
51
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
52
+
53
+ self.use_ema = use_ema
54
+ if self.use_ema:
55
+ self.model_ema = LitEma(self)
56
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
57
+
58
+ if ckpt_path is not None:
59
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
60
+ self.scheduler_config = scheduler_config
61
+ self.lr_g_factor = lr_g_factor
62
+
63
+ @contextmanager
64
+ def ema_scope(self, context=None):
65
+ if self.use_ema:
66
+ self.model_ema.store(self.parameters())
67
+ self.model_ema.copy_to(self)
68
+ if context is not None:
69
+ print(f"{context}: Switched to EMA weights")
70
+ try:
71
+ yield None
72
+ finally:
73
+ if self.use_ema:
74
+ self.model_ema.restore(self.parameters())
75
+ if context is not None:
76
+ print(f"{context}: Restored training weights")
77
+
78
+ def init_from_ckpt(self, path, ignore_keys=list()):
79
+ sd = torch.load(path, map_location="cpu")["state_dict"]
80
+ keys = list(sd.keys())
81
+ for k in keys:
82
+ for ik in ignore_keys:
83
+ if k.startswith(ik):
84
+ print("Deleting key {} from state_dict.".format(k))
85
+ del sd[k]
86
+ missing, unexpected = self.load_state_dict(sd, strict=False)
87
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
88
+ if len(missing) > 0:
89
+ print(f"Missing Keys: {missing}")
90
+ print(f"Unexpected Keys: {unexpected}")
91
+
92
+ def on_train_batch_end(self, *args, **kwargs):
93
+ if self.use_ema:
94
+ self.model_ema(self)
95
+
96
+ def encode(self, x, return_all=False):
97
+ h = self.encoder(x)
98
+ h = self.quant_conv(h)
99
+ quant, emb_loss, info = self.quantize(h)
100
+ return quant, emb_loss, info
101
+
102
+ def encode_to_prequant(self, x):
103
+ h = self.encoder(x)
104
+ h = self.quant_conv(h)
105
+ return h
106
+
107
+ def decode(self, quant):
108
+ quant = self.post_quant_conv(quant)
109
+ dec = self.decoder(quant)
110
+ return dec
111
+
112
+ def decode_code(self, code_b):
113
+ quant_b = self.quantize.embed_code(code_b)
114
+ dec = self.decode(quant_b)
115
+ return dec
116
+
117
+ def forward(self, input, return_pred_indices=False):
118
+ quant, diff, (_,_,ind) = self.encode(input)
119
+ dec = self.decode(quant)
120
+ if return_pred_indices:
121
+ return dec, diff, ind
122
+ return dec, diff
123
+
124
+ def get_input(self, batch, k):
125
+ x = batch[k]
126
+ if len(x.shape) == 3:
127
+ x = x[..., None]
128
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
129
+ if self.batch_resize_range is not None:
130
+ lower_size = self.batch_resize_range[0]
131
+ upper_size = self.batch_resize_range[1]
132
+ if self.global_step <= 4:
133
+ # do the first few batches with max size to avoid later oom
134
+ new_resize = upper_size
135
+ else:
136
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
137
+ if new_resize != x.shape[2]:
138
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
139
+ x = x.detach()
140
+ return x
141
+
142
+ def training_step(self, batch, batch_idx, optimizer_idx):
143
+ # https://github.com/pytorch/pytorch/issues/37142
144
+ # try not to fool the heuristics
145
+ x = self.get_input(batch, self.image_key)
146
+ xrec, qloss, ind = self(x, return_pred_indices=True)
147
+
148
+ if optimizer_idx == 0:
149
+ # autoencode
150
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
151
+ last_layer=self.get_last_layer(), split="train",
152
+ predicted_indices=ind)
153
+
154
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
155
+ return aeloss
156
+
157
+ if optimizer_idx == 1:
158
+ # discriminator
159
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
160
+ last_layer=self.get_last_layer(), split="train")
161
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
162
+ return discloss
163
+
164
+ def validation_step(self, batch, batch_idx):
165
+ log_dict = self._validation_step(batch, batch_idx)
166
+ with self.ema_scope():
167
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
168
+ return log_dict
169
+
170
+ def _validation_step(self, batch, batch_idx, suffix=""):
171
+ x = self.get_input(batch, self.image_key)
172
+ xrec, qloss, ind = self(x, return_pred_indices=True)
173
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
174
+ self.global_step,
175
+ last_layer=self.get_last_layer(),
176
+ split="val"+suffix,
177
+ predicted_indices=ind
178
+ )
179
+
180
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
181
+ self.global_step,
182
+ last_layer=self.get_last_layer(),
183
+ split="val"+suffix,
184
+ predicted_indices=ind
185
+ )
186
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
187
+ self.log(f"val{suffix}/rec_loss", rec_loss,
188
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
189
+ self.log(f"val{suffix}/aeloss", aeloss,
190
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
191
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
192
+ del log_dict_ae[f"val{suffix}/rec_loss"]
193
+ self.log_dict(log_dict_ae)
194
+ self.log_dict(log_dict_disc)
195
+ return self.log_dict
196
+
197
+ def configure_optimizers(self):
198
+ lr_d = self.learning_rate
199
+ lr_g = self.lr_g_factor*self.learning_rate
200
+ print("lr_d", lr_d)
201
+ print("lr_g", lr_g)
202
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
203
+ list(self.decoder.parameters())+
204
+ list(self.quantize.parameters())+
205
+ list(self.quant_conv.parameters())+
206
+ list(self.post_quant_conv.parameters()),
207
+ lr=lr_g, betas=(0.5, 0.9))
208
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
209
+ lr=lr_d, betas=(0.5, 0.9))
210
+
211
+ if self.scheduler_config is not None:
212
+ scheduler = instantiate_from_config(self.scheduler_config)
213
+
214
+ print("Setting up LambdaLR scheduler...")
215
+ scheduler = [
216
+ {
217
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
218
+ 'interval': 'step',
219
+ 'frequency': 1
220
+ },
221
+ {
222
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
223
+ 'interval': 'step',
224
+ 'frequency': 1
225
+ },
226
+ ]
227
+ return [opt_ae, opt_disc], scheduler
228
+ return [opt_ae, opt_disc], []
229
+
230
+ def get_last_layer(self):
231
+ return self.decoder.conv_out.weight
232
+
233
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
234
+ log = dict()
235
+ x = self.get_input(batch, self.image_key)
236
+ x = x.to(self.device)
237
+ if only_inputs:
238
+ log["inputs"] = x
239
+ return log
240
+ xrec, _ = self(x)
241
+ if x.shape[1] > 3:
242
+ # colorize with random projection
243
+ assert xrec.shape[1] > 3
244
+ x = self.to_rgb(x)
245
+ xrec = self.to_rgb(xrec)
246
+ log["inputs"] = x
247
+ log["reconstructions"] = xrec
248
+ if plot_ema:
249
+ with self.ema_scope():
250
+ xrec_ema, _ = self(x)
251
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
252
+ log["reconstructions_ema"] = xrec_ema
253
+ return log
254
+
255
+ def to_rgb(self, x):
256
+ assert self.image_key == "segmentation"
257
+ if not hasattr(self, "colorize"):
258
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
259
+ x = F.conv2d(x, weight=self.colorize)
260
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
261
+ return x
262
+
263
+ class VQModelInterface(VQModel):
264
+ def __init__(self, embed_dim, *args, **kwargs):
265
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
266
+ self.embed_dim = embed_dim
267
+
268
+ def encode(self, x, return_all=False):
269
+ h = self.encoder(x)
270
+ h = self.quant_conv(h)
271
+ return h
272
+
273
+
274
+ def decode(self, h, force_not_quantize=False):
275
+ # also go through quantization layer
276
+ if not force_not_quantize:
277
+ quant, emb_loss, info = self.quantize(h)
278
+ else:
279
+ quant = h
280
+ quant = self.post_quant_conv(quant)
281
+ dec = self.decoder(quant)
282
+ return dec
283
+
284
+
285
+ class AutoencoderKL(pl.LightningModule):
286
+ def __init__(self,
287
+ ddconfig,
288
+ lossconfig,
289
+ embed_dim,
290
+ ckpt_path=None,
291
+ ignore_keys=[],
292
+ image_key="image",
293
+ colorize_nlabels=None,
294
+ monitor=None,
295
+ ):
296
+ super().__init__()
297
+ self.image_key = image_key
298
+ self.encoder = Encoder(**ddconfig)
299
+ self.decoder = Decoder(**ddconfig)
300
+ self.loss = instantiate_from_config(lossconfig)
301
+ assert ddconfig["double_z"]
302
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
303
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
304
+ self.embed_dim = embed_dim
305
+ if colorize_nlabels is not None:
306
+ assert type(colorize_nlabels)==int
307
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
308
+ if monitor is not None:
309
+ self.monitor = monitor
310
+ if ckpt_path is not None:
311
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
312
+
313
+ def init_from_ckpt(self, path, ignore_keys=list()):
314
+ sd = torch.load(path, map_location="cpu")["state_dict"]
315
+ keys = list(sd.keys())
316
+ for k in keys:
317
+ for ik in ignore_keys:
318
+ if k.startswith(ik):
319
+ print("Deleting key {} from state_dict.".format(k))
320
+ del sd[k]
321
+ self.load_state_dict(sd, strict=False)
322
+ print(f"Restored from {path}")
323
+
324
+ def encode(self, x, return_all=False):
325
+ h = self.encoder(x)
326
+ moments = self.quant_conv(h)
327
+ posterior = DiagonalGaussianDistribution(moments)
328
+ if return_all: return posterior, moments
329
+ return posterior
330
+
331
+ def decode(self, z):
332
+ z = self.post_quant_conv(z)
333
+ dec = self.decoder(z)
334
+ return dec
335
+
336
+ def forward(self, input, sample_posterior=True):
337
+ posterior = self.encode(input)
338
+ if sample_posterior:
339
+ z = posterior.sample()
340
+ else:
341
+ z = posterior.mode()
342
+ dec = self.decode(z)
343
+ return dec, posterior
344
+
345
+ def get_input(self, batch, k):
346
+ x = batch[k]
347
+ if len(x.shape) == 3:
348
+ x = x[..., None]
349
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
350
+ return x
351
+
352
+ def training_step(self, batch, batch_idx, optimizer_idx):
353
+ inputs = self.get_input(batch, self.image_key)
354
+ reconstructions, posterior = self(inputs)
355
+
356
+ if optimizer_idx == 0:
357
+ # train encoder+decoder+logvar
358
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
359
+ last_layer=self.get_last_layer(), split="train")
360
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
361
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
362
+ return aeloss
363
+
364
+ if optimizer_idx == 1:
365
+ # train the discriminator
366
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
367
+ last_layer=self.get_last_layer(), split="train")
368
+
369
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
370
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
371
+ return discloss
372
+
373
+ def validation_step(self, batch, batch_idx):
374
+ inputs = self.get_input(batch, self.image_key)
375
+ reconstructions, posterior = self(inputs)
376
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
377
+ last_layer=self.get_last_layer(), split="val")
378
+
379
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
380
+ last_layer=self.get_last_layer(), split="val")
381
+
382
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
383
+ self.log_dict(log_dict_ae)
384
+ self.log_dict(log_dict_disc)
385
+ return self.log_dict
386
+
387
+ def configure_optimizers(self):
388
+ lr = self.learning_rate
389
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
390
+ list(self.decoder.parameters())+
391
+ list(self.quant_conv.parameters())+
392
+ list(self.post_quant_conv.parameters()),
393
+ lr=lr, betas=(0.5, 0.9))
394
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
395
+ lr=lr, betas=(0.5, 0.9))
396
+ return [opt_ae, opt_disc], []
397
+
398
+ def get_last_layer(self):
399
+ return self.decoder.conv_out.weight
400
+
401
+ @torch.no_grad()
402
+ def log_images(self, batch, only_inputs=False, **kwargs):
403
+ log = dict()
404
+ x = self.get_input(batch, self.image_key)
405
+ x = x.to(self.device)
406
+ if not only_inputs:
407
+ xrec, posterior = self(x)
408
+ if x.shape[1] > 3:
409
+ # colorize with random projection
410
+ assert xrec.shape[1] > 3
411
+ x = self.to_rgb(x)
412
+ xrec = self.to_rgb(xrec)
413
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
414
+ log["reconstructions"] = xrec
415
+ log["inputs"] = x
416
+ return log
417
+
418
+ def to_rgb(self, x):
419
+ assert self.image_key == "segmentation"
420
+ if not hasattr(self, "colorize"):
421
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
422
+ x = F.conv2d(x, weight=self.colorize)
423
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
424
+ return x
425
+
426
+
427
+ class IdentityFirstStage(torch.nn.Module):
428
+ def __init__(self, *args, vq_interface=False, **kwargs):
429
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
430
+ super().__init__()
431
+
432
+ def encode(self, x, *args, **kwargs):
433
+ return x
434
+
435
+ def decode(self, x, *args, **kwargs):
436
+ return x
437
+
438
+ def quantize(self, x, *args, **kwargs):
439
+ if self.vq_interface:
440
+ return x, None, [None, None, None]
441
+ return x
442
+
443
+ def forward(self, x, *args, **kwargs):
444
+ return x
ldm/models/diffusion/__init__.py ADDED
File without changes
ldm/models/diffusion/classifier.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pytorch_lightning as pl
4
+ from omegaconf import OmegaConf
5
+ from torch.nn import functional as F
6
+ from torch.optim import AdamW
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+ from copy import deepcopy
9
+ from einops import rearrange
10
+ from glob import glob
11
+ from natsort import natsorted
12
+
13
+ from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14
+ from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15
+
16
+ __models__ = {
17
+ 'class_label': EncoderUNetModel,
18
+ 'segmentation': UNetModel
19
+ }
20
+
21
+
22
+ def disabled_train(self, mode=True):
23
+ """Overwrite model.train with this function to make sure train/eval mode
24
+ does not change anymore."""
25
+ return self
26
+
27
+
28
+ class NoisyLatentImageClassifier(pl.LightningModule):
29
+
30
+ def __init__(self,
31
+ diffusion_path,
32
+ num_classes,
33
+ ckpt_path=None,
34
+ pool='attention',
35
+ label_key=None,
36
+ diffusion_ckpt_path=None,
37
+ scheduler_config=None,
38
+ weight_decay=1.e-2,
39
+ log_steps=10,
40
+ monitor='val/loss',
41
+ *args,
42
+ **kwargs):
43
+ super().__init__(*args, **kwargs)
44
+ self.num_classes = num_classes
45
+ # get latest config of diffusion model
46
+ diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47
+ self.diffusion_config = OmegaConf.load(diffusion_config).model
48
+ self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49
+ self.load_diffusion()
50
+
51
+ self.monitor = monitor
52
+ self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53
+ self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54
+ self.log_steps = log_steps
55
+
56
+ self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57
+ else self.diffusion_model.cond_stage_key
58
+
59
+ assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60
+
61
+ if self.label_key not in __models__:
62
+ raise NotImplementedError()
63
+
64
+ self.load_classifier(ckpt_path, pool)
65
+
66
+ self.scheduler_config = scheduler_config
67
+ self.use_scheduler = self.scheduler_config is not None
68
+ self.weight_decay = weight_decay
69
+
70
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71
+ sd = torch.load(path, map_location="cpu")
72
+ if "state_dict" in list(sd.keys()):
73
+ sd = sd["state_dict"]
74
+ keys = list(sd.keys())
75
+ for k in keys:
76
+ for ik in ignore_keys:
77
+ if k.startswith(ik):
78
+ print("Deleting key {} from state_dict.".format(k))
79
+ del sd[k]
80
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81
+ sd, strict=False)
82
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83
+ if len(missing) > 0:
84
+ print(f"Missing Keys: {missing}")
85
+ if len(unexpected) > 0:
86
+ print(f"Unexpected Keys: {unexpected}")
87
+
88
+ def load_diffusion(self):
89
+ model = instantiate_from_config(self.diffusion_config)
90
+ self.diffusion_model = model.eval()
91
+ self.diffusion_model.train = disabled_train
92
+ for param in self.diffusion_model.parameters():
93
+ param.requires_grad = False
94
+
95
+ def load_classifier(self, ckpt_path, pool):
96
+ model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97
+ model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98
+ model_config.out_channels = self.num_classes
99
+ if self.label_key == 'class_label':
100
+ model_config.pool = pool
101
+
102
+ self.model = __models__[self.label_key](**model_config)
103
+ if ckpt_path is not None:
104
+ print('#####################################################################')
105
+ print(f'load from ckpt "{ckpt_path}"')
106
+ print('#####################################################################')
107
+ self.init_from_ckpt(ckpt_path)
108
+
109
+ @torch.no_grad()
110
+ def get_x_noisy(self, x, t, noise=None):
111
+ noise = default(noise, lambda: torch.randn_like(x))
112
+ continuous_sqrt_alpha_cumprod = None
113
+ if self.diffusion_model.use_continuous_noise:
114
+ continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115
+ # todo: make sure t+1 is correct here
116
+
117
+ return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118
+ continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119
+
120
+ def forward(self, x_noisy, t, *args, **kwargs):
121
+ return self.model(x_noisy, t)
122
+
123
+ @torch.no_grad()
124
+ def get_input(self, batch, k):
125
+ x = batch[k]
126
+ if len(x.shape) == 3:
127
+ x = x[..., None]
128
+ x = rearrange(x, 'b h w c -> b c h w')
129
+ x = x.to(memory_format=torch.contiguous_format).float()
130
+ return x
131
+
132
+ @torch.no_grad()
133
+ def get_conditioning(self, batch, k=None):
134
+ if k is None:
135
+ k = self.label_key
136
+ assert k is not None, 'Needs to provide label key'
137
+
138
+ targets = batch[k].to(self.device)
139
+
140
+ if self.label_key == 'segmentation':
141
+ targets = rearrange(targets, 'b h w c -> b c h w')
142
+ for down in range(self.numd):
143
+ h, w = targets.shape[-2:]
144
+ targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145
+
146
+ # targets = rearrange(targets,'b c h w -> b h w c')
147
+
148
+ return targets
149
+
150
+ def compute_top_k(self, logits, labels, k, reduction="mean"):
151
+ _, top_ks = torch.topk(logits, k, dim=1)
152
+ if reduction == "mean":
153
+ return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154
+ elif reduction == "none":
155
+ return (top_ks == labels[:, None]).float().sum(dim=-1)
156
+
157
+ def on_train_epoch_start(self):
158
+ # save some memory
159
+ self.diffusion_model.model.to('cpu')
160
+
161
+ @torch.no_grad()
162
+ def write_logs(self, loss, logits, targets):
163
+ log_prefix = 'train' if self.training else 'val'
164
+ log = {}
165
+ log[f"{log_prefix}/loss"] = loss.mean()
166
+ log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167
+ logits, targets, k=1, reduction="mean"
168
+ )
169
+ log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170
+ logits, targets, k=5, reduction="mean"
171
+ )
172
+
173
+ self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174
+ self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175
+ self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176
+ lr = self.optimizers().param_groups[0]['lr']
177
+ self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178
+
179
+ def shared_step(self, batch, t=None):
180
+ x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181
+ targets = self.get_conditioning(batch)
182
+ if targets.dim() == 4:
183
+ targets = targets.argmax(dim=1)
184
+ if t is None:
185
+ t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186
+ else:
187
+ t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188
+ x_noisy = self.get_x_noisy(x, t)
189
+ logits = self(x_noisy, t)
190
+
191
+ loss = F.cross_entropy(logits, targets, reduction='none')
192
+
193
+ self.write_logs(loss.detach(), logits.detach(), targets.detach())
194
+
195
+ loss = loss.mean()
196
+ return loss, logits, x_noisy, targets
197
+
198
+ def training_step(self, batch, batch_idx):
199
+ loss, *_ = self.shared_step(batch)
200
+ return loss
201
+
202
+ def reset_noise_accs(self):
203
+ self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204
+ range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205
+
206
+ def on_validation_start(self):
207
+ self.reset_noise_accs()
208
+
209
+ @torch.no_grad()
210
+ def validation_step(self, batch, batch_idx):
211
+ loss, *_ = self.shared_step(batch)
212
+
213
+ for t in self.noisy_acc:
214
+ _, logits, _, targets = self.shared_step(batch, t)
215
+ self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216
+ self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217
+
218
+ return loss
219
+
220
+ def configure_optimizers(self):
221
+ optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222
+
223
+ if self.use_scheduler:
224
+ scheduler = instantiate_from_config(self.scheduler_config)
225
+
226
+ print("Setting up LambdaLR scheduler...")
227
+ scheduler = [
228
+ {
229
+ 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230
+ 'interval': 'step',
231
+ 'frequency': 1
232
+ }]
233
+ return [optimizer], scheduler
234
+
235
+ return optimizer
236
+
237
+ @torch.no_grad()
238
+ def log_images(self, batch, N=8, *args, **kwargs):
239
+ log = dict()
240
+ x = self.get_input(batch, self.diffusion_model.first_stage_key)
241
+ log['inputs'] = x
242
+
243
+ y = self.get_conditioning(batch)
244
+
245
+ if self.label_key == 'class_label':
246
+ y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247
+ log['labels'] = y
248
+
249
+ if ismap(y):
250
+ log['labels'] = self.diffusion_model.to_rgb(y)
251
+
252
+ for step in range(self.log_steps):
253
+ current_time = step * self.log_time_interval
254
+
255
+ _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256
+
257
+ log[f'inputs@t{current_time}'] = x_noisy
258
+
259
+ pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260
+ pred = rearrange(pred, 'b h w c -> b c h w')
261
+
262
+ log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263
+
264
+ for key in log:
265
+ log[key] = log[key][:N]
266
+
267
+ return log
ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+
10
+
11
+ class DDIMSampler(object):
12
+ def __init__(self, model, schedule="linear", **kwargs):
13
+ super().__init__()
14
+ self.model = model
15
+ self.ddpm_num_timesteps = model.num_timesteps
16
+ self.schedule = schedule
17
+
18
+ def register_buffer(self, name, attr):
19
+ if type(attr) == torch.Tensor:
20
+ if attr.device != torch.device("cuda"):
21
+ attr = attr.to(torch.device("cuda"))
22
+ setattr(self, name, attr)
23
+
24
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
26
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
27
+ alphas_cumprod = self.model.alphas_cumprod
28
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
29
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
30
+
31
+ self.register_buffer('betas', to_torch(self.model.betas))
32
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
33
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
34
+
35
+ # calculations for diffusion q(x_t | x_{t-1}) and others
36
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
37
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
38
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
40
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
41
+
42
+ # ddim sampling parameters
43
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
44
+ ddim_timesteps=self.ddim_timesteps,
45
+ eta=ddim_eta,verbose=verbose)
46
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
47
+ self.register_buffer('ddim_alphas', ddim_alphas)
48
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
49
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
50
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
51
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
52
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
53
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
54
+
55
+ @torch.no_grad()
56
+ def sample(self,
57
+ S,
58
+ batch_size,
59
+ shape,
60
+ conditioning=None,
61
+ callback=None,
62
+ normals_sequence=None,
63
+ img_callback=None,
64
+ quantize_x0=False,
65
+ eta=0.,
66
+ mask=None,
67
+ x0=None,
68
+ temperature=1.,
69
+ noise_dropout=0.,
70
+ score_corrector=None,
71
+ corrector_kwargs=None,
72
+ verbose=True,
73
+ x_T=None,
74
+ log_every_t=100,
75
+ unconditional_guidance_scale=1.,
76
+ unconditional_conditioning=None,
77
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
78
+ **kwargs
79
+ ):
80
+ if conditioning is not None:
81
+ if isinstance(conditioning, dict):
82
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
83
+ if cbs != batch_size:
84
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
85
+ else:
86
+ if conditioning.shape[0] != batch_size:
87
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
88
+
89
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
90
+ # sampling
91
+ C, H, W = shape
92
+ size = (batch_size, C, H, W)
93
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
94
+
95
+ samples, intermediates = self.ddim_sampling(conditioning, size,
96
+ callback=callback,
97
+ img_callback=img_callback,
98
+ quantize_denoised=quantize_x0,
99
+ mask=mask, x0=x0,
100
+ ddim_use_original_steps=False,
101
+ noise_dropout=noise_dropout,
102
+ temperature=temperature,
103
+ score_corrector=score_corrector,
104
+ corrector_kwargs=corrector_kwargs,
105
+ x_T=x_T,
106
+ log_every_t=log_every_t,
107
+ unconditional_guidance_scale=unconditional_guidance_scale,
108
+ unconditional_conditioning=unconditional_conditioning,
109
+ )
110
+ return samples, intermediates
111
+
112
+ @torch.no_grad()
113
+ def ddim_sampling(self, cond, shape,
114
+ x_T=None, ddim_use_original_steps=False,
115
+ callback=None, timesteps=None, quantize_denoised=False,
116
+ mask=None, x0=None, img_callback=None, log_every_t=100,
117
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
118
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
119
+ device = self.model.betas.device
120
+ b = shape[0]
121
+ if x_T is None:
122
+ img = torch.randn(shape, device=device)
123
+ else:
124
+ img = x_T
125
+
126
+ if timesteps is None:
127
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
128
+ elif timesteps is not None and not ddim_use_original_steps:
129
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
130
+ timesteps = self.ddim_timesteps[:subset_end]
131
+
132
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
133
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
134
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
135
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
136
+
137
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
138
+
139
+ for i, step in enumerate(iterator):
140
+ index = total_steps - i - 1
141
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
142
+
143
+ if mask is not None:
144
+ assert x0 is not None
145
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
146
+ img = img_orig * mask + (1. - mask) * img
147
+
148
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
149
+ quantize_denoised=quantize_denoised, temperature=temperature,
150
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
151
+ corrector_kwargs=corrector_kwargs,
152
+ unconditional_guidance_scale=unconditional_guidance_scale,
153
+ unconditional_conditioning=unconditional_conditioning)
154
+ img, pred_x0 = outs
155
+ if callback: callback(i)
156
+ if img_callback: img_callback(pred_x0, i)
157
+
158
+ if index % log_every_t == 0 or index == total_steps - 1:
159
+ intermediates['x_inter'].append(img)
160
+ intermediates['pred_x0'].append(pred_x0)
161
+
162
+ return img, intermediates
163
+
164
+ @torch.no_grad()
165
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
166
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
167
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
168
+ b, *_, device = *x.shape, x.device
169
+
170
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
171
+ e_t = self.model.apply_model(x, t, c)
172
+ else:
173
+ x_in = torch.cat([x] * 2)
174
+ t_in = torch.cat([t] * 2)
175
+ c_in = torch.cat([unconditional_conditioning, c])
176
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
177
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
178
+
179
+ if score_corrector is not None:
180
+ assert self.model.parameterization == "eps"
181
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
182
+
183
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
184
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
185
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
186
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
187
+ # select parameters corresponding to the currently considered timestep
188
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
189
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
190
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
191
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
192
+
193
+ # current prediction for x_0
194
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
195
+ if quantize_denoised:
196
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
197
+ # direction pointing to x_t
198
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
199
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
200
+ if noise_dropout > 0.:
201
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
202
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
203
+ return x_prev, pred_x0
ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,1515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ wild mixture of
3
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
+ https://github.com/CompVis/taming-transformers
6
+ -- merci
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ import pytorch_lightning as pl
13
+ from torch.optim.lr_scheduler import LambdaLR
14
+ from einops import rearrange, repeat
15
+ from contextlib import contextmanager
16
+ from functools import partial
17
+ from tqdm import tqdm
18
+ from torchvision.utils import make_grid
19
+ from pytorch_lightning.utilities.rank_zero import rank_zero_only
20
+
21
+
22
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
23
+ from ldm.modules.ema import LitEma
24
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
25
+ from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
26
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like, betas_for_alpha_bar
27
+ from ldm.models.diffusion.ddim import DDIMSampler
28
+
29
+ __conditioning_keys__ = {'concat': 'c_concat',
30
+ 'crossattn': 'c_crossattn',
31
+ 'adm': 'y'}
32
+
33
+
34
+ def disabled_train(self, mode=True):
35
+ """Overwrite model.train with this function to make sure train/eval mode
36
+ does not change anymore."""
37
+ return self
38
+
39
+
40
+ def uniform_on_device(r1, r2, shape, device):
41
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
42
+
43
+
44
+ class DDPM(pl.LightningModule):
45
+ # classic DDPM with Gaussian diffusion, in image space
46
+ def __init__(self,
47
+ unet_config,
48
+ timesteps=1000,
49
+ beta_schedule="linear", # "linear", "cosine", "sqrt_linear"
50
+ loss_type="l2",
51
+ ckpt_path=None,
52
+ ignore_keys=[],
53
+ load_only_unet=False,
54
+ monitor="val/loss",
55
+ use_ema=True,
56
+ first_stage_key="image",
57
+ image_size=256,
58
+ channels=3,
59
+ log_every_t=100,
60
+ clip_denoised=True,
61
+ linear_start=1e-4,
62
+ linear_end=2e-2,
63
+ cosine_s=8e-3,
64
+ given_betas=None,
65
+ original_elbo_weight=0.,
66
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
67
+ l_simple_weight=1.,
68
+ conditioning_key=None,
69
+ parameterization="eps", #was eps, x0 # all assuming fixed variance schedules
70
+ scheduler_config=None,
71
+ use_positional_encodings=False,
72
+ learn_logvar=False,
73
+ logvar_init=0.,
74
+ ):
75
+ super().__init__()
76
+ assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
77
+ self.parameterization = parameterization
78
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
79
+ self.cond_stage_model = None
80
+ self.clip_denoised = clip_denoised
81
+ self.log_every_t = log_every_t
82
+ self.first_stage_key = first_stage_key
83
+ self.image_size = image_size # try conv?
84
+ self.channels = channels
85
+ self.use_positional_encodings = use_positional_encodings
86
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
87
+ count_params(self.model, verbose=True)
88
+ self.use_ema = use_ema
89
+ if self.use_ema:
90
+ self.model_ema = LitEma(self.model)
91
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
92
+
93
+ self.use_scheduler = scheduler_config is not None
94
+ if self.use_scheduler:
95
+ self.scheduler_config = scheduler_config
96
+
97
+ self.v_posterior = v_posterior
98
+ self.original_elbo_weight = original_elbo_weight
99
+ self.l_simple_weight = l_simple_weight
100
+
101
+ if monitor is not None:
102
+ self.monitor = monitor
103
+ if ckpt_path is not None:
104
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
105
+
106
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
107
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
108
+
109
+ self.loss_type = loss_type
110
+
111
+ self.learn_logvar = learn_logvar
112
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
113
+ if self.learn_logvar:
114
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
115
+
116
+
117
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
118
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
119
+ #beta_schedule="edm"
120
+ if exists(given_betas):
121
+ betas = given_betas
122
+ elif beta_schedule=="edm":
123
+ alpha = 0.1
124
+ sigma_min = 0.002
125
+ sigma_max = 80
126
+ sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max),timesteps))
127
+ self.num_timesteps = int(timesteps)
128
+ self.sigma_min = sigma_min
129
+ self.sigma_max = sigma_max
130
+ assert sigmas.shape[0] == self.num_timesteps, 'sigmas have to be defined for each timestep'
131
+ to_torch = partial(torch.tensor, dtype=torch.float32)
132
+
133
+ alphas_cumprod = 1. - sigmas**2
134
+ sigma_prev = np.append(0., sigmas[:-1])
135
+ betas = sigmas**2 - sigma_prev**2
136
+
137
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(torch.ones_like(sigmas)))
138
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
139
+ self.register_buffer('betas', to_torch(betas))
140
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(sigmas))
141
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
142
+ self.register_buffer('sqrt_recip_alphas_cumprod',to_torch(torch.ones_like(sigmas)))
143
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(sigmas))
144
+ self.register_buffer('sigmas', to_torch(sigmas))
145
+ posterior_variance = (1 - self.v_posterior)*(sigma_prev/sigmas)**2 * (1/(betas)) + self.v_posterior*betas
146
+
147
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
148
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
149
+
150
+ self.register_buffer('posterior_mean_coef1', to_torch(1. - (sigma_prev/sigmas)**2))
151
+ self.register_buffer('posterior_mean_coef2', to_torch((sigma_prev/sigmas)**2))
152
+
153
+ if self.parameterization == "eps":
154
+ lvlb_weights = self.sqrt_recipm1_alphas_cumprod**2 / (2*self.posterior_variance)
155
+ elif self.parameterization == "x0":
156
+ ##not changed because not needed
157
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
158
+ else:
159
+ raise NotImplementedError("mu not supported")
160
+ else:
161
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
162
+ alphas = 1. - betas
163
+ alphas_cumprod = np.cumprod(alphas, axis=0)
164
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
165
+
166
+ ### beta_jump = alpha_
167
+ timesteps, = betas.shape
168
+ self.num_timesteps = int(timesteps)
169
+ self.linear_start = linear_start
170
+ self.linear_end = linear_end
171
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
172
+
173
+ to_torch = partial(torch.tensor, dtype=torch.float32)
174
+
175
+ self.register_buffer('betas', to_torch(betas))
176
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
177
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
178
+
179
+ # calculations for diffusion q(x_t | x_{t-1}) and others
180
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) ##for mean
181
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
182
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
183
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
184
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
185
+
186
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
187
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
188
+ 1. - alphas_cumprod) + self.v_posterior * betas
189
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
190
+
191
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
192
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
193
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
194
+ self.register_buffer('posterior_mean_coef1', to_torch(
195
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
196
+ self.register_buffer('posterior_mean_coef2', to_torch(
197
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
198
+
199
+ if self.parameterization == "eps":
200
+ lvlb_weights = self.betas ** 2 / (
201
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
202
+ elif self.parameterization == "x0":
203
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
204
+ else:
205
+ raise NotImplementedError("mu not supported")
206
+
207
+ # TODO how to choose this term
208
+ lvlb_weights[0] = lvlb_weights[1]
209
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
210
+ assert not torch.isnan(self.lvlb_weights).all()
211
+
212
+ @contextmanager
213
+ def ema_scope(self, context=None):
214
+ if self.use_ema:
215
+ self.model_ema.store(self.model.parameters())
216
+ self.model_ema.copy_to(self.model)
217
+ if context is not None:
218
+ print(f"{context}: Switched to EMA weights")
219
+ try:
220
+ yield None
221
+ finally:
222
+ if self.use_ema:
223
+ self.model_ema.restore(self.model.parameters())
224
+ if context is not None:
225
+ print(f"{context}: Restored training weights")
226
+
227
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
228
+ sd = torch.load(path, map_location="cpu")
229
+ if "state_dict" in list(sd.keys()):
230
+ sd = sd["state_dict"]
231
+ keys = list(sd.keys())
232
+ for k in keys:
233
+ for ik in ignore_keys:
234
+ if k.startswith(ik):
235
+ print("Deleting key {} from state_dict.".format(k))
236
+ del sd[k]
237
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
238
+ sd, strict=False)
239
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
240
+ if len(missing) > 0:
241
+ print(f"Missing Keys: {missing}")
242
+ if len(unexpected) > 0:
243
+ print(f"Unexpected Keys: {unexpected}")
244
+
245
+ def q_mean_variance(self, x_start, t):
246
+ """
247
+ Get the distribution q(x_t | x_0).
248
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
249
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
250
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
251
+ """
252
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
253
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
254
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
255
+ return mean, variance, log_variance
256
+
257
+ def predict_start_from_noise(self, x_t, t, noise):
258
+ return (
259
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
260
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
261
+ )
262
+ ##
263
+
264
+ def q_posterior(self, x_start, x_t, t):
265
+ posterior_mean = (
266
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
267
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
268
+ )
269
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
270
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
271
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
272
+
273
+
274
+ def p_mean_variance(self, x, t, clip_denoised: bool):
275
+ model_out = self.model(x, t)
276
+ if self.parameterization == "eps":
277
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
278
+ elif self.parameterization == "x0":
279
+ x_recon = model_out
280
+ if clip_denoised:
281
+ x_recon.clamp_(-1., 1.)
282
+
283
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
284
+ return model_mean, posterior_variance, posterior_log_variance
285
+
286
+ #@torch.no_grad()
287
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
288
+ b, *_, device = *x.shape, x.device
289
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
290
+ noise = noise_like(x.shape, device, repeat_noise)
291
+ # no noise when t == 0
292
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
293
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
294
+
295
+ @torch.no_grad()
296
+ def p_sample_loop(self, shape, return_intermediates=False):
297
+ device = self.betas.device
298
+ b = shape[0]
299
+ img = torch.randn(shape, device=device)
300
+ intermediates = [img]
301
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
302
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
303
+ clip_denoised=self.clip_denoised)
304
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
305
+ intermediates.append(img)
306
+ if return_intermediates:
307
+ return img, intermediates
308
+ return img
309
+
310
+ #@torch.no_grad()
311
+ def sample(self, batch_size=16, return_intermediates=False):
312
+ image_size = self.image_size
313
+ channels = self.channels
314
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
315
+ return_intermediates=return_intermediates)
316
+
317
+ def q_sample(self, x_start, t, noise=None):
318
+ noise = default(noise, lambda: torch.randn_like(x_start))
319
+ first = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
320
+ first = first * x_start
321
+ second = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
322
+ second = second * noise
323
+ return ( first + second
324
+ )
325
+
326
+ def q_sample_seq(self, x_start, t, noise=None):
327
+ noise = default(noise, lambda: torch.randn_like(x_start))
328
+ t_sorted, indices = torch.sort(t)
329
+ sigma_t = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t_sorted, x_start.shape)
330
+ #sigma_prev = torch.append(0., sigma_t[:-1])
331
+ #sigmas_cond_prev_t =
332
+
333
+ x_t = x_start
334
+ x_t[0] = x_start[0] + sigma_t[0]*noise[0]
335
+ cum_noise = sigma_t[0]*noise[0]
336
+ for i in range(1,x_start.shape[0]):
337
+ x_t[i] = x_t[i-1] + noise[i] * torch.sqrt(sigma_t[i]**2 - sigma_t[i-1]**2)
338
+ cum_noise += noise[i] * torch.sqrt(sigma_t[i]**2 - sigma_t[i-1]**2)
339
+ noise[i] = (cum_noise)/(sigma_t[i])
340
+
341
+
342
+ return x_t, noise
343
+
344
+ def get_loss(self, pred, target, mean=True):
345
+ if self.loss_type == 'l1':
346
+ loss = (target - pred).abs()
347
+ if mean:
348
+ loss = loss.mean()
349
+ elif self.loss_type == 'l2':
350
+ if mean:
351
+ loss = torch.nn.functional.mse_loss(target, pred)
352
+ else:
353
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
354
+ else:
355
+ raise NotImplementedError("unknown loss type '{loss_type}'")
356
+
357
+ return loss
358
+
359
+ def p_losses(self, x_start, t, noise=None):
360
+ noise = default(noise, lambda: torch.randn_like(x_start))
361
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
362
+ model_out = self.model(x_noisy, t)
363
+
364
+ loss_dict = {}
365
+ if self.parameterization == "eps":
366
+ target = noise
367
+ elif self.parameterization == "x0":
368
+ target = x_start
369
+ else:
370
+ raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
371
+
372
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
373
+
374
+ log_prefix = 'train' if self.training else 'val'
375
+
376
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
377
+ loss_simple = loss.mean() * self.l_simple_weight
378
+
379
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
380
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
381
+
382
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
383
+
384
+ loss_dict.update({f'{log_prefix}/loss': loss})
385
+
386
+ return loss, loss_dict
387
+
388
+ def forward(self, x, *args, **kwargs):
389
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
390
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
391
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
392
+ return self.p_losses(x, t, *args, **kwargs)
393
+
394
+ def get_input(self, batch, k):
395
+ x = batch[k]
396
+ if len(x.shape) == 3:
397
+ x = x[..., None]
398
+ x = rearrange(x, 'b h w c -> b c h w')
399
+ x = x.to(memory_format=torch.contiguous_format).float()
400
+ return x
401
+
402
+ def shared_step(self, batch):
403
+ x = self.get_input(batch, self.first_stage_key)
404
+ loss, loss_dict = self(x)
405
+ return loss, loss_dict
406
+
407
+ def training_step(self, batch, batch_idx):
408
+ loss, loss_dict = self.shared_step(batch)
409
+
410
+ self.log_dict(loss_dict, prog_bar=True,
411
+ logger=True, on_step=True, on_epoch=True)
412
+
413
+ self.log("global_step", self.global_step,
414
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
415
+
416
+ if self.use_scheduler:
417
+ lr = self.optimizers().param_groups[0]['lr']
418
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
419
+
420
+ return loss
421
+
422
+ @torch.no_grad()
423
+ def validation_step(self, batch, batch_idx):
424
+ _, loss_dict_no_ema = self.shared_step(batch)
425
+ with self.ema_scope():
426
+ _, loss_dict_ema = self.shared_step(batch)
427
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
428
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
429
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
430
+
431
+ def on_train_batch_end(self, *args, **kwargs):
432
+ if self.use_ema:
433
+ self.model_ema(self.model)
434
+
435
+ def _get_rows_from_list(self, samples):
436
+ n_imgs_per_row = len(samples)
437
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
438
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
439
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
440
+ return denoise_grid
441
+
442
+ @torch.no_grad()
443
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
444
+ log = dict()
445
+ x = self.get_input(batch, self.first_stage_key)
446
+ N = min(x.shape[0], N)
447
+ n_row = min(x.shape[0], n_row)
448
+ x = x.to(self.device)[:N]
449
+ log["inputs"] = x
450
+
451
+ # get diffusion row
452
+ diffusion_row = list()
453
+ x_start = x[:n_row]
454
+
455
+ for t in range(self.num_timesteps):
456
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
457
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
458
+ t = t.to(self.device).long()
459
+ noise = torch.randn_like(x_start)
460
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
461
+ diffusion_row.append(x_noisy)
462
+
463
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
464
+
465
+ if sample:
466
+ # get denoise row
467
+ with self.ema_scope("Plotting"):
468
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
469
+
470
+ log["samples"] = samples
471
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
472
+
473
+ if return_keys:
474
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
475
+ return log
476
+ else:
477
+ return {key: log[key] for key in return_keys}
478
+ return log
479
+
480
+ def configure_optimizers(self):
481
+ lr = self.learning_rate
482
+ params = list(self.model.parameters())
483
+ if self.learn_logvar:
484
+ params = params + [self.logvar]
485
+ opt = torch.optim.AdamW(params, lr=lr)
486
+ return opt
487
+
488
+
489
+ class LatentDiffusion(DDPM):
490
+ """main class"""
491
+ def __init__(self,
492
+ first_stage_config,
493
+ cond_stage_config,
494
+ num_timesteps_cond=None,
495
+ cond_stage_key="image",
496
+ cond_stage_trainable=False,
497
+ concat_mode=True,
498
+ cond_stage_forward=None,
499
+ conditioning_key=None,
500
+ scale_factor=1.0,
501
+ scale_by_std=False,
502
+ *args, **kwargs):
503
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
504
+ self.scale_by_std = scale_by_std
505
+ assert self.num_timesteps_cond <= kwargs['timesteps']
506
+ # for backwards compatibility after implementation of DiffusionWrapper
507
+ if conditioning_key is None:
508
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
509
+ if cond_stage_config == '__is_unconditional__':
510
+ conditioning_key = None
511
+ ckpt_path = kwargs.pop("ckpt_path", None)
512
+ ignore_keys = kwargs.pop("ignore_keys", [])
513
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
514
+ self.concat_mode = concat_mode
515
+ self.cond_stage_trainable = cond_stage_trainable
516
+ self.cond_stage_key = cond_stage_key
517
+ try:
518
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
519
+ except:
520
+ self.num_downs = 0
521
+ if not scale_by_std:
522
+ self.scale_factor = scale_factor
523
+ else:
524
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
525
+ self.instantiate_first_stage(first_stage_config)
526
+ self.instantiate_cond_stage(cond_stage_config)
527
+ self.cond_stage_forward = cond_stage_forward
528
+ self.clip_denoised = False
529
+ self.bbox_tokenizer = None
530
+
531
+ self.restarted_from_ckpt = False
532
+ if ckpt_path is not None:
533
+ self.init_from_ckpt(ckpt_path, ignore_keys)
534
+ self.restarted_from_ckpt = True
535
+
536
+ def make_cond_schedule(self, ):
537
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
538
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
539
+ self.cond_ids[:self.num_timesteps_cond] = ids
540
+
541
+ @rank_zero_only
542
+ @torch.no_grad()
543
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
544
+ # only for very first batch
545
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
546
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
547
+ # set rescale weight to 1./std of encodings
548
+ print("### USING STD-RESCALING ###")
549
+ x = super().get_input(batch, self.first_stage_key)
550
+ x = x.to(self.device)
551
+ encoder_posterior = self.encode_first_stage(x)
552
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
553
+ del self.scale_factor
554
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
555
+ print(f"setting self.scale_factor to {self.scale_factor}")
556
+ print("### USING STD-RESCALING ###")
557
+
558
+ def register_schedule(self,
559
+ given_betas=None, beta_schedule="linear", timesteps=1000,
560
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
561
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
562
+
563
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
564
+ if self.shorten_cond_schedule:
565
+ self.make_cond_schedule()
566
+
567
+ def instantiate_first_stage(self, config):
568
+ model = instantiate_from_config(config)
569
+ self.first_stage_model = model.eval()
570
+ self.first_stage_model.train = disabled_train
571
+ for param in self.first_stage_model.parameters():
572
+ param.requires_grad = False
573
+
574
+ def instantiate_cond_stage(self, config):
575
+ if not self.cond_stage_trainable:
576
+ if config == "__is_first_stage__":
577
+ print("Using first stage also as cond stage.")
578
+ self.cond_stage_model = self.first_stage_model
579
+ elif config == "__is_unconditional__":
580
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
581
+ self.cond_stage_model = None
582
+ # self.be_unconditional = True
583
+ else:
584
+ model = instantiate_from_config(config)
585
+ self.cond_stage_model = model.eval()
586
+ self.cond_stage_model.train = disabled_train
587
+ for param in self.cond_stage_model.parameters():
588
+ param.requires_grad = False
589
+ else:
590
+ assert config != '__is_first_stage__'
591
+ assert config != '__is_unconditional__'
592
+ model = instantiate_from_config(config)
593
+ self.cond_stage_model = model
594
+
595
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
596
+ denoise_row = []
597
+ for zd in tqdm(samples, desc=desc):
598
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
599
+ force_not_quantize=force_no_decoder_quantization))
600
+ n_imgs_per_row = len(denoise_row)
601
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
602
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
603
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
604
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
605
+ return denoise_grid
606
+
607
+ def get_first_stage_encoding(self, encoder_posterior):
608
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
609
+ z = encoder_posterior.sample()
610
+ elif isinstance(encoder_posterior, torch.Tensor):
611
+ z = encoder_posterior
612
+ else:
613
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
614
+ return self.scale_factor * z
615
+
616
+ def get_learned_conditioning(self, c):
617
+ if self.cond_stage_forward is None:
618
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
619
+ c = self.cond_stage_model.encode(c)
620
+ if isinstance(c, DiagonalGaussianDistribution):
621
+ c = c.mode()
622
+ else:
623
+ c = self.cond_stage_model(c)
624
+ else:
625
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
626
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
627
+ return c
628
+
629
+ def meshgrid(self, h, w):
630
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
631
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
632
+
633
+ arr = torch.cat([y, x], dim=-1)
634
+ return arr
635
+
636
+ def delta_border(self, h, w):
637
+ """
638
+ :param h: height
639
+ :param w: width
640
+ :return: normalized distance to image border,
641
+ wtith min distance = 0 at border and max dist = 0.5 at image center
642
+ """
643
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
644
+ arr = self.meshgrid(h, w) / lower_right_corner
645
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
646
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
647
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
648
+ return edge_dist
649
+
650
+ def get_weighting(self, h, w, Ly, Lx, device):
651
+ weighting = self.delta_border(h, w)
652
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
653
+ self.split_input_params["clip_max_weight"], )
654
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
655
+
656
+ if self.split_input_params["tie_braker"]:
657
+ L_weighting = self.delta_border(Ly, Lx)
658
+ L_weighting = torch.clip(L_weighting,
659
+ self.split_input_params["clip_min_tie_weight"],
660
+ self.split_input_params["clip_max_tie_weight"])
661
+
662
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
663
+ weighting = weighting * L_weighting
664
+ return weighting
665
+
666
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
667
+ """
668
+ :param x: img of size (bs, c, h, w)
669
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
670
+ """
671
+ bs, nc, h, w = x.shape
672
+
673
+ # number of crops in image
674
+ Ly = (h - kernel_size[0]) // stride[0] + 1
675
+ Lx = (w - kernel_size[1]) // stride[1] + 1
676
+
677
+ if uf == 1 and df == 1:
678
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
679
+ unfold = torch.nn.Unfold(**fold_params)
680
+
681
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
682
+
683
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
684
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
685
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
686
+
687
+ elif uf > 1 and df == 1:
688
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
689
+ unfold = torch.nn.Unfold(**fold_params)
690
+
691
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
692
+ dilation=1, padding=0,
693
+ stride=(stride[0] * uf, stride[1] * uf))
694
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
695
+
696
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
697
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
698
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
699
+
700
+ elif df > 1 and uf == 1:
701
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
702
+ unfold = torch.nn.Unfold(**fold_params)
703
+
704
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
705
+ dilation=1, padding=0,
706
+ stride=(stride[0] // df, stride[1] // df))
707
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
708
+
709
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
710
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
711
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
712
+
713
+ else:
714
+ raise NotImplementedError
715
+
716
+ return fold, unfold, normalization, weighting
717
+
718
+ @torch.no_grad()
719
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
720
+ cond_key=None, return_original_cond=False, bs=None):
721
+ x = super().get_input(batch, k)
722
+ if bs is not None:
723
+ x = x[:bs]
724
+ x = x.to(self.device)
725
+ encoder_posterior = self.encode_first_stage(x)
726
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
727
+
728
+ if self.model.conditioning_key is not None:
729
+ if cond_key is None:
730
+ cond_key = self.cond_stage_key
731
+ if cond_key != self.first_stage_key:
732
+ if cond_key in ['caption', 'coordinates_bbox']:
733
+ xc = batch[cond_key]
734
+ elif cond_key == 'class_label':
735
+ xc = batch
736
+ else:
737
+ xc = super().get_input(batch, cond_key).to(self.device)
738
+ else:
739
+ xc = x
740
+ if not self.cond_stage_trainable or force_c_encode:
741
+ if isinstance(xc, dict) or isinstance(xc, list):
742
+ # import pudb; pudb.set_trace()
743
+ c = self.get_learned_conditioning(xc)
744
+ else:
745
+ c = self.get_learned_conditioning(xc.to(self.device))
746
+ else:
747
+ c = xc
748
+ if bs is not None:
749
+ c = c[:bs]
750
+
751
+ if self.use_positional_encodings:
752
+ pos_x, pos_y = self.compute_latent_shifts(batch)
753
+ ckey = __conditioning_keys__[self.model.conditioning_key]
754
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
755
+
756
+ else:
757
+ c = None
758
+ xc = None
759
+ if self.use_positional_encodings:
760
+ pos_x, pos_y = self.compute_latent_shifts(batch)
761
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
762
+ out = [z, c]
763
+ if return_first_stage_outputs:
764
+ xrec = self.decode_first_stage(z)
765
+ out.extend([x, xrec])
766
+ if return_original_cond:
767
+ out.append(xc)
768
+ return out
769
+
770
+ #@torch.no_grad()
771
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
772
+ if predict_cids:
773
+ if z.dim() == 4:
774
+ z = torch.argmax(z.exp(), dim=1).long()
775
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
776
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
777
+
778
+ z = 1. / self.scale_factor * z
779
+
780
+ if hasattr(self, "split_input_params"):
781
+ if self.split_input_params["patch_distributed_vq"]:
782
+ ks = self.split_input_params["ks"] # eg. (128, 128)
783
+ stride = self.split_input_params["stride"] # eg. (64, 64)
784
+ uf = self.split_input_params["vqf"]
785
+ bs, nc, h, w = z.shape
786
+ if ks[0] > h or ks[1] > w:
787
+ ks = (min(ks[0], h), min(ks[1], w))
788
+ print("reducing Kernel")
789
+
790
+ if stride[0] > h or stride[1] > w:
791
+ stride = (min(stride[0], h), min(stride[1], w))
792
+ print("reducing stride")
793
+
794
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
795
+
796
+ z = unfold(z) # (bn, nc * prod(**ks), L)
797
+ # 1. Reshape to img shape
798
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
799
+
800
+ # 2. apply model loop over last dim
801
+ if isinstance(self.first_stage_model, VQModelInterface):
802
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
803
+ force_not_quantize=predict_cids or force_not_quantize)
804
+ for i in range(z.shape[-1])]
805
+ else:
806
+
807
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
808
+ for i in range(z.shape[-1])]
809
+
810
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
811
+ o = o * weighting
812
+ # Reverse 1. reshape to img shape
813
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
814
+ # stitch crops together
815
+ decoded = fold(o)
816
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
817
+ return decoded
818
+ else:
819
+ if isinstance(self.first_stage_model, VQModelInterface):
820
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
821
+ else:
822
+ return self.first_stage_model.decode(z)
823
+
824
+ else:
825
+ if isinstance(self.first_stage_model, VQModelInterface):
826
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
827
+ else:
828
+ return self.first_stage_model.decode(z)
829
+
830
+ # same as above but without decorator
831
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
832
+ if predict_cids:
833
+ if z.dim() == 4:
834
+ z = torch.argmax(z.exp(), dim=1).long()
835
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
836
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
837
+
838
+ z = 1. / self.scale_factor * z
839
+
840
+ if hasattr(self, "split_input_params"):
841
+ if self.split_input_params["patch_distributed_vq"]:
842
+ ks = self.split_input_params["ks"] # eg. (128, 128)
843
+ stride = self.split_input_params["stride"] # eg. (64, 64)
844
+ uf = self.split_input_params["vqf"]
845
+ bs, nc, h, w = z.shape
846
+ if ks[0] > h or ks[1] > w:
847
+ ks = (min(ks[0], h), min(ks[1], w))
848
+ print("reducing Kernel")
849
+
850
+ if stride[0] > h or stride[1] > w:
851
+ stride = (min(stride[0], h), min(stride[1], w))
852
+ print("reducing stride")
853
+
854
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
855
+
856
+ z = unfold(z) # (bn, nc * prod(**ks), L)
857
+ # 1. Reshape to img shape
858
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
859
+
860
+ # 2. apply model loop over last dim
861
+ if isinstance(self.first_stage_model, VQModelInterface):
862
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
863
+ force_not_quantize=predict_cids or force_not_quantize)
864
+ for i in range(z.shape[-1])]
865
+ else:
866
+
867
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
868
+ for i in range(z.shape[-1])]
869
+
870
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
871
+ o = o * weighting
872
+ # Reverse 1. reshape to img shape
873
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
874
+ # stitch crops together
875
+ decoded = fold(o)
876
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
877
+ return decoded
878
+ else:
879
+ if isinstance(self.first_stage_model, VQModelInterface):
880
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
881
+ else:
882
+ return self.first_stage_model.decode(z)
883
+
884
+ else:
885
+ if isinstance(self.first_stage_model, VQModelInterface):
886
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
887
+ else:
888
+ return self.first_stage_model.decode(z)
889
+
890
+ #@torch.no_grad()
891
+ def encode_first_stage(self, x, return_all=None):
892
+ if hasattr(self, "split_input_params"):
893
+ if self.split_input_params["patch_distributed_vq"]:
894
+ ks = self.split_input_params["ks"] # eg. (128, 128)
895
+ stride = self.split_input_params["stride"] # eg. (64, 64)
896
+ df = self.split_input_params["vqf"]
897
+ self.split_input_params['original_image_size'] = x.shape[-2:]
898
+ bs, nc, h, w = x.shape
899
+ if ks[0] > h or ks[1] > w:
900
+ ks = (min(ks[0], h), min(ks[1], w))
901
+ print("reducing Kernel")
902
+
903
+ if stride[0] > h or stride[1] > w:
904
+ stride = (min(stride[0], h), min(stride[1], w))
905
+ print("reducing stride")
906
+
907
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
908
+ z = unfold(x) # (bn, nc * prod(**ks), L)
909
+ # Reshape to img shape
910
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
911
+
912
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
913
+ for i in range(z.shape[-1])]
914
+
915
+ o = torch.stack(output_list, axis=-1)
916
+ o = o * weighting
917
+
918
+ # Reverse reshape to img shape
919
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
920
+ # stitch crops together
921
+ decoded = fold(o)
922
+ decoded = decoded / normalization
923
+ return decoded
924
+ else:
925
+ return self.first_stage_model.encode(x,return_all)
926
+ else:
927
+ posterior = self.first_stage_model.encode(x, return_all) #
928
+ #print(self.first_stage_model.loss.logvar)
929
+ return posterior #
930
+
931
+ def shared_step(self, batch, **kwargs):
932
+ x, c = self.get_input(batch, self.first_stage_key)
933
+ loss = self(x, c)
934
+ return loss
935
+
936
+ def forward(self, x, c, *args, **kwargs):
937
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
938
+ if self.model.conditioning_key is not None:
939
+ assert c is not None
940
+ if self.cond_stage_trainable:
941
+ c = self.get_learned_conditioning(c)
942
+ if self.shorten_cond_schedule: # TODO: drop this option
943
+ tc = self.cond_ids[t].to(self.device)
944
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
945
+ return self.p_losses(x, c, t, *args, **kwargs)
946
+
947
+ def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
948
+ def rescale_bbox(bbox):
949
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
950
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
951
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
952
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
953
+ return x0, y0, w, h
954
+
955
+ return [rescale_bbox(b) for b in bboxes]
956
+
957
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
958
+
959
+ if isinstance(cond, dict):
960
+ # hybrid case, cond is exptected to be a dict
961
+ pass
962
+ else:
963
+ if not isinstance(cond, list):
964
+ cond = [cond]
965
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
966
+ cond = {key: cond}
967
+
968
+ if hasattr(self, "split_input_params"):
969
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
970
+ assert not return_ids
971
+ ks = self.split_input_params["ks"] # eg. (128, 128)
972
+ stride = self.split_input_params["stride"] # eg. (64, 64)
973
+
974
+ h, w = x_noisy.shape[-2:]
975
+
976
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
977
+
978
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
979
+ # Reshape to img shape
980
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
981
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
982
+
983
+ if self.cond_stage_key in ["image", "LR_image", "segmentation",
984
+ 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
985
+ c_key = next(iter(cond.keys())) # get key
986
+ c = next(iter(cond.values())) # get value
987
+ assert (len(c) == 1) # todo extend to list with more than one elem
988
+ c = c[0] # get element
989
+
990
+ c = unfold(c)
991
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
992
+
993
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
994
+
995
+ elif self.cond_stage_key == 'coordinates_bbox':
996
+ assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
997
+
998
+ # assuming padding of unfold is always 0 and its dilation is always 1
999
+ n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
1000
+ full_img_h, full_img_w = self.split_input_params['original_image_size']
1001
+ # as we are operating on latents, we need the factor from the original image size to the
1002
+ # spatial latent size to properly rescale the crops for regenerating the bbox annotations
1003
+ num_downs = self.first_stage_model.encoder.num_resolutions - 1
1004
+ rescale_latent = 2 ** (num_downs)
1005
+
1006
+ # get top left postions of patches as conforming for the bbbox tokenizer, therefore we
1007
+ # need to rescale the tl patch coordinates to be in between (0,1)
1008
+ tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
1009
+ rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
1010
+ for patch_nr in range(z.shape[-1])]
1011
+
1012
+ # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
1013
+ patch_limits = [(x_tl, y_tl,
1014
+ rescale_latent * ks[0] / full_img_w,
1015
+ rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
1016
+ # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
1017
+
1018
+ # tokenize crop coordinates for the bounding boxes of the respective patches
1019
+ patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
1020
+ for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
1021
+ print(patch_limits_tknzd[0].shape)
1022
+ # cut tknzd crop position from conditioning
1023
+ assert isinstance(cond, dict), 'cond must be dict to be fed into model'
1024
+ cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
1025
+ print(cut_cond.shape)
1026
+
1027
+ adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
1028
+ adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
1029
+ print(adapted_cond.shape)
1030
+ adapted_cond = self.get_learned_conditioning(adapted_cond)
1031
+ print(adapted_cond.shape)
1032
+ adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
1033
+ print(adapted_cond.shape)
1034
+
1035
+ cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
1036
+
1037
+ else:
1038
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
1039
+
1040
+ # apply model by loop over crops
1041
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
1042
+ assert not isinstance(output_list[0],
1043
+ tuple) # todo cant deal with multiple model outputs check this never happens
1044
+
1045
+ o = torch.stack(output_list, axis=-1)
1046
+ o = o * weighting
1047
+ # Reverse reshape to img shape
1048
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
1049
+ # stitch crops together
1050
+ x_recon = fold(o) / normalization
1051
+
1052
+ else:
1053
+ x_recon = self.model(x_noisy, t, **cond)
1054
+
1055
+ if isinstance(x_recon, tuple) and not return_ids:
1056
+ return x_recon[0]
1057
+ else:
1058
+ return x_recon
1059
+
1060
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
1061
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
1062
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
1063
+
1064
+ def _prior_bpd(self, x_start):
1065
+ """
1066
+ Get the prior KL term for the variational lower-bound, measured in
1067
+ bits-per-dim.
1068
+ This term can't be optimized, as it only depends on the encoder.
1069
+ :param x_start: the [N x C x ...] tensor of inputs.
1070
+ :return: a batch of [N] KL values (in bits), one per batch element.
1071
+ """
1072
+ batch_size = x_start.shape[0]
1073
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1074
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1075
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
1076
+ return mean_flat(kl_prior) / np.log(2.0)
1077
+
1078
+ def p_losses(self, x_start, cond, t, noise=None):
1079
+ noise = default(noise, lambda: torch.randn_like(x_start))
1080
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
1081
+ model_output = self.apply_model(x_noisy, t, cond)
1082
+
1083
+ loss_dict = {}
1084
+ prefix = 'train' if self.training else 'val'
1085
+
1086
+ if self.parameterization == "x0":
1087
+ target = x_start
1088
+ elif self.parameterization == "eps":
1089
+ target = noise
1090
+ else:
1091
+ raise NotImplementedError()
1092
+
1093
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
1094
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
1095
+
1096
+ logvar_t = self.logvar[t.cpu()].to(self.device)
1097
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
1098
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
1099
+ if self.learn_logvar:
1100
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
1101
+ loss_dict.update({'logvar': self.logvar.data.mean()})
1102
+
1103
+ loss = self.l_simple_weight * loss.mean()
1104
+
1105
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
1106
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
1107
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
1108
+ loss += (self.original_elbo_weight * loss_vlb)
1109
+ loss_dict.update({f'{prefix}/loss': loss})
1110
+
1111
+ return loss, loss_dict
1112
+
1113
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
1114
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
1115
+ t_in = t
1116
+ if c is not None:
1117
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
1118
+ else:
1119
+ model_out = self.model(x, t_in)
1120
+
1121
+ if score_corrector is not None:
1122
+ assert self.parameterization == "eps"
1123
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
1124
+
1125
+ if return_codebook_ids:
1126
+ model_out, logits = model_out
1127
+
1128
+ if self.parameterization == "eps":
1129
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
1130
+ elif self.parameterization == "x0":
1131
+ x_recon = model_out
1132
+ else:
1133
+ raise NotImplementedError()
1134
+
1135
+ if clip_denoised:
1136
+ x_recon.clamp_(-1., 1.)
1137
+ if quantize_denoised:
1138
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
1139
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
1140
+ if return_codebook_ids:
1141
+ return model_mean, posterior_variance, posterior_log_variance, logits
1142
+ elif return_x0:
1143
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
1144
+ else:
1145
+ return model_mean, posterior_variance, posterior_log_variance
1146
+
1147
+ #@torch.no_grad()
1148
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
1149
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
1150
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
1151
+ b, *_, device = *x.shape, x.device
1152
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
1153
+ return_codebook_ids=return_codebook_ids,
1154
+ quantize_denoised=quantize_denoised,
1155
+ return_x0=return_x0,
1156
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1157
+ if return_codebook_ids:
1158
+ raise DeprecationWarning("Support dropped.")
1159
+ model_mean, _, model_log_variance, logits = outputs
1160
+ elif return_x0:
1161
+ model_mean, _, model_log_variance, x0 = outputs
1162
+ else:
1163
+ model_mean, _, model_log_variance = outputs
1164
+
1165
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
1166
+ if noise_dropout > 0.:
1167
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1168
+ # no noise when t == 0
1169
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
1170
+
1171
+ if return_codebook_ids:
1172
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
1173
+ if return_x0:
1174
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
1175
+ else:
1176
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
1177
+
1178
+ #@torch.no_grad()
1179
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
1180
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
1181
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
1182
+ log_every_t=None):
1183
+ if not log_every_t:
1184
+ log_every_t = self.log_every_t
1185
+ timesteps = self.num_timesteps
1186
+ if batch_size is not None:
1187
+ b = batch_size if batch_size is not None else shape[0]
1188
+ shape = [batch_size] + list(shape)
1189
+ else:
1190
+ b = batch_size = shape[0]
1191
+ if x_T is None:
1192
+ img = torch.randn(shape, device=self.device)
1193
+ else:
1194
+ img = x_T
1195
+ intermediates = []
1196
+ if cond is not None:
1197
+ if isinstance(cond, dict):
1198
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1199
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1200
+ else:
1201
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1202
+
1203
+ if start_T is not None:
1204
+ timesteps = min(timesteps, start_T)
1205
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1206
+ total=timesteps) if verbose else reversed(
1207
+ range(0, timesteps))
1208
+ if type(temperature) == float:
1209
+ temperature = [temperature] * timesteps
1210
+
1211
+ for i in iterator:
1212
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1213
+ if self.shorten_cond_schedule:
1214
+ assert self.model.conditioning_key != 'hybrid'
1215
+ tc = self.cond_ids[ts].to(cond.device)
1216
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1217
+
1218
+ img, x0_partial = self.p_sample(img, cond, ts,
1219
+ clip_denoised=self.clip_denoised,
1220
+ quantize_denoised=quantize_denoised, return_x0=True,
1221
+ temperature=temperature[i], noise_dropout=noise_dropout,
1222
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1223
+ if mask is not None:
1224
+ assert x0 is not None
1225
+ img_orig = self.q_sample(x0, ts)
1226
+ img = img_orig * mask + (1. - mask) * img
1227
+
1228
+ if i % log_every_t == 0 or i == timesteps - 1:
1229
+ intermediates.append(x0_partial)
1230
+ if callback: callback(i)
1231
+ if img_callback: img_callback(img, i)
1232
+ return img, intermediates
1233
+
1234
+ @torch.no_grad()
1235
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
1236
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1237
+ mask=None, x0=None, img_callback=None, start_T=None,
1238
+ log_every_t=None):
1239
+
1240
+ if not log_every_t:
1241
+ log_every_t = self.log_every_t
1242
+ device = self.betas.device
1243
+ b = shape[0]
1244
+ if x_T is None:
1245
+ img = torch.randn(shape, device=device)
1246
+ else:
1247
+ img = x_T
1248
+
1249
+ intermediates = [img]
1250
+ if timesteps is None:
1251
+ timesteps = self.num_timesteps
1252
+
1253
+ if start_T is not None:
1254
+ timesteps = min(timesteps, start_T)
1255
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1256
+ range(0, timesteps))
1257
+
1258
+ if mask is not None:
1259
+ assert x0 is not None
1260
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1261
+
1262
+ for i in iterator:
1263
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
1264
+ if self.shorten_cond_schedule:
1265
+ assert self.model.conditioning_key != 'hybrid'
1266
+ tc = self.cond_ids[ts].to(cond.device)
1267
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1268
+
1269
+ img = self.p_sample(img, cond, ts,
1270
+ clip_denoised=self.clip_denoised,
1271
+ quantize_denoised=quantize_denoised)
1272
+ if mask is not None:
1273
+ img_orig = self.q_sample(x0, ts)
1274
+ img = img_orig * mask + (1. - mask) * img
1275
+
1276
+ if i % log_every_t == 0 or i == timesteps - 1:
1277
+ intermediates.append(img)
1278
+ if callback: callback(i)
1279
+ if img_callback: img_callback(img, i)
1280
+
1281
+ if return_intermediates:
1282
+ return img, intermediates
1283
+ return img
1284
+
1285
+ @torch.no_grad()
1286
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1287
+ verbose=True, timesteps=None, quantize_denoised=False,
1288
+ mask=None, x0=None, shape=None,**kwargs):
1289
+ if shape is None:
1290
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
1291
+ if cond is not None:
1292
+ if isinstance(cond, dict):
1293
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1294
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1295
+ else:
1296
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1297
+ return self.p_sample_loop(cond,
1298
+ shape,
1299
+ return_intermediates=return_intermediates, x_T=x_T,
1300
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1301
+ mask=mask, x0=x0)
1302
+
1303
+ @torch.no_grad()
1304
+ def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
1305
+
1306
+ if ddim:
1307
+ ddim_sampler = DDIMSampler(self)
1308
+ shape = (self.channels, self.image_size, self.image_size)
1309
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
1310
+ shape,cond,verbose=False,**kwargs)
1311
+
1312
+ else:
1313
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1314
+ return_intermediates=True,**kwargs)
1315
+
1316
+ return samples, intermediates
1317
+
1318
+
1319
+ @torch.no_grad()
1320
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1321
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1322
+ plot_diffusion_rows=True, **kwargs):
1323
+
1324
+ use_ddim = ddim_steps is not None
1325
+
1326
+ log = dict()
1327
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1328
+ return_first_stage_outputs=True,
1329
+ force_c_encode=True,
1330
+ return_original_cond=True,
1331
+ bs=N)
1332
+ N = min(x.shape[0], N)
1333
+ n_row = min(x.shape[0], n_row)
1334
+ log["inputs"] = x
1335
+ log["reconstruction"] = xrec
1336
+ if self.model.conditioning_key is not None:
1337
+ if hasattr(self.cond_stage_model, "decode"):
1338
+ xc = self.cond_stage_model.decode(c)
1339
+ log["conditioning"] = xc
1340
+ elif self.cond_stage_key in ["caption"]:
1341
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
1342
+ log["conditioning"] = xc
1343
+ elif self.cond_stage_key == 'class_label':
1344
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
1345
+ log['conditioning'] = xc
1346
+ elif isimage(xc):
1347
+ log["conditioning"] = xc
1348
+ if ismap(xc):
1349
+ log["original_conditioning"] = self.to_rgb(xc)
1350
+
1351
+ if plot_diffusion_rows:
1352
+ # get diffusion row
1353
+ diffusion_row = list()
1354
+ z_start = z[:n_row]
1355
+ for t in range(self.num_timesteps):
1356
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1357
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1358
+ t = t.to(self.device).long()
1359
+ noise = torch.randn_like(z_start)
1360
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1361
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1362
+
1363
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1364
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1365
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1366
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1367
+ log["diffusion_row"] = diffusion_grid
1368
+
1369
+ if sample:
1370
+ # get denoise row
1371
+ with self.ema_scope("Plotting"):
1372
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1373
+ ddim_steps=ddim_steps,eta=ddim_eta)
1374
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1375
+ x_samples = self.decode_first_stage(samples)
1376
+ log["samples"] = x_samples
1377
+ if plot_denoise_rows:
1378
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1379
+ log["denoise_row"] = denoise_grid
1380
+
1381
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1382
+ self.first_stage_model, IdentityFirstStage):
1383
+ # also display when quantizing x0 while sampling
1384
+ with self.ema_scope("Plotting Quantized Denoised"):
1385
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1386
+ ddim_steps=ddim_steps,eta=ddim_eta,
1387
+ quantize_denoised=True)
1388
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1389
+ # quantize_denoised=True)
1390
+ x_samples = self.decode_first_stage(samples.to(self.device))
1391
+ log["samples_x0_quantized"] = x_samples
1392
+
1393
+ if inpaint:
1394
+ # make a simple center square
1395
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
1396
+ mask = torch.ones(N, h, w).to(self.device)
1397
+ # zeros will be filled in
1398
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1399
+ mask = mask[:, None, ...]
1400
+ with self.ema_scope("Plotting Inpaint"):
1401
+
1402
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
1403
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1404
+ x_samples = self.decode_first_stage(samples.to(self.device))
1405
+ log["samples_inpainting"] = x_samples
1406
+ log["mask"] = mask
1407
+
1408
+ # outpaint
1409
+ with self.ema_scope("Plotting Outpaint"):
1410
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
1411
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1412
+ x_samples = self.decode_first_stage(samples.to(self.device))
1413
+ log["samples_outpainting"] = x_samples
1414
+
1415
+ if plot_progressive_rows:
1416
+ with self.ema_scope("Plotting Progressives"):
1417
+ img, progressives = self.progressive_denoising(c,
1418
+ shape=(self.channels, self.image_size, self.image_size),
1419
+ batch_size=N)
1420
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1421
+ log["progressive_row"] = prog_row
1422
+
1423
+ if return_keys:
1424
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1425
+ return log
1426
+ else:
1427
+ return {key: log[key] for key in return_keys}
1428
+ return log
1429
+
1430
+ def configure_optimizers(self):
1431
+ lr = self.learning_rate
1432
+ params = list(self.model.parameters())
1433
+ if self.cond_stage_trainable:
1434
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1435
+ params = params + list(self.cond_stage_model.parameters())
1436
+ if self.learn_logvar:
1437
+ print('Diffusion model optimizing logvar')
1438
+ params.append(self.logvar)
1439
+ opt = torch.optim.AdamW(params, lr=lr)
1440
+ if self.use_scheduler:
1441
+ assert 'target' in self.scheduler_config
1442
+ scheduler = instantiate_from_config(self.scheduler_config)
1443
+
1444
+ print("Setting up LambdaLR scheduler...")
1445
+ scheduler = [
1446
+ {
1447
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1448
+ 'interval': 'step',
1449
+ 'frequency': 1
1450
+ }]
1451
+ return [opt], scheduler
1452
+ return opt
1453
+
1454
+ @torch.no_grad()
1455
+ def to_rgb(self, x):
1456
+ x = x.float()
1457
+ if not hasattr(self, "colorize"):
1458
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1459
+ x = nn.functional.conv2d(x, weight=self.colorize)
1460
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1461
+ return x
1462
+
1463
+
1464
+ class DiffusionWrapper(pl.LightningModule):
1465
+ def __init__(self, diff_model_config, conditioning_key):
1466
+ super().__init__()
1467
+ #self.automatic_optimization = False
1468
+ self.diffusion_model = instantiate_from_config(diff_model_config)
1469
+ self.conditioning_key = conditioning_key
1470
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
1471
+
1472
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
1473
+ if self.conditioning_key is None:
1474
+ out = self.diffusion_model(x, t)
1475
+ elif self.conditioning_key == 'concat':
1476
+ xc = torch.cat([x] + c_concat, dim=1)
1477
+ out = self.diffusion_model(xc, t)
1478
+ elif self.conditioning_key == 'crossattn':
1479
+ cc = torch.cat(c_crossattn, 1)
1480
+ out = self.diffusion_model(x, t, context=cc)
1481
+ elif self.conditioning_key == 'hybrid':
1482
+ xc = torch.cat([x] + c_concat, dim=1)
1483
+ cc = torch.cat(c_crossattn, 1)
1484
+ out = self.diffusion_model(xc, t, context=cc)
1485
+ elif self.conditioning_key == 'adm':
1486
+ cc = c_crossattn[0]
1487
+ out = self.diffusion_model(x, t, y=cc)
1488
+ else:
1489
+ raise NotImplementedError()
1490
+
1491
+ return out
1492
+
1493
+
1494
+ class Layout2ImgDiffusion(LatentDiffusion):
1495
+ # TODO: move all layout-specific hacks to this class
1496
+ def __init__(self, cond_stage_key, *args, **kwargs):
1497
+ assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
1498
+ super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
1499
+
1500
+ def log_images(self, batch, N=8, *args, **kwargs):
1501
+ logs = super().log_images(batch=batch, N=N, *args, **kwargs)
1502
+
1503
+ key = 'train' if self.training else 'validation'
1504
+ dset = self.trainer.datamodule.datasets[key]
1505
+ mapper = dset.conditional_builders[self.cond_stage_key]
1506
+
1507
+ bbox_imgs = []
1508
+ map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
1509
+ for tknzd_bbox in batch[self.cond_stage_key][:N]:
1510
+ bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
1511
+ bbox_imgs.append(bboximg)
1512
+
1513
+ cond_img = torch.stack(bbox_imgs, dim=0)
1514
+ logs['bbox_image'] = cond_img
1515
+ return logs
ldm/models/diffusion/plms.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+
10
+
11
+ class PLMSSampler(object):
12
+ def __init__(self, model, schedule="linear", **kwargs):
13
+ super().__init__()
14
+ self.model = model
15
+ self.ddpm_num_timesteps = model.num_timesteps
16
+ self.schedule = schedule
17
+
18
+ def register_buffer(self, name, attr):
19
+ if type(attr) == torch.Tensor:
20
+ if attr.device != torch.device("cuda"):
21
+ attr = attr.to(torch.device("cuda"))
22
+ setattr(self, name, attr)
23
+
24
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25
+ if ddim_eta != 0:
26
+ raise ValueError('ddim_eta must be 0 for PLMS')
27
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
28
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
29
+ alphas_cumprod = self.model.alphas_cumprod
30
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
31
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
32
+
33
+ self.register_buffer('betas', to_torch(self.model.betas))
34
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
36
+
37
+ # calculations for diffusion q(x_t | x_{t-1}) and others
38
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
40
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
41
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
42
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
43
+
44
+ # ddim sampling parameters
45
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
46
+ ddim_timesteps=self.ddim_timesteps,
47
+ eta=ddim_eta,verbose=verbose)
48
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
49
+ self.register_buffer('ddim_alphas', ddim_alphas)
50
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
51
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
52
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
53
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
54
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
55
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
56
+
57
+ @torch.no_grad()
58
+ def sample(self,
59
+ S,
60
+ batch_size,
61
+ shape,
62
+ conditioning=None,
63
+ callback=None,
64
+ normals_sequence=None,
65
+ img_callback=None,
66
+ quantize_x0=False,
67
+ eta=0.,
68
+ mask=None,
69
+ x0=None,
70
+ temperature=1.,
71
+ noise_dropout=0.,
72
+ score_corrector=None,
73
+ corrector_kwargs=None,
74
+ verbose=True,
75
+ x_T=None,
76
+ log_every_t=100,
77
+ unconditional_guidance_scale=1.,
78
+ unconditional_conditioning=None,
79
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
80
+ **kwargs
81
+ ):
82
+ if conditioning is not None:
83
+ if isinstance(conditioning, dict):
84
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
85
+ if cbs != batch_size:
86
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87
+ else:
88
+ if conditioning.shape[0] != batch_size:
89
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
90
+
91
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
92
+ # sampling
93
+ C, H, W = shape
94
+ size = (batch_size, C, H, W)
95
+ print(f'Data shape for PLMS sampling is {size}')
96
+
97
+ samples, intermediates = self.plms_sampling(conditioning, size,
98
+ callback=callback,
99
+ img_callback=img_callback,
100
+ quantize_denoised=quantize_x0,
101
+ mask=mask, x0=x0,
102
+ ddim_use_original_steps=False,
103
+ noise_dropout=noise_dropout,
104
+ temperature=temperature,
105
+ score_corrector=score_corrector,
106
+ corrector_kwargs=corrector_kwargs,
107
+ x_T=x_T,
108
+ log_every_t=log_every_t,
109
+ unconditional_guidance_scale=unconditional_guidance_scale,
110
+ unconditional_conditioning=unconditional_conditioning,
111
+ )
112
+ return samples, intermediates
113
+
114
+ @torch.no_grad()
115
+ def plms_sampling(self, cond, shape,
116
+ x_T=None, ddim_use_original_steps=False,
117
+ callback=None, timesteps=None, quantize_denoised=False,
118
+ mask=None, x0=None, img_callback=None, log_every_t=100,
119
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
120
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
121
+ device = self.model.betas.device
122
+ b = shape[0]
123
+ if x_T is None:
124
+ img = torch.randn(shape, device=device)
125
+ else:
126
+ img = x_T
127
+
128
+ if timesteps is None:
129
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
130
+ elif timesteps is not None and not ddim_use_original_steps:
131
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
132
+ timesteps = self.ddim_timesteps[:subset_end]
133
+
134
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
135
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
136
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
137
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
138
+
139
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
140
+ old_eps = []
141
+
142
+ for i, step in enumerate(iterator):
143
+ index = total_steps - i - 1
144
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
145
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
146
+
147
+ if mask is not None:
148
+ assert x0 is not None
149
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
150
+ img = img_orig * mask + (1. - mask) * img
151
+
152
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
153
+ quantize_denoised=quantize_denoised, temperature=temperature,
154
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
155
+ corrector_kwargs=corrector_kwargs,
156
+ unconditional_guidance_scale=unconditional_guidance_scale,
157
+ unconditional_conditioning=unconditional_conditioning,
158
+ old_eps=old_eps, t_next=ts_next)
159
+ img, pred_x0, e_t = outs
160
+ old_eps.append(e_t)
161
+ if len(old_eps) >= 4:
162
+ old_eps.pop(0)
163
+ if callback: callback(i)
164
+ if img_callback: img_callback(pred_x0, i)
165
+
166
+ if index % log_every_t == 0 or index == total_steps - 1:
167
+ intermediates['x_inter'].append(img)
168
+ intermediates['pred_x0'].append(pred_x0)
169
+
170
+ return img, intermediates
171
+
172
+ @torch.no_grad()
173
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
174
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
175
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
176
+ b, *_, device = *x.shape, x.device
177
+
178
+ def get_model_output(x, t):
179
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
180
+ e_t = self.model.apply_model(x, t, c)
181
+ else:
182
+ x_in = torch.cat([x] * 2)
183
+ t_in = torch.cat([t] * 2)
184
+ c_in = torch.cat([unconditional_conditioning, c])
185
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
186
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
187
+
188
+ if score_corrector is not None:
189
+ assert self.model.parameterization == "eps"
190
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
191
+
192
+ return e_t
193
+
194
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
195
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
196
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
197
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
198
+
199
+ def get_x_prev_and_pred_x0(e_t, index):
200
+ # select parameters corresponding to the currently considered timestep
201
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
202
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
203
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
204
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
205
+
206
+ # current prediction for x_0
207
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
208
+ if quantize_denoised:
209
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
210
+ # direction pointing to x_t
211
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
212
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
213
+ if noise_dropout > 0.:
214
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
215
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
216
+ return x_prev, pred_x0
217
+
218
+ e_t = get_model_output(x, t)
219
+ if len(old_eps) == 0:
220
+ # Pseudo Improved Euler (2nd order)
221
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
222
+ e_t_next = get_model_output(x_prev, t_next)
223
+ e_t_prime = (e_t + e_t_next) / 2
224
+ elif len(old_eps) == 1:
225
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
226
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
227
+ elif len(old_eps) == 2:
228
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
229
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
230
+ elif len(old_eps) >= 3:
231
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
232
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
233
+
234
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
235
+
236
+ return x_prev, pred_x0, e_t