Spaces:
Sleeping
Sleeping
import json | |
import os | |
from pathlib import Path | |
from datetime import datetime | |
from matplotlib import pyplot as plt | |
from ttts.unet1d.embeddings import TextTimeEmbedding | |
from ttts.unet1d.unet_1d_condition import UNet1DConditionModel | |
from vocos import Vocos | |
from torch import expm1, nn | |
import ttts.diffusion.commons as commons | |
from accelerate import Accelerator | |
from ttts.diffusion.operations import OPERATIONS_ENCODER | |
from accelerate import DistributedDataParallelKwargs | |
import math | |
from multiprocessing import cpu_count | |
from pathlib import Path | |
from random import random | |
from functools import partial | |
from collections import namedtuple | |
from torch.utils.tensorboard import SummaryWriter | |
import logging | |
import torch | |
import torch.nn.functional as F | |
from torch import nn, einsum | |
from torch.optim import AdamW | |
from torch.utils.data import Dataset, DataLoader | |
from einops import rearrange, reduce, repeat | |
from einops.layers.torch import Rearrange | |
from tqdm.auto import tqdm | |
# Dynamic range of the log-mel spectrograms handled by this model.
# Numerically these are ln(256) and ln(1e-7); a floor of ln(1e-5)
# (-11.5129...) was an alternative that is no longer used.
TACOTRON_MEL_MAX = 5.5451774444795624753378569716654
TACOTRON_MEL_MIN = -16.118095650958319788125940182791


def denormalize_tacotron_mel(norm_mel):
    """Map a mel in [-1, 1] back to [TACOTRON_MEL_MIN, TACOTRON_MEL_MAX].

    Inverse of ``normalize_tacotron_mel``; works on floats and tensors.
    """
    span = TACOTRON_MEL_MAX - TACOTRON_MEL_MIN
    unit = (norm_mel + 1) / 2
    return unit * span + TACOTRON_MEL_MIN


def normalize_tacotron_mel(mel):
    """Map a mel in [TACOTRON_MEL_MIN, TACOTRON_MEL_MAX] linearly onto [-1, 1]."""
    unit = (mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)
    return 2 * unit - 1
def exists(x):
    """Return whether ``x`` holds a value, i.e. is not ``None``."""
    if x is None:
        return False
    return True
def cycle(dl):
    """Yield items from ``dl`` forever, restarting iteration whenever it is exhausted."""
    while True:
        yield from dl
