import math

import torch
from torch import nn
from torch.nn.utils import weight_norm

from modules.v2.dit_model import ModelArgs, Transformer
from modules.commons import sequence_mask

def modulate(x, shift, scale):
    # AdaLN-style modulation: shift and scale are per-sample (N, D) tensors,
    # broadcast across the sequence dimension of x (N, T, D).
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

#################################################################################
#                        Embedding Layers for Timesteps                         #
#################################################################################
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000, scale=1000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :param scale: multiplier applied to t before the sinusoids are computed.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = scale * t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
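
# A small illustration of the static helper above (not part of the original file;
# kept as a comment so nothing runs at import time): for three fractional
# timesteps and dim=256, `TimestepEmbedder.timestep_embedding` returns an
# (N, dim) tensor whose first half holds cosines and second half sines.
#
#     t = torch.tensor([0.0, 0.5, 1.0])
#     emb = TimestepEmbedder.timestep_embedding(t, dim=256)
#     emb.shape  # torch.Size([3, 256])
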
class DiT(torch.nn.Module):
    def __init__(
        self,
        time_as_token,
        style_as_token,
        uvit_skip_connection,
        block_size,
        depth,
        num_heads,
        hidden_dim,
        in_channels,
        content_dim,
        style_encoder_dim,
        class_dropout_prob,
        dropout_rate,
        attn_dropout_rate,
    ):
        super().__init__()
        self.time_as_token = time_as_token
        self.style_as_token = style_as_token
        self.uvit_skip_connection = uvit_skip_connection
        model_args = ModelArgs(
            block_size=block_size,
            n_layer=depth,
            n_head=num_heads,
            dim=hidden_dim,
            head_dim=hidden_dim // num_heads,
            vocab_size=1,  # we don't use this
            uvit_skip_connection=self.uvit_skip_connection,
            time_as_token=self.time_as_token,
            dropout_rate=dropout_rate,
            attn_dropout_rate=attn_dropout_rate,
        )
        self.transformer = Transformer(model_args)
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.num_heads = num_heads
        self.x_embedder = weight_norm(nn.Linear(in_channels, hidden_dim, bias=True))
        self.content_dim = content_dim  # for continuous content
        self.cond_projection = nn.Linear(content_dim, hidden_dim, bias=True)  # continuous content
        self.t_embedder = TimestepEmbedder(hidden_dim)
        self.final_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, in_channels),
        )
        self.class_dropout_prob = class_dropout_prob
        self.cond_x_merge_linear = nn.Linear(hidden_dim + in_channels + in_channels, hidden_dim)
        self.style_in = nn.Linear(style_encoder_dim, hidden_dim)

    def forward(self, x, prompt_x, x_lens, t, style, cond):
        # Randomly drop the prompt/style (classifier-free guidance) and the
        # content condition during training.
        class_dropout = False
        content_dropout = False
        if self.training and torch.rand(1) < self.class_dropout_prob:
            class_dropout = True
        if self.training and torch.rand(1) < 0.5:
            content_dropout = True
        cond_in_module = self.cond_projection

        B, _, T = x.size()
        t1 = self.t_embedder(t)  # (N, D)

        cond = cond_in_module(cond)  # (N, T, content_dim) -> (N, T, D)

        x = x.transpose(1, 2)  # (N, C, T) -> (N, T, C)
        prompt_x = prompt_x.transpose(1, 2)  # (N, C, T) -> (N, T, C)
        x_in = torch.cat([x, prompt_x, cond], dim=-1)
        if class_dropout:
            x_in[..., self.in_channels:self.in_channels * 2] = 0  # zero the prompt slice
        if content_dropout:
            x_in[..., self.in_channels * 2:] = 0  # zero the content slice
        x_in = self.cond_x_merge_linear(x_in)  # (N, T, D)

        style = self.style_in(style)
        style = torch.zeros_like(style) if class_dropout else style
        if self.style_as_token:
            x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
        if self.time_as_token:
            x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)

        # Extend the valid lengths by one for each prepended token.
        x_mask = sequence_mask(
            x_lens + self.style_as_token + self.time_as_token, max_length=x_in.size(1)
        ).to(x.device).unsqueeze(1)
        input_pos = torch.arange(x_in.size(1)).to(x.device)
        x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1)
        x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded)

        # Strip the prepended time/style tokens before projecting back to features.
        x_res = x_res[:, 1:] if self.time_as_token else x_res
        x_res = x_res[:, 1:] if self.style_as_token else x_res
        x = self.final_mlp(x_res)
        x = x.transpose(1, 2)  # (N, T, C) -> (N, C, T)
        return x
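

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). All dimensions below
# are illustrative assumptions; the real values come from the model config, and
# the sketch assumes modules.v2.dit_model.Transformer accepts these arguments.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = DiT(
        time_as_token=False,
        style_as_token=False,
        uvit_skip_connection=False,
        block_size=1024,        # assumed maximum sequence length
        depth=4,
        num_heads=4,
        hidden_dim=256,
        in_channels=80,         # e.g. mel-spectrogram channels (assumption)
        content_dim=512,        # e.g. content-encoder dimension (assumption)
        style_encoder_dim=192,  # e.g. style-embedding dimension (assumption)
        class_dropout_prob=0.1,
        dropout_rate=0.0,
        attn_dropout_rate=0.0,
    )
    B, T = 2, 100
    x = torch.randn(B, 80, T)           # noisy target features, (N, C, T)
    prompt_x = torch.randn(B, 80, T)    # prompt/reference features, (N, C, T)
    x_lens = torch.tensor([T, T // 2])  # valid length of each sequence
    t = torch.rand(B)                   # diffusion timesteps
    style = torch.randn(B, 192)         # global style embedding
    cond = torch.randn(B, T, 512)       # frame-level content condition, (N, T, C)
    out = model(x, prompt_x, x_lens, t, style, cond)
    print(out.shape)                    # expected: torch.Size([2, 80, 100])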