RamAnanth1 committed
Commit cb69642 · 1 Parent(s): 9bfa18a

Delete cldm

Files changed (2)
  1. cldm/cldm.py +0 -417
  2. cldm/model.py +0 -21
cldm/cldm.py DELETED
@@ -1,417 +0,0 @@
-import einops
-import torch
-import torch as th
-import torch.nn as nn
-
-from ldm.modules.diffusionmodules.util import (
-    conv_nd,
-    linear,
-    zero_module,
-    timestep_embedding,
-)
-
-from einops import rearrange, repeat
-from torchvision.utils import make_grid
-from ldm.modules.attention import SpatialTransformer
-from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
-from ldm.models.diffusion.ddpm import LatentDiffusion
-from ldm.util import log_txt_as_img, exists, instantiate_from_config
-from ldm.models.diffusion.ddim import DDIMSampler
-
-
-class ControlledUnetModel(UNetModel):
-    def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
-        hs = []
-        with torch.no_grad():
-            t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
-            emb = self.time_embed(t_emb)
-            h = x.type(self.dtype)
-            for module in self.input_blocks:
-                h = module(h, emb, context)
-                hs.append(h)
-            h = self.middle_block(h, emb, context)
-
-        h += control.pop()
-
-        for i, module in enumerate(self.output_blocks):
-            if only_mid_control:
-                h = torch.cat([h, hs.pop()], dim=1)
-            else:
-                h = torch.cat([h, hs.pop() + control.pop()], dim=1)
-            h = module(h, emb, context)
-
-        h = h.type(x.dtype)
-        return self.out(h)
-
-
-class ControlNet(nn.Module):
-    def __init__(
-            self,
-            image_size,
-            in_channels,
-            model_channels,
-            hint_channels,
-            num_res_blocks,
-            attention_resolutions,
-            dropout=0,
-            channel_mult=(1, 2, 4, 8),
-            conv_resample=True,
-            dims=2,
-            use_checkpoint=False,
-            use_fp16=False,
-            num_heads=-1,
-            num_head_channels=-1,
-            num_heads_upsample=-1,
-            use_scale_shift_norm=False,
-            resblock_updown=False,
-            use_new_attention_order=False,
-            use_spatial_transformer=False,  # custom transformer support
-            transformer_depth=1,  # custom transformer support
-            context_dim=None,  # custom transformer support
-            n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
-            legacy=True,
-            disable_self_attentions=None,
-            num_attention_blocks=None,
-            disable_middle_self_attn=False,
-            use_linear_in_transformer=False,
-    ):
-        super().__init__()
-        if use_spatial_transformer:
-            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
-
-        if context_dim is not None:
-            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
-            from omegaconf.listconfig import ListConfig
-            if type(context_dim) == ListConfig:
-                context_dim = list(context_dim)
-
-        if num_heads_upsample == -1:
-            num_heads_upsample = num_heads
-
-        if num_heads == -1:
-            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
-
-        if num_head_channels == -1:
-            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
-
-        self.dims = dims
-        self.image_size = image_size
-        self.in_channels = in_channels
-        self.model_channels = model_channels
-        if isinstance(num_res_blocks, int):
-            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
-        else:
-            if len(num_res_blocks) != len(channel_mult):
-                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
-                                 "as a list/tuple (per-level) with the same length as channel_mult")
-            self.num_res_blocks = num_res_blocks
-        if disable_self_attentions is not None:
-            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
-            assert len(disable_self_attentions) == len(channel_mult)
-        if num_attention_blocks is not None:
-            assert len(num_attention_blocks) == len(self.num_res_blocks)
-            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
-            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
-                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
-                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
-                  f"attention will still not be set.")
-
-        self.attention_resolutions = attention_resolutions
-        self.dropout = dropout
-        self.channel_mult = channel_mult
-        self.conv_resample = conv_resample
-        self.use_checkpoint = use_checkpoint
-        self.dtype = th.float16 if use_fp16 else th.float32
-        self.num_heads = num_heads
-        self.num_head_channels = num_head_channels
-        self.num_heads_upsample = num_heads_upsample
-        self.predict_codebook_ids = n_embed is not None
-
-        time_embed_dim = model_channels * 4
-        self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim),
-            nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim),
-        )
-
-        self.input_blocks = nn.ModuleList(
-            [
-                TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
-                )
-            ]
-        )
-        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
-
-        self.input_hint_block = TimestepEmbedSequential(
-            conv_nd(dims, hint_channels, 16, 3, padding=1),
-            nn.SiLU(),
-            conv_nd(dims, 16, 16, 3, padding=1),
-            nn.SiLU(),
-            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
-            nn.SiLU(),
-            conv_nd(dims, 32, 32, 3, padding=1),
-            nn.SiLU(),
-            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
-            nn.SiLU(),
-            conv_nd(dims, 96, 96, 3, padding=1),
-            nn.SiLU(),
-            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
-            nn.SiLU(),
-            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
-        )
-
-        self._feature_size = model_channels
-        input_block_chans = [model_channels]
-        ch = model_channels
-        ds = 1
-        for level, mult in enumerate(channel_mult):
-            for nr in range(self.num_res_blocks[level]):
-                layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=mult * model_channels,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
-                ]
-                ch = mult * model_channels
-                if ds in attention_resolutions:
-                    if num_head_channels == -1:
-                        dim_head = ch // num_heads
-                    else:
-                        num_heads = ch // num_head_channels
-                        dim_head = num_head_channels
-                    if legacy:
-                        #num_heads = 1
-                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
-                    if exists(disable_self_attentions):
-                        disabled_sa = disable_self_attentions[level]
-                    else:
-                        disabled_sa = False
-
-                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
-                        layers.append(
-                            AttentionBlock(
-                                ch,
-                                use_checkpoint=use_checkpoint,
-                                num_heads=num_heads,
-                                num_head_channels=dim_head,
-                                use_new_attention_order=use_new_attention_order,
-                            ) if not use_spatial_transformer else SpatialTransformer(
-                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
-                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                                use_checkpoint=use_checkpoint
-                            )
-                        )
-                self.input_blocks.append(TimestepEmbedSequential(*layers))
-                self.zero_convs.append(self.make_zero_conv(ch))
-                self._feature_size += ch
-                input_block_chans.append(ch)
-            if level != len(channel_mult) - 1:
-                out_ch = ch
-                self.input_blocks.append(
-                    TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
-                        )
-                        if resblock_updown
-                        else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch
-                        )
-                    )
-                )
-                ch = out_ch
-                input_block_chans.append(ch)
-                self.zero_convs.append(self.make_zero_conv(ch))
-                ds *= 2
-                self._feature_size += ch
-
-        if num_head_channels == -1:
-            dim_head = ch // num_heads
-        else:
-            num_heads = ch // num_head_channels
-            dim_head = num_head_channels
-        if legacy:
-            #num_heads = 1
-            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
-        self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
-            ),
-            AttentionBlock(
-                ch,
-                use_checkpoint=use_checkpoint,
-                num_heads=num_heads,
-                num_head_channels=dim_head,
-                use_new_attention_order=use_new_attention_order,
-            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
-                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
-                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                use_checkpoint=use_checkpoint
-            ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
-            ),
-        )
-        self.middle_block_out = self.make_zero_conv(ch)
-        self._feature_size += ch
-
-    def make_zero_conv(self, channels):
-        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
-
-    def forward(self, x, hint, timesteps, context, **kwargs):
-        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
-        emb = self.time_embed(t_emb)
-
-        guided_hint = self.input_hint_block(hint, emb, context)
-
-        outs = []
-
-        h = x.type(self.dtype)
-        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
-            if guided_hint is not None:
-                h = module(h, emb, context)
-                h += guided_hint
-                guided_hint = None
-            else:
-                h = module(h, emb, context)
-            outs.append(zero_conv(h, emb, context))
-
-        h = self.middle_block(h, emb, context)
-        outs.append(self.middle_block_out(h, emb, context))
-
-        return outs
-
-
-class ControlLDM(LatentDiffusion):
-
-    def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.control_model = instantiate_from_config(control_stage_config)
-        self.control_key = control_key
-        self.only_mid_control = only_mid_control
-
-    @torch.no_grad()
-    def get_input(self, batch, k, bs=None, *args, **kwargs):
-        x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
-        control = batch[self.control_key]
-        if bs is not None:
-            control = control[:bs]
-        control = control.to(self.device)
-        control = einops.rearrange(control, 'b h w c -> b c h w')
-        control = control.to(memory_format=torch.contiguous_format).float()
-        return x, dict(c_crossattn=[c], c_concat=[control])
-
-    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
-        assert isinstance(cond, dict)
-        diffusion_model = self.model.diffusion_model
-        cond_txt = torch.cat(cond['c_crossattn'], 1)
-        cond_hint = torch.cat(cond['c_concat'], 1)
-
-        control = self.control_model(x=x_noisy, hint=cond_hint, timesteps=t, context=cond_txt)
-        eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
-
-        return eps
-
-    @torch.no_grad()
-    def get_unconditional_conditioning(self, N):
-        return self.get_learned_conditioning([""] * N)
-
-    @torch.no_grad()
-    def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
-                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
-                   plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
-                   use_ema_scope=True,
-                   **kwargs):
-        use_ddim = ddim_steps is not None
-
-        log = dict()
-        z, c = self.get_input(batch, self.first_stage_key, bs=N)
-        c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
-        N = min(z.shape[0], N)
-        n_row = min(z.shape[0], n_row)
-        log["reconstruction"] = self.decode_first_stage(z)
-        log["control"] = c_cat * 2.0 - 1.0
-        log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
-
-        if plot_diffusion_rows:
-            # get diffusion row
-            diffusion_row = list()
-            z_start = z[:n_row]
-            for t in range(self.num_timesteps):
-                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
-                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
-                    t = t.to(self.device).long()
-                    noise = torch.randn_like(z_start)
-                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
-                    diffusion_row.append(self.decode_first_stage(z_noisy))
-
-            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
-            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
-            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
-            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
-            log["diffusion_row"] = diffusion_grid
-
-        if sample:
-            # get denoise row
-            samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
-                                                     batch_size=N, ddim=use_ddim,
-                                                     ddim_steps=ddim_steps, eta=ddim_eta)
-            x_samples = self.decode_first_stage(samples)
-            log["samples"] = x_samples
-            if plot_denoise_rows:
-                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
-                log["denoise_row"] = denoise_grid
-
-        if unconditional_guidance_scale > 1.0:
-            uc_cross = self.get_unconditional_conditioning(N)
-            uc_cat = c_cat  # torch.zeros_like(c_cat)
-            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
-            samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
-                                             batch_size=N, ddim=use_ddim,
-                                             ddim_steps=ddim_steps, eta=ddim_eta,
-                                             unconditional_guidance_scale=unconditional_guidance_scale,
-                                             unconditional_conditioning=uc_full,
-                                             )
-            x_samples_cfg = self.decode_first_stage(samples_cfg)
-            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
-
-        return log
-
-    @torch.no_grad()
-    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
-        ddim_sampler = DDIMSampler(self)
-        b, c, h, w = cond["c_concat"][0].shape
-        shape = (self.channels, h // 8, w // 8)
-        samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
-        return samples, intermediates
-
-    def configure_optimizers(self):
-        lr = self.learning_rate
-        params = list(self.control_model.parameters())
-        if not self.sd_locked:
-            params += list(self.model.diffusion_model.output_blocks.parameters())
-            params += list(self.model.diffusion_model.out.parameters())
-        opt = torch.optim.AdamW(params, lr=lr)
-        return opt
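For context, a minimal smoke-test sketch of the ControlNet module removed above. The hyperparameters mirror a typical SD 1.5-style config and are assumptions for illustration only (they are not taken from this commit), and the sketch presumes the ldm package is still importable:

import torch
from cldm.cldm import ControlNet  # module deleted by this commit

# Assumed SD 1.5-style hyperparameters (hypothetical, for illustration only).
control_net = ControlNet(
    image_size=32, in_channels=4, model_channels=320, hint_channels=3,
    num_res_blocks=2, attention_resolutions=[4, 2, 1], channel_mult=[1, 2, 4, 4],
    num_heads=8, use_spatial_transformer=True, transformer_depth=1,
    context_dim=768, legacy=False,
)

x = torch.randn(1, 4, 32, 32)        # noisy latent
hint = torch.randn(1, 3, 256, 256)   # control image, 8x the latent resolution
t = torch.randint(0, 1000, (1,))     # diffusion timestep
context = torch.randn(1, 77, 768)    # text conditioning (CLIP embeddings)

# Returns one residual per zero convolution plus the middle block, which
# ControlledUnetModel.forward consumes through its `control` argument.
residuals = control_net(x, hint, t, context)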
cldm/model.py DELETED
@@ -1,21 +0,0 @@
-import torch
-
-from omegaconf import OmegaConf
-from ldm.util import instantiate_from_config
-
-
-def get_state_dict(d):
-    return d.get('state_dict', d)
-
-
-def load_state_dict(ckpt_path, location='cpu'):
-    state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
-    print(f'Loaded state_dict from [{ckpt_path}]')
-    return state_dict
-
-
-def create_model(config_path):
-    config = OmegaConf.load(config_path)
-    model = instantiate_from_config(config.model).cpu()
-    print(f'Loaded model config from [{config_path}]')
-    return model
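A hedged sketch of how the two helpers deleted above were typically combined before this commit; the config and checkpoint paths are placeholders, not files referenced by this change:

from cldm.model import create_model, load_state_dict  # removed by this commit

# Build the model described by a ControlLDM config (on CPU), then restore its
# weights from a checkpoint's state_dict. Paths are hypothetical placeholders.
model = create_model('./models/cldm_v15.yaml')
model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cpu'))
model = model.cuda()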