# ttts/diffusion/diffusion_util.py
import logging
import math
from collections import namedtuple

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange, reduce
from tqdm.auto import tqdm

import ttts.diffusion.commons as commons
from ttts.diffusion.operations import OPERATIONS_ENCODER
from ttts.unet1d.embeddings import TextTimeEmbedding
from ttts.unet1d.unet_1d_condition import UNet1DConditionModel
TACOTRON_MEL_MAX = 5.5451774444795624753378569716654
TACOTRON_MEL_MIN = -16.118095650958319788125940182791
# TACOTRON_MEL_MIN = -11.512925464970228420089957273422
# -16.118095650958319788125940182791
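# These bounds appear to be natural-log mel magnitudes (roughly ln(1e-7) and ln(256));
# normalize_tacotron_mel maps such mels linearly into [-1, 1] and
# denormalize_tacotron_mel inverts the mapping.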
def denormalize_tacotron_mel(norm_mel):
return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
def normalize_tacotron_mel(mel):
return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
def exists(x):
return x is not None
def cycle(dl):
while True:
for data in dl:
yield data
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class TransformerEncoderLayer(nn.Module):
def __init__(self, layer, hidden_size, dropout):
super().__init__()
self.layer = layer
self.hidden_size = hidden_size
self.dropout = dropout
self.op = OPERATIONS_ENCODER[layer](hidden_size, dropout)
def forward(self, x, **kwargs):
return self.op(x, **kwargs)
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
class ConvTBC(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, padding=0):
super(ConvTBC, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.weight = torch.nn.Parameter(torch.Tensor(
self.kernel_size, in_channels, out_channels))
self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
def forward(self, input):
return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding)
class ConvLayer(nn.Module):
def __init__(self, c_in, c_out, kernel_size, dropout=0):
super().__init__()
self.layer_norm = LayerNorm(c_in)
conv = ConvTBC(c_in, c_out, kernel_size, padding=kernel_size // 2)
std = math.sqrt((4 * (1.0 - dropout)) / (kernel_size * c_in))
nn.init.normal_(conv.weight, mean=0, std=std)
nn.init.constant_(conv.bias, 0)
self.conv = conv
def forward(self, x, encoder_padding_mask=None, **kwargs):
layer_norm_training = kwargs.get('layer_norm_training', None)
if layer_norm_training is not None:
self.layer_norm.training = layer_norm_training
if encoder_padding_mask is not None:
x = x.masked_fill(encoder_padding_mask.t().unsqueeze(-1), 0)
x = self.layer_norm(x)
x = self.conv(x)
return x
class PhoneEncoder(nn.Module):
def __init__(self,
in_channels=128,
hidden_channels=512,
out_channels=512,
n_layers=6,
p_dropout=0.2,
last_ln = True):
super().__init__()
self.arch = [8 for _ in range(n_layers)]
self.num_layers = n_layers
self.hidden_size = hidden_channels
self.padding_idx = 0
self.dropout = p_dropout
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerEncoderLayer(self.arch[i], self.hidden_size, self.dropout)
for i in range(self.num_layers)
])
self.last_ln = last_ln
self.pre = ConvLayer(in_channels, hidden_channels, 1, p_dropout)
# self.prompt_proj = ConvLayer(in_channels, hidden_channels, 1, p_dropout)
self.out_proj = ConvLayer(hidden_channels, out_channels, 1, p_dropout)
if last_ln:
self.layer_norm = LayerNorm(out_channels)
self.spk_proj = nn.Conv1d(100,hidden_channels,1)
def forward(self, src_tokens, lengths, g=None):
# B x C x T -> T x B x C
src_tokens = self.spk_proj(src_tokens+g)
src_tokens = rearrange(src_tokens, 'b c t -> t b c')
# compute padding mask
encoder_padding_mask = ~commons.sequence_mask(lengths, src_tokens.size(0)).to(torch.bool)
# prompt_mask = ~commons.sequence_mask(prompt_lengths, prompt.size(0)).to(torch.bool)
x = src_tokens
x = self.pre(x, encoder_padding_mask=encoder_padding_mask)
x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
# prompt = self.prompt_proj(prompt, encoder_padding_mask=prompt_mask)
# encoder layers
for i in range(self.num_layers):
x = self.layers[i](x, encoder_padding_mask=encoder_padding_mask)
# x = x+self.attn_blocks[i](x, prompt, prompt, key_padding_mask=prompt_mask)[0]
x = self.out_proj(x, encoder_padding_mask=encoder_padding_mask)
if self.last_ln:
x = self.layer_norm(x)
x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
x = rearrange(x, 't b c-> b c t')
return x
class PromptEncoder(nn.Module):
def __init__(self,
in_channels=128,
hidden_channels=256,
out_channels=512,
n_layers=6,
p_dropout=0.2,
last_ln = True):
super().__init__()
self.arch = [8 for _ in range(n_layers)]
self.num_layers = n_layers
self.hidden_size = hidden_channels
self.padding_idx = 0
self.dropout = p_dropout
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerEncoderLayer(self.arch[i], self.hidden_size, self.dropout)
for i in range(self.num_layers)
])
self.last_ln = last_ln
if last_ln:
self.layer_norm = LayerNorm(out_channels)
self.pre = ConvLayer(in_channels, hidden_channels, 1, p_dropout)
self.out_proj = ConvLayer(hidden_channels, out_channels, 1, p_dropout)
def forward(self, src_tokens, lengths=None):
# B x C x T -> T x B x C
src_tokens = rearrange(src_tokens, 'b c t -> t b c')
# compute padding mask
encoder_padding_mask = ~commons.sequence_mask(lengths, src_tokens.size(0)).to(torch.bool)
x = src_tokens
x = self.pre(x, encoder_padding_mask=encoder_padding_mask)
x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
# encoder layers
for layer in self.layers:
x = layer(x, encoder_padding_mask=encoder_padding_mask)
x = self.out_proj(x, encoder_padding_mask=encoder_padding_mask)
if self.last_ln:
x = self.layer_norm(x)
x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
x = rearrange(x, 't b c-> b c t')
return x
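# Standard sinusoidal position/timestep embedding: maps a (b,) tensor of steps to a
# (b, dim) tensor of sin/cos features at geometrically spaced frequencies.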
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
@torch.jit.script
def silu(x):
return x * torch.sigmoid(x)
class ResidualBlock(nn.Module):
def __init__(self, n_mels, residual_channels, dilation, kernel_size, dropout):
        '''
        :param n_mels: in_channels of the conv1x1 projection for the spectrogram conditioner
        :param residual_channels: channel width of the audio convolutions
        :param dilation: nominal dilation of the audio conv (unused here; ConvLayer pads internally)
        :param kernel_size: kernel size of the gated (dilated) convolution
        :param dropout: dropout applied to the projected conditioner
        '''
        super().__init__()
self.dilated_conv = ConvLayer(residual_channels, 2 * residual_channels, kernel_size)
self.conditioner_projection = ConvLayer(n_mels, 2 * residual_channels, 1)
# self.output_projection = ConvLayer(residual_channels, 2 * residual_channels, 1)
self.output_projection = ConvLayer(residual_channels, residual_channels, 1)
self.t_proj = ConvLayer(residual_channels, residual_channels, 1)
self.drop = nn.Dropout(dropout)
def forward(self, x, diffusion_step, conditioner,x_mask):
assert (conditioner is None and self.conditioner_projection is None) or \
(conditioner is not None and self.conditioner_projection is not None)
#T B C
y = x + self.t_proj(diffusion_step.unsqueeze(0))
y = y.masked_fill(x_mask.t().unsqueeze(-1), 0)
conditioner = self.conditioner_projection(conditioner)
conditioner = self.drop(conditioner)
y = self.dilated_conv(y) + conditioner
y = y.masked_fill(x_mask.t().unsqueeze(-1), 0)
gate, filter_ = torch.chunk(y, 2, dim=-1)
y = torch.sigmoid(gate) * torch.tanh(filter_)
y = y.masked_fill(x_mask.t().unsqueeze(-1), 0)
y = self.output_projection(y)
return y
# y = y.masked_fill(x_mask.t().unsqueeze(-1), 0)
# residual, skip = torch.chunk(y, 2, dim=-1)
# return (x + residual) / math.sqrt(2.0), skip
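# Pre_model bundles the two conditioning encoders: PhoneEncoder consumes the
# reconstructed mel (plus a global reference embedding g from TextTimeEmbedding),
# PromptEncoder consumes the reference mel; both return (b, c, t) features.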
class Pre_model(nn.Module):
def __init__(self, cfg) -> None:
super().__init__()
self.cfg = cfg
self.phoneme_encoder = PhoneEncoder(**self.cfg['phoneme_encoder'])
print("phoneme params:", count_parameters(self.phoneme_encoder))
self.prompt_encoder = PromptEncoder(**self.cfg['prompt_encoder'])
print("prompt params:", count_parameters(self.prompt_encoder))
dim = self.cfg['phoneme_encoder']['out_channels']
self.ref_enc = TextTimeEmbedding(100, 100, 1)
def forward(self,data, g=None):
mel_recon_padded, mel_padded, mel_lengths, refer_padded, refer_lengths = data
mel_recon_padded, refer_padded = normalize_tacotron_mel(mel_recon_padded), normalize_tacotron_mel(refer_padded)
g = self.ref_enc(refer_padded.transpose(1,2)).unsqueeze(-1)
audio_prompt = self.prompt_encoder(refer_padded,refer_lengths)
content = self.phoneme_encoder(mel_recon_padded, mel_lengths, g)
return content, audio_prompt
def infer(self, data):
mel_recon_padded, refer_padded, mel_lengths, refer_lengths = data
mel_recon_padded, refer_padded = normalize_tacotron_mel(mel_recon_padded), normalize_tacotron_mel(refer_padded)
g = self.ref_enc(refer_padded.transpose(1,2)).unsqueeze(-1)
audio_prompt = self.prompt_encoder(refer_padded,refer_lengths)
content = self.phoneme_encoder(mel_recon_padded, mel_lengths, g)
return content, audio_prompt
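# The denoiser: the noisy mel x is concatenated channel-wise with the content
# features and passed to a 1-D conditional UNet; the prompt features are injected
# through cross-attention, masked via commons.sequence_mask over prompt_lengths.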
class Diffusion_Encoder(nn.Module):
def __init__(self,
in_channels=128,
out_channels=128,
hidden_channels=256,
block_out_channels = [128,256,384,512],
n_heads=8,
p_dropout=0.2,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.n_heads=n_heads
self.unet = UNet1DConditionModel(
in_channels=in_channels+hidden_channels,
out_channels=out_channels,
block_out_channels=block_out_channels,
norm_num_groups=8,
cross_attention_dim=hidden_channels,
attention_head_dim=n_heads,
addition_embed_type='text',
resnet_time_scale_shift='scale_shift',
)
def forward(self, x, data, t):
        assert not torch.isnan(x).any()
contentvec, prompt, contentvec_lengths, prompt_lengths = data
prompt = rearrange(prompt,' b c t-> b t c')
x = torch.cat([x, contentvec], dim=1)
prompt_mask = commons.sequence_mask(prompt_lengths, prompt.size(1)).to(torch.bool)
x = self.unet(x, t, prompt, encoder_attention_mask=prompt_mask)
return x.sample
# tensor helper functions
def log(t, eps = 1e-20):
return torch.log(t.clamp(min = eps))
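# Gather the per-sample schedule coefficient a[t] (shape (b,)) and reshape it to
# (b, 1, ..., 1) so it broadcasts against tensors shaped like x_shape.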
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def linear_beta_schedule(timesteps):
"""
linear schedule, proposed in original ddpm paper
"""
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
class Diffuser(nn.Module):
def __init__(self,
cfg,
ddim_sampling_eta = 0,
min_snr_loss_weight = False,
min_snr_gamma = 5,
conditioning_free = True,
conditioning_free_k = 1.0
):
super().__init__()
self.pre_model = Pre_model(cfg)
print("pre params: ", count_parameters(self.pre_model))
self.diff_model = Diffusion_Encoder(**cfg['diffusion'])
print("diff params: ", count_parameters(self.diff_model))
self.dim = self.diff_model.in_channels
timesteps = cfg['train']['timesteps']
beta_schedule_fn = linear_beta_schedule
betas = beta_schedule_fn(timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim = 0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
self.num_timesteps = timesteps
        # learned embedding used in place of the reference encoding when conditioning is dropped
        self.unconditioned_refer = nn.Parameter(torch.randn(1, cfg['phoneme_encoder']['out_channels'], 1))
# self.sampling_timesteps = cfg['train']['sampling_timesteps']
self.ddim_sampling_eta = ddim_sampling_eta
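        # Cache the usual DDPM schedule quantities as non-trainable float32 buffers
        # so they move with the module across devices.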
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
register_buffer('betas', betas)
register_buffer('alphas_cumprod', alphas_cumprod)
register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
register_buffer('posterior_variance', posterior_variance)
register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
snr = alphas_cumprod / (1 - alphas_cumprod)
maybe_clipped_snr = snr.clone()
if min_snr_loss_weight:
maybe_clipped_snr.clamp_(max = min_snr_gamma)
register_buffer('loss_weight', maybe_clipped_snr)
self.conditioning_free = conditioning_free
self.conditioning_free_k = conditioning_free_k
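    # Invert the forward-noising identity x_t = sqrt(abar_t) x_0 + sqrt(1 - abar_t) eps:
    # given x_t and a predicted x_0, the implied noise is
    # eps = (sqrt(1/abar_t) * x_t - x_0) / sqrt(1/abar_t - 1).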
def predict_noise_from_start(self, x_t, t, x0):
return (
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def model_predictions(self, x, t, data = None):
model_output = self.diff_model(x,data, t)
t = t.type(torch.int64)
x_start = model_output
pred_noise = self.predict_noise_from_start(x, t, x_start)
return ModelPrediction(pred_noise, x_start)
    def sample_fun(self, x, t, data = None):
        if self.conditioning_free:
            # NOTE: the "unconditional" pass currently reuses the same conditioning data
            # (cf. the original `data[1] = self.unconditioned_refer` sketch), so guidance
            # has no effect until an unconditional `data` is substituted here.
            model_output_no_conditioning = self.diff_model(x, data, t)
        model_output = self.diff_model(x, data, t)
        t = t.type(torch.int64)
        if self.conditioning_free:
            cfk = self.conditioning_free_k
            model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning
        return model_output
def p_mean_variance(self, x, t, data):
preds = self.model_predictions(x, t, data)
x_start = preds.pred_x_start
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
return model_mean, posterior_variance, posterior_log_variance, x_start
@torch.no_grad()
def p_sample(self, x, t: int, data):
b, *_, device = *x.shape, x.device
batched_times = torch.full((b,), t, device = device, dtype = torch.long)
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, data=data)
noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
return pred_img, x_start
@torch.no_grad()
    def p_sample_loop(self, mel_recon, refer, lengths, refer_lengths):
        # Pre_model.infer expects (mel_recon, refer, mel_lengths, refer_lengths)
        data = (mel_recon, refer, lengths, refer_lengths)
        content, refer = self.pre_model.infer(data)
        shape = (content.shape[0], self.dim, content.shape[2])  # (b, mel_dim, t), matching the encoders' b c t output
batch, device = shape[0], refer.device
img = torch.randn(shape, device = device)
imgs = [img]
x_start = None
for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img, x_start = self.p_sample(img, t, (content,refer,lengths,refer_lengths))
imgs.append(img)
ret = img
return ret
@torch.no_grad()
    def ddim_sample(self, mel_recon, refer, lengths, refer_lengths):
        data = (mel_recon, refer, lengths, refer_lengths)
        content, refer = self.pre_model.infer(data)
        shape = (content.shape[0], self.dim, content.shape[2])
batch, device, total_timesteps, sampling_timesteps, eta = shape[0], refer.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta
times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
img = torch.randn(shape, device = device)
imgs = [img]
x_start = None
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
pred_noise, x_start, *_ = self.model_predictions(img, time_cond, (content,refer,lengths,refer_lengths))
if time_next < 0:
img = x_start
imgs.append(img)
continue
alpha = self.alphas_cumprod[time]
alpha_next = self.alphas_cumprod[time_next]
sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c = (1 - alpha_next - sigma ** 2).sqrt()
noise = torch.randn_like(img)
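                # DDIM update: x_{t_next} = sqrt(abar_next) * x0_hat
                #              + sqrt(1 - abar_next - sigma^2) * eps_hat + sigma * z,
                # with sigma = 0 (ddim_sampling_eta defaults to 0) giving the deterministic sampler.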
img = x_start * alpha_next.sqrt() + \
c * pred_noise + \
sigma * noise
imgs.append(img)
ret = img
return ret
@torch.no_grad()
def sample(self,
mel_recon, refer, lengths, refer_lengths,
# c, refer, f0, uv, lengths, refer_lengths, vocos,
sampling_timesteps=100, sample_method='unipc'
):
mel_recon, refer = normalize_tacotron_mel(mel_recon), normalize_tacotron_mel(refer)
if refer.shape[0]==2:
refer = refer[0].unsqueeze(0)
self.sampling_timesteps = sampling_timesteps
        if sample_method == 'ddpm':
            mel = self.p_sample_loop(mel_recon, refer, lengths, refer_lengths)
        elif sample_method == 'ddim':
            mel = self.ddim_sample(mel_recon, refer, lengths, refer_lengths)
elif sample_method == 'dpmsolver':
            from ttts.sampler.dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas)
def my_wrapper(fn):
def wrapped(x, t, **kwargs):
ret = fn(x, t, **kwargs)
self.bar.update(1)
return ret
return wrapped
            data = (mel_recon, refer, lengths, refer_lengths)
            content, refer = self.pre_model.infer(data)
            shape = (content.shape[0], self.dim, content.shape[2])
batch, device, total_timesteps, sampling_timesteps, eta = shape[0], refer.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta
audio = torch.randn(shape, device = device)
model_fn = model_wrapper(
my_wrapper(self.sample_fun),
noise_schedule,
model_type="x_start", #"noise" or "x_start" or "v" or "score"
model_kwargs={"data":(content,refer,lengths,refer_lengths)}
)
dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
steps = 40
self.bar = tqdm(desc="sample time step", total=steps)
            mel = dpm_solver.sample(
audio,
steps=steps,
order=2,
skip_type="time_uniform",
method="multistep",
)
self.bar.close()
elif sample_method =='unipc':
from ttts.sampler.uni_pc import NoiseScheduleVP, model_wrapper, UniPC
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas)
def my_wrapper(fn):
def wrapped(x, t, **kwargs):
ret = fn(x, t, **kwargs)
self.bar.update(1)
return ret
return wrapped
data = (mel_recon, refer, lengths, refer_lengths)
content, refer = self.pre_model.infer(data)
shape = (content.shape[0], self.dim, content.shape[2])
batch, device, total_timesteps, sampling_timesteps, eta = shape[0], refer.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta
audio = torch.randn(shape, device = device)
model_fn = model_wrapper(
my_wrapper(self.sample_fun),
noise_schedule,
model_type="noise", #"noise" or "x_start" or "v" or "score"
model_kwargs={"data":(content,refer,lengths,refer_lengths)}
)
uni_pc = UniPC(model_fn, noise_schedule, variant='bh2')
steps = 30
self.bar = tqdm(desc="sample time step", total=steps)
mel = uni_pc.sample(
audio,
steps=steps,
order=2,
skip_type="time_uniform",
method="multistep",
)
self.bar.close()
# mel = audio
# vocos.to(audio.device)
# audio = vocos.decode(audio)
# if audio.ndim == 3:
# audio = rearrange(audio, 'b 1 n -> b n')
# return denormalize(mel)
return denormalize_tacotron_mel(mel)
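    # Forward diffusion q(x_t | x_0) = N(sqrt(abar_t) x_0, (1 - abar_t) I),
    # sampled via the reparameterisation below.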
def q_sample(self, x_start, t, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
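    # Training step: normalize the target mel, draw a random timestep, noise it with
    # q_sample and regress the network output against the injected noise (MSE, SNR-weighted).
    # With conditioning_free=True the reference conditioning is swapped for the learned
    # unconditional embedding so classifier-free guidance can be used at sampling time.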
def forward(self, data, conditioning_free=False):
unused_params = []
mel_recon_padded, mel_padded, mel_lengths, refer_padded, refer_lengths = data
        mel_recon_padded, mel_padded = normalize_tacotron_mel(mel_recon_padded), normalize_tacotron_mel(mel_padded)
assert mel_recon_padded.shape[2] == mel_padded.shape[2]
b, d, n, device = *mel_padded.shape, mel_padded.device
x_mask = torch.unsqueeze(commons.sequence_mask(mel_lengths, mel_padded.size(2)), 1).to(mel_padded.dtype)
x_start = mel_padded*x_mask
# get pre model outputs
content, refer = self.pre_model(data)
        if conditioning_free:
            # drop the reference conditioning; `+ refer.mean()*0` keeps the prompt encoder in the graph
            refer = self.unconditioned_refer.repeat(data[0].shape[0], 1, 1) + refer.mean() * 0
else:
unused_params.append(self.unconditioned_refer)
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
noise = torch.randn_like(x_start)*x_mask
# noise sample
x = self.q_sample(x_start = x_start, t = t, noise = noise)
# predict and take gradient step
model_out = self.diff_model(x,(content,refer,mel_lengths,refer_lengths), t)
target = noise
loss = F.mse_loss(model_out, target, reduction = 'none')
loss_diff = reduce(loss, 'b ... -> b (...)', 'mean')
        loss_diff = loss_diff * extract(self.loss_weight, t, loss_diff.shape)  # (b, 1) weight broadcasts over the flattened loss
loss_diff = loss_diff.mean()
loss = loss_diff
extraneous_addition = 0
for p in unused_params:
extraneous_addition = extraneous_addition + p.mean()
loss = loss + extraneous_addition * 0
return loss
def get_grad_norm(model):
    total_norm = 0
    for name, p in model.named_parameters():
        if p.grad is None:
            print(name)
            continue
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5
    return total_norm
logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger('numba').setLevel(logging.WARNING)
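# Illustrative sanity check (not part of the original pipeline): it exercises only the
# self-contained helpers above and assumes the package's dependencies are installed,
# e.g. run as `python -m ttts.diffusion.diffusion_util`.
if __name__ == "__main__":
    # normalization round-trip: random mels in [MEL_MIN, MEL_MAX] map into [-1, 1] and back
    mel = torch.rand(2, 100, 50) * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN
    norm = normalize_tacotron_mel(mel)
    assert norm.min() >= -1 and norm.max() <= 1
    assert torch.allclose(denormalize_tacotron_mel(norm), mel, atol=1e-4)
    # linear beta schedule: monotonically increasing, one beta per timestep
    betas = linear_beta_schedule(1000)
    assert betas.shape == (1000,) and betas[0] < betas[-1]
    print("mel normalization round-trip and beta schedule OK")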