"""Shared utilities: audio file discovery, TensorBoard logging helpers,
checkpoint cleanup, EMA updates, and small 1-D conv/attention building blocks.
"""
import functools
import glob
import logging
import math
import os
import re
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

from ttts.utils.xtransformers import ContinuousTransformerWrapper, RelativePositionBias


def get_paths_with_cache(search_path, cache_path=None):
    """Return the audio file paths under ``search_path``, optionally cached.

    :param search_path: directory that is searched recursively for audio files.
    :param cache_path: optional file used as a cache. If it exists, its
                       contents are loaded and returned without scanning.
                       Otherwise the directory is scanned and the result is
                       written to this path.
    :return: list of file paths (as produced by ``find_audio_files``), or
             whatever object the cache file contains.
    """
    if cache_path is not None and os.path.exists(cache_path):
        # The cache is deserialized with torch.load; only point this at
        # files you trust.
        return torch.load(cache_path)

    out_paths = find_audio_files(Path(search_path), ['.wav', '.m4a', '.mp3'])
    if cache_path is not None:
        print("Building cache..")
        # NOTE(review): the save call was garbled in the source and is
        # reconstructed from the surviving `cache_path)` fragment.
        torch.save(out_paths, cache_path)
    return out_paths


def find_audio_files(folder_path, suffixes):
    """Recursively collect files under ``folder_path`` ending in any suffix.

    :param folder_path: root directory to search.
    :param suffixes: iterable of filename suffixes, e.g. ``['.wav', '.mp3']``.
    :return: list of matching paths, grouped by suffix in the given order.
    """
    files = []
    for suffix in suffixes:
        pattern = os.path.join(folder_path, '**', f'*{suffix}')
        files.extend(glob.glob(pattern, recursive=True))
    return files


def summarize(writer, global_step, scalars=None, histograms=None, images=None,
              audios=None, audio_sampling_rate=22050):
    """Write a batch of scalars/histograms/images/audios to a summary writer.

    :param writer: object exposing ``add_scalar``, ``add_histogram``,
                   ``add_image`` and ``add_audio``.
    :param global_step: step value passed to every ``add_*`` call.
    :param scalars: mapping tag -> scalar value (optional).
    :param histograms: mapping tag -> histogram values (optional).
    :param images: mapping tag -> image in HWC layout (optional).
    :param audios: mapping tag -> audio data (optional).
    :param audio_sampling_rate: sample rate forwarded to ``add_audio``.
    """
    # Defaults are None rather than {} to avoid shared mutable defaults.
    for k, v in (scalars or {}).items():
        writer.add_scalar(k, v, global_step)
    for k, v in (histograms or {}).items():
        writer.add_histogram(k, v, global_step)
    for k, v in (images or {}).items():
        writer.add_image(k, v, global_step, dataformats='HWC')
    for k, v in (audios or {}).items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)


MATPLOTLIB_FLAG = False


def plot_spectrogram_to_numpy(spectrogram):
    """Render a 2-D array as an image and return it as a uint8 RGB array.

    :param spectrogram: 2-D array-like accepted by ``Axes.imshow``.
    :return: ``numpy.ndarray`` of shape (height, width, 3), dtype uint8.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        # One-time setup: select the Agg backend and quiet matplotlib's logger.
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    canvas = fig.canvas
    canvas.draw()
    if hasattr(canvas, "buffer_rgba"):
        # Preferred path: tostring_rgb is deprecated/removed in newer
        # Matplotlib. Drop the alpha channel and copy out of the canvas buffer.
        data = np.asarray(canvas.buffer_rgba())[..., :3].copy()
    else:
        # Fallback for canvases without buffer_rgba. np.frombuffer replaces
        # the deprecated binary mode of np.fromstring.
        data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
        data = data.reshape(canvas.get_width_height()[::-1] + (3,)).copy()
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)
    return data


logger = logging

# Matches checkpoint names of the form "model-<step>.pt" (prefix match).
_CKPT_STEP_RE = re.compile(r'model-(\d+)\.pt')


def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2,
                      sort_by_time=True):
    """Free up space by deleting saved checkpoints.

    Only regular files whose name starts with ``model`` and does not end
    with ``_0.pth`` are considered.

    Arguments:
    path_to_models    --  Path to the model directory
    n_ckpts_to_keep   --  Number of most recent ckpts to keep (>= 0)
    sort_by_time      --  True -> delete oldest first by modification time
                          False -> delete lowest step first, using the number
                          in "model-<step>.pt"; files that do not match that
                          pattern are left untouched

    Raises ValueError if n_ckpts_to_keep is negative.
    """
    if n_ckpts_to_keep < 0:
        raise ValueError("n_ckpts_to_keep must be >= 0")

    candidates = [
        f for f in os.listdir(path_to_models)
        if os.path.isfile(os.path.join(path_to_models, f))
        and f.startswith('model') and not f.endswith('_0.pth')
    ]

    if sort_by_time:
        ordered = sorted(
            candidates,
            key=lambda f: os.path.getmtime(os.path.join(path_to_models, f)),
        )
    else:
        # Skip names without a step number instead of crashing on them.
        numbered = []
        for f in candidates:
            m = _CKPT_STEP_RE.match(f)
            if m is not None:
                numbered.append((int(m.group(1)), f))
        numbered.sort(key=lambda pair: pair[0])
        ordered = [f for _, f in numbered]

    # Explicit count: a plain `[:-n]` slice deletes nothing when n == 0.
    n_to_delete = max(len(ordered) - n_ckpts_to_keep, 0)
    for name in ordered[:n_to_delete]:
        full_path = os.path.join(path_to_models, name)
        os.remove(full_path)
        # NOTE(review): the logging call was garbled in the source and is
        # reconstructed as logger.info around the surviving message text.
        logger.info(f".. Free up space by deleting ckpt {full_path}")


# exponential moving average
class EMA:
    """Exponential moving average with decay factor ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Return ``old * beta + (1 - beta) * new``, or ``new`` if old is None."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


