Spaces:
Runtime error
Runtime error
import einops | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
import cv2 | |
from PIL import Image | |
from einops import rearrange, repeat | |
from torchvision.utils import make_grid | |
from ldm.models.diffusion.ddpm import LatentDiffusion | |
from ldm.util import log_txt_as_img, instantiate_from_config | |
from ldm.models.diffusion.ddim import DDIMSampler | |
from models.q_formers import load_qformer_model | |
class AnyControlNet(LatentDiffusion):
    """Latent diffusion model conditioned on local and (optionally) global controls.

    Local control images are preprocessed for a Q-Former and fed, together with
    the text prompt, to a local adapter whose outputs (a list of tensors, scaled
    by ``self.local_control_scales``) are passed to the frozen diffusion UNet.
    In ``'uni'`` mode a global adapter additionally turns the ``'clipembedding'``
    and ``'color'`` global conditions into content/colour controls; in
    ``'local'`` mode those controls are all zeros.

    Args:
        mode: ``'local'`` or ``'uni'``.
        qformer_config: Passed to ``load_qformer_model``.
        local_control_config: ``instantiate_from_config`` config of the local adapter.
        global_control_config: ``instantiate_from_config`` config of the global
            adapter (only instantiated in ``'uni'`` mode). Its
            ``params.clip_embeddings_dim`` and ``params.color_in_dim`` are read
            in BOTH modes, so it must be provided even for ``mode='local'``.
        *args, **kwargs: Forwarded to ``LatentDiffusion.__init__``.
    """

    def __init__(self, mode, qformer_config=None, local_control_config=None, global_control_config=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert mode in ['local', 'uni']
        self.mode = mode

        self.qformer_config = qformer_config
        self.local_control_config = local_control_config
        self.global_control_config = global_control_config

        # Freeze the whole diffusion UNet (one call is enough; the original
        # repeated it three times). configure_optimizers re-enables gradients on
        # the decoder half when sd_locked is false.
        self.model.diffusion_model.requires_grad_(False)

        q_former, (vis_processor, txt_processor) = load_qformer_model(qformer_config)
        self.q_former = q_former
        self.qformer_vis_processor = vis_processor
        self.qformer_txt_processor = txt_processor

        self.local_adapter = instantiate_from_config(local_control_config)
        # One multiplier per tensor returned by the local adapter. zip() in
        # apply_model truncates to the shorter side. NOTE(review): 13 presumably
        # matches the number of UNet control injection points — confirm.
        self.local_control_scales = [1.0] * 13
        self.global_adapter = instantiate_from_config(global_control_config) if self.mode == 'uni' else None

        # Read unconditionally: in 'local' mode these size the all-zero
        # content/colour controls built in apply_model.
        self.clip_embeddings_dim = global_control_config.params.clip_embeddings_dim
        self.color_in_dim = global_control_config.params.color_in_dim

    @torch.no_grad()
    def get_input(self, batch, k, bs=None, *args, **kwargs):
        """Return the latent ``x`` and the conditioning dict for ``batch``.

        Args:
            batch: Reads ``batch['txt']``, ``batch['local_conditions']`` and
                ``batch['global_conditions']`` (sequences of objects exposing
                ``.data``), plus whatever ``LatentDiffusion.get_input`` reads.
            k: Unused; ``self.first_stage_key`` is always used. Kept for
                signature compatibility with the parent class.
            bs: Optional number of samples to keep for the local/global controls.
                It is NOT forwarded to the parent, so ``x`` and ``c_crossattn``
                keep the full batch size.

        Returns:
            ``(x, conditions)`` where ``conditions`` maps ``'text'``,
            ``'c_crossattn'``, ``'local_control'`` and ``'global_control'`` to
            single-element lists.
        """
        # latent and text
        x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
        bs = bs or x.size(0)

        local_control = self.get_local_conditions_for_vision_encoder(batch, bs)
        local_control = local_control.to(memory_format=torch.contiguous_format).float()

        # Stack every global-condition entry (keys taken from the first sample).
        global_control = {}
        global_conditions = batch['global_conditions'][:bs]
        for key in batch['global_conditions'][0].data.keys():
            global_cond = torch.stack([torch.Tensor(dc.data[key]) for dc in global_conditions])
            global_cond = global_cond.to(self.device).to(memory_format=torch.contiguous_format).float()
            global_control[key] = global_cond

        conditions = dict(
            text=[batch['txt']],
            c_crossattn=[c],
            local_control=[local_control],
            global_control=[global_control],
        )
        return x, conditions

    def apply_model(self, x_noisy, t, cond, local_strength=1.0, content_strength=1.0, color_strength=1.0, *args, **kwargs):
        """Run the diffusion UNet on ``x_noisy`` at timesteps ``t`` with all controls.

        Args:
            x_noisy: Noisy latents; ``x_noisy.shape[0]`` is the batch size.
            t: Timesteps forwarded to the adapter and UNet.
            cond: Conditioning dict as produced by ``get_input`` / ``log_images``.
            local_strength, content_strength, color_strength: Weights forwarded
                to the UNet as ``local_w``, ``extra_w`` and ``color_w``.

        Returns:
            The UNet output ``eps``.
        """
        assert isinstance(cond, dict)
        diffusion_model = self.model.diffusion_model

        cond_txt = torch.cat(cond['c_crossattn'], 1)
        text = cond['text'][0]
        bs = x_noisy.shape[0]

        # extract global control
        if self.mode == 'uni':
            global_cond = cond['global_control'][0]
            content_control, color_control = self.global_adapter(
                global_cond['clipembedding'], global_cond['color'])
        else:
            # No global adapter: the UNet receives all-zero content/colour controls.
            content_control = torch.zeros(bs, self.clip_embeddings_dim, device=self.device, dtype=torch.float32)
            color_control = torch.zeros(bs, self.color_in_dim, device=self.device, dtype=torch.float32)

        # extract local control (both 'local' and 'uni' modes use it)
        local_features = self.local_adapter.extract_local_features(self.q_former, text, cond['local_control'][0])
        local_control = self.local_adapter(x=x_noisy, timesteps=t, context=cond_txt, local_features=local_features)
        local_control = [c * scale for c, scale in zip(local_control, self.local_control_scales)]

        eps = diffusion_model(
            x=x_noisy, timesteps=t, context=cond_txt,
            local_control=local_control, local_w=local_strength,
            content_control=content_control, extra_w=content_strength,
            color_control=color_control, color_w=color_strength)
        return eps

    @torch.no_grad()
    def get_unconditional_conditioning(self, N):
        """Return the text conditioning of ``N`` empty prompts."""
        return self.get_learned_conditioning([""] * N)

    def get_unconditional_global_conditioning(self, c):
        """Return zeros shaped like ``c`` (a dict/list of tensors, or a tensor)."""
        if isinstance(c, dict):
            return {k: torch.zeros_like(v) for k, v in c.items()}
        elif isinstance(c, list):
            return [torch.zeros_like(v) for v in c]
        else:
            return torch.zeros_like(c)

    def get_shape(self, batch, N):
        """Return ``data[0].shape[:2]`` of each of the first ``N`` local conditions.

        For HWC images this is ``(h, w)`` of the first local condition per sample.
        """
        return [dc.data[0].shape[:2] for dc in batch['local_conditions'][:N]]

    def get_local_conditions_for_vision_encoder(self, batch, N):
        """Preprocess local condition images for the Q-Former vision encoder.

        Each image is converted with ``Image.fromarray`` and passed through the
        ``'eval'`` visual processor. A sample's images are concatenated along
        dim 1, and shorter samples are zero-padded on dim 1 up to ``max_len * 3``.

        Returns:
            Tensor on ``self.device`` of shape ``(N, num_conds * 3, h, w)``
            (assumes the processor emits 3-channel images of a common size — TODO confirm).
        """
        samples = batch['local_conditions'][:N]
        max_len = max(len(dc.data) for dc in samples)
        processor = self.qformer_vis_processor['eval']
        local_conditions = []
        for dc in samples:
            conds = torch.cat([processor(Image.fromarray(img)).unsqueeze(0) for img in dc.data], dim=1)
            # F.pad lists dims last-to-first: pad only dim 1 (channels) at the end.
            conds = F.pad(conds, (0, 0, 0, 0, 0, max_len * 3 - conds.shape[1], 0, 0))
            local_conditions.append(conds)
        return torch.cat(local_conditions, dim=0).to(self.device)

    def get_local_conditions_for_logging(self, batch, N):
        """Stack the raw local condition images of the first ``N`` samples for logging.

        Each image is permuted ``(h, w, c) -> (c, h, w)`` and mapped with
        ``x / 255 * 2 - 1`` (i.e. ``[0, 255] -> [-1, 1]`` for 8-bit input).
        Shorter samples are zero-padded along the image axis to ``max_len``.

        Returns:
            Tensor on ``self.device`` of shape ``(N, max_len * c, h, w)``.
        """
        local_conditions = []
        max_len = max([len(dc.data) for dc in batch['local_conditions'][:N]])
        for dc in batch['local_conditions'][:N]:
            conds = torch.stack([torch.Tensor(img).permute(2, 0, 1) for img in dc.data], dim=0)  # (n, c, h, w)
            conds = conds.float() / 255.
            conds = conds * 2.0 - 1.0
            local_conditions.append(conds)
        local_conditions = [F.pad(cond, (0, 0, 0, 0, 0, 0, 0, max_len - cond.shape[0])) for cond in local_conditions]
        local_conditions = torch.stack(local_conditions, dim=0).to(self.device)  # (bs, n, c, h, w)
        local_conditions = local_conditions.flatten(1, 2)
        return local_conditions

    def clip_batch(self, batch, key, N, flag=True):
        """Truncate a (possibly nested) conditioning entry to its first ``N`` items.

        A tensor or list ``batch`` is sliced directly. Otherwise ``batch[key][0]``
        (``flag=True``) or ``batch[key]`` (``flag=False``) is sliced; dict values
        are clipped recursively.

        Raises:
            ValueError: If the selected entry is not a tensor, list or dict.
        """
        if isinstance(batch, torch.Tensor):
            return batch[:N]
        elif isinstance(batch, list):
            return batch[:N]
        batch = batch[key][0] if flag else batch[key]
        if isinstance(batch, torch.Tensor):
            return batch[:N]
        elif isinstance(batch, list):
            return batch[:N]
        elif isinstance(batch, dict):
            return {k: self.clip_batch(v, '', N, flag=False) for k, v in batch.items()}
        else:
            raise ValueError(f'Unsupported type {type(batch)}')

    @torch.no_grad()
    def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0,
                   plot_denoise_rows=False, plot_diffusion_rows=False, unconditional_guidance_scale=9.0, **kwargs):
        """Build a dict of images for logging (reconstruction, controls, samples).

        Args:
            batch: Data batch (see ``get_input``).
            N: Maximum number of samples to log.
            n_row: Number of samples in the forward-diffusion grid.
            sample: If true, log DDIM samples without guidance.
            ddim_steps, ddim_eta: DDIM sampler settings.
            plot_denoise_rows: Also log intermediate denoising grids (requires ``sample``).
            plot_diffusion_rows: Log a grid of forward-diffused latents.
            unconditional_guidance_scale: If > 1, also log classifier-free-guided
                samples. The unconditional branch uses empty-prompt text embeddings
                and zeroed global controls, but keeps the local controls.
            **kwargs: Ignored.

        Returns:
            Dict mapping log names to image tensors.
        """
        # NOTE: sample_log always uses DDIMSampler; use_ddim is only forwarded.
        use_ddim = ddim_steps is not None

        log = dict()
        z, c = self.get_input(batch, self.first_stage_key, bs=N)
        shape = self.get_shape(batch, N)
        c_local = self.clip_batch(c, "local_control", N)
        c_global = self.clip_batch(c, "global_control", N)
        c_context = self.clip_batch(c, "c_crossattn", N)
        c_text = self.clip_batch(batch, self.cond_stage_key, N, False)

        N = min(z.shape[0], N)
        n_row = min(z.shape[0], n_row)
        log["reconstruction"] = self.decode_first_stage(z)
        log["conditioning"] = log_txt_as_img((512, 512), c_text, size=16)
        log["local_control"] = self.get_local_conditions_for_logging(batch, N)

        if plot_diffusion_rows:
            # Forward-diffuse the first n_row latents every log_every_t steps.
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t_batch = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t_batch = t_batch.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t_batch, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        cond_dict = dict(
            local_control=[c_local],
            global_control=[c_global],
            c_crossattn=[c_context],
            text=[c_text],
            shape=[shape],
        )

        if sample:
            samples, z_denoise_row = self.sample_log(cond=cond_dict,
                                                     batch_size=N, ddim=use_ddim,
                                                     ddim_steps=ddim_steps, eta=ddim_eta,
                                                     log_every_t=self.log_every_t * 0.05)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                if isinstance(z_denoise_row, dict):
                    for key in ['pred_x0', 'x_inter']:
                        z_denoise_row_key = z_denoise_row[key]
                        denoise_grid = self._get_denoise_row_from_list(z_denoise_row_key)
                        log[f"denoise_row_{key}"] = denoise_grid
                else:
                    denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                    log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            uc_context = self.get_unconditional_conditioning(N)
            uc_global = self.get_unconditional_global_conditioning(c_global)
            uc_local = c_local
            uc_text = c_text
            uncond_dict = dict(
                local_control=[uc_local],
                global_control=[uc_global],
                c_crossattn=[uc_context],
                text=[uc_text],
                shape=[shape],
            )
            samples_cfg, _ = self.sample_log(cond=cond_dict,
                                             batch_size=N, ddim=use_ddim,
                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                             unconditional_guidance_scale=unconditional_guidance_scale,
                                             unconditional_conditioning=uncond_dict,
                                             )
            x_samples_cfg = self.decode_first_stage(samples_cfg)
            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        return log

    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        """Sample latents with a DDIM sampler.

        The latent size is taken from the first sample's ``cond['shape']`` entry
        (``(h, w)`` // 8), falling back to 512x512 when ``'shape'`` is missing or
        ``None``. ``ddim`` is accepted for interface compatibility but unused.

        Returns:
            ``(samples, intermediates)`` from ``DDIMSampler.sample``.
        """
        ddim_sampler = DDIMSampler(self)
        shapes = cond.get('shape')
        if shapes is None:
            h, w = 512, 512
        else:
            h, w = shapes[0][0]
        shape = (self.channels, h // 8, w // 8)
        samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
        return samples, intermediates

    def configure_optimizers(self):
        """Create an AdamW optimizer over the trainable parameters.

        Always optimizes the Q-Former and the local adapter. When
        ``self.sd_locked`` is false, the UNet's ``output_blocks`` and ``out``
        are also trained. ``self.learning_rate`` and ``self.sd_locked`` are not
        set in this class (presumably assigned by the training script).

        NOTE(review): ``self.global_adapter`` is never optimized, even in
        ``'uni'`` mode — confirm this is intended.
        """
        lr = self.learning_rate
        params = list(self.q_former.parameters()) + list(self.local_adapter.parameters())
        if not self.sd_locked:
            # __init__ froze the whole UNet; without re-enabling gradients these
            # parameters would never get a .grad and AdamW would silently skip them.
            for module in (self.model.diffusion_model.output_blocks, self.model.diffusion_model.out):
                module.requires_grad_(True)
                params += list(module.parameters())
        opt = torch.optim.AdamW(params, lr=lr)
        return opt

    def low_vram_shift(self, is_diffusing):
        """Swap modules between GPU and CPU to save VRAM.

        While diffusing, the UNet wrapper and local adapter go to CUDA and the
        first-stage and cond-stage models go to CPU; otherwise the reverse.
        ``q_former`` and ``global_adapter`` are not moved.
        """
        if is_diffusing:
            self.model = self.model.cuda()
            self.local_adapter = self.local_adapter.cuda()
            self.first_stage_model = self.first_stage_model.cpu()
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model = self.model.cpu()
            self.local_adapter = self.local_adapter.cpu()
            self.first_stage_model = self.first_stage_model.cuda()
            self.cond_stage_model = self.cond_stage_model.cuda()