import functools
import math
import os

import fsspec
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import LogitsWarper

from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias

def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    groups = 32
    if channels <= 16:
        groups = 8
    elif channels <= 64:
        groups = 16
    while channels % groups != 0:
        groups = int(groups / 2)
    assert groups > 2
    return GroupNorm32(groups, channels)
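
# Usage sketch (illustrative addition, not part of the original module): for 80 input
# channels the loop above settles on 16 groups, since 80 % 32 != 0 but 80 % 16 == 0.
#
#   norm = normalization(80)             # -> GroupNorm32(16, 80)
#   y = norm(torch.randn(2, 80, 100))    # normalized in float32, cast back to input dtype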

class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, mask=None, rel_pos=None):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        if rel_pos is not None:
            weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(
                bs * self.n_heads, weight.shape[-2], weight.shape[-1]
            )
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        if mask is not None:
            # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
            mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
            weight = weight * mask
        a = torch.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)
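
# Shape sketch (illustrative addition): with 2 heads and 8 channels per head, the
# fused qkv tensor has width 2 * 3 * 8 = 48.
#
#   attn = QKVAttentionLegacy(n_heads=2)
#   qkv = torch.randn(4, 48, 100)        # [N x (H * 3 * C) x T]
#   out = attn(qkv)                      # -> [4, 16, 100], i.e. [N x (H * C) x T]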

class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        do_checkpoint=True,
        relative_pos_embeddings=False,
    ):
        super().__init__()
        self.channels = channels
        self.do_checkpoint = do_checkpoint
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.norm = normalization(channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        # split heads before split qkv
        self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
        if relative_pos_embeddings:
            self.relative_pos_embeddings = RelativePositionBias(
                scale=(channels // self.num_heads) ** 0.5,
                causal=False,
                heads=num_heads,
                num_buckets=32,
                max_distance=64,
            )
        else:
            self.relative_pos_embeddings = None

    def forward(self, x, mask=None):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv, mask, self.relative_pos_embeddings)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
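
# Usage sketch (illustrative addition): self-attention over a 1-d feature map; the
# residual path plus the zero-initialized proj_out make the block an identity at init.
#
#   block = AttentionBlock(channels=256, num_heads=4)
#   x = torch.randn(2, 256, 50)          # [batch, channels, positions]
#   y = block(x)                         # -> same shape, [2, 256, 50]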

class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    """

    def __init__(self, channels, use_conv, out_channels=None, factor=4):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.factor = factor
        if use_conv:
            ksize = 5
            pad = 2
            self.conv = nn.Conv1d(self.channels, self.out_channels, ksize, padding=pad)

    def forward(self, x):
        assert x.shape[1] == self.channels
        x = F.interpolate(x, scale_factor=self.factor, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x
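
# Usage sketch (illustrative addition): nearest-neighbour upsampling by `factor`,
# optionally smoothed by the kernel-5 convolution (which preserves length).
#
#   up = Upsample(channels=64, use_conv=True, factor=2)
#   y = up(torch.randn(1, 64, 100))      # -> [1, 64, 200]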

class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    """

    def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv

        stride = factor
        if use_conv:
            self.op = nn.Conv1d(self.channels, self.out_channels, ksize, stride=stride, padding=pad)
        else:
            assert self.channels == self.out_channels
            self.op = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
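
# Usage sketch (illustrative addition): strided-conv (or, as here, average-pool)
# downsampling by `factor`.
#
#   down = Downsample(channels=64, use_conv=False, factor=2)
#   y = down(torch.randn(1, 64, 200))    # -> [1, 64, 100]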

class ResBlock(nn.Module):
    def __init__(
        self,
        channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        up=False,
        down=False,
        kernel_size=3,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        padding = 1 if kernel_size == 3 else 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False)
            self.x_upd = Upsample(channels, False)
        elif down:
            self.h_upd = Downsample(channels, False)
            self.x_upd = Downsample(channels, False)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)),
        )
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding)
        else:
            self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)

    def forward(self, x):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        h = self.out_layers(h)
        return self.skip_connection(x) + h
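
# Usage sketch (illustrative addition): a pre-activation residual block; the 1x1 skip
# convolution is only created when out_channels differs from channels.
#
#   block = ResBlock(channels=128, dropout=0.1, out_channels=256)
#   y = block(torch.randn(2, 128, 100))  # -> [2, 256, 100]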

class AudioMiniEncoder(nn.Module):
    def __init__(
        self,
        spec_dim,
        embedding_dim,
        base_channels=128,
        depth=2,
        resnet_blocks=2,
        attn_blocks=4,
        num_attn_heads=4,
        dropout=0,
        downsample_factor=2,
        kernel_size=3,
    ):
        super().__init__()
        self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1))
        ch = base_channels
        res = []
        for l in range(depth):
            for r in range(resnet_blocks):
                res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
            res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor))
            ch *= 2
        self.res = nn.Sequential(*res)
        self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1))
        attn = []
        for a in range(attn_blocks):
            attn.append(
                AttentionBlock(
                    embedding_dim,
                    num_attn_heads,
                )
            )
        self.attn = nn.Sequential(*attn)
        self.dim = embedding_dim

    def forward(self, x):
        h = self.init(x)
        h = self.res(h)
        h = self.final(h)
        h = self.attn(h)
        return h[:, :, 0]
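
# Usage sketch (illustrative addition): encodes a spectrogram [batch, spec_dim, frames]
# into one embedding per clip by keeping only the first position after attention.
#
#   enc = AudioMiniEncoder(spec_dim=80, embedding_dim=512)
#   emb = enc(torch.randn(2, 80, 400))   # -> [2, 512]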

DEFAULT_MEL_NORM_FILE = "https://coqui.gateway.scarf.sh/v0.14.1_models/mel_norms.pth"


class TorchMelSpectrogram(nn.Module):
    def __init__(
        self,
        filter_length=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=80,
        mel_fmin=0,
        mel_fmax=8000,
        sampling_rate=22050,
        normalize=False,
        mel_norm_file=DEFAULT_MEL_NORM_FILE,
    ):
        super().__init__()
        # These are the default tacotron values for the MEL spectrogram.
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.sampling_rate = sampling_rate
        self.mel_stft = torchaudio.transforms.MelSpectrogram(
            n_fft=self.filter_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            power=2,
            normalized=normalize,
            sample_rate=self.sampling_rate,
            f_min=self.mel_fmin,
            f_max=self.mel_fmax,
            n_mels=self.n_mel_channels,
            norm="slaney",
        )
        self.mel_norm_file = mel_norm_file
        if self.mel_norm_file is not None:
            with fsspec.open(self.mel_norm_file) as f:
                self.mel_norms = torch.load(f)
        else:
            self.mel_norms = None

    def forward(self, inp):
        if len(inp.shape) == 3:
            # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
            inp = inp.squeeze(1)
        assert len(inp.shape) == 2
        self.mel_stft = self.mel_stft.to(inp.device)
        mel = self.mel_stft(inp)
        # Perform dynamic range compression
        mel = torch.log(torch.clamp(mel, min=1e-5))
        if self.mel_norms is not None:
            self.mel_norms = self.mel_norms.to(mel.device)
            mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
        return mel
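
# Usage sketch (illustrative addition): converts a batch of 22.05 kHz waveforms into
# log-compressed mel spectrograms; passing mel_norm_file=None skips fetching the norm file.
#
#   to_mel = TorchMelSpectrogram(mel_norm_file=None)
#   mel = to_mel(torch.randn(2, 1, 22050))   # -> [2, 80, frames]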

class CheckpointedLayer(nn.Module):
    """
    Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
    checkpoint for all other args.
    """

    def __init__(self, wrap):
        super().__init__()
        self.wrap = wrap

    def forward(self, x, *args, **kwargs):
        for k, v in kwargs.items():
            assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
        partial = functools.partial(self.wrap, **kwargs)
        return partial(x, *args)
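
# Note (added for clarity): in this version forward() simply calls the wrapped module,
# so the layer only enforces that no kwarg tensor requires grad; no activation
# checkpointing is actually performed here.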

class CheckpointedXTransformerEncoder(nn.Module):
    """
    Wraps a ContinuousTransformerWrapper, applies CheckpointedLayer to each of its layers, and permutes inputs
    from the channels-mid format to the channels-last format that XTransformer expects.
    """

    def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs):
        super().__init__()
        self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
        self.needs_permute = needs_permute
        self.exit_permute = exit_permute

        if not checkpoint:
            return
        for i in range(len(self.transformer.attn_layers.layers)):
            n, b, r = self.transformer.attn_layers.layers[i]
            self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])

    def forward(self, x, **kwargs):
        if self.needs_permute:
            x = x.permute(0, 2, 1)
        h = self.transformer(x, **kwargs)
        if self.exit_permute:
            h = h.permute(0, 2, 1)
        return h

class TypicalLogitsWarper(LogitsWarper):
    def __init__(
        self,
        mass: float = 0.9,
        filter_value: float = -float("Inf"),
        min_tokens_to_keep: int = 1,
    ):
        self.filter_value = filter_value
        self.mass = mass
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # calculate entropy
        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
        p = torch.exp(normalized)
        ent = -(normalized * p).nansum(-1, keepdim=True)

        # shift and sort
        shifted_scores = torch.abs((-normalized) - ent)
        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
        sorted_logits = scores.gather(-1, sorted_indices)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Remove tokens with cumulative mass above the threshold
        last_ind = (cumulative_probs < self.mass).sum(dim=1)
        last_ind[last_ind < 0] = 0
        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
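
# Usage sketch (illustrative addition): typical-sampling filter on a row of logits;
# tokens whose surprisal is far from the distribution's entropy are masked to -inf.
# input_ids is unused by the filter itself, so a dummy tensor suffices here.
#
#   warper = TypicalLogitsWarper(mass=0.9)
#   dummy_ids = torch.zeros(1, 1, dtype=torch.long)
#   filtered = warper(dummy_ids, torch.randn(1, 32))   # [batch, vocab] with masked entries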