Commit 99cc645
Parent(s): c30668f
add alignment

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- orator/src/orator/__pycache__/__init__.cpython-311.pyc +0 -0
- orator/src/orator/__pycache__/tts.cpython-311.pyc +0 -0
- orator/src/orator/models/bigvgan/__pycache__/activations.cpython-311.pyc +0 -0
- orator/src/orator/models/bigvgan/__pycache__/bigvgan.cpython-311.pyc +0 -0
- orator/src/orator/models/bigvgan/activations.py +120 -0
- orator/src/orator/models/bigvgan/alias_free_torch/__init__.py +6 -0
- orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/__init__.cpython-311.pyc +0 -0
- orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/act.cpython-311.pyc +0 -0
- orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/filter.cpython-311.pyc +0 -0
- orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/resample.cpython-311.pyc +0 -0
- orator/src/orator/models/bigvgan/alias_free_torch/act.py +28 -0
- orator/src/orator/models/bigvgan/alias_free_torch/filter.py +95 -0
- orator/src/orator/models/bigvgan/alias_free_torch/resample.py +55 -0
- orator/src/orator/models/bigvgan/bigvgan.py +212 -0
- orator/src/orator/models/s3gen/__pycache__/__init__.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/__pycache__/const.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/__pycache__/decoder.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/__pycache__/flow.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/__pycache__/flow_matching.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/__pycache__/hifigan.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/__pycache__/s3gen.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/__pycache__/xvector.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/utils/__pycache__/mask.cpython-311.pyc +0 -0
- orator/src/orator/models/s3gen/utils/__pycache__/mel.cpython-311.pyc +0 -0
- orator/src/orator/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc +0 -0
- orator/src/orator/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc +0 -0
- orator/src/orator/models/t3/__pycache__/__init__.cpython-311.pyc +0 -0
- orator/src/orator/models/t3/__pycache__/llama_configs.cpython-311.pyc +0 -0
- orator/src/orator/models/t3/__pycache__/t3.cpython-311.pyc +0 -0
- orator/src/orator/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc +0 -0
- orator/src/orator/models/t3/inference/alignment_stream_analyzer.py +154 -0
- orator/src/orator/models/t3/inference/t3_hf_backend.py +6 -6
- orator/src/orator/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc +0 -0
- orator/src/orator/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc +0 -0
- orator/src/orator/models/t3/modules/__pycache__/perceiver.cpython-311.pyc +0 -0
- orator/src/orator/models/t3/modules/__pycache__/t3_config.cpython-311.pyc +0 -0
orator/src/orator/__pycache__/__init__.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/__pycache__/__init__.cpython-311.pyc and b/orator/src/orator/__pycache__/__init__.cpython-311.pyc differ
orator/src/orator/__pycache__/tts.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/__pycache__/tts.cpython-311.pyc and b/orator/src/orator/__pycache__/tts.cpython-311.pyc differ
orator/src/orator/models/bigvgan/__pycache__/activations.cpython-311.pyc
ADDED
Binary file (6.09 kB)
orator/src/orator/models/bigvgan/__pycache__/bigvgan.cpython-311.pyc
ADDED
Binary file (13.3 kB)
orator/src/orator/models/bigvgan/activations.py
ADDED
@@ -0,0 +1,120 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.

import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    '''
    Implementation of a sine-based periodic activation function
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default, higher values = higher-frequency.
            alpha will be trained along with the rest of your model.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2(ax)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x


class SnakeBeta(nn.Module):
    '''
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = SnakeBeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default, higher values = higher-frequency.
            beta is initialized to 1 by default, higher values = higher-magnitude.
            alpha will be trained along with the rest of your model.
        '''
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2(ax)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
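For reference, a minimal usage sketch of the two activations. The `orator.models.bigvgan.activations` import path is inferred from this commit's file layout and is an assumption; both modules are elementwise over [B, C, T] with one learned alpha (and beta, for SnakeBeta) per channel.

import torch
from orator.models.bigvgan.activations import Snake, SnakeBeta  # path assumed from this commit's layout

x = torch.randn(2, 80, 100)                 # [B, C=80, T=100]
snake = Snake(80)
snakebeta = SnakeBeta(80, alpha_logscale=True)

assert snake(x).shape == x.shape            # elementwise, shape-preserving
assert snakebeta(x).shape == x.shape
# with the default alpha = 1, Snake reduces (up to the 1e-9 epsilon) to x + sin(x)^2
assert torch.allclose(snake(x), x + torch.sin(x) ** 2, atol=1e-6)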
orator/src/orator/models/bigvgan/alias_free_torch/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

from .filter import *
from .resample import *
from .act import *
orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (281 Bytes)
orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/act.cpython-311.pyc
ADDED
Binary file (1.67 kB)
orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/filter.cpython-311.pyc
ADDED
Binary file (4.51 kB)
orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/resample.cpython-311.pyc
ADDED
Binary file (3.43 kB)
orator/src/orator/models/bigvgan/alias_free_torch/act.py
ADDED
@@ -0,0 +1,28 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn

from .resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(self,
                 activation,
                 up_ratio: int = 2,
                 down_ratio: int = 2,
                 up_kernel_size: int = 12,
                 down_kernel_size: int = 12):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B, C, T]
    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)
        return x
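A minimal sketch of the wrapper in action (import paths assumed from this commit's layout): the nonlinearity runs at 2x the sample rate, then the result is band-limited back down, suppressing the aliasing that the periodic sine terms would otherwise introduce.

import torch
from orator.models.bigvgan.activations import SnakeBeta            # path assumed
from orator.models.bigvgan.alias_free_torch.act import Activation1d  # path assumed

act = Activation1d(activation=SnakeBeta(8, alpha_logscale=True))
x = torch.randn(1, 8, 64)             # [B, C, T]
y = act(x)
assert y.shape == x.shape             # upsample 2x -> activation -> downsample 2x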
orator/src/orator/models/bigvgan/alias_free_torch/filter.py
ADDED
@@ -0,0 +1,95 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


if 'sinc' in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    # LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(x == 0,
                           torch.tensor(1., device=x.device, dtype=x.dtype),
                           torch.sin(math.pi * x) / math.pi / x)


# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):  # return filter [1,1,kernel_size]
    even = (kernel_size % 2 == 0)
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.:
        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
    else:
        beta = 0.
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = (torch.arange(-half_size, half_size) + 0.5)
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        filter_ /= filter_.sum()
    filter = filter_.view(1, 1, kernel_size)

    return filter


class LowPassFilter1d(nn.Module):
    def __init__(self,
                 cutoff=0.5,
                 half_width=0.6,
                 stride: int = 1,
                 padding: bool = True,
                 padding_mode: str = 'replicate',
                 kernel_size: int = 12):
        # kernel_size should be even number for stylegan3 setup,
        # in this implementation, odd number is also possible.
        super().__init__()
        if cutoff < -0.:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = (kernel_size % 2 == 0)
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    # input [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)

        return out
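A quick check of the filter prototype above (import path assumed from this commit's layout): a well-formed windowed-sinc lowpass has unit DC gain, i.e. its taps sum to 1 after normalization, and the depthwise convolution with stride 2 halves the time axis.

import torch
from orator.models.bigvgan.alias_free_torch.filter import kaiser_sinc_filter1d, LowPassFilter1d  # path assumed

filt = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
assert filt.shape == (1, 1, 12)
assert torch.isclose(filt.sum(), torch.tensor(1.0), atol=1e-6)   # unit DC gain

lp = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=2, kernel_size=12)
x = torch.randn(1, 4, 64)             # [B, C, T]
assert lp(x).shape == (1, 4, 32)      # per-channel lowpass, then stride-2 decimation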
orator/src/orator/models/bigvgan/alias_free_torch/resample.py
ADDED
@@ -0,0 +1,55 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from torch.nn import functional as F

from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        filter = kaiser_sinc_filter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            kernel_size=self.kernel_size
        )
        self.register_buffer("filter", filter)

    # x: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode='replicate')
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
        )
        x = x[..., self.pad_left:-self.pad_right]

        return x


class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size
        )

    def forward(self, x):
        xx = self.lowpass(x)

        return xx
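A shape sanity check for the resamplers above (import path assumed from this commit's layout): UpSample1d multiplies T by `ratio` exactly thanks to its padding/cropping arithmetic, and DownSample1d divides it.

import torch
from orator.models.bigvgan.alias_free_torch.resample import UpSample1d, DownSample1d  # path assumed

x = torch.randn(1, 4, 100)            # [B, C, T]
up = UpSample1d(ratio=2)
down = DownSample1d(ratio=2)

assert up(x).shape == (1, 4, 200)     # 2x sinc interpolation
assert down(x).shape == (1, 4, 50)    # lowpass, then stride-2 decimation
assert down(up(x)).shape == x.shape   # round trip preserves length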
orator/src/orator/models/bigvgan/bigvgan.py
ADDED
@@ -0,0 +1,212 @@
# Copyright (c) 2022 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.

import logging

import torch          # previously pulled in implicitly via the star-import below
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm
from torch.nn.utils.weight_norm import WeightNorm

from .activations import SnakeBeta
from .alias_free_torch import *


LRELU_SLOPE = 0.1

logger = logging.getLogger(__name__)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


class AMPBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(AMPBlock1, self).__init__()

        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

        self.num_layers = len(self.convs1) + len(self.convs2)  # total number of conv layers

        self.activations = nn.ModuleList([
            Activation1d(activation=SnakeBeta(channels, alpha_logscale=True))
            for _ in range(self.num_layers)
        ])

    def forward(self, x):
        acts1, acts2 = self.activations[::2], self.activations[1::2]
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
            xt = a1(x)
            xt = c1(xt)
            xt = a2(xt)
            xt = c2(xt)
            x = xt + x

        return x

    def set_weight_norm(self, enabled: bool):
        weight_norm_fn = weight_norm if enabled else remove_weight_norm
        for l in self.convs1:
            weight_norm_fn(l)
        for l in self.convs2:
            weight_norm_fn(l)


class BigVGAN(nn.Module):
    # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.

    # We've got a model in prod that has the wrong hparams for this. It's simpler to add this check than to
    # redistribute the model.
    ignore_state_dict_unexpected = ("cond_layer.*",)

    def __init__(self):
        super().__init__()

        input_dims = 80

        upsample_rates = [10, 8, 4, 2]
        upsample_kernel_sizes = [x * 2 for x in upsample_rates]
        upsample_initial_channel = 1024

        resblock_kernel_sizes = [3, 7, 11]
        resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        # pre conv
        self.conv_pre = weight_norm(Conv1d(input_dims, upsample_initial_channel, 7, 1, padding=3))
        self.cond_layer = None

        # transposed conv-based upsamplers. does not apply anti-aliasing
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(nn.ModuleList([
                weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i),
                                            upsample_initial_channel // (2 ** (i + 1)),
                                            k, u, padding=(k - u) // 2))
            ]))

        # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(AMPBlock1(ch, k, d))

        # post conv
        activation_post = SnakeBeta(ch, alpha_logscale=True)
        self.activation_post = Activation1d(activation=activation_post)
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))

        # weight initialization
        for i in range(len(self.ups)):
            self.ups[i].apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x) -> torch.Tensor:
        """
        Args
        ----
        x: torch.Tensor of shape [B, C, T] (conv_pre consumes C=80 mel bands on dim 1)
        """
        with torch.inference_mode():

            x = self.conv_pre(x)

            for i in range(self.num_upsamples):
                # upsampling
                for i_up in range(len(self.ups[i])):
                    x = self.ups[i][i_up](x)
                # AMP blocks
                xs = None
                for j in range(self.num_kernels):
                    if xs is None:
                        xs = self.resblocks[i * self.num_kernels + j](x)
                    else:
                        xs += self.resblocks[i * self.num_kernels + j](x)
                x = xs / self.num_kernels

            # post conv
            x = self.activation_post(x)
            x = self.conv_post(x)

            # Bound the output to [-1, 1]
            x = torch.tanh(x)

        return x

    @property
    def weight_norm_enabled(self) -> bool:
        return any(
            isinstance(hook, WeightNorm) and hook.name == "weight"
            for k, hook in self.conv_pre._forward_pre_hooks.items()
        )

    def set_weight_norm(self, enabled: bool):
        """
        N.B.: weight norm modifies the state dict, causing incompatibilities. Conventions:
        - BigVGAN runs with weight norm for training, without for inference (done automatically by instantiate())
        - All checkpoints are saved with weight norm (allows resuming training)
        """
        if enabled != self.weight_norm_enabled:
            weight_norm_fn = weight_norm if enabled else remove_weight_norm
            logger.debug(f"{'Applying' if enabled else 'Removing'} weight norm...")

            for l in self.ups:
                for l_i in l:
                    weight_norm_fn(l_i)
            for l in self.resblocks:
                l.set_weight_norm(enabled)
            weight_norm_fn(self.conv_pre)
            weight_norm_fn(self.conv_post)

    def train_mode(self):
        self.train()
        self.set_weight_norm(enabled=True)

    def inference_mode(self):
        self.eval()
        self.set_weight_norm(enabled=False)


if __name__ == '__main__':
    import soundfile as sf
    model = BigVGAN()

    state_dict = torch.load("bigvgan32k.pt")
    msg = model.load_state_dict(state_dict)
    model.eval()
    model.set_weight_norm(enabled=False)

    print(msg)
    mels = torch.load("mels.pt")
    with torch.inference_mode():
        y = model(mels.cpu())

    for i, wav in enumerate(y):
        wav = wav.view(-1).detach().numpy()
        sf.write(f"bigvgan_test{i}.flac", wav, samplerate=32_000, format="FLAC")
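A random-weights smoke test of the shapes (the import path is inferred from this commit's layout, and no checkpoint is needed): the upsample rates [10, 8, 4, 2] multiply to 640 output samples per mel frame, and the tanh bounds the waveform.

import torch
from orator.models.bigvgan.bigvgan import BigVGAN  # path assumed from this commit's layout

model = BigVGAN()
model.inference_mode()                 # eval + strip weight norm

mels = torch.randn(1, 80, 50)          # [B, C=80 mel bands, T=50 frames]
wav = model(mels)

# total upsampling is 10 * 8 * 4 * 2 = 640 samples per mel frame
assert wav.shape == (1, 1, 50 * 640)
assert wav.abs().max() <= 1.0          # tanh-bounded output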
orator/src/orator/models/s3gen/__pycache__/__init__.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/__init__.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/__init__.cpython-311.pyc differ
orator/src/orator/models/s3gen/__pycache__/const.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/const.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/const.cpython-311.pyc differ
orator/src/orator/models/s3gen/__pycache__/decoder.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/decoder.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/decoder.cpython-311.pyc differ
orator/src/orator/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc differ
orator/src/orator/models/s3gen/__pycache__/flow.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/flow.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/flow.cpython-311.pyc differ
orator/src/orator/models/s3gen/__pycache__/flow_matching.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/flow_matching.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/flow_matching.cpython-311.pyc differ
orator/src/orator/models/s3gen/__pycache__/hifigan.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/hifigan.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/hifigan.cpython-311.pyc differ
orator/src/orator/models/s3gen/__pycache__/s3gen.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/s3gen.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/s3gen.cpython-311.pyc differ
orator/src/orator/models/s3gen/__pycache__/xvector.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/xvector.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/xvector.cpython-311.pyc differ
orator/src/orator/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc and b/orator/src/orator/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc differ
orator/src/orator/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc and b/orator/src/orator/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc differ
orator/src/orator/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc and b/orator/src/orator/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc differ
orator/src/orator/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc differ
orator/src/orator/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc differ
orator/src/orator/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc differ
orator/src/orator/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc differ
orator/src/orator/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc differ
orator/src/orator/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc differ
orator/src/orator/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc differ
orator/src/orator/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc differ
orator/src/orator/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc differ
orator/src/orator/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc and b/orator/src/orator/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc differ
orator/src/orator/models/s3gen/utils/__pycache__/mask.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/utils/__pycache__/mask.cpython-311.pyc and b/orator/src/orator/models/s3gen/utils/__pycache__/mask.cpython-311.pyc differ
orator/src/orator/models/s3gen/utils/__pycache__/mel.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3gen/utils/__pycache__/mel.cpython-311.pyc and b/orator/src/orator/models/s3gen/utils/__pycache__/mel.cpython-311.pyc differ
orator/src/orator/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc and b/orator/src/orator/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc differ
orator/src/orator/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc and b/orator/src/orator/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc differ
orator/src/orator/models/t3/__pycache__/__init__.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/t3/__pycache__/__init__.cpython-311.pyc and b/orator/src/orator/models/t3/__pycache__/__init__.cpython-311.pyc differ
orator/src/orator/models/t3/__pycache__/llama_configs.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/t3/__pycache__/llama_configs.cpython-311.pyc and b/orator/src/orator/models/t3/__pycache__/llama_configs.cpython-311.pyc differ
orator/src/orator/models/t3/__pycache__/t3.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/t3/__pycache__/t3.cpython-311.pyc and b/orator/src/orator/models/t3/__pycache__/t3.cpython-311.pyc differ
orator/src/orator/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc and b/orator/src/orator/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc differ
orator/src/orator/models/t3/inference/alignment_stream_analyzer.py
ADDED
@@ -0,0 +1,154 @@
# Copyright (c) 2025 Resemble AI
# Author: John Meade, Jeremy Hsu
# MIT License
import logging
import torch
from dataclasses import dataclass
from types import MethodType


logger = logging.getLogger(__name__)


@dataclass
class AlignmentAnalysisResult:
    # was this frame detected as being part of a noisy beginning chunk with potential hallucinations?
    false_start: bool
    # was this frame detected as being part of a long tail with potential hallucinations?
    long_tail: bool
    # was this frame detected as repeating existing text content?
    repetition: bool
    # was the alignment position of this frame too far from the previous frame?
    discontinuity: bool
    # has inference reached the end of the text tokens? eg, this remains false if inference stops early
    complete: bool
    # approximate position in the text token sequence. Can be used for generating online timestamps.
    position: int


class AlignmentStreamAnalyzer:
    def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_idx=0):
        """
        Some transformer TTS models implicitly solve text-speech alignment in one or more of their self-attention
        activation maps. This module exploits this to perform online integrity checks while streaming.
        A hook is injected into the specified attention layer, and heuristics are used to determine alignment
        position, repetition, etc.

        NOTE: currently requires no queues.
        """
        # self.queue = queue
        self.text_tokens_slice = (i, j) = text_tokens_slice
        self.eos_idx = eos_idx
        self.alignment = torch.zeros(0, j-i)
        # self.alignment_bin = torch.zeros(0, j-i)
        self.curr_frame_pos = 0
        self.text_position = 0

        self.started = False
        self.started_at = None

        self.complete = False
        self.completed_at = None

        # Using `output_attentions=True` is incompatible with optimized attention kernels, so
        # using it for all layers slows things down too much. We can apply it to just one layer
        # by intercepting the kwargs and adding a forward hook (credit: jrm)
        self.last_aligned_attn = None
        self._add_attention_spy(tfmr, alignment_layer_idx)

    def _add_attention_spy(self, tfmr, alignment_layer_idx):
        """
        Adds a forward hook to a specific attention layer to collect outputs.
        Using `output_attentions=True` is incompatible with optimized attention kernels, so
        using it for all layers slows things down too much.
        (credit: jrm)
        """

        def attention_forward_hook(module, input, output):
            """
            See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`.
            NOTE:
            - When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`.
            - `attn_weights` has shape [B, H, T0, T0] for the 0th entry, and [B, H, 1, T0+i] for the i-th.
            """
            step_attention = output[1].cpu()  # (B, 16, N, N)
            self.last_aligned_attn = step_attention[0].mean(0)  # (N, N)

        target_layer = tfmr.layers[alignment_layer_idx].self_attn
        hook_handle = target_layer.register_forward_hook(attention_forward_hook)

        # Backup original forward
        original_forward = target_layer.forward
        def patched_forward(self, *args, **kwargs):
            kwargs['output_attentions'] = True
            return original_forward(*args, **kwargs)

        # TODO: how to unpatch it?
        target_layer.forward = MethodType(patched_forward, target_layer)

    def step(self, logits):
        """
        Emits an AlignmentAnalysisResult into the output queue, and potentially modifies the logits to force an EOS.
        """
        # extract approximate alignment matrix chunk (1 frame at a time after the first chunk)
        aligned_attn = self.last_aligned_attn  # (N, N)
        i, j = self.text_tokens_slice
        if self.curr_frame_pos == 0:
            # first chunk has conditioning info, text tokens, and BOS token
            A_chunk = aligned_attn[j:, i:j].clone().cpu()  # (T, S)
        else:
            # subsequent chunks have 1 frame due to KV-caching
            A_chunk = aligned_attn[:, i:j].clone().cpu()  # (1, S)

        # TODO: monotonic masking; could have issue b/c spaces are often skipped.
        A_chunk[:, self.curr_frame_pos + 1:] = 0

        self.alignment = torch.cat((self.alignment, A_chunk), dim=0)

        A = self.alignment
        T, S = A.shape

        # update position
        cur_text_posn = A_chunk[-1].argmax()
        discontinuity = not (-4 < cur_text_posn - self.text_position < 7)  # NOTE: very lenient!
        if not discontinuity:
            self.text_position = cur_text_posn

        # Hallucinations at the start of speech show up as activations at the bottom of the attention maps!
        # To mitigate this, we just wait until there are no activations far off-diagonal in the last 2 tokens,
        # and there are some strong activations in the first few tokens.
        false_start = (not self.started) and (A[-2:, -2:].max() > 0.1 or A[:, :4].max() < 0.5)
        self.started = not false_start
        if self.started and self.started_at is None:
            self.started_at = T

        # Is generation likely complete?
        self.complete = self.complete or self.text_position >= S - 3
        if self.complete and self.completed_at is None:
            self.completed_at = T

        # NOTE: EOS rarely assigned activations, and second-last token is often punctuation, so use last 3 tokens.
        # NOTE: due to the false-start behaviour, we need to make sure we skip activations for the first few tokens.
        last_text_token_duration = A[15:, -3:].sum()

        # Activations for the final token that last too long are likely hallucinations.
        long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 10)  # 400ms

        # If there are activations in previous tokens after generation has completed, assume this is a repetition error.
        repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5)

        # If a bad ending is detected, force emit EOS by modifying logits
        # NOTE: this means logits may be inconsistent with latents!
        if long_tail or repetition:
            logger.warning(f"forcing EOS token, {long_tail=}, {repetition=}")
            # (±2**15 is safe for all dtypes >= 16bit)
            logits = -(2**15) * torch.ones_like(logits)
            logits[..., self.eos_idx] = 2**15

        # Suppress EOS to prevent early termination
        if cur_text_posn < S - 3:  # FIXME: arbitrary
            logits[..., self.eos_idx] = -2**15

        self.curr_frame_pos += 1
        return logits
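The layer-spy trick in `_add_attention_spy` is worth isolating. The toy sketch below (where `Inner` and `Outer` are stand-ins, not classes from this repo) shows the same pattern: monkey-patch one submodule's forward to force `output_attentions=True`, and read the attention map off a forward hook, leaving every other layer on the fast kernel path.

import torch
import torch.nn as nn
from types import MethodType

class Inner(nn.Module):
    def forward(self, x, output_attentions=False):
        attn = torch.softmax(x @ x.transpose(-1, -2), dim=-1)
        return x, (attn if output_attentions else None)

class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()
    def forward(self, x):
        return self.inner(x)[0]

outer = Outer()
captured = {}

def spy_hook(module, inputs, output):
    captured["attn"] = output[1]          # second element of Inner's output tuple

outer.inner.register_forward_hook(spy_hook)

original_forward = outer.inner.forward    # grab the bound method before patching
def patched_forward(self, *args, **kwargs):
    kwargs["output_attentions"] = True    # force attentions on, for this layer only
    return original_forward(*args, **kwargs)
outer.inner.forward = MethodType(patched_forward, outer.inner)

_ = outer(torch.randn(2, 4, 8))
assert captured["attn"] is not None and captured["attn"].shape == (2, 4, 4)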
orator/src/orator/models/t3/inference/t3_hf_backend.py
CHANGED
@@ -23,14 +23,14 @@ class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin):
         speech_head,
         latents_queue=None,
         logits_queue=None,
+        alignment_stream_analyzer: 'AlignmentStreamAnalyzer'=None,
     ):
         super().__init__(config)
         self.model = llama
         self.speech_enc = speech_enc
         self.speech_head = speech_head
-        self.latents_queue = latents_queue
-        self.logits_queue = logits_queue
         self._added_cond = False
+        self.alignment_stream_analyzer = alignment_stream_analyzer

     @torch.inference_mode()
     def prepare_inputs_for_generation(
@@ -101,12 +101,12 @@ class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin):
             return_dict=True,
         )
         hidden_states = tfmr_out.hidden_states[-1]  # (B, seq, dim)
-        if self.latents_queue is not None:
-            self.latents_queue.put(hidden_states)

         logits = self.speech_head(hidden_states)
-
-
+        assert inputs_embeds.size(0) == 1
+
+        # NOTE: hallucination handler may modify logits to force emit an EOS token
+        logits = self.alignment_stream_analyzer.step(logits)

         return CausalLMOutputWithCrossAttentions(
             logits=logits,
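The analyzer's intervention on the logits wired in above is plain masking. A standalone illustration of the two moves (force EOS, suppress EOS), with a toy vocabulary size of 8:

import torch

eos_idx = 0
logits = torch.randn(1, 1, 8)                  # (B, T=1, vocab) toy logits

# Force EOS: flatten all logits to a large negative value, then raise EOS.
forced = -(2**15) * torch.ones_like(logits)
forced[..., eos_idx] = 2**15
assert forced.argmax(dim=-1).item() == eos_idx

# Suppress EOS: push only the EOS logit down so generation cannot stop yet.
suppressed = logits.clone()
suppressed[..., eos_idx] = -2**15
assert suppressed.argmax(dim=-1).item() != eos_idx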
orator/src/orator/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc and b/orator/src/orator/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc differ
orator/src/orator/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc and b/orator/src/orator/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc differ
orator/src/orator/models/t3/modules/__pycache__/perceiver.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/t3/modules/__pycache__/perceiver.cpython-311.pyc and b/orator/src/orator/models/t3/modules/__pycache__/perceiver.cpython-311.pyc differ
orator/src/orator/models/t3/modules/__pycache__/t3_config.cpython-311.pyc
CHANGED
Binary files a/orator/src/orator/models/t3/modules/__pycache__/t3_config.cpython-311.pyc and b/orator/src/orator/models/t3/modules/__pycache__/t3_config.cpython-311.pyc differ