# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------

from copy import deepcopy
import torch
import torch.nn as nn
import numpy as np
import math
import collections.abc
from itertools import repeat

from ldm.modules.new_attention import PositionEmbedding
from einops import rearrange


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def to_2tuple(x):
    if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
        return x
    return tuple(repeat(x, 2))


#################################################################
#               Embedding Layers for Timesteps                  #
#################################################################

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.proj_w = nn.Linear(frequency_embedding_size, frequency_embedding_size, bias=False)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, w_cond=None):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        if w_cond is not None:
            t_freq = t_freq + self.proj_w(w_cond)
        t_emb = self.mlp(t_freq)
        return t_emb
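

# Shape sketch for TimestepEmbedder (illustrative only; the helper name, hidden_size=512,
# and the batch size below are assumptions rather than values from this repo's configs):
def _demo_timestep_embedder():
    emb = TimestepEmbedder(hidden_size=512)
    t = torch.randint(0, 1000, (4,))           # (N,) integer diffusion steps
    w = torch.randn(4, 256)                    # optional conditioning, (N, frequency_embedding_size)
    assert emb(t).shape == (4, 512)            # (N, hidden_size)
    assert emb(t, w_cond=w).shape == (4, 512)  # w_cond is projected and added to the frequency embedding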


class Conv1DFinalLayer(nn.Module):
    """
    The final layer of CrossAttnDiT.
    """
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.GroupNorm(16, hidden_size)
        self.conv1d = nn.Conv1d(hidden_size, out_channels, kernel_size=1)

    def forward(self, x):  # x: (B, C, T)
        x = self.norm_final(x)
        x = self.conv1d(x)
        return x


class ConditionEmbedder(nn.Module):
    def __init__(self, hidden_size, context_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(context_dim, hidden_size, bias=True),
            nn.GELU(approximate='tanh'),
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.LayerNorm(hidden_size)
        )

    def forward(self, x):
        return self.mlp(x)


from ldm.modules.new_attention import CrossAttention, Conv1dFeedForward, checkpoint, Normalize, zero_module


class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., gated_ff=True, checkpoint=True):  # 1 self + 1 cross, or 2 self
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # acts as self-attention when no context is given
        self.ff = Conv1dFeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # intended as cross-attention; also self-attention here, since forward passes no context
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)

    def _forward(self, x):  # x: (B, T, C)
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x)) + x
        x = self.ff(self.norm3(x).permute(0, 2, 1)).permute(0, 2, 1) + x
        return x


class TemporalTransformer(nn.Module):
    """
    Transformer block for temporal (1-D) data.
    First, project the input with a 1x1 convolution and reshape to (b, t, c).
    Then apply standard transformer blocks.
    Finally, reshape back to (b, c, t) and add a residual connection to the input.
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        self.proj_in = nn.Conv1d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout)
             for d in range(depth)]
        )
        self.proj_out = zero_module(nn.Conv1d(inner_dim,
                                              in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0))  # initialized with zeros

    def forward(self, x):  # x: (b, c, t)
        # note: no context is given to the blocks, so cross-attention defaults to self-attention
        x_in = x
        x = self.norm(x)      # group norm
        x = self.proj_in(x)   # no shape change when inner_dim == in_channels
        x = rearrange(x, 'b c t -> b t c')
        for block in self.transformer_blocks:
            x = block(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.proj_out(x)
        x = x + x_in
        return x
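

# Minimal shape sketch for TemporalTransformer (illustrative; the helper name and the
# channel/head sizes below are assumptions, and it relies on the CrossAttention /
# Conv1dFeedForward / Normalize / checkpoint ops imported from ldm.modules.new_attention
# behaving like their standard ldm counterparts):
def _demo_temporal_transformer():
    block = TemporalTransformer(in_channels=64, n_heads=4, d_head=16, depth=1)
    x = torch.randn(2, 64, 100)            # (b, c, t)
    with torch.no_grad():
        assert block(x).shape == x.shape   # residual block: output shape equals input shape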


class ConcatDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    The timestep token and the embedded context tokens are concatenated in front of
    the latent sequence before the transformer blocks.
    """
    def __init__(
        self,
        in_channels,
        context_dim,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        max_len=1000,
    ):
        super().__init__()
        self.in_channels = in_channels  # vae dim
        self.out_channels = in_channels
        self.num_heads = num_heads
        kernel_size = 5
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.c_embedder = ConditionEmbedder(hidden_size, context_dim)
        self.proj_in = nn.Conv1d(in_channels, hidden_size, kernel_size=kernel_size, padding=kernel_size // 2)
        self.pos_emb = PositionEmbedding(num_embeddings=max_len, embedding_dim=hidden_size)
        self.blocks = nn.ModuleList([
            TemporalTransformer(hidden_size, num_heads, d_head=hidden_size // num_heads, depth=1, context_dim=context_dim)
            for _ in range(depth)
        ])
        self.final_layer = Conv1DFinalLayer(hidden_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

    def forward(self, x, t, context, w_cond=None):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of temporal inputs (latent representations of the mel-spectrogram)
        t: (N,) tensor of diffusion timesteps
        context: (N, max_tokens_len=77, context_dim) tensor of text-token embeddings
        """
        t = self.t_embedder(t, w_cond=w_cond).unsqueeze(1)  # (N, 1, hidden_size)
        c = self.c_embedder(context)                        # (N, c_len, hidden_size)
        extra_len = c.shape[1] + 1
        x = self.proj_in(x)
        x = rearrange(x, 'b c t -> b t c')
        x = torch.concat([t, c, x], dim=1)
        x = self.pos_emb(x)
        x = rearrange(x, 'b t c -> b c t')
        for block in self.blocks:
            x = block(x)            # (N, D, extra_len + T)
        x = x[..., extra_len:]      # (N, D, T)
        x = self.final_layer(x)     # (N, out_channels, T)
        return x
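

# End-to-end shape sketch for ConcatDiT (illustrative; the helper name and the small sizes
# below are assumptions chosen so the sketch runs quickly, not values from this repo's
# configs, and it assumes PositionEmbedding accepts a (b, t, c) input and returns the same
# shape, as used in forward()):
def _demo_concat_dit():
    model = ConcatDiT(in_channels=20, context_dim=1024, hidden_size=256, depth=2, num_heads=4)
    x = torch.randn(2, 20, 100)          # (N, C, T) latent mel-spectrogram sequence
    t = torch.randint(0, 1000, (2,))     # (N,) diffusion timesteps
    context = torch.randn(2, 77, 1024)   # (N, 77, context_dim) text-token embeddings
    with torch.no_grad():
        out = model(x, t, context)
    assert out.shape == x.shape          # the prefix tokens are stripped before the final layer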


class ConcatDiT2MLP(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    The context is split into two halves along the token dimension and each half is
    embedded by its own MLP before being concatenated in front of the latent sequence.
    """
    def __init__(
        self,
        in_channels,
        context_dim,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        max_len=1000,
    ):
        super().__init__()
        self.in_channels = in_channels  # vae dim
        self.out_channels = in_channels
        self.num_heads = num_heads
        kernel_size = 5
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.c1_embedder = ConditionEmbedder(hidden_size, context_dim)
        self.c2_embedder = ConditionEmbedder(hidden_size, context_dim)
        self.proj_in = nn.Conv1d(in_channels, hidden_size, kernel_size=kernel_size, padding=kernel_size // 2)
        self.pos_emb = PositionEmbedding(num_embeddings=max_len, embedding_dim=hidden_size)
        self.blocks = nn.ModuleList([
            TemporalTransformer(hidden_size, num_heads, d_head=hidden_size // num_heads, depth=1, context_dim=context_dim)
            for _ in range(depth)
        ])
        self.final_layer = Conv1DFinalLayer(hidden_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

    def forward(self, x, t, context, w_cond=None):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of temporal inputs (latent representations of the mel-spectrogram)
        t: (N,) tensor of diffusion timesteps
        context: (N, max_tokens_len=77, context_dim) tensor of text-token embeddings,
                 split in half along the token dimension
        """
        t = self.t_embedder(t, w_cond=w_cond).unsqueeze(1)  # (N, 1, hidden_size)
        c1, c2 = context.chunk(2, dim=1)
        c1 = self.c1_embedder(c1)   # (N, c_len // 2, hidden_size)
        c2 = self.c2_embedder(c2)   # (N, c_len // 2, hidden_size)
        c = torch.cat((c1, c2), dim=1)
        extra_len = c.shape[1] + 1
        x = self.proj_in(x)
        x = rearrange(x, 'b c t -> b t c')
        x = torch.concat([t, c, x], dim=1)
        x = self.pos_emb(x)
        x = rearrange(x, 'b t c -> b c t')
        for block in self.blocks:
            x = block(x)            # (N, D, extra_len + T)
        x = x[..., extra_len:]      # (N, D, T)
        x = self.final_layer(x)     # (N, out_channels, T)
        return x


class ConcatOrderDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    A learned order embedding is added onto the token embeddings of each object
    described in the text prompt.
    """
    def __init__(
        self,
        in_channels,
        context_dim,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        max_len=1000,
    ):
        super().__init__()
        self.in_channels = in_channels  # vae dim
        self.out_channels = in_channels
        self.num_heads = num_heads
        kernel_size = 5
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.c_embedder = ConditionEmbedder(hidden_size, context_dim)
        self.proj_in = nn.Conv1d(in_channels, hidden_size, kernel_size=kernel_size, padding=kernel_size // 2)
        self.pos_emb = PositionEmbedding(num_embeddings=max_len, embedding_dim=hidden_size)
        self.order_embedding = nn.Embedding(num_embeddings=100, embedding_dim=hidden_size)
        self.blocks = nn.ModuleList([
            TemporalTransformer(hidden_size, num_heads, d_head=hidden_size // num_heads, depth=1, context_dim=context_dim)
            for _ in range(depth)
        ])
        self.final_layer = Conv1DFinalLayer(hidden_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

    def add_order_embedding(self, token_emb, token_ids, orders_list):
        """
        token_emb: shape (N, max_tokens_len=77, hidden_size)
        token_ids: shape (N, max_tokens_len)
        orders_list: list of N lists; len(orders_list[i]) == number of objects in text[i]
        """
        for b, orderl in enumerate(orders_list):
            orderl = torch.LongTensor(orderl).to(device=self.order_embedding.weight.device)
            order_emb = self.order_embedding(orderl)
            # obj2index[i] is the object index of token i, or -1 for special tokens
            obj2index = []
            cur_obj = 0
            for i in range(token_ids.shape[1]):  # max_length
                token_id = token_ids[b][i]
                if token_id in [101, 102, 0, 1064]:  # <start>, <eos>, <pad>, <|>; change these ids if another tokenizer is used
                    obj2index.append(-1)
                    if token_id == 1064:
                        cur_obj += 1
                else:
                    obj2index.append(cur_obj)
            for i, order_index in enumerate(obj2index):
                if order_index != -1:
                    token_emb[b][i] += order_emb[order_index]
        return token_emb
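
    # Worked example of add_order_embedding (hypothetical word ids, assuming the
    # special-token ids noted above, i.e. 101=<start>, 102=<eos>, 0=<pad>, 1064=<|>):
    #   token_ids[b]   = [101, 2054, 1064, 2003, 102, 0, ...]   two objects separated by <|>
    #   orders_list[b] = [1, 0]                                 order value 1 for object 0, 0 for object 1
    #   obj2index      = [ -1,    0,   -1,    1,  -1, -1, ...]
    #   so token 2054 receives order_embedding(1), token 2003 receives order_embedding(0),
    #   and the special tokens are left unchanged.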

    def forward(self, x, t, context):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of temporal inputs (latent representations of the mel-spectrogram)
        t: (N,) tensor of diffusion timesteps
        context: dict with
            'token_embedding': (N, max_tokens_len=77, context_dim)
            'token_ids': (N, max_tokens_len=77)
            'orders': orders_list
        """
        token_embedding = context['token_embedding']
        token_ids = context['token_ids']
        orders = context['orders']
        t = self.t_embedder(t).unsqueeze(1)     # (N, 1, hidden_size)
        c = self.c_embedder(token_embedding)    # (N, c_len, hidden_size)
        c = self.add_order_embedding(c, token_ids, orders)
        extra_len = c.shape[1] + 1
        x = self.proj_in(x)
        x = rearrange(x, 'b c t -> b t c')
        x = torch.concat([t, c, x], dim=1)
        x = self.pos_emb(x)
        x = rearrange(x, 'b t c -> b c t')
        for block in self.blocks:
            x = block(x)            # (N, D, extra_len + T)
        x = x[..., extra_len:]      # (N, D, T)
        x = self.final_layer(x)     # (N, out_channels, T)
        return x


class ConcatOrderDiT2(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    Instead of adding order embeddings onto existing tokens, each object's order
    embedding is concatenated into the token sequence as an extra token.
    """
    def __init__(
        self,
        in_channels,
        context_dim,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        max_len=1000,
    ):
        super().__init__()
        self.in_channels = in_channels  # vae dim
        self.out_channels = in_channels
        self.num_heads = num_heads
        kernel_size = 5
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.c_embedder = ConditionEmbedder(hidden_size, context_dim)
        self.proj_in = nn.Conv1d(in_channels, hidden_size, kernel_size=kernel_size, padding=kernel_size // 2)
        self.pos_emb = PositionEmbedding(num_embeddings=max_len, embedding_dim=hidden_size)
        self.max_objs = 10
        self.max_objs_order = 100
        self.order_embedding = nn.Embedding(num_embeddings=self.max_objs_order + 1, embedding_dim=hidden_size)  # the last index is the pad order
        self.blocks = nn.ModuleList([
            TemporalTransformer(hidden_size, num_heads, d_head=hidden_size // num_heads, depth=1, context_dim=context_dim)
            for _ in range(depth)
        ])
        self.final_layer = Conv1DFinalLayer(hidden_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

    def concat_order_embedding(self, token_emb, token_ids, orders_list):
        """
        token_emb: shape (N, max_tokens_len=77, hidden_size)
        token_ids: shape (N, max_tokens_len)
        orders_list: list of N lists; len(orders_list[i]) == number of objects in text[i]
        return: token_emb of shape (N, max_tokens_len + self.max_objs, hidden_size)
        """
        bsz, t, c = token_emb.shape
        token_emb = list(torch.tensor_split(token_emb, bsz))  # token_emb[i] shape (1, t, c)
        orders_list = deepcopy(orders_list)  # avoid in-place modification
        for i in range(bsz):
            token_emb[i] = list(torch.tensor_split(token_emb[i].squeeze(0), t))  # token_emb[i][j] shape (1, c)
        for b, orderl in enumerate(orders_list):
            orderl.append(self.max_objs_order)  # the last entry is the pad order
            orderl = torch.LongTensor(orderl).to(device=self.order_embedding.weight.device)
            order_emb = self.order_embedding(orderl)                # shape (len(orderl), hidden_size)
            order_emb = torch.tensor_split(order_emb, len(orderl))  # order_emb[i] shape (1, hidden_size)
            obj_insert_index = []
            for i in range(token_ids.shape[1]):  # max_length
                token_id = token_ids[b][i]
                if token_id == 1064:  # <|> after each object; change this id if another tokenizer is used
                    obj_insert_index.append(i + len(obj_insert_index))
            for i, index in enumerate(obj_insert_index):
                token_emb[b].insert(index, order_emb[i])
            for i in range(self.max_objs - len(orderl) + 1):
                token_emb[b].append(order_emb[-1])  # pad to max_tokens_len + self.max_objs
            token_emb[b] = torch.concat(token_emb[b])  # shape (max_tokens_len + self.max_objs, hidden_size)
        token_emb = torch.stack(token_emb)
        return token_emb
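
    # Length bookkeeping sketch for concat_order_embedding (hypothetical input, assuming
    # max_tokens_len=77, self.max_objs=10, and 1064 == <|> as noted above):
    #   if token_ids[b] contains two <|> tokens, two order embeddings are inserted (one
    #   just before each <|>), and max_objs - num_objs = 8 pad-order embeddings are
    #   appended, giving 77 + 2 + 8 = 87 = max_tokens_len + max_objs tokens per sample,
    #   so the batch can still be stacked into a single tensor.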

    def forward(self, x, t, context):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of temporal inputs (latent representations of the mel-spectrogram)
        t: (N,) tensor of diffusion timesteps
        context: dict with
            'token_embedding': (N, max_tokens_len=77, context_dim)
            'token_ids': (N, max_tokens_len=77)
            'orders': orders_list
        """
        token_embedding = context['token_embedding']
        token_ids = context['token_ids']
        orders = context['orders']
        t = self.t_embedder(t).unsqueeze(1)     # (N, 1, hidden_size)
        c = self.c_embedder(token_embedding)    # (N, c_len, hidden_size)
        c = self.concat_order_embedding(c, token_ids, orders)
        extra_len = c.shape[1] + 1
        x = self.proj_in(x)
        x = rearrange(x, 'b c t -> b t c')
        x = torch.concat([t, c, x], dim=1)
        x = self.pos_emb(x)
        x = rearrange(x, 'b t c -> b c t')
        for block in self.blocks:
            x = block(x)            # (N, D, extra_len + T)
        x = x[..., extra_len:]      # (N, D, T)
        x = self.final_layer(x)     # (N, out_channels, T)
        return x
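

# Usage sketch for the order-conditioned variant (illustrative; the helper name, the small
# sizes, and the word ids 2054/2003 below are assumptions, and it again relies on the
# PositionEmbedding / attention ops imported from ldm.modules.new_attention):
def _demo_concat_order_dit():
    model = ConcatOrderDiT(in_channels=20, context_dim=1024, hidden_size=256, depth=1, num_heads=4)
    N, T, L = 2, 100, 77
    token_ids = torch.zeros(N, L, dtype=torch.long)  # <pad>
    token_ids[:, 0] = 101                            # <start>
    token_ids[:, 1] = 2054                           # object 0 word (hypothetical id)
    token_ids[:, 2] = 1064                           # <|>
    token_ids[:, 3] = 2003                           # object 1 word (hypothetical id)
    token_ids[:, 4] = 102                            # <eos>
    context = {
        'token_embedding': torch.randn(N, L, 1024),
        'token_ids': token_ids,
        'orders': [[1, 0], [0, 1]],                  # one order value per object
    }
    x = torch.randn(N, 20, T)
    t = torch.randint(0, 1000, (N,))
    with torch.no_grad():
        assert model(x, t, context).shape == x.shape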