def count_parameters(model):
    """Return the number of scalar elements across all trainable parameters of ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
class TransformerEncoderLayer(nn.Module):
    """Wrap a single operation chosen from ``OPERATIONS_ENCODER``.

    Args:
        layer: key into ``OPERATIONS_ENCODER`` selecting the operation factory.
        hidden_size: channel width passed to the factory.
        dropout: dropout rate passed to the factory.
    """

    def __init__(self, layer, hidden_size, dropout):
        super().__init__()
        self.layer, self.hidden_size, self.dropout = layer, hidden_size, dropout
        op_factory = OPERATIONS_ENCODER[layer]
        self.op = op_factory(hidden_size, dropout)

    def forward(self, x, **kwargs):
        """Delegate to the wrapped operation, forwarding all keyword arguments."""
        return self.op(x, **kwargs)
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    """Build a ``torch.nn.LayerNorm``.

    ``export`` is accepted only for signature compatibility and has no effect.
    """
    del export  # intentionally unused
    return torch.nn.LayerNorm(
        normalized_shape,
        eps=eps,
        elementwise_affine=elementwise_affine,
    )
class ConvTBC(nn.Module):
    """1-D convolution over time-major input using ``torch.conv_tbc``.

    Input and output are laid out as (time, batch, channels).

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        kernel_size: temporal kernel width.
        padding: zero padding added at both ends of the time axis.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super(ConvTBC, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        # Weight layout required by torch.conv_tbc: (kernel, in, out).
        self.weight = torch.nn.Parameter(torch.empty(kernel_size, in_channels, out_channels))
        self.bias = torch.nn.Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weights (Xavier normal) and zero the bias.

        The storage above is uninitialised memory (it may contain NaN/Inf), so
        it must be filled before use. Wrappers such as ``ConvLayer`` may still
        re-initialise the parameters afterwards.
        """
        nn.init.xavier_normal_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, input):
        """Convolve ``input`` of shape (time, batch, in_channels)."""
        return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding)
class ConvLayer(nn.Module):
    """Pre-LayerNorm convolution on (time, batch, channels) tensors.

    Args:
        c_in: input channels (also the LayerNorm width).
        c_out: output channels.
        kernel_size: temporal kernel width; padding is ``kernel_size // 2``.
        dropout: only scales the std of the weight initialisation; no dropout
            is applied in ``forward``.
    """

    def __init__(self, c_in, c_out, kernel_size, dropout=0):
        super().__init__()
        self.layer_norm = LayerNorm(c_in)
        conv = ConvTBC(c_in, c_out, kernel_size, padding=kernel_size // 2)
        init_std = math.sqrt(4 * (1.0 - dropout) / (kernel_size * c_in))
        nn.init.normal_(conv.weight, mean=0, std=init_std)
        nn.init.constant_(conv.bias, 0)
        self.conv = conv

    def forward(self, x, encoder_padding_mask=None, **kwargs):
        """Mask padded frames, layer-normalise, then convolve.

        Args:
            x: (time, batch, c_in) tensor.
            encoder_padding_mask: optional boolean mask, expected (batch, time),
                whose True entries are zeroed in ``x`` before normalisation.
            layer_norm_training: optional kwarg that overwrites the LayerNorm's
                ``training`` flag (``nn.LayerNorm`` has no mode-dependent
                behaviour, so this only changes the flag).
        """
        ln_training = kwargs.get('layer_norm_training', None)
        if ln_training is not None:
            self.layer_norm.training = ln_training
        if encoder_padding_mask is not None:
            # Transpose to (time, batch, 1) so it broadcasts over channels.
            x = x.masked_fill(encoder_padding_mask.t().unsqueeze(-1), 0)
        normed = self.layer_norm(x)
        return self.conv(normed)
class PhoneEncoder(nn.Module):
    """Content encoder over mel frames, conditioned on an optional global embedding.

    The input is projected from a hard-coded 100 channels to ``hidden_channels``
    by ``spk_proj``. It is then processed in time-major (T x B x C) layout by
    ``pre``, ``n_layers`` encoder operations and ``out_proj``.

    Args:
        in_channels: input width of ``pre``. NOTE(review): ``pre`` receives the
            ``spk_proj`` output (``hidden_channels`` wide), so this presumably
            has to equal ``hidden_channels``; confirm against the configs.
        hidden_channels: width of the encoder layers.
        out_channels: width of the output.
        n_layers: number of encoder operations (all use operation id 8).
        p_dropout: dropout passed to the encoder operations and conv init.
        last_ln: apply a final LayerNorm (and re-mask padding) after ``out_proj``.
    """

    def __init__(self,
                 in_channels=128,
                 hidden_channels=512,
                 out_channels=512,
                 n_layers=6,
                 p_dropout=0.2,
                 last_ln=True):
        super().__init__()
        self.arch = [8 for _ in range(n_layers)]
        self.num_layers = n_layers
        self.hidden_size = hidden_channels
        self.padding_idx = 0
        self.dropout = p_dropout
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(self.arch[i], self.hidden_size, self.dropout)
            for i in range(self.num_layers)
        ])
        self.last_ln = last_ln
        self.pre = ConvLayer(in_channels, hidden_channels, 1, p_dropout)
        self.out_proj = ConvLayer(hidden_channels, out_channels, 1, p_dropout)
        if last_ln:
            self.layer_norm = LayerNorm(out_channels)
        self.spk_proj = nn.Conv1d(100, hidden_channels, 1)

    def forward(self, src_tokens, lengths, g=None):
        """Encode ``src_tokens``.

        Args:
            src_tokens: (batch, 100, time) input (width fixed by ``spk_proj``).
            lengths: number of valid frames per batch item.
            g: optional tensor added to ``src_tokens`` before projection; it
                must broadcast against ``src_tokens``. Skipped when None.

        Returns:
            (batch, out_channels, time) tensor.
        """
        if g is not None:
            src_tokens = src_tokens + g
        x = self.spk_proj(src_tokens)
        # B x C x T -> T x B x C, the layout ConvLayer / ConvTBC expect.
        x = rearrange(x, 'b c t -> t b c')
        # True where a frame is padding.
        encoder_padding_mask = ~commons.sequence_mask(lengths, x.size(0)).to(torch.bool)
        # (T, B, 1) multiplier that zeroes padded frames.
        keep = (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]

        x = self.pre(x, encoder_padding_mask=encoder_padding_mask)
        x = x * keep
        for layer in self.layers:
            x = layer(x, encoder_padding_mask=encoder_padding_mask)
        x = self.out_proj(x, encoder_padding_mask=encoder_padding_mask)
        if self.last_ln:
            x = self.layer_norm(x)
            x = x * keep
        return rearrange(x, 't b c -> b c t')
class PromptEncoder(nn.Module):
    """Encoder for the reference (prompt) mel in time-major (T x B x C) layout.

    Args:
        in_channels: channels of the input mel.
        hidden_channels: width of the encoder layers.
        out_channels: width of the output.
        n_layers: number of encoder operations (all use operation id 8).
        p_dropout: dropout passed to the encoder operations and conv init.
        last_ln: apply a final LayerNorm (and re-mask padding) after ``out_proj``.
    """

    def __init__(self,
                 in_channels=128,
                 hidden_channels=256,
                 out_channels=512,
                 n_layers=6,
                 p_dropout=0.2,
                 last_ln=True):
        super().__init__()
        self.arch = [8 for _ in range(n_layers)]
        self.num_layers = n_layers
        self.hidden_size = hidden_channels
        self.padding_idx = 0
        self.dropout = p_dropout
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(self.arch[i], self.hidden_size, self.dropout)
            for i in range(self.num_layers)
        ])
        self.last_ln = last_ln
        if last_ln:
            self.layer_norm = LayerNorm(out_channels)
        self.pre = ConvLayer(in_channels, hidden_channels, 1, p_dropout)
        self.out_proj = ConvLayer(hidden_channels, out_channels, 1, p_dropout)

    def forward(self, src_tokens, lengths=None):
        """Encode a (batch, in_channels, time) prompt.

        Args:
            src_tokens: (batch, in_channels, time) input.
            lengths: valid frames per batch item. When None, every frame is
                treated as valid (previously this default could not work).

        Returns:
            (batch, out_channels, time) tensor.
        """
        # B x C x T -> T x B x C
        x = rearrange(src_tokens, 'b c t -> t b c')
        if lengths is None:
            lengths = torch.full((x.size(1),), x.size(0), dtype=torch.long, device=x.device)
        # True where a frame is padding.
        encoder_padding_mask = ~commons.sequence_mask(lengths, x.size(0)).to(torch.bool)
        # (T, B, 1) multiplier that zeroes padded frames.
        keep = (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]

        x = self.pre(x, encoder_padding_mask=encoder_padding_mask)
        x = x * keep
        for layer in self.layers:
            x = layer(x, encoder_padding_mask=encoder_padding_mask)
        x = self.out_proj(x, encoder_padding_mask=encoder_padding_mask)
        if self.last_ln:
            x = self.layer_norm(x)
            x = x * keep
        return rearrange(x, 't b c -> b c t')
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of a 1-D tensor of positions/timesteps.

    Output has ``2 * (dim // 2)`` features: sines followed by cosines over
    geometrically spaced frequencies from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(-scale * torch.arange(half, device=x.device))
        args = x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([args.sin(), args.cos()], dim=-1)
def silu(x):
    """Sigmoid-weighted linear unit: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
class ResidualBlock(nn.Module):
    """Gated convolution block over time-major (T x B x C) features.

    Adds a projected diffusion-step embedding to ``x`` and adds a projected
    conditioner after the main conv. It then applies a sigmoid/tanh gate and
    projects back to ``residual_channels``.

    Args:
        n_mels: channel count of ``conditioner``.
        residual_channels: channel count of ``x`` and of the output.
        dilation: accepted for API compatibility; it has no effect, because
            ``dilated_conv`` is a plain ``ConvLayer`` padded by kernel_size // 2.
        kernel_size: kernel width of ``dilated_conv``.
        dropout: dropout rate applied to the projected conditioner.
    """

    def __init__(self, n_mels, residual_channels, dilation, kernel_size, dropout):
        super().__init__()
        self.dilated_conv = ConvLayer(residual_channels, 2 * residual_channels, kernel_size)
        self.conditioner_projection = ConvLayer(n_mels, 2 * residual_channels, 1)
        self.output_projection = ConvLayer(residual_channels, residual_channels, 1)
        self.t_proj = ConvLayer(residual_channels, residual_channels, 1)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, diffusion_step, conditioner, x_mask):
        """Apply the block.

        Args:
            x: (T, B, residual_channels) features.
            diffusion_step: step embedding; gets a leading time axis of size 1
                so it broadcasts over T (presumably (B, residual_channels)).
            conditioner: (T, B, n_mels) conditioning features.
            x_mask: boolean mask, transposed before use; True entries are zeroed.

        Returns:
            (T, B, residual_channels) tensor.
        """
        assert (conditioner is None and self.conditioner_projection is None) or \
            (conditioner is not None and self.conditioner_projection is not None)
        pad = x_mask.t().unsqueeze(-1)

        y = x + self.t_proj(diffusion_step.unsqueeze(0))
        y = y.masked_fill(pad, 0)
        conditioner = self.drop(self.conditioner_projection(conditioner))
        y = self.dilated_conv(y) + conditioner
        y = y.masked_fill(pad, 0)

        gate, filter_ = torch.chunk(y, 2, dim=-1)
        y = torch.sigmoid(gate) * torch.tanh(filter_)
        y = y.masked_fill(pad, 0)
        return self.output_projection(y)
