"""
utils.py

Desc: A file for miscellaneous util functions
"""

import numpy as np
import torch


# MonoTransform: this does not exist in PyTorch anymore since it is a simple
# mean calculation, so we provide an implementation here.
class MonoTransform(object):
    """
    Convert an audio sample to a single (mono) channel by averaging channels.

    Args for __call__:
        sample: audio with shape (C, T) or (B, C, T), where C is the number of
            channels. Either a ``torch.Tensor`` or anything ``np.asarray``
            accepts.

    Returns:
        The mean over the channel axis (axis -2) with that axis kept, i.e.
        shape (1, T) or (B, 1, T). A tensor input yields a tensor; any other
        input yields an ``np.ndarray``.

    Raises:
        ValueError: if ``sample`` has fewer than 2 dimensions (there is no
            channel axis to average over).
    """

    def __init__(self):
        pass

    def __call__(self, sample):
        if isinstance(sample, torch.Tensor):
            if sample.dim() < 2:
                raise ValueError(
                    "MonoTransform expects shape (C, T) or (B, C, T), "
                    f"got {tuple(sample.shape)}"
                )
            # torch.mean is undefined for integer dtypes (e.g. int16 PCM).
            if not sample.is_floating_point():
                sample = sample.float()
            return sample.mean(dim=-2, keepdim=True)

        arr = np.asarray(sample)
        if arr.ndim < 2:
            raise ValueError(
                "MonoTransform expects shape (C, T) or (B, C, T), "
                f"got {arr.shape}"
            )
        return arr.mean(axis=-2, keepdims=True)


"""
Below: Helper functions for Grad-TTS
"""


## Duration Loss
## Desc: A function for computing the duration loss for the duration predictor
def duration_loss(logw, logw_, lengths):
    """Return the summed squared error between ``logw`` and ``logw_``,
    normalised by ``sum(lengths)``."""
    loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
    return loss


def intersperse(lst, item):
    """Return ``lst`` with ``item`` inserted before, between and after every
    element (used to add the blank symbol), e.g. [a, b] -> [x, a, x, b, x]."""
    result = [item] * (len(lst) * 2 + 1)
    # Odd positions receive the original elements; even positions keep `item`.
    result[1::2] = lst
    return result


def sequence_mask(length, max_length=None):
    """Build a boolean mask of shape (len(length), max_length) where row i is
    True for positions < length[i].

    If ``max_length`` is None, ``length.max()`` is used.
    """
    if max_length is None:
        max_length = length.max()
    x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
    # Broadcast (1, max_length) against (N, 1) to compare every position.
    return x.unsqueeze(0) < length.unsqueeze(1)


def fix_len_compatibility(length, num_downsamplings_in_unet=2):
    """Round ``length`` up to the nearest multiple of
    ``2 ** num_downsamplings_in_unet``.

    Values that are already a multiple are returned unchanged.
    """
    factor = 2 ** num_downsamplings_in_unet
    remainder = length % factor
    # Same result as incrementing by 1 until divisible, but computed directly.
    if remainder == 0:
        return length
    return length + (factor - remainder)


def convert_pad_shape(pad_shape):
    """Convert per-dimension ``[[before, after], ...]`` pairs, given
    first-to-last dimension, into the flat last-dimension-first list that
    ``torch.nn.functional.pad`` expects."""
    reversed_pairs = pad_shape[::-1]
    return [item for pair in reversed_pairs for item in pair]


def generate_path(duration, mask):
    """Expand per-token durations into a hard alignment path.

    Args:
        duration: tensor of shape (b, t_x); the ``view(b * t_x)`` below
            requires this shape. Presumably integer frame counts per input
            token — TODO confirm against callers.
        mask: tensor of shape (b, t_x, t_y); the result is multiplied by it
            and cast to its dtype.

    Returns:
        A (b, t_x, t_y) tensor in which row j of batch i is 1 exactly on the
        frames ``[cum_duration[j-1], cum_duration[j])``, then masked by
        ``mask``.
    """
    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)

    cum_duration_flat = cum_duration.view(b * t_x)
    # Row j is 1 on frames [0, cum_duration[j]).
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # Subtract the previous row (shifted down by one along t_x) so that only
    # the frames [cum_duration[j-1], cum_duration[j]) remain set.
    path = path - torch.nn.functional.pad(
        path, convert_pad_shape([[0, 0], [1, 0], [0, 0]])
    )[:, :-1]
    path = path * mask
    return path