def update_moving_average(ema_updater, ma_model, current_model):
    """Blend ``current_model`` parameters into ``ma_model`` in place.

    :param ema_updater: object with ``update_average(old, new)`` (e.g. EMA).
    :param ma_model: module holding the moving-average weights (modified).
    :param current_model: module holding the current weights.
    """
    for current_params, ma_params in zip(current_model.parameters(),
                                         ma_model.parameters()):
        # NOTE(review): the `.data` accesses were garbled in the source and
        # are reconstructed from the variable names that remained.
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class GroupNorm32(nn.GroupNorm):
    """GroupNorm that computes in float32 and casts back to the input dtype."""

    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def normalization(channels):
    """
    Make a standard normalization layer.

    The group count starts at 32 (16 if channels <= 64, 8 if channels <= 16)
    and is halved until it divides ``channels``.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    :raises AssertionError: if the resulting group count is not above 2.
    """
    groups = 32
    if channels <= 16:
        groups = 8
    elif channels <= 64:
        groups = 16
    while channels % groups != 0:
        groups //= 2
    assert groups > 2
    return GroupNorm32(groups, channels)


class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, mask=None, rel_pos=None):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :param mask: optional mask multiplied into the attention weights
                     after the softmax.
        :param rel_pos: optional callable applied to the
                        [N x H x T x T] attention logits.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        if rel_pos is not None:
            weight = rel_pos(
                weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])
            ).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        if mask is not None:
            # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
            mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
            weight = weight * mask
        a = torch.einsum("bts,bcs->bct", weight, v)

        return a.reshape(bs, -1, length)


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Ported from an earlier implementation (the reference link is missing from
    this file) and adapted to the N-d case.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        do_checkpoint=True,
        relative_pos_embeddings=False,
    ):
        super().__init__()
        self.channels = channels
        # Stored only; nothing in this class reads it.
        self.do_checkpoint = do_checkpoint
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.norm = normalization(channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        # split heads before split qkv
        self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
        if relative_pos_embeddings:
            # NOTE(review): `heads` uses the `num_heads` argument, which can
            # differ from self.num_heads when num_head_channels is set.
            # Left unchanged to keep parameter shapes stable - confirm intent.
            self.relative_pos_embeddings = RelativePositionBias(
                scale=(channels // self.num_heads) ** .5, causal=False,
                heads=num_heads, num_buckets=32, max_distance=64)
        else:
            self.relative_pos_embeddings = None

    def forward(self, x, mask=None):
        """Apply residual self-attention over the flattened trailing dims."""
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv, mask, self.relative_pos_embeddings)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param out_channels: output channels of the convolution (defaults to
                         ``channels``).
    :param factor: nearest-neighbour upsampling factor.
    """

    def __init__(self, channels, use_conv, out_channels=None, factor=4):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.factor = factor
        if use_conv:
            ksize = 5
            pad = 2
            self.conv = nn.Conv1d(self.channels, self.out_channels, ksize, padding=pad)

    def forward(self, x):
        assert x.shape[1] == self.channels
        x = F.interpolate(x, scale_factor=self.factor, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param out_channels: output channels (must equal ``channels`` when
                         ``use_conv`` is False).
    :param factor: stride of the convolution / average pool.
    :param ksize: convolution kernel size.
    :param pad: convolution padding.
    """

    def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv

        stride = factor
        if use_conv:
            self.op = nn.Conv1d(
                self.channels, self.out_channels, ksize, stride=stride, padding=pad
            )
        else:
            assert self.channels == self.out_channels
            self.op = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResBlock(nn.Module):
    """1-D residual block: norm/SiLU/conv twice, with optional up/down-sampling.

    The last convolution is zero-initialised, and the skip path is an
    identity when ``out_channels == channels``, otherwise a convolution.
    """

    def __init__(
        self,
        channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        up=False,
        down=False,
        kernel_size=3,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        padding = 1 if kernel_size == 3 else 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False)
            self.x_upd = Upsample(channels, False)
        elif down:
            self.h_upd = Downsample(channels, False)
            self.x_upd = Downsample(channels, False)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = nn.Conv1d(
                channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)

    def forward(self, x):
        if self.updown:
            # Resample between the norm/activation and the first convolution,
            # and resample the skip input the same way.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        h = self.out_layers(h)
        return self.skip_connection(x) + h


class AudioMiniEncoder(nn.Module):
    """Small conv + attention encoder that returns one vector per input.

    ``forward`` returns the first position along the last axis of the
    attention output, i.e. a tensor indexed as ``h[:, :, 0]``.
    """

    def __init__(self,
                 spec_dim,
                 embedding_dim,
                 base_channels=128,
                 depth=2,
                 resnet_blocks=2,
                 attn_blocks=4,
                 num_attn_heads=4,
                 dropout=0,
                 downsample_factor=2,
                 kernel_size=3):
        super().__init__()
        self.init = nn.Sequential(
            nn.Conv1d(spec_dim, base_channels, 3, padding=1)
        )
        ch = base_channels
        res = []
        for _ in range(depth):
            for _ in range(resnet_blocks):
                res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
            res.append(Downsample(ch, use_conv=True, out_channels=ch * 2,
                                  factor=downsample_factor))
            ch *= 2
        self.res = nn.Sequential(*res)
        # NOTE(review): the attribute name was garbled in the source and is
        # reconstructed as `final`; changing it would rename state_dict keys.
        self.final = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            nn.Conv1d(ch, embedding_dim, 1)
        )
        attn = [AttentionBlock(embedding_dim, num_attn_heads)
                for _ in range(attn_blocks)]
        self.attn = nn.Sequential(*attn)
        self.dim = embedding_dim

    def forward(self, x):
        h = self.init(x)
        h = self.res(h)
        h = self.final(h)
        h = self.attn(h)
        return h[:, :, 0]


DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data/mel_norms.pth')


class TorchMelSpectrogram(nn.Module):
    """Log-mel spectrogram extractor built on torchaudio's MelSpectrogram.

    If ``mel_norm_file`` is not None it is loaded with ``torch.load`` at
    construction time, and the log-mel output is divided by it per channel.
    """

    def __init__(self, filter_length=1024, hop_length=256, win_length=1024, n_mel_channels=80, mel_fmin=0, mel_fmax=8000,
                 sampling_rate=22050, normalize=False, mel_norm_file=DEFAULT_MEL_NORM_FILE):
        super().__init__()
        # These are the default tacotron values for the MEL spectrogram.
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.sampling_rate = sampling_rate
        self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
                                                             win_length=self.win_length, power=2, normalized=normalize,
                                                             sample_rate=self.sampling_rate, f_min=self.mel_fmin,
                                                             f_max=self.mel_fmax, n_mels=self.n_mel_channels,
                                                             norm="slaney")
        self.mel_norm_file = mel_norm_file
        if self.mel_norm_file is not None:
            self.mel_norms = torch.load(self.mel_norm_file)
        else:
            self.mel_norms = None

    def forward(self, inp):
        """Return the log-mel spectrogram of a 2-D (or 3-D mono) waveform batch."""
        if len(inp.shape) == 3:  # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
            inp = inp.squeeze(1)
        assert len(inp.shape) == 2
        # NOTE(review): the three `.to(...)` calls below were garbled in the
        # source and are reconstructed from the surviving fragments - confirm.
        if torch.backends.mps.is_available():
            inp = inp.to('cpu')
        self.mel_stft = self.mel_stft.to(inp.device)
        mel = self.mel_stft(inp)
        # Perform dynamic range compression
        mel = torch.log(torch.clamp(mel, min=1e-5))
        if self.mel_norms is not None:
            self.mel_norms = self.mel_norms.to(mel.device)
            mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
        return mel


class CheckpointedLayer(nn.Module):
    """
    Wraps a module and forwards calls to it, asserting that no keyword
    argument is a tensor that requires grad.

    Despite the name, the forward pass here calls the wrapped module directly
    and does not invoke torch.utils.checkpoint.
    """

    def __init__(self, wrap):
        super().__init__()
        self.wrap = wrap

    def forward(self, x, *args, **kwargs):
        for k, v in kwargs.items():
            assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
        partial = functools.partial(self.wrap, **kwargs)
        return partial(x, *args)


class CheckpointedXTransformerEncoder(nn.Module):
    """
    Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid to
    channels-last that XTransformer expects.

    :param needs_permute: permute the input (0, 2, 1) before the transformer.
    :param exit_permute: permute the output (0, 2, 1) after the transformer.
    :param checkpoint: wrap the middle element of each attention layer triple
                       in CheckpointedLayer.
    :param xtransformer_kwargs: forwarded to ContinuousTransformerWrapper.
    """

    def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs):
        super().__init__()
        self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
        self.needs_permute = needs_permute
        self.exit_permute = exit_permute

        if not checkpoint:
            return
        layers = self.transformer.attn_layers.layers
        # Index-based loop because entries are replaced in place.
        for i in range(len(layers)):
            n, b, r = layers[i]
            layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])

    def forward(self, x, **kwargs):
        if self.needs_permute:
            x = x.permute(0, 2, 1)
        h = self.transformer(x, **kwargs)
        if self.exit_permute:
            h = h.permute(0, 2, 1)
        return h