class Pre_model(nn.Module):
    """Front-end encoders producing content and prompt features for the diffusion model.

    Args:
        cfg: mapping with ``'phoneme_encoder'`` and ``'prompt_encoder'`` kwargs
            dictionaries for ``PhoneEncoder`` / ``PromptEncoder``.
    """

    def __init__(self, cfg) -> None:
        super().__init__()
        self.cfg = cfg
        self.phoneme_encoder = PhoneEncoder(**self.cfg['phoneme_encoder'])
        print("phoneme params:", count_parameters(self.phoneme_encoder))
        self.prompt_encoder = PromptEncoder(**self.cfg['prompt_encoder'])
        print("prompt params:", count_parameters(self.prompt_encoder))
        self.ref_enc = TextTimeEmbedding(100, 100, 1)

    def _encode(self, mel_recon, mel_lengths, refer, refer_lengths):
        """Normalise both mels, then run the reference, prompt and content encoders.

        Returns:
            ``(content, audio_prompt)``.
        """
        mel_recon = normalize_tacotron_mel(mel_recon)
        refer = normalize_tacotron_mel(refer)
        # Global reference embedding with a trailing singleton axis, added to
        # the content input (presumably (batch, 100, 1); verify TextTimeEmbedding).
        g = self.ref_enc(refer.transpose(1, 2)).unsqueeze(-1)
        audio_prompt = self.prompt_encoder(refer, refer_lengths)
        content = self.phoneme_encoder(mel_recon, mel_lengths, g)
        return content, audio_prompt

    def forward(self, data, g=None):
        """Training entry point.

        Args:
            data: ``(mel_recon_padded, mel_padded, mel_lengths, refer_padded,
                refer_lengths)`` with un-normalised mels; ``mel_padded`` is unused here.
            g: ignored; kept for signature compatibility. The global embedding is
                always recomputed from the reference.
        """
        mel_recon_padded, _mel_padded, mel_lengths, refer_padded, refer_lengths = data
        return self._encode(mel_recon_padded, mel_lengths, refer_padded, refer_lengths)

    def infer(self, data):
        """Inference entry point.

        Args:
            data: ``(mel_recon_padded, refer_padded, mel_lengths, refer_lengths)``
                with un-normalised mels (note the different order from ``forward``).
        """
        mel_recon_padded, refer_padded, mel_lengths, refer_lengths = data
        return self._encode(mel_recon_padded, mel_lengths, refer_padded, refer_lengths)
