diff --git a/orator/src/orator/__pycache__/__init__.cpython-311.pyc b/orator/src/orator/__pycache__/__init__.cpython-311.pyc index 6ad8f5de7c9ca3e1810d0327c25196f38459a597..4d73364ffc4777d01681ff959635ee6123e075a6 100644 Binary files a/orator/src/orator/__pycache__/__init__.cpython-311.pyc and b/orator/src/orator/__pycache__/__init__.cpython-311.pyc differ diff --git a/orator/src/orator/__pycache__/tts.cpython-311.pyc b/orator/src/orator/__pycache__/tts.cpython-311.pyc index 29395581a11f480b8650fe8156c2c05d04b7c889..090fe10720565fb5bc2b44a709c10af6512786bb 100644 Binary files a/orator/src/orator/__pycache__/tts.cpython-311.pyc and b/orator/src/orator/__pycache__/tts.cpython-311.pyc differ diff --git a/orator/src/orator/models/bigvgan/__pycache__/activations.cpython-311.pyc b/orator/src/orator/models/bigvgan/__pycache__/activations.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abdfafb4abc0f4c125638ebe9d4f456039bf68fc Binary files /dev/null and b/orator/src/orator/models/bigvgan/__pycache__/activations.cpython-311.pyc differ diff --git a/orator/src/orator/models/bigvgan/__pycache__/bigvgan.cpython-311.pyc b/orator/src/orator/models/bigvgan/__pycache__/bigvgan.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bc39b3811393b032fad3f25eb8c822ad831b6c0 Binary files /dev/null and b/orator/src/orator/models/bigvgan/__pycache__/bigvgan.cpython-311.pyc differ diff --git a/orator/src/orator/models/bigvgan/activations.py b/orator/src/orator/models/bigvgan/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..30a3c85145eb147e61331f9dbd5d2b3650146851 --- /dev/null +++ b/orator/src/orator/models/bigvgan/activations.py @@ -0,0 +1,120 @@ +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. + +import torch +from torch import nn, sin, pow +from torch.nn import Parameter + + +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. 
+ Snake ∶= x + 1/a * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x diff --git a/orator/src/orator/models/bigvgan/alias_free_torch/__init__.py b/orator/src/orator/models/bigvgan/alias_free_torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f756ed83f87f9839e457b240f60469bc187707d --- /dev/null +++ b/orator/src/orator/models/bigvgan/alias_free_torch/__init__.py @@ -0,0 +1,6 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. 
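A quick editorial sanity check of the Snake activation added above (not part of the diff). It verifies the forward pass against the closed-form Snake(x) = x + (1/α)·sin²(αx) with linear-scale α; for `alpha_logscale=True` the stored parameter is log α and is exponentiated first. Only `torch` is required:

```python
import torch

B, C, T = 2, 4, 8
x = torch.randn(B, C, T)
alpha = torch.ones(C)                    # linear-scale init, as in Snake(C, alpha_logscale=False)
a = alpha.unsqueeze(0).unsqueeze(-1)     # broadcast against x: [1, C, 1]
y = x + (1.0 / (a + 1e-9)) * torch.sin(x * a).pow(2)
assert y.shape == x.shape                # elementwise: shape (B, C, T) is preserved
```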
+ +from .filter import * +from .resample import * +from .act import * diff --git a/orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/__init__.cpython-311.pyc b/orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdf57d13c2b1e94a2c20321d6fcab00ee86ba913 Binary files /dev/null and b/orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/__init__.cpython-311.pyc differ diff --git a/orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/act.cpython-311.pyc b/orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/act.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e4a139e20f899bddacf05d467861e2857286268 Binary files /dev/null and b/orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/act.cpython-311.pyc differ diff --git a/orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/filter.cpython-311.pyc b/orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/filter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6416602224f887fa03f8bce27fc952f8f6ff23a Binary files /dev/null and b/orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/filter.cpython-311.pyc differ diff --git a/orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/resample.cpython-311.pyc b/orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/resample.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af56e62e2e4bffcd9444f653101a91af4241494b Binary files /dev/null and b/orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/resample.cpython-311.pyc differ diff --git a/orator/src/orator/models/bigvgan/alias_free_torch/act.py b/orator/src/orator/models/bigvgan/alias_free_torch/act.py new file mode 100644 index 0000000000000000000000000000000000000000..ef231b01506f01c2b66d2dc4f3f0891219b3b41a --- /dev/null +++ b/orator/src/orator/models/bigvgan/alias_free_torch/act.py @@ -0,0 +1,28 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn + +from .resample import UpSample1d, DownSample1d + + +class Activation1d(nn.Module): + def __init__(self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B, C, T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + return x diff --git a/orator/src/orator/models/bigvgan/alias_free_torch/filter.py b/orator/src/orator/models/bigvgan/alias_free_torch/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..066dce8eef9f31a868554f08efbef7c3f4422b7b --- /dev/null +++ b/orator/src/orator/models/bigvgan/alias_free_torch/filter.py @@ -0,0 +1,95 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. 
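A minimal usage sketch of the `Activation1d` wrapper added above (upsample 2x, apply the nonlinearity, downsample 2x to suppress aliasing). This is illustrative only; the import paths are assumptions based on this src layout, not something the diff itself establishes:

```python
import torch
from orator.models.bigvgan.activations import SnakeBeta
from orator.models.bigvgan.alias_free_torch import Activation1d

channels = 16
act = Activation1d(activation=SnakeBeta(channels, alpha_logscale=True))
x = torch.randn(1, channels, 100)   # [B, C, T]
y = act(x)                          # 2x upsample -> SnakeBeta -> 2x downsample
assert y.shape == x.shape           # the time axis is restored after the round trip
```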
+ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +if 'sinc' in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where(x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = (kernel_size % 2 == 0) + half_size = kernel_size // 2 + + #For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.: + beta = 0.1102 * (A - 8.7) + elif A >= 21.: + beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) + else: + beta = 0. + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = (torch.arange(-half_size, half_size) + 0.5) + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__(self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = 'replicate', + kernel_size: int = 12): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = (kernel_size % 2 == 0) + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + #input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + + return out diff --git a/orator/src/orator/models/bigvgan/alias_free_torch/resample.py b/orator/src/orator/models/bigvgan/alias_free_torch/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..73670db9735504a51231fbe93cb812f722fb74ae --- /dev/null +++ b/orator/src/orator/models/bigvgan/alias_free_torch/resample.py @@ -0,0 +1,55 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. 
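A small check of `kaiser_sinc_filter1d` from the file above — not part of the diff, and the import path is assumed from this src layout. Because the taps are normalized to sum to 1, the filter has unit DC gain, i.e. a constant input passes through unchanged:

```python
import torch
from orator.models.bigvgan.alias_free_torch.filter import kaiser_sinc_filter1d

filt = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(filt.shape)         # torch.Size([1, 1, 12]) -> ready for grouped conv1d
print(float(filt.sum()))  # ~1.0: unit DC gain after normalization
```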
+ +import torch.nn as nn +from torch.nn import functional as F + +from .filter import LowPassFilter1d +from .filter import kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size + ) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + ) + x = x[..., self.pad_left:-self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size + ) + + def forward(self, x): + xx = self.lowpass(x) + + return xx diff --git a/orator/src/orator/models/bigvgan/bigvgan.py b/orator/src/orator/models/bigvgan/bigvgan.py new file mode 100644 index 0000000000000000000000000000000000000000..356142106f6c91b0cd4c8db4ec28e04811e8e1ef --- /dev/null +++ b/orator/src/orator/models/bigvgan/bigvgan.py @@ -0,0 +1,212 @@ +# Copyright (c) 2022 NVIDIA CORPORATION. +# Licensed under the MIT license. +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. 
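For reference, the resamplers defined above change only the time axis: `UpSample1d` multiplies T by `ratio`, `DownSample1d` divides it. A hedged usage sketch (import path assumed from this src layout):

```python
import torch
from orator.models.bigvgan.alias_free_torch import UpSample1d, DownSample1d

x = torch.randn(1, 8, 100)    # [B, C, T]
up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
print(up(x).shape)            # torch.Size([1, 8, 200])
print(down(x).shape)          # torch.Size([1, 8, 50])
```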
+ +import logging +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm +from torch.nn.utils.weight_norm import WeightNorm + +from .activations import SnakeBeta +from .alias_free_torch import * + + + +LRELU_SLOPE = 0.1 + +logger = logging.getLogger(__name__) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +class AMPBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(AMPBlock1, self).__init__() + + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + + self.activations = nn.ModuleList([ + Activation1d(activation=SnakeBeta(channels, alpha_logscale=True)) + for _ in range(self.num_layers) + ]) + + def forward(self, x): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = xt + x + + return x + + def set_weight_norm(self, enabled: bool): + weight_norm_fn = weight_norm if enabled else remove_weight_norm + for l in self.convs1: + weight_norm_fn(l) + for l in self.convs2: + weight_norm_fn(l) + + +class BigVGAN(nn.Module): + # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks. + + # We've got a model in prod that has the wrong hparams for this. It's simpler to add this check than to + # redistribute the model. + ignore_state_dict_unexpected = ("cond_layer.*",) + + def __init__(self): + super().__init__() + + input_dims = 80 + + upsample_rates = [10, 8, 4, 2] + upsample_kernel_sizes = [x * 2 for x in upsample_rates] + upsample_initial_channel = 1024 + + resblock_kernel_sizes = [3, 7, 11] + resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + + # pre conv + self.conv_pre = weight_norm(Conv1d(input_dims, upsample_initial_channel, 7, 1, padding=3)) + self.cond_layer = None + + # transposed conv-based upsamplers. 
does not apply anti-aliasing + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append(nn.ModuleList([ + weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i), + upsample_initial_channel // (2 ** (i + 1)), + k, u, padding=(k - u) // 2)) + ])) + + # residual blocks using anti-aliased multi-periodicity composition modules (AMP) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(AMPBlock1(ch, k, d)) + + # post conv + activation_post = SnakeBeta(ch, alpha_logscale=True) + self.activation_post = Activation1d(activation=activation_post) + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + + # weight initialization + for i in range(len(self.ups)): + self.ups[i].apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x) -> torch.Tensor: + """ + Args + ---- + x: torch.Tensor of shape [B, T, C] + """ + with torch.inference_mode(): + + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + # upsampling + for i_up in range(len(self.ups[i])): + x = self.ups[i][i_up](x) + # AMP blocks + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + # post conv + x = self.activation_post(x) + x = self.conv_post(x) + + # Bound the output to [-1, 1] + x = torch.tanh(x) + + return x + + @property + def weight_norm_enabled(self) -> bool: + return any( + isinstance(hook, WeightNorm) and hook.name == "weight" + for k, hook in self.conv_pre._forward_pre_hooks.items() + ) + + def set_weight_norm(self, enabled: bool): + """ + N.B.: weight norm modifies the state dict, causing incompatibilities. 
Conventions: + - BigVGAN runs with weight norm for training, without for inference (done automatically by instantiate()) + - All checkpoints are saved with weight norm (allows resuming training) + """ + if enabled != self.weight_norm_enabled: + weight_norm_fn = weight_norm if enabled else remove_weight_norm + logger.debug(f"{'Applying' if enabled else 'Removing'} weight norm...") + + for l in self.ups: + for l_i in l: + weight_norm_fn(l_i) + for l in self.resblocks: + l.set_weight_norm(enabled) + weight_norm_fn(self.conv_pre) + weight_norm_fn(self.conv_post) + + def train_mode(self): + self.train() + self.set_weight_norm(enabled=True) + + def inference_mode(self): + self.eval() + self.set_weight_norm(enabled=False) + + +if __name__ == '__main__': + import sys + import soundfile as sf + model = BigVGAN() + + state_dict = torch.load("bigvgan32k.pt") + msg = model.load_state_dict(state_dict) + model.eval() + model.set_weight_norm(enabled=False) + + print(msg) + mels = torch.load("mels.pt") + with torch.inference_mode(): + y = model(mels.cpu()) + + for i, wav in enumerate(y): + wav = wav.view(-1).detach().numpy() + sf.write(f"bigvgan_test{i}.flac", wav, samplerate=32_000, format="FLAC") diff --git a/orator/src/orator/models/s3gen/__pycache__/__init__.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/__init__.cpython-311.pyc index 8c1699d85577c0eb3fe46fe8d05804981f0498e1..5d342fa7f91de42eb90c2f718e96aee92e8a508b 100644 Binary files a/orator/src/orator/models/s3gen/__pycache__/__init__.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/__init__.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/__pycache__/const.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/const.cpython-311.pyc index d6fe0e739088104765e9d4b0a805d61b5fca4bc9..e48f54f1c8576d6a38cc373a299e3db210217574 100644 Binary files a/orator/src/orator/models/s3gen/__pycache__/const.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/const.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/__pycache__/decoder.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/decoder.cpython-311.pyc index 15039b8ac9d6b33e48e3171267d0f625f572d501..ec011f3ab10de9b0af8106bdc40f7d899bdb0ea2 100644 Binary files a/orator/src/orator/models/s3gen/__pycache__/decoder.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/decoder.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc index 8a166e2f186d6fdf7d82570266e1ac96ec6add09..7d48cc3d1d19db9d44c7181067ab5603ec06554d 100644 Binary files a/orator/src/orator/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/__pycache__/flow.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/flow.cpython-311.pyc index fdbe55fdb8e75b258ce91e94ca09cc743adbb2eb..87974e30b64cf53f11a44dbfd7e98b9e9aedfb92 100644 Binary files a/orator/src/orator/models/s3gen/__pycache__/flow.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/flow.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/__pycache__/flow_matching.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/flow_matching.cpython-311.pyc index 96606c6e34fb092f600b96dad307e292162edbd6..595dfad5532ed2b6585ed8a7c0a63ab3de713f74 100644 Binary files 
a/orator/src/orator/models/s3gen/__pycache__/flow_matching.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/flow_matching.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/__pycache__/hifigan.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/hifigan.cpython-311.pyc index 9ab636cef70ced1ae7dbd376e19e73f03d089a04..2efcda58d717d84fa19895f595c60569bd871aae 100644 Binary files a/orator/src/orator/models/s3gen/__pycache__/hifigan.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/hifigan.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/__pycache__/s3gen.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/s3gen.cpython-311.pyc index f643c02e44e0ad043029e5379998c15ea69e9c0f..a9c95f44491c7e486960719e2ca6e1bc81e4896d 100644 Binary files a/orator/src/orator/models/s3gen/__pycache__/s3gen.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/s3gen.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/__pycache__/xvector.cpython-311.pyc b/orator/src/orator/models/s3gen/__pycache__/xvector.cpython-311.pyc index d2bf0cbd41dc525b9131d9c18f96021ecccd4b72..75038ede7bcb6e855476701b9b49005babe03412 100644 Binary files a/orator/src/orator/models/s3gen/__pycache__/xvector.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/xvector.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc b/orator/src/orator/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc index 42e1166c8d244884c63a58957d94ec95160ea70b..a4720cffdda43eddaa412a82e32a28b2c4da0fd9 100644 Binary files a/orator/src/orator/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc and b/orator/src/orator/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc b/orator/src/orator/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc index cf4494fa06afb946d3cc8e613a9b6fb2cf6f2411..cc8c38ede378987f616ba398ba51b6856825de33 100644 Binary files a/orator/src/orator/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc and b/orator/src/orator/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc b/orator/src/orator/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc index 562bb3e3db1dc26262d483b98ff00e9184264a81..cd888d39388ea076c44e8209dba94d354d760d97 100644 Binary files a/orator/src/orator/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc and b/orator/src/orator/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc b/orator/src/orator/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc index 68975983af6e89a0802d465c904e25165e7a94f9..fb5c4d5f281b7e7a94b55d890414723c172333cd 100644 Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc b/orator/src/orator/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc index c4c2e7e5ed0829a5f0b1d810c25a2e6da94c302e..b8388f1425fc6faddad43b710f0c58d5c374f58a 100644 Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc and 
b/orator/src/orator/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc b/orator/src/orator/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc index 21b931aa26786a0bc88a3aa8cea5e1c174fdea37..cfcd257b733783a40035b0c8fbe62d0df2409c20 100644 Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc b/orator/src/orator/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc index 0bd11bf383f2959ca171b136301a685ae73898da..d7cbc1355320d0b1cdcbb187e80b7b78d9802453 100644 Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc b/orator/src/orator/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc index f45bc1da1ca3ec42d268e82e89f551d5517740bd..619023d20d5831a83d777d768eb5876bebfff72b 100644 Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc b/orator/src/orator/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc index 0bfca5ce699f05d5afa8e077df735eb7684e56e3..051a22fa4053a02ce51f1e6768a97acb3a8bd6f1 100644 Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc b/orator/src/orator/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc index 8493bd49f27b689b902b376ed4e8e60660e05162..ad9e6ed4c9925c561c6c1f7026478bffa7ab2ab2 100644 Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc b/orator/src/orator/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc index 069b00e6c39c956c5398b9f9da80d0801a56fe14..44b67d483ae3c70b9b144f5030b68877eb075393 100644 Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc b/orator/src/orator/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc index 3266bb5614d1c229a25543d11fa6c9f0d39c9d74..fcf1cec2735b03a6b2442bcf57003ac624ea552f 100644 Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc 
b/orator/src/orator/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc index 12857ba8f4f8bd70e07bd6e21b959c5d7f11ffec..4b8cc4ebe7e8779f376df268307fc0fa20d4fb20 100644 Binary files a/orator/src/orator/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc and b/orator/src/orator/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/utils/__pycache__/mask.cpython-311.pyc b/orator/src/orator/models/s3gen/utils/__pycache__/mask.cpython-311.pyc index b9c6b04b80c213f9f464d8619bcdf4d6ccb80abb..d3bd7c9f20a0c4b75cff9d1ccc723c2839045706 100644 Binary files a/orator/src/orator/models/s3gen/utils/__pycache__/mask.cpython-311.pyc and b/orator/src/orator/models/s3gen/utils/__pycache__/mask.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3gen/utils/__pycache__/mel.cpython-311.pyc b/orator/src/orator/models/s3gen/utils/__pycache__/mel.cpython-311.pyc index 214a3ebeab56b71cd7de1a499192d6d4b5d893ed..42e63bec8db32bb9a5349b24e752bbbfb9c288f9 100644 Binary files a/orator/src/orator/models/s3gen/utils/__pycache__/mel.cpython-311.pyc and b/orator/src/orator/models/s3gen/utils/__pycache__/mel.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc b/orator/src/orator/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc index 717760c538cda8eb3167fe178e2619cd52b8bc29..21105ed2e9931207e586ab4020974f13613385f9 100644 Binary files a/orator/src/orator/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc and b/orator/src/orator/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc differ diff --git a/orator/src/orator/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc b/orator/src/orator/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc index d25539d27187172122848ab843e24cfd1bf0fb1e..ec1a2cb71573c30c0b17137a662f4d22009f804d 100644 Binary files a/orator/src/orator/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc and b/orator/src/orator/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc differ diff --git a/orator/src/orator/models/t3/__pycache__/__init__.cpython-311.pyc b/orator/src/orator/models/t3/__pycache__/__init__.cpython-311.pyc index 26e581a98c437c8f004bace5aed0f5e445248941..4955aa017c54594b34bd714a7ab635edc6c7f0a3 100644 Binary files a/orator/src/orator/models/t3/__pycache__/__init__.cpython-311.pyc and b/orator/src/orator/models/t3/__pycache__/__init__.cpython-311.pyc differ diff --git a/orator/src/orator/models/t3/__pycache__/llama_configs.cpython-311.pyc b/orator/src/orator/models/t3/__pycache__/llama_configs.cpython-311.pyc index ee55c55d42c9253ee521f637f35cf3c61a6d9c5d..c2e37f3930829ce356a9e01fd75a51a856624f0f 100644 Binary files a/orator/src/orator/models/t3/__pycache__/llama_configs.cpython-311.pyc and b/orator/src/orator/models/t3/__pycache__/llama_configs.cpython-311.pyc differ diff --git a/orator/src/orator/models/t3/__pycache__/t3.cpython-311.pyc b/orator/src/orator/models/t3/__pycache__/t3.cpython-311.pyc index 8604a9a811aaf1ff22e22df605bb7207bdb54820..d70491dab33ea926ac13d89f18800ca491b1fa0e 100644 Binary files a/orator/src/orator/models/t3/__pycache__/t3.cpython-311.pyc and b/orator/src/orator/models/t3/__pycache__/t3.cpython-311.pyc differ diff --git a/orator/src/orator/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc b/orator/src/orator/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc index 9a3a360d772733fd0e439662f94c3eade57ce7db..e72820205c9dc263715ccbcdef1321f7644b7d32 100644 Binary files 
a/orator/src/orator/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc and b/orator/src/orator/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc differ diff --git a/orator/src/orator/models/t3/inference/alignment_stream_analyzer.py b/orator/src/orator/models/t3/inference/alignment_stream_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..d3a144f0f7f0cdef4a7a4c049db3b5433744296e --- /dev/null +++ b/orator/src/orator/models/t3/inference/alignment_stream_analyzer.py @@ -0,0 +1,154 @@ +# Copyright (c) 2025 Resemble AI +# Author: John Meade, Jeremy Hsu +# MIT License +import logging +import torch +from dataclasses import dataclass +from types import MethodType + + +logger = logging.getLogger(__name__) + + +@dataclass +class AlignmentAnalysisResult: + # was this frame detected as being part of a noisy beginning chunk with potential hallucinations? + false_start: bool + # was this frame detected as being part of a long tail with potential hallucinations? + long_tail: bool + # was this frame detected as repeating existing text content? + repetition: bool + # was the alignment position of this frame too far from the previous frame? + discontinuity: bool + # has inference reached the end of the text tokens? eg, this remains false if inference stops early + complete: bool + # approximate position in the text token sequence. Can be used for generating online timestamps. + position: int + + +class AlignmentStreamAnalyzer: + def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_idx=0): + """ + Some transformer TTS models implicitly solve text-speech alignment in one or more of their self-attention + activation maps. This module exploits this to perform online integrity checks which streaming. + A hook is injected into the specified attention layer, and heuristics are used to determine alignment + position, repetition, etc. + + NOTE: currently requires no queues. + """ + # self.queue = queue + self.text_tokens_slice = (i, j) = text_tokens_slice + self.eos_idx = eos_idx + self.alignment = torch.zeros(0, j-i) + # self.alignment_bin = torch.zeros(0, j-i) + self.curr_frame_pos = 0 + self.text_position = 0 + + self.started = False + self.started_at = None + + self.complete = False + self.completed_at = None + + # Using `output_attentions=True` is incompatible with optimized attention kernels, so + # using it for all layers slows things down too much. We can apply it to just one layer + # by intercepting the kwargs and adding a forward hook (credit: jrm) + self.last_aligned_attn = None + self._add_attention_spy(tfmr, alignment_layer_idx) + + def _add_attention_spy(self, tfmr, alignment_layer_idx): + """ + Adds a forward hook to a specific attention layer to collect outputs. + Using `output_attentions=True` is incompatible with optimized attention kernels, so + using it for all layers slows things down too much. + (credit: jrm) + """ + + def attention_forward_hook(module, input, output): + """ + See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`. + NOTE: + - When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`. + - `attn_output` has shape [B, H, T0, T0] for the 0th entry, and [B, H, 1, T0+i] for the rest i-th. 
+ """ + step_attention = output[1].cpu() # (B, 16, N, N) + self.last_aligned_attn = step_attention[0].mean(0) # (N, N) + + target_layer = tfmr.layers[alignment_layer_idx].self_attn + hook_handle = target_layer.register_forward_hook(attention_forward_hook) + + # Backup original forward + original_forward = target_layer.forward + def patched_forward(self, *args, **kwargs): + kwargs['output_attentions'] = True + return original_forward(*args, **kwargs) + + # TODO: how to unpatch it? + target_layer.forward = MethodType(patched_forward, target_layer) + + def step(self, logits): + """ + Emits an AlignmentAnalysisResult into the output queue, and potentially modifies the logits to force an EOS. + """ + # extract approximate alignment matrix chunk (1 frame at a time after the first chunk) + aligned_attn = self.last_aligned_attn # (N, N) + i, j = self.text_tokens_slice + if self.curr_frame_pos == 0: + # first chunk has conditioning info, text tokens, and BOS token + A_chunk = aligned_attn[j:, i:j].clone().cpu() # (T, S) + else: + # subsequent chunks have 1 frame due to KV-caching + A_chunk = aligned_attn[:, i:j].clone().cpu() # (1, S) + + # TODO: monotonic masking; could have issue b/c spaces are often skipped. + A_chunk[:, self.curr_frame_pos + 1:] = 0 + + + self.alignment = torch.cat((self.alignment, A_chunk), dim=0) + + A = self.alignment + T, S = A.shape + + # update position + cur_text_posn = A_chunk[-1].argmax() + discontinuity = not(-4 < cur_text_posn - self.text_position < 7) # NOTE: very lenient! + if not discontinuity: + self.text_position = cur_text_posn + + # Hallucinations at the start of speech show up as activations at the bottom of the attention maps! + # To mitigate this, we just wait until there are no activations far off-diagonal in the last 2 tokens, + # and there are some strong activations in the first few tokens. + false_start = (not self.started) and (A[-2:, -2:].max() > 0.1 or A[:, :4].max() < 0.5) + self.started = not false_start + if self.started and self.started_at is None: + self.started_at = T + + # Is generation likely complete? + self.complete = self.complete or self.text_position >= S - 3 + if self.complete and self.completed_at is None: + self.completed_at = T + + # NOTE: EOS rarely assigned activations, and second-last token is often punctuation, so use last 3 tokens. + # NOTE: due to the false-start behaviour, we need to make sure we skip activations for the first few tokens. + last_text_token_duration = A[15:, -3:].sum() + + # Activations for the final token that last too long are likely hallucinations. + long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 10) # 400ms + + # If there are activations in previous tokens after generation has completed, assume this is a repetition error. + repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5) + + # If a bad ending is detected, force emit EOS by modifying logits + # NOTE: this means logits may be inconsistent with latents! 
+ if long_tail or repetition: + logger.warn(f"forcing EOS token, {long_tail=}, {repetition=}") + # (±2**15 is safe for all dtypes >= 16bit) + logits = -(2**15) * torch.ones_like(logits) + logits[..., self.eos_idx] = 2**15 + + # Suppress EoS to prevent early termination + if cur_text_posn < S - 3: # FIXME: arbitrary + logits[..., self.eos_idx] = -2**15 + + self.curr_frame_pos += 1 + return logits diff --git a/orator/src/orator/models/t3/inference/t3_hf_backend.py b/orator/src/orator/models/t3/inference/t3_hf_backend.py index 8d2b175093074e9b8a566ce02a807de9804160a0..6130722ce967b5d82dc0ca29390fb21748358424 100644 --- a/orator/src/orator/models/t3/inference/t3_hf_backend.py +++ b/orator/src/orator/models/t3/inference/t3_hf_backend.py @@ -23,14 +23,14 @@ class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin): speech_head, latents_queue=None, logits_queue=None, + alignment_stream_analyzer: 'AlignmentStreamAnalyzer'=None, ): super().__init__(config) self.model = llama self.speech_enc = speech_enc self.speech_head = speech_head - self.latents_queue = latents_queue - self.logits_queue = logits_queue self._added_cond = False + self.alignment_stream_analyzer = alignment_stream_analyzer @torch.inference_mode() def prepare_inputs_for_generation( @@ -101,12 +101,12 @@ class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin): return_dict=True, ) hidden_states = tfmr_out.hidden_states[-1] # (B, seq, dim) - if self.latents_queue is not None: - self.latents_queue.put(hidden_states) logits = self.speech_head(hidden_states) - if self.logits_queue is not None: - self.logits_queue.put(logits) + assert inputs_embeds.size(0) == 1 + + # NOTE: hallucination handler may modify logits to force emit an EOS token + logits = self.alignment_stream_analyzer.step(logits) return CausalLMOutputWithCrossAttentions( logits=logits, diff --git a/orator/src/orator/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc b/orator/src/orator/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc index 2c1cbde46e791ff61e0fe278140506bb2d299994..4b99ee1fb64a8d54d165437d566e7cf64223e3cb 100644 Binary files a/orator/src/orator/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc and b/orator/src/orator/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc differ diff --git a/orator/src/orator/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc b/orator/src/orator/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc index f5fac9d18536a99dd46ef5719c9aba8d574dc727..31256b321e52e439fc42a8d4bcccad9a76eeb7c6 100644 Binary files a/orator/src/orator/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc and b/orator/src/orator/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc differ diff --git a/orator/src/orator/models/t3/modules/__pycache__/perceiver.cpython-311.pyc b/orator/src/orator/models/t3/modules/__pycache__/perceiver.cpython-311.pyc index 39d8bd691b40e87ee83e217913874321f04bfcad..935c51d564eab2bf7d6870592d7b4020a1d8f9d5 100644 Binary files a/orator/src/orator/models/t3/modules/__pycache__/perceiver.cpython-311.pyc and b/orator/src/orator/models/t3/modules/__pycache__/perceiver.cpython-311.pyc differ diff --git a/orator/src/orator/models/t3/modules/__pycache__/t3_config.cpython-311.pyc b/orator/src/orator/models/t3/modules/__pycache__/t3_config.cpython-311.pyc index d74a6c5dcf3897b1fcac425735a289fa3901b6e7..31beca9bd0a70041f29c21a2145f09cb06426211 100644 Binary files a/orator/src/orator/models/t3/modules/__pycache__/t3_config.cpython-311.pyc and 
b/orator/src/orator/models/t3/modules/__pycache__/t3_config.cpython-311.pyc differ diff --git a/orator/src/orator/models/t3/modules/perceiver.py b/orator/src/orator/models/t3/modules/perceiver.py index eaa4b87c65832d380741c332c0a8288a4e8a9854..be9c5b863ce43ab43c0124a8ae0fa125b0da9673 100644 --- a/orator/src/orator/models/t3/modules/perceiver.py +++ b/orator/src/orator/models/t3/modules/perceiver.py @@ -1,3 +1,6 @@ +# Copyright (c) 2025 Resemble AI +# Author: Manmay Nakhashi +# MIT License import math import torch @@ -168,6 +171,7 @@ class AttentionBlock2(nn.Module): class Perceiver(nn.Module): + """Inspired by https://arxiv.org/abs/2103.03206""" def __init__(self, pre_attention_query_token=32, pre_attention_query_size=1024, embedding_dim=1024, num_attn_heads=4): """ Initialize the perceiver module. diff --git a/orator/src/orator/models/t3/t3.py b/orator/src/orator/models/t3/t3.py index 39978dfa8588f8d7bbcf2ea639c739119503708e..d1af8a1c07293748f6f7874d76f3aabfabde1633 100644 --- a/orator/src/orator/models/t3/t3.py +++ b/orator/src/orator/models/t3/t3.py @@ -1,3 +1,5 @@ +# Copyright (c) 2025 Resemble AI +# MIT License import logging from typing import Union, Optional, List @@ -10,8 +12,9 @@ from .modules.learned_pos_emb import LearnedPositionEmbeddings from .modules.cond_enc import T3CondEnc, T3Cond from .modules.t3_config import T3Config -from .inference.t3_hf_backend import T3HuggingfaceBackend from .llama_configs import LLAMA_CONFIGS +from .inference.t3_hf_backend import T3HuggingfaceBackend +from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer logger = logging.getLogger(__name__) @@ -221,9 +224,6 @@ class T3(nn.Module): """ Args: text_tokens: a 1D (unbatched) or 2D (batched) tensor. - tokens_queue: if a ReferenceQueue is provided, tokens will be streamed to it during generation. - latents_queue: if a ReferenceQueue is provided, latents will be streamed to it during generation. - logits_queue: if a ReferenceQueue is provided, logits will be streamed to it during generation. """ # Validate / sanitize inputs assert prepend_prompt_speech_tokens is None, "not implemented" @@ -235,7 +235,7 @@ class T3(nn.Module): initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1]) # Prepare custom input embeds - embeds, _ = self.prepare_input_embeds( + embeds, len_cond = self.prepare_input_embeds( t3_cond=t3_cond, text_tokens=text_tokens, speech_tokens=initial_speech_tokens, @@ -249,11 +249,19 @@ class T3(nn.Module): # TODO? synchronize the expensive compile function # with self.compile_lock: if not self.compiled: + alignment_stream_analyzer = AlignmentStreamAnalyzer( + self.tfmr, + None, + text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)), + alignment_layer_idx=9, # TODO: hparam or something? 
+ eos_idx=self.hp.stop_speech_token, + ) patched_model = T3HuggingfaceBackend( config=self.cfg, llama=self.tfmr, speech_enc=self.speech_emb, speech_head=self.speech_head, + alignment_stream_analyzer=alignment_stream_analyzer, ) self.patched_model = patched_model self.compiled = True diff --git a/orator/src/orator/models/tokenizers/__pycache__/__init__.cpython-311.pyc b/orator/src/orator/models/tokenizers/__pycache__/__init__.cpython-311.pyc index c275725d8ec6d80c237b139479031d3ede3646b9..fad1839ea6e3aba23d1aece28f34443d95f441c8 100644 Binary files a/orator/src/orator/models/tokenizers/__pycache__/__init__.cpython-311.pyc and b/orator/src/orator/models/tokenizers/__pycache__/__init__.cpython-311.pyc differ diff --git a/orator/src/orator/models/tokenizers/__pycache__/tokenizer.cpython-311.pyc b/orator/src/orator/models/tokenizers/__pycache__/tokenizer.cpython-311.pyc index 835bf5816abdb64bc97371e9f597a618a6d337ad..22fff1b2973015e0489db75db537633e880c2c98 100644 Binary files a/orator/src/orator/models/tokenizers/__pycache__/tokenizer.cpython-311.pyc and b/orator/src/orator/models/tokenizers/__pycache__/tokenizer.cpython-311.pyc differ diff --git a/orator/src/orator/models/voice_encoder/__pycache__/__init__.cpython-311.pyc b/orator/src/orator/models/voice_encoder/__pycache__/__init__.cpython-311.pyc index 734cd8a35d12eb71f0f1c29d55bbb39347e42839..f3a8c6b767936a3466e9ede2182d17a5e87de5ff 100644 Binary files a/orator/src/orator/models/voice_encoder/__pycache__/__init__.cpython-311.pyc and b/orator/src/orator/models/voice_encoder/__pycache__/__init__.cpython-311.pyc differ diff --git a/orator/src/orator/models/voice_encoder/__pycache__/voice_encoder.cpython-311.pyc b/orator/src/orator/models/voice_encoder/__pycache__/voice_encoder.cpython-311.pyc index c93d77f6ac7f26d75b260410a5cc2b90b54b5fa5..f7e8414b8ab2656a531547f37a18d85c88628fc1 100644 Binary files a/orator/src/orator/models/voice_encoder/__pycache__/voice_encoder.cpython-311.pyc and b/orator/src/orator/models/voice_encoder/__pycache__/voice_encoder.cpython-311.pyc differ diff --git a/orator/src/orator/models/voice_encoder/config.py b/orator/src/orator/models/voice_encoder/config.py new file mode 100644 index 0000000000000000000000000000000000000000..8e9782a20eac8bc41afaf38d80a8af862adac232 --- /dev/null +++ b/orator/src/orator/models/voice_encoder/config.py @@ -0,0 +1,18 @@ +class VoiceEncConfig: + num_mels = 40 + sample_rate = 16000 + speaker_embed_size = 256 + ve_hidden_size = 256 + flatten_lstm_params = False + n_fft = 400 + hop_size = 160 + win_size = 400 + fmax = 8000 + fmin = 0 + preemphasis = 0. + mel_power = 2.0 + mel_type = "amp" + normalized_mels = False + ve_partial_frames = 160 + ve_final_relu = True + stft_magnitude_min = 1e-4 diff --git a/orator/src/orator/models/voice_encoder/melspec.py b/orator/src/orator/models/voice_encoder/melspec.py new file mode 100644 index 0000000000000000000000000000000000000000..6b324ae0af4dfcb432488bacb6bea9b320dd272e --- /dev/null +++ b/orator/src/orator/models/voice_encoder/melspec.py @@ -0,0 +1,75 @@ +from functools import lru_cache + +import numpy as np +import torch +from torchaudio.transforms import MelSpectrogram + +from .config import VoiceEncConfig + + +class ResembleMelSpectrogram(torch.nn.Module): + def __init__(self, hp=VoiceEncConfig()): + """ + Torch implementation of Resemble's mel extraction. + Note that the values are NOT identical to librosa's implementation due to floating point precisions, however + the results are very very close. 
One test file gave an L1 error of just 0.005%, full results: + Librosa mel max: 0.871768 + Torch mel max: 0.871768 + Librosa mel mean: 0.316302 + Torch mel mean: 0.316289 + Max diff: 0.061105 + Mean diff: 1.453384e-05 + Percent error: 0.004595% + """ + super().__init__() + self.melspec = MelSpectrogram( + hp.sample_rate, + n_fft=hp.n_fft, + win_length=hp.win_size, + hop_length=hp.hop_size, + f_min=hp.fmin, + f_max=hp.fmax, + n_mels=hp.num_mels, + power=1, + normalized=False, + # NOTE: Folowing librosa's default. + pad_mode="constant", + norm="slaney", + mel_scale="slaney", + ) + self.register_buffer( + "stft_magnitude_min", + torch.FloatTensor([hp.stft_magnitude_min]) + ) + self.min_level_db = 20 * np.log10(hp.stft_magnitude_min) + self.preemphasis = hp.preemphasis + self.hop_size = hp.hop_size + + def forward(self, wav, pad=True): + """ + Args: + wav: [B, T] + """ + if self.preemphasis > 0: + wav = torch.nn.functional.pad(wav, [1, 0], value=0) + wav = wav[..., 1:] - self.preemphasis * wav[..., :-1] + + mel = self.melspec(wav) + + mel = self._amp_to_db(mel) + mel_normed = self._normalize(mel) + assert not pad or mel_normed.shape[-1] == 1 + \ + wav.shape[-1] // self.hop_size # Sanity check + return mel_normed # (M, T) + + def _normalize(self, s, headroom_db=15): + s = (s - self.min_level_db) / (-self.min_level_db + headroom_db) + return s + + def _amp_to_db(self, x): + return 20 * torch.maximum(self.stft_magnitude_min, x).log10() + + +@lru_cache() +def melspectrogram(): + return ResembleMelSpectrogram() diff --git a/orator/src/orator/models/voice_encoder/voice_encoder.py b/orator/src/orator/models/voice_encoder/voice_encoder.py index 68b398d113eedd4f77a91ebd143389aa2be69b15..41c5f6ca381971c5cce90a05c66cd764c6d28396 100644 --- a/orator/src/orator/models/voice_encoder/voice_encoder.py +++ b/orator/src/orator/models/voice_encoder/voice_encoder.py @@ -1,37 +1,54 @@ # Adapted from https://github.com/CorentinJ/Real-Time-Voice-Cloning # MIT License - from typing import List, Union, Optional import numpy as np from numpy.lib.stride_tricks import as_strided import librosa -from librosa import resample import torch import torch.nn.functional as F from torch import nn, Tensor -from ....orator.transforms.spectrogram import melspectrogram -from ....orator.transforms.syn_transforms import pack - - -class VoiceEncConfig: - num_mels = 40 - sample_rate = 16000 - speaker_embed_size = 256 - ve_hidden_size = 256 - flatten_lstm_params = False - n_fft = 400 - hop_size = 160 - win_size = 400 - fmax = 8000 - fmin = 0 - preemphasis = 0. - mel_power = 2.0 - mel_type = "amp" - normalized_mels = False - ve_partial_frames = 160 - ve_final_relu = True +from .config import VoiceEncConfig +from .melspec import melspectrogram + + +def pack(arrays, seq_len: int=None, pad_value=0): + """ + Given a list of length B of array-like objects of shapes (Ti, ...), packs them in a single tensor of + shape (B, T, ...) by padding each individual array on the right. + + :param arrays: a list of array-like objects of matching shapes except for the first axis. + :param seq_len: the value of T. It must be the maximum of the lengths Ti of the arrays at + minimum. Will default to that value if None. + :param pad_value: the value to pad the arrays with. + :return: a (B, T, ...) 
tensor + """ + if seq_len is None: + seq_len = max(len(array) for array in arrays) + else: + assert seq_len >= max(len(array) for array in arrays) + + # Convert lists to np.array + if isinstance(arrays[0], list): + arrays = [np.array(array) for array in arrays] + + # Convert to tensor and handle device + device = None + if isinstance(arrays[0], torch.Tensor): + tensors = arrays + device = tensors[0].device + else: + tensors = [torch.as_tensor(array) for array in arrays] + + # Fill the packed tensor with the array data + packed_shape = (len(tensors), seq_len, *tensors[0].shape[1:]) + packed_tensor = torch.full(packed_shape, pad_value, dtype=tensors[0].dtype, device=device) + + for i, tensor in enumerate(tensors): + packed_tensor[i, :tensor.size(0)] = tensor + + return packed_tensor def get_num_wins( @@ -242,7 +259,7 @@ class VoiceEncoder(nn.Module): """ if sample_rate != self.hp.sample_rate: wavs = [ - resample(wav, orig_sr=sample_rate, target_sr=self.hp.sample_rate, res_type="kaiser_fast") + librosa.resample(wav, orig_sr=sample_rate, target_sr=self.hp.sample_rate, res_type="kaiser_fast") for wav in wavs ] @@ -252,5 +269,8 @@ class VoiceEncoder(nn.Module): if "rate" not in kwargs: kwargs["rate"] = 1.3 # Resemble's default value. - mels = [melspectrogram(w, self.hp).T for w in wavs] + mel_func = melspectrogram() + mels = [mel_func(torch.from_numpy(w) + [None])[0].T for w in wavs] + return self.embeds_from_mels(mels, as_spk=as_spk, batch_size=batch_size, **kwargs) diff --git a/orator/src/orator/transforms/__pycache__/spectrogram.cpython-311.pyc b/orator/src/orator/transforms/__pycache__/spectrogram.cpython-311.pyc index e3a682cb6495b552f14d322b779476cf250fc1d0..b5ec09740da2846bda8787e4ae5c8d64d4d9989a 100644 Binary files a/orator/src/orator/transforms/__pycache__/spectrogram.cpython-311.pyc and b/orator/src/orator/transforms/__pycache__/spectrogram.cpython-311.pyc differ diff --git a/orator/src/orator/transforms/__pycache__/syn_transforms.cpython-311.pyc b/orator/src/orator/transforms/__pycache__/syn_transforms.cpython-311.pyc index 037b45e08c6f8d5db21241c609a0e8212fe10d2f..2f220a91e3eec386e62a21db3d03ca6648cb527f 100644 Binary files a/orator/src/orator/transforms/__pycache__/syn_transforms.cpython-311.pyc and b/orator/src/orator/transforms/__pycache__/syn_transforms.cpython-311.pyc differ
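Finally, a usage sketch of the `pack` helper now defined in voice_encoder.py above — purely illustrative, with the import path assumed from this src layout. It right-pads a list of variable-length tensors into a single `(B, T, ...)` batch:

```python
import torch
from orator.models.voice_encoder.voice_encoder import pack

mels = [torch.ones(3, 40), torch.ones(5, 40)]   # two utterances of different lengths
batch = pack(mels, pad_value=0)
print(batch.shape)                              # torch.Size([2, 5, 40])
print(batch[0, 3:].abs().sum().item())          # 0.0 -> the short utterance is zero-padded
```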