import torch
from torch import nn
from torch.nn import functional as F
from functools import partial
import numpy as np
import typing as tp

from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
from .dit import DiffusionTransformer
from .factory import create_pretransform_from_config
from .pretransforms import Pretransform
from ..inference.generation import generate_diffusion_cond

from .adp import UNetCFG1d, UNet1d

from time import time

class Profiler:

    def __init__(self):
        self.ticks = [[time(), None]]

    def tick(self, msg):
        self.ticks.append([time(), msg])

    def __repr__(self):
        rep = 80 * "=" + "\n"
        for i in range(1, len(self.ticks)):
            msg = self.ticks[i][1]
            elapsed = self.ticks[i][0] - self.ticks[i - 1][0]
            rep += msg + f": {elapsed*1000:.2f}ms\n"
        rep += 80 * "=" + "\n\n\n"
        return rep

class DiffusionModel(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, t, **kwargs):
        raise NotImplementedError()

class DiffusionModelWrapper(nn.Module):
    def __init__(
            self,
            model: DiffusionModel,
            io_channels,
            sample_size,
            sample_rate,
            min_input_length,
            pretransform: tp.Optional[Pretransform] = None,
    ):
        super().__init__()

        self.io_channels = io_channels
        self.sample_size = sample_size
        self.sample_rate = sample_rate
        self.min_input_length = min_input_length

        self.model = model

        if pretransform is not None:
            self.pretransform = pretransform
        else:
            self.pretransform = None

    def forward(self, x, t, **kwargs):
        return self.model(x, t, **kwargs)

class ConditionedDiffusionModel(nn.Module):
    def __init__(self,
                 *args,
                 supports_cross_attention: bool = False,
                 supports_input_concat: bool = False,
                 supports_global_cond: bool = False,
                 supports_prepend_cond: bool = False,
                 **kwargs):
        super().__init__(*args, **kwargs)

        self.supports_cross_attention = supports_cross_attention
        self.supports_input_concat = supports_input_concat
        self.supports_global_cond = supports_global_cond
        self.supports_prepend_cond = supports_prepend_cond

    def forward(self,
                x: torch.Tensor,
                t: torch.Tensor,
                cross_attn_cond: torch.Tensor = None,
                cross_attn_mask: torch.Tensor = None,
                input_concat_cond: torch.Tensor = None,
                global_embed: torch.Tensor = None,
                prepend_cond: torch.Tensor = None,
                prepend_cond_mask: torch.Tensor = None,
                cfg_scale: float = 1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                **kwargs):
        raise NotImplementedError()
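# Illustrative sketch (not part of this module): a ConditionedDiffusionModel subclass
# declares the conditioning routes it supports via the supports_* flags and accepts the
# matching keyword arguments in forward(). The wrappers below (UNetCFG1DWrapper,
# DiTWrapper, ...) are the real implementations; the names here are hypothetical.
#
#   class MyCondModel(ConditionedDiffusionModel):
#       def __init__(self, net: nn.Module):
#           super().__init__(supports_cross_attention=True)
#           self.net = net
#
#       def forward(self, x, t, cross_attn_cond=None, cross_attn_mask=None, **kwargs):
#           return self.net(x, t, context=cross_attn_cond, context_mask=cross_attn_mask)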
class ConditionedDiffusionModelWrapper(nn.Module):
    """
    A diffusion model that takes in conditioning
    """
    def __init__(
            self,
            model: ConditionedDiffusionModel,
            conditioner: MultiConditioner,
            io_channels,
            sample_rate,
            min_input_length: int,
            diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
            pretransform: tp.Optional[Pretransform] = None,
            cross_attn_cond_ids: tp.List[str] = [],
            global_cond_ids: tp.List[str] = [],
            input_concat_ids: tp.List[str] = [],
            prepend_cond_ids: tp.List[str] = [],
    ):
        super().__init__()

        self.model = model
        self.conditioner = conditioner
        self.io_channels = io_channels
        self.sample_rate = sample_rate
        self.diffusion_objective = diffusion_objective
        self.pretransform = pretransform
        self.cross_attn_cond_ids = cross_attn_cond_ids
        self.global_cond_ids = global_cond_ids
        self.input_concat_ids = input_concat_ids
        self.prepend_cond_ids = prepend_cond_ids
        self.min_input_length = min_input_length

    def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
        cross_attention_input = None
        cross_attention_masks = None
        global_cond = None
        input_concat_cond = None
        prepend_cond = None
        prepend_cond_mask = None

        if len(self.cross_attn_cond_ids) > 0:
            # Concatenate all cross-attention inputs over the sequence dimension
            # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
            cross_attention_input = []
            cross_attention_masks = []

            for key in self.cross_attn_cond_ids:
                # Conditioner outputs may carry a third element (e.g. llm_loss); only the
                # tensor and mask are used here
                if len(conditioning_tensors[key]) == 2:
                    cross_attn_in, cross_attn_mask = conditioning_tensors[key]
                elif len(conditioning_tensors[key]) == 3:
                    cross_attn_in, cross_attn_mask, _ = conditioning_tensors[key]

                # Add sequence dimension if it's not there
                if len(cross_attn_in.shape) == 2:
                    cross_attn_in = cross_attn_in.unsqueeze(1)
                    cross_attn_mask = cross_attn_mask.unsqueeze(1)

                cross_attention_input.append(cross_attn_in)
                cross_attention_masks.append(cross_attn_mask)

            cross_attention_input = torch.cat(cross_attention_input, dim=1)
            cross_attention_masks = torch.cat(cross_attention_masks, dim=1)

        if len(self.global_cond_ids) > 0:
            # Concatenate all global conditioning inputs over the channel dimension
            # Assumes that the global conditioning inputs are of shape (batch, channels)
            global_conds = []
            for key in self.global_cond_ids:
                global_cond_input = conditioning_tensors[key][0]
                global_conds.append(global_cond_input)

            # Concatenate over the channel dimension
            global_cond = torch.cat(global_conds, dim=-1)

            if len(global_cond.shape) == 3:
                global_cond = global_cond.squeeze(1)

        if len(self.input_concat_ids) > 0:
            # Concatenate all input concat conditioning inputs over the channel dimension
            # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
            input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1)

        if len(self.prepend_cond_ids) > 0:
            # Concatenate all prepend conditioning inputs over the sequence dimension
            # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
            prepend_conds = []
            prepend_cond_masks = []

            for key in self.prepend_cond_ids:
                prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
                prepend_conds.append(prepend_cond_input)
                prepend_cond_masks.append(prepend_cond_mask)

            prepend_cond = torch.cat(prepend_conds, dim=1)
            prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)

        if negative:
            return {
                "negative_cross_attn_cond": cross_attention_input,
                "negative_cross_attn_mask": cross_attention_masks,
                "negative_global_cond": global_cond,
                "negative_input_concat_cond": input_concat_cond
            }
        else:
            return {
                "cross_attn_cond": cross_attention_input,
                "cross_attn_mask": cross_attention_masks,
                "global_cond": global_cond,
                "input_concat_cond": input_concat_cond,
                "prepend_cond": prepend_cond,
                "prepend_cond_mask": prepend_cond_mask
            }

    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
        return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)

    def generate(self, *args, **kwargs):
        return generate_diffusion_cond(self, *args, **kwargs)
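# Usage sketch (assumptions: a MultiConditioner that maps a list of metadata dicts plus a
# device to per-key (tensor, mask) pairs, and a wrapper configured with a "prompt"
# cross-attention conditioner and a "seconds_total" global conditioner; shapes are
# illustrative). The wrapper flattens the conditioner output into the kwargs consumed by
# ConditionedDiffusionModel.forward:
#
#   conditioning_tensors = wrapper.conditioner([{"prompt": "a dog barking", "seconds_total": 10.0}], device)
#   # e.g. {"prompt": (tensor[B, S, C], mask[B, S]), "seconds_total": (tensor[B, D], mask[B, 1])}
#   model_kwargs = wrapper.get_conditioning_inputs(conditioning_tensors)
#   # -> {"cross_attn_cond": ..., "cross_attn_mask": ..., "global_cond": ...,
#   #     "input_concat_cond": None, "prepend_cond": None, "prepend_cond_mask": None}
#   v = wrapper(noised_latents, t, cond=conditioning_tensors)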
class UNetCFG1DWrapper(ConditionedDiffusionModel):
    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)

        self.model = UNetCFG1d(*args, **kwargs)

        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self,
                x,
                t,
                cross_attn_cond=None,
                cross_attn_mask=None,
                input_concat_cond=None,
                global_cond=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                **kwargs):
        p = Profiler()

        p.tick("start")

        channels_list = None
        if input_concat_cond is not None:
            channels_list = [input_concat_cond]

        outputs = self.model(
            x,
            t,
            embedding=cross_attn_cond,
            embedding_mask=cross_attn_mask,
            features=global_cond,
            channels_list=channels_list,
            embedding_scale=cfg_scale,
            embedding_mask_proba=cfg_dropout_prob,
            batch_cfg=batch_cfg,
            rescale_cfg=rescale_cfg,
            negative_embedding=negative_cross_attn_cond,
            negative_embedding_mask=negative_cross_attn_mask,
            **kwargs)

        p.tick("UNetCFG1D forward")

        #print(f"Profiler: {p}")

        return outputs

class UNet1DCondWrapper(ConditionedDiffusionModel):
    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)

        self.model = UNet1d(*args, **kwargs)

        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self,
                x,
                t,
                input_concat_cond=None,
                global_cond=None,
                cross_attn_cond=None,
                cross_attn_mask=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                **kwargs):

        channels_list = None
        if input_concat_cond is not None:
            # Interpolate input_concat_cond to the same length as x
            if input_concat_cond.shape[2] != x.shape[2]:
                input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')

            channels_list = [input_concat_cond]

        outputs = self.model(
            x,
            t,
            features=global_cond,
            channels_list=channels_list,
            **kwargs)

        return outputs

class UNet1DUncondWrapper(DiffusionModel):
    def __init__(
        self,
        in_channels,
        *args,
        **kwargs
    ):
        super().__init__()

        self.model = UNet1d(in_channels=in_channels, *args, **kwargs)

        self.io_channels = in_channels

        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self, x, t, **kwargs):
        return self.model(x, t, **kwargs)

class DAU1DCondWrapper(ConditionedDiffusionModel):
    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)

        self.model = DiffusionAttnUnet1D(*args, **kwargs)

        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self,
                x,
                t,
                input_concat_cond=None,
                cross_attn_cond=None,
                cross_attn_mask=None,
                global_cond=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                prepend_cond=None,
                **kwargs):
        return self.model(x, t, cond=input_concat_cond)
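# Sketch (hypothetical tensors): classifier-free guidance goes through the same forward
# call; UNetCFG1DWrapper forwards cfg_scale / cfg_dropout_prob to adp's UNetCFG1d as
# embedding_scale / embedding_mask_proba, and cross_attn_cond / cross_attn_mask as
# embedding / embedding_mask.
#
#   v = unet_cfg_wrapper(
#       x, t,
#       cross_attn_cond=text_embeds,   # (batch, seq, channels)
#       cross_attn_mask=text_mask,     # (batch, seq)
#       cfg_scale=6.0,                 # >1.0 enables guidance at inference
#       cfg_dropout_prob=0.1,          # conditioning dropout during training
#   )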
class DiffusionAttnUnet1D(nn.Module):
    def __init__(
        self,
        io_channels=2,
        depth=14,
        n_attn_layers=6,
        channels=[128, 128, 256, 256] + [512] * 10,
        cond_dim=0,
        cond_noise_aug=False,
        kernel_size=5,
        learned_resample=False,
        strides=[2] * 13,
        conv_bias=True,
        use_snake=False
    ):
        super().__init__()

        self.cond_noise_aug = cond_noise_aug

        self.io_channels = io_channels

        if self.cond_noise_aug:
            self.rng = torch.quasirandom.SobolEngine(1, scramble=True)

        self.timestep_embed = FourierFeatures(1, 16)

        attn_layer = depth - n_attn_layers

        strides = [1] + strides

        block = nn.Identity()

        conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias=conv_bias, use_snake=use_snake)

        for i in range(depth, 0, -1):
            c = channels[i - 1]
            stride = strides[i - 1]
            if stride > 2 and not learned_resample:
                raise ValueError("Must have stride 2 without learned resampling")

            if i > 1:
                c_prev = channels[i - 2]
                add_attn = i >= attn_layer and n_attn_layers > 0
                block = SkipBlock(
                    Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
                    conv_block(c_prev, c, c),
                    SelfAttention1d(
                        c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(
                        c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(
                        c, c // 32) if add_attn else nn.Identity(),
                    block,
                    conv_block(c * 2 if i != depth else c, c, c),
                    SelfAttention1d(
                        c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(
                        c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c_prev),
                    SelfAttention1d(c_prev, c_prev // 32) if add_attn else nn.Identity(),
                    Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
                )
            else:
                cond_embed_dim = 16 if not self.cond_noise_aug else 32
                block = nn.Sequential(
                    conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
                    conv_block(c, c, c),
                    conv_block(c, c, c),
                    block,
                    conv_block(c * 2, c, c),
                    conv_block(c, c, c),
                    conv_block(c, c, io_channels, is_last=True),
                )
        self.net = block

        with torch.no_grad():
            for param in self.net.parameters():
                param *= 0.5

    def forward(self, x, t, cond=None, cond_aug_scale=None):

        timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)

        inputs = [x, timestep_embed]

        if cond is not None:
            if cond.shape[2] != x.shape[2]:
                cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)

            if self.cond_noise_aug:
                # Get a random number between 0 and 1, uniformly sampled
                if cond_aug_scale is None:
                    aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
                else:
                    aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)

                # Add noise to the conditioning signal
                cond = cond + torch.randn_like(cond) * aug_level[:, None, None]

                # Get embedding for noise cond level, reusing timestep_embed
                aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)

                inputs.append(aug_level_embed)

            inputs.append(cond)

        outputs = self.net(torch.cat(inputs, dim=1))

        return outputs
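# Sketch (illustrative values): the innermost block's input width is
# io_channels + cond_dim + cond_embed_dim, where cond_embed_dim covers the 16-channel
# Fourier timestep embedding plus another 16 channels when cond_noise_aug adds a
# noise-level embedding. The input length must be divisible by the product of strides.
#
#   model = DiffusionAttnUnet1D(io_channels=2, cond_dim=32, cond_noise_aug=True)
#   x = torch.randn(1, 2, 2**17)       # (batch, io_channels, length)
#   cond = torch.randn(1, 32, 2**10)   # interpolated up to x's length inside forward()
#   t = torch.rand(1)                  # diffusion timestep per batch element
#   out = model(x, t, cond=cond)       # same shape as x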
class DiTWrapper(ConditionedDiffusionModel):
    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)

        self.model = DiffusionTransformer(*args, **kwargs)

        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self,
                x,
                t,
                cross_attn_cond=None,
                cross_attn_mask=None,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                input_concat_cond=None,
                negative_input_concat_cond=None,
                global_cond=None,
                negative_global_cond=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = True,
                rescale_cfg: bool = False,
                scale_phi: float = 0.0,
                **kwargs):

        assert batch_cfg, "batch_cfg must be True for DiTWrapper"
        #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"

        return self.model(
            x,
            t,
            cross_attn_cond=cross_attn_cond,
            cross_attn_cond_mask=cross_attn_mask,
            negative_cross_attn_cond=negative_cross_attn_cond,
            negative_cross_attn_mask=negative_cross_attn_mask,
            input_concat_cond=input_concat_cond,
            prepend_cond=prepend_cond,
            prepend_cond_mask=prepend_cond_mask,
            cfg_scale=cfg_scale,
            cfg_dropout_prob=cfg_dropout_prob,
            scale_phi=scale_phi,
            global_embed=global_cond,
            **kwargs)

class DiTUncondWrapper(DiffusionModel):
    def __init__(
        self,
        in_channels,
        *args,
        **kwargs
    ):
        super().__init__()

        self.model = DiffusionTransformer(io_channels=in_channels, *args, **kwargs)

        self.io_channels = in_channels

        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self, x, t, **kwargs):
        return self.model(x, t, **kwargs)

def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
    diffusion_uncond_config = config["model"]

    model_type = diffusion_uncond_config.get('type', None)

    diffusion_config = diffusion_uncond_config.get('config', {})

    assert model_type is not None, "Must specify model type in config"

    pretransform = diffusion_uncond_config.get("pretransform", None)

    sample_size = config.get("sample_size", None)
    assert sample_size is not None, "Must specify sample size in config"

    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "Must specify sample rate in config"

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio
    else:
        min_input_length = 1

    if model_type == 'DAU1d':
        model = DiffusionAttnUnet1D(
            **diffusion_config
        )
    elif model_type == "adp_uncond_1d":
        model = UNet1DUncondWrapper(
            **diffusion_config
        )
    elif model_type == "dit":
        model = DiTUncondWrapper(
            **diffusion_config
        )
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return DiffusionModelWrapper(model,
                                 io_channels=model.io_channels,
                                 sample_size=sample_size,
                                 sample_rate=sample_rate,
                                 pretransform=pretransform,
                                 min_input_length=min_input_length)
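# Sketch of an unconditional model config as consumed by create_diffusion_uncond_from_config
# (the top-level keys are the ones asserted above; values and the elided DiffusionTransformer
# kwargs are illustrative, not a recommended configuration):
#
#   config = {
#       "sample_size": 65536,
#       "sample_rate": 44100,
#       "model": {
#           "type": "dit",
#           "config": {"in_channels": 2, ...},   # remaining keys go to DiffusionTransformer
#       },
#   }
#   model = create_diffusion_uncond_from_config(config)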
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):

    model_config = config["model"]

    model_type = config["model_type"]

    diffusion_config = model_config.get('diffusion', None)
    assert diffusion_config is not None, "Must specify diffusion config"

    diffusion_model_type = diffusion_config.get('type', None)
    assert diffusion_model_type is not None, "Must specify diffusion model type"

    diffusion_model_config = diffusion_config.get('config', None)
    assert diffusion_model_config is not None, "Must specify diffusion model config"

    if diffusion_model_type == 'adp_cfg_1d':
        diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'adp_1d':
        diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'dit':
        diffusion_model = DiTWrapper(**diffusion_model_config)

    io_channels = model_config.get('io_channels', None)
    assert io_channels is not None, "Must specify io_channels in model config"

    sample_rate = config.get('sample_rate', None)
    assert sample_rate is not None, "Must specify sample_rate in config"

    diffusion_objective = diffusion_config.get('diffusion_objective', 'v')

    conditioning_config = model_config.get('conditioning', None)

    conditioner = None
    if conditioning_config is not None:
        conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)

    cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
    global_cond_ids = diffusion_config.get('global_cond_ids', [])
    input_concat_ids = diffusion_config.get('input_concat_ids', [])
    prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])

    pretransform = model_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio
    else:
        min_input_length = 1

    if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
        min_input_length *= np.prod(diffusion_model_config["factors"])
    elif diffusion_model_type == "dit":
        min_input_length *= diffusion_model.model.patch_size

    # Get the proper wrapper class

    extra_kwargs = {}

    if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint":
        wrapper_fn = ConditionedDiffusionModelWrapper
        extra_kwargs["diffusion_objective"] = diffusion_objective
    elif model_type == "diffusion_prior":
        prior_type = model_config.get("prior_type", None)
        assert prior_type is not None, "Must specify prior_type in diffusion prior model config"

        if prior_type == "mono_stereo":
            from .diffusion_prior import MonoToStereoDiffusionPrior
            wrapper_fn = MonoToStereoDiffusionPrior

    return wrapper_fn(
        diffusion_model,
        conditioner,
        min_input_length=min_input_length,
        sample_rate=sample_rate,
        cross_attn_cond_ids=cross_attention_ids,
        global_cond_ids=global_cond_ids,
        input_concat_ids=input_concat_ids,
        prepend_cond_ids=prepend_cond_ids,
        pretransform=pretransform,
        io_channels=io_channels,
        **extra_kwargs
    )
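# End-to-end sketch (illustrative config; the elided pretransform, conditioning, and DiT
# sub-configs must match whatever .pretransforms, .conditioners, and .dit actually expect):
#
#   config = {
#       "model_type": "diffusion_cond",
#       "sample_rate": 44100,
#       "model": {
#           "io_channels": 64,
#           "pretransform": {...},                 # passed to create_pretransform_from_config
#           "conditioning": {...},                 # passed to create_multi_conditioner_from_conditioning_config
#           "diffusion": {
#               "type": "dit",
#               "cross_attention_cond_ids": ["prompt"],
#               "global_cond_ids": ["seconds_total"],
#               "config": {...},                   # DiffusionTransformer kwargs
#           },
#       },
#   }
#   model = create_diffusion_cond_from_config(config)
#   audio = model.generate(...)                    # see ..inference.generation.generate_diffusion_cond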