class Diffusion_Encoder(nn.Module):
    """Denoiser: a 1-D conditional UNet over noisy mel concatenated with content features.

    Args:
        in_channels: channels of the noisy mel ``x``.
        out_channels: channels of the prediction.
        hidden_channels: channels of the content features (concatenated to ``x``)
            and the cross-attention width for the prompt.
        block_out_channels: UNet block widths; defaults to ``[128, 256, 384, 512]``.
            A ``None`` default is used instead of a shared mutable list.
        n_heads: passed to the UNet as ``attention_head_dim``.
        p_dropout: accepted for config compatibility; not used.
    """

    def __init__(self,
                 in_channels=128,
                 out_channels=128,
                 hidden_channels=256,
                 block_out_channels=None,
                 n_heads=8,
                 p_dropout=0.2,
                 ):
        super().__init__()
        if block_out_channels is None:
            block_out_channels = [128, 256, 384, 512]
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.n_heads = n_heads
        self.unet = UNet1DConditionModel(
            in_channels=in_channels + hidden_channels,
            out_channels=out_channels,
            block_out_channels=block_out_channels,
            norm_num_groups=8,
            cross_attention_dim=hidden_channels,
            attention_head_dim=n_heads,
            addition_embed_type='text',
            resnet_time_scale_shift='scale_shift',
        )

    def forward(self, x, data, t):
        """Predict from noisy ``x`` at timesteps ``t``.

        Args:
            x: (batch, in_channels, frames) noisy mel.
            data: ``(content, prompt, content_lengths, prompt_lengths)``; content is
                concatenated on the channel axis and the prompt (batch, C, T_p) is
                used for cross-attention. ``content_lengths`` is unused.
            t: diffusion timesteps.
        """
        assert not torch.isnan(x).any()
        contentvec, prompt, _contentvec_lengths, prompt_lengths = data
        prompt = rearrange(prompt, 'b c t -> b t c')
        x = torch.cat([x, contentvec], dim=1)
        prompt_mask = commons.sequence_mask(prompt_lengths, prompt.size(1)).to(torch.bool)
        return self.unet(x, t, prompt, encoder_attention_mask=prompt_mask).sample
# ---- tensor helper functions ----
def log(t, eps=1e-20):
    """Natural log of ``t`` with inputs floored at ``eps`` so zeros do not give -inf."""
    safe = torch.clamp(t, min=eps)
    return torch.log(safe)
def extract(a, t, x_shape):
    """Gather ``a[t]`` per batch element and reshape to broadcast against ``x_shape``.

    Returns a tensor of shape ``(batch, 1, ..., 1)`` with ``len(x_shape)`` dims.
    """
    batch, *_ = t.shape
    gathered = a.gather(-1, t)
    trailing = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing)
def linear_beta_schedule(timesteps):
    """Linear beta schedule from the original DDPM paper (float64).

    Endpoints 1e-4 and 0.02 are scaled by ``1000 / timesteps`` so that other
    step counts keep a comparable total noise level.
    """
    scale = 1000 / timesteps
    return torch.linspace(scale * 0.0001, scale * 0.02, timesteps, dtype=torch.float64)
def default(val, d):
    """Return ``val`` unless it is None; otherwise ``d()`` if ``d`` is callable, else ``d``."""
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
# Pair returned by Diffuser.model_predictions: the predicted noise and predicted clean sample.
ModelPrediction = namedtuple('ModelPrediction', 'pred_noise pred_x_start')
class Diffuser(nn.Module):
    """Gaussian diffusion that turns ``mel_recon`` into a mel conditioned on a reference mel.

    The denoiser (``Diffusion_Encoder``) is trained with ``target = noise``
    (see ``forward``), so every sampler interprets its output as predicted noise.

    Args:
        cfg: config mapping. Reads ``cfg['diffusion']``, ``cfg['train']['timesteps']``,
            ``cfg['phoneme_encoder']['out_channels']`` and whatever ``Pre_model`` reads.
        ddim_sampling_eta: DDIM stochasticity (0 = deterministic).
        min_snr_loss_weight: clip the SNR used for loss weighting at ``min_snr_gamma``.
        min_snr_gamma: clipping value for ``min_snr_loss_weight``.
        conditioning_free: stored flag; sampling does not currently apply guidance
            (see ``sample_fun``).
        conditioning_free_k: stored guidance strength (see ``sample_fun``).
    """

    def __init__(self,
                 cfg,
                 ddim_sampling_eta=0,
                 min_snr_loss_weight=False,
                 min_snr_gamma=5,
                 conditioning_free=True,
                 conditioning_free_k=1.0
                 ):
        super().__init__()
        self.pre_model = Pre_model(cfg)
        print("pre params: ", count_parameters(self.pre_model))
        self.diff_model = Diffusion_Encoder(**cfg['diffusion'])
        print("diff params: ", count_parameters(self.diff_model))
        self.dim = self.diff_model.in_channels

        betas = linear_beta_schedule(cfg['train']['timesteps'])
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
        timesteps, = betas.shape
        self.num_timesteps = timesteps
        # Learned stand-in for the reference prompt in conditioning-free training.
        self.unconditioned_content = nn.Parameter(torch.randn(1, cfg['phoneme_encoder']['out_channels'], 1))
        # Default DDIM step count so ddim_sample works standalone; sample() overrides it.
        self.sampling_timesteps = timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # Schedules are computed in float64 and stored as float32 buffers.
        def register_buffer(name, val):
            self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        register_buffer('posterior_variance', posterior_variance)
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
        snr = alphas_cumprod / (1 - alphas_cumprod)
        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max=min_snr_gamma)
        # (Clipped) SNR is the x0-objective weight; forward() divides by SNR
        # because the model is trained on the noise objective.
        register_buffer('loss_weight', maybe_clipped_snr)
        self.conditioning_free = conditioning_free
        self.conditioning_free_k = conditioning_free_k

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_0 from x_t and predicted noise: sqrt(1/a_bar) x_t - sqrt(1/a_bar - 1) eps."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Recover the noise implied by x_t and a predicted x_0."""
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) /
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def q_posterior(self, x_start, x_t, t):
        """Mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(self, x, t, data=None):
        """Run the denoiser and return its noise prediction plus the implied x_0.

        The denoiser is trained with ``target = noise`` (see ``forward``), so its
        output is the noise and x_0 is derived from it.
        """
        pred_noise = self.diff_model(x, data, t)
        t = t.type(torch.int64)
        x_start = self.predict_start_from_noise(x, t, pred_noise)
        return ModelPrediction(pred_noise, x_start)

    def sample_fun(self, x, t, data=None):
        """Denoiser call used by the ODE solvers; returns the noise prediction.

        NOTE(review): an earlier version evaluated ``diff_model`` twice on identical
        inputs when ``conditioning_free`` was set, then discarded the guided mix
        and returned the plain conditional output. Real classifier-free guidance
        would need a genuinely unconditional input. Until one exists, only the
        conditional pass is run, which returns the same value at half the cost.
        """
        return self.diff_model(x, data, t)

    def p_mean_variance(self, x, t, data):
        """Posterior mean/variance for one reverse step, plus the predicted x_0."""
        preds = self.model_predictions(x, t, data)
        x_start = preds.pred_x_start
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    def p_sample(self, x, t: int, data):
        """One ancestral DDPM step at integer timestep ``t``."""
        b, *_, device = *x.shape, x.device
        batched_times = torch.full((b,), t, device=device, dtype=torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x=x, t=batched_times, data=data)
        noise = torch.randn_like(x) if t > 0 else 0.  # no noise if t == 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    def _noise_shape(self, content):
        """Shape of the generated mel: (batch, self.dim, frames of ``content``)."""
        return (content.shape[0], self.dim, content.shape[2])

    def _with_progress(self, fn):
        """Wrap ``fn`` so every call advances ``self.bar`` by one."""
        def wrapped(x, t, **kwargs):
            ret = fn(x, t, **kwargs)
            self.bar.update(1)
            return ret
        return wrapped

    def p_sample_loop(self, content, refer, lengths, refer_lengths, f0=None, uv=None, auto_predict_f0=True):
        """Ancestral DDPM sampling over all ``num_timesteps`` steps.

        Args:
            content: raw (un-normalised) ``mel_recon``; normalised by ``Pre_model.infer``.
            refer: raw reference mel.
            lengths, refer_lengths: valid frame counts.
            f0, uv, auto_predict_f0: unused; kept for backward compatibility.

        Returns:
            Normalised mel of shape (batch, self.dim, frames).
        """
        content, refer = self.pre_model.infer((content, refer, lengths, refer_lengths))
        cond = (content, refer, lengths, refer_lengths)
        img = torch.randn(self._noise_shape(content), device=refer.device)
        for t in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
            img, _ = self.p_sample(img, t, cond)
        return img

    def ddim_sample(self, content, refer, lengths, refer_lengths, f0=None, uv=None, auto_predict_f0=True):
        """DDIM sampling with ``self.sampling_timesteps`` steps and eta ``ddim_sampling_eta``.

        Arguments and return value are as for ``p_sample_loop``.
        """
        content, refer = self.pre_model.infer((content, refer, lengths, refer_lengths))
        cond = (content, refer, lengths, refer_lengths)
        shape = self._noise_shape(content)
        batch, device = shape[0], refer.device
        total_timesteps, sampling_timesteps, eta = self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta
        # [-1, 0, 1, ..., T-1] when sampling_timesteps == total_timesteps
        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), ..., (0, -1)]
        img = torch.randn(shape, device=device)
        for time, time_next in tqdm(time_pairs, desc='sampling loop time step'):
            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
            pred_noise, x_start = self.model_predictions(img, time_cond, cond)
            if time_next < 0:
                img = x_start
                continue
            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]
            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()
            noise = torch.randn_like(img)
            img = x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise
        return img

    def _solver_sample(self, mel_recon, refer, lengths, refer_lengths, sample_method):
        """Sample with DPM-Solver++ (40 steps) or UniPC (30 steps) on the noise prediction."""
        if sample_method == 'dpmsolver':
            from sampler.dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
            steps = 40
        else:
            from ttts.sampler.uni_pc import NoiseScheduleVP, model_wrapper, UniPC
            steps = 30
        noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas)
        content, refer = self.pre_model.infer((mel_recon, refer, lengths, refer_lengths))
        x = torch.randn(self._noise_shape(content), device=refer.device)
        model_fn = model_wrapper(
            self._with_progress(self.sample_fun),
            noise_schedule,
            model_type="noise",  # the denoiser is trained to predict noise
            model_kwargs={"data": (content, refer, lengths, refer_lengths)},
        )
        if sample_method == 'dpmsolver':
            solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
        else:
            solver = UniPC(model_fn, noise_schedule, variant='bh2')
        self.bar = tqdm(desc="sample time step", total=steps)
        try:
            return solver.sample(
                x,
                steps=steps,
                order=2,
                skip_type="time_uniform",
                method="multistep",
            )
        finally:
            self.bar.close()

    def sample(self,
               mel_recon, refer, lengths, refer_lengths,
               sampling_timesteps=100, sample_method='unipc'
               ):
        """Generate a mel spectrogram from ``mel_recon`` conditioned on ``refer``.

        Args:
            mel_recon: raw (un-normalised) input mel.
            refer: raw reference mel; if its batch size is 2, only the first item is kept.
            lengths, refer_lengths: valid frame counts.
            sampling_timesteps: DDIM step count (used only by ``'ddim'``).
            sample_method: ``'ddpm'``, ``'ddim'``, ``'dpmsolver'`` or ``'unipc'``.

        Returns:
            De-normalised mel of shape (batch, self.dim, frames).

        Raises:
            ValueError: for an unknown ``sample_method``.
        """
        # Inputs stay raw. Pre_model.infer normalises them itself, just as
        # Pre_model.forward does in training; normalising here too applied it twice.
        if refer.shape[0] == 2:
            # NOTE(review): refer_lengths is not sliced to match; confirm intent.
            refer = refer[0].unsqueeze(0)
        self.sampling_timesteps = sampling_timesteps
        if sample_method == 'ddpm':
            mel = self.p_sample_loop(mel_recon, refer, lengths, refer_lengths)
        elif sample_method == 'ddim':
            mel = self.ddim_sample(mel_recon, refer, lengths, refer_lengths)
        elif sample_method in ('dpmsolver', 'unipc'):
            mel = self._solver_sample(mel_recon, refer, lengths, refer_lengths, sample_method)
        else:
            raise ValueError(f"unknown sample_method: {sample_method!r}")
        return denormalize_tacotron_mel(mel)

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to timestep ``t``: sqrt(a_bar) x_0 + sqrt(1 - a_bar) noise."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def forward(self, data, conditioning_free=False):
        """Compute the diffusion training loss for one batch.

        Args:
            data: ``(mel_recon_padded, mel_padded, mel_lengths, refer_padded,
                refer_lengths)`` with un-normalised, batch-first mels.
            conditioning_free: replace the reference prompt by the learned
                ``unconditioned_content`` embedding.

        Returns:
            Scalar loss.
        """
        unused_params = []
        mel_recon_padded, mel_padded, mel_lengths, refer_padded, refer_lengths = data
        # The target is the real mel. Previously mel_recon_padded was normalised
        # into both names, so the model was trained to reproduce its own input.
        mel_recon_padded = normalize_tacotron_mel(mel_recon_padded)
        mel_padded = normalize_tacotron_mel(mel_padded)
        assert mel_recon_padded.shape[2] == mel_padded.shape[2]
        b, _, n = mel_padded.shape
        device = mel_padded.device
        x_mask = torch.unsqueeze(commons.sequence_mask(mel_lengths, n), 1).to(mel_padded.dtype)
        x_start = mel_padded * x_mask

        # Pre_model normalises the raw mels in ``data`` itself.
        content, refer = self.pre_model(data)
        if conditioning_free:
            # ``refer.mean() * 0`` keeps the prompt encoder in the autograd graph.
            refer = self.unconditioned_content.repeat(b, 1, 1) + refer.mean() * 0
        else:
            # Presumably lets distributed training (DDP) see a gradient for the
            # otherwise-unused parameter; it is added to the loss times 0 below.
            unused_params.append(self.unconditioned_content)

        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
        noise = torch.randn_like(x_start) * x_mask
        x = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.diff_model(x, (content, refer, mel_lengths, refer_lengths), t)
        target = noise

        loss = F.mse_loss(model_out, target, reduction='none')
        # Per-sample mean, shape (b,). 'b ... -> b (...)' did not reduce, so the
        # (b, 1, 1) weight broadcast to (b, b, d*n) and mixed samples.
        loss_diff = reduce(loss, 'b ... -> b', 'mean')
        # loss_weight is (clipped) SNR, the x0-objective weight. For the noise
        # objective the equivalent is clipped_snr / snr (1 when clipping is off).
        snr = self.alphas_cumprod / (1 - self.alphas_cumprod)
        weight = extract(self.loss_weight / snr, t, loss_diff.shape)
        loss = (loss_diff * weight).mean()

        extraneous_addition = 0
        for p in unused_params:
            extraneous_addition = extraneous_addition + p.mean()
        return loss + extraneous_addition * 0
def get_grad_norm(model):
    """Return the global L2 norm of all parameter gradients of ``model``.

    Parameters without a gradient are skipped and their names printed.
    A missing gradient is now detected explicitly. The previous bare
    ``except`` also swallowed unrelated errors, including KeyboardInterrupt.
    """
    total_sq = 0.0
    for name, p in model.named_parameters():
        if p.grad is None:
            print(name)
            continue
        total_sq += p.grad.detach().norm(2).item() ** 2
    return total_sq ** 0.5
# Silence chatty third-party loggers below WARNING.
for _noisy_logger in ('matplotlib', 'numba'):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
del _noisy_logger