time varying speaker style
- Modules/hifigan.py +13 -5
- models.py +53 -80
- msinference.py +178 -75
- requirements.txt +1 -1
Modules/hifigan.py
CHANGED
@@ -12,16 +12,24 @@ import numpy as np
 LRELU_SLOPE = 0.1
 
 class AdaIN1d(nn.Module):
+
+    # used by HiFiGan & ProsodyPredictor
+
     def __init__(self, style_dim, num_features):
         super().__init__()
         self.norm = nn.InstanceNorm1d(num_features, affine=False)
         self.fc = nn.Linear(style_dim, num_features*2)
 
     def forward(self, x, s):
-
-
-
-
+
+        s = self.fc(s)  # [bs, 1024, 130]
+        s = F.interpolate(s[:, :, 0, :].transpose(1, 2), x.shape[2], mode='linear')  # different time-resolution than Dur
+
+        gamma, beta = torch.chunk(s, chunks=2, dim=1)  # channels vary in for loop
+
+        # affine (1 + lin(x)) * inst(x) + lin(x) is this a skip connection where the weight is a lin of itself
+
+        return (1 + gamma) * self.norm(x) + beta  # norm(x) = PLBERT has norm / beta&gamma = style has no norm()
 
 class AdaINResBlock1(torch.nn.Module):
     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
@@ -443,7 +451,7 @@ class Decoder(nn.Module):
         self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes)
 
 
-    def forward(self, asr, F0_curve, N, s):
+    def forward(self, asr=None, F0_curve=None, N=None, s=None):
         if self.training:
             downlist = [0, 3, 7]
             F0_down = downlist[random.randint(0, 2)]
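The reworked AdaIN1d above consumes a sequence of style frames rather than a single style vector, interpolating them to the time axis of the feature it modulates. A minimal, self-contained sketch of that idea (shapes assumed from the comments in the diff, not the repo's exact class):

# Sketch only: a time-varying AdaIN with assumed shapes, mirroring the diff above.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TimeVaryingAdaIN1d(nn.Module):
    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm1d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):                        # x: [bs, C, T_x], s: [bs, T_s, 1, style_dim]
        s = self.fc(s)                              # [bs, T_s, 1, 2C]
        s = F.interpolate(s[:, :, 0, :].transpose(1, 2), x.shape[2], mode='linear')  # [bs, 2C, T_x]
        gamma, beta = torch.chunk(s, chunks=2, dim=1)
        return (1 + gamma) * self.norm(x) + beta    # per-frame affine modulation

x = torch.randn(1, 512, 300)      # acoustic features
s = torch.randn(1, 11, 1, 128)    # 11 style frames from the reference audio
print(TimeVaryingAdaIN1d(128, 512)(x, s).shape)    # torch.Size([1, 512, 300])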
models.py
CHANGED
@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from torch.nn.utils import weight_norm, spectral_norm
 from Utils.ASR.models import ASRCNN
 from Utils.JDC.model import JDCNet
-from
+from Modules.hifigan import AdaIN1d
 import yaml
 
 
@@ -110,7 +110,7 @@ class ResBlk(nn.Module):
 
 class StyleEncoder(nn.Module):
 
-    #
+    # for both acoustic & prosodic ref_s/p
 
     def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
         super().__init__()
@@ -125,15 +125,20 @@ class StyleEncoder(nn.Module):
 
         blocks += [nn.LeakyReLU(0.2)]
         blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
-
+
+        # blocks += [nn.AdaptiveAvgPool2d(1)]  # THIS AVERAGES THE TIME-FRAMES OF SPEAKER STYLE
+
         blocks += [nn.LeakyReLU(0.2)]
         self.shared = nn.Sequential(*blocks)
 
         self.unshared = nn.Linear(dim_out, style_dim)
 
     def forward(self, x):
-        h = self.shared(x)
-
+        h = self.shared(x)  # [bs, 512, 1, 11]
+
+        h = h.mean(3, keepdims=True)  # UN COMMENT FOR TIME INVARIANT GLOBAL SPEAKER STYLE
+
+        h = h.transpose(1, 3)
         s = self.unshared(h)
         return s
 
@@ -289,21 +294,6 @@ class TextEncoder(nn.Module):
         mask = torch.gt(mask+1, lengths.unsqueeze(1))
         return mask
 
-
-
-class AdaIN1d(nn.Module):
-    def __init__(self, style_dim, num_features):
-        super().__init__()
-        self.norm = nn.InstanceNorm1d(num_features, affine=False)
-        self.fc = nn.Linear(style_dim, num_features*2)
-
-    def forward(self, x, s):
-        h = self.fc(s)
-        h = h.view(h.size(0), h.size(1), 1)
-        gamma, beta = torch.chunk(h, chunks=2, dim=1)
-        # affine (1 + lin(x)) * inst(x) + lin(x) is this a skip connection where the weight is a lin of itself
-        return (1 + gamma) * self.norm(x) + beta  # norm(x) = PLBERT has norm / beta&gamma = style has no norm()
-
 class UpSample1d(nn.Module):
     def __init__(self, layer_type):
         super().__init__()
@@ -316,8 +306,15 @@ class UpSample1d(nn.Module):
             return F.interpolate(x, scale_factor=2, mode='nearest')
 
 class AdainResBlk1d(nn.Module):
-
-
+
+    # only instantiated in ProsodyPredictor
+
+    def __init__(self, dim_in,
+                 dim_out,
+                 style_dim=64,
+                 actv=nn.LeakyReLU(0.2),
+                 upsample='none',
+                 dropout_p=0.0):
         super().__init__()
         self.actv = actv
         self.upsample_type = upsample
@@ -362,26 +359,22 @@ class AdainResBlk1d(nn.Module):
         return out
 
 class AdaLayerNorm(nn.Module):
-
+
+    # only instantiated in DurationPredictor()
+
+    def __init__(self, style_dim, channels=None, eps=1e-5):
         super().__init__()
-        self.channels = channels
         self.eps = eps
-
-        self.fc = nn.Linear(style_dim, channels*2)
+        self.fc = nn.Linear(style_dim, 1024)
 
     def forward(self, x, s):
-
-
-
-        h = self.fc(s)
-        h = h.view(h.size(0), h.size(1), 1)
-        gamma, beta = torch.chunk(h, chunks=2, dim=1)
-        gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
+        h = self.fc(s.transpose(1, 2))  # has to be transposed due to interpolate needing the last dim to be frames
+        gamma = h[:, :, :512]
+        beta = h[:, :, 512:1024]
 
-
-        x = F.layer_norm(x, (self.channels,), eps=self.eps)
+        x = F.layer_norm(x.transpose(1, 2), (512, ), eps=self.eps)
         x = (1 + gamma) * x + beta
-        return x
+        return x  # [1, 75, 512]
 
 class ProsodyPredictor(nn.Module):
 
@@ -414,7 +407,12 @@ class ProsodyPredictor(nn.Module):
         x, _ = self.shared(x.transpose(-1, -2))
 
         F0 = x.transpose(-1, -2)
+
+
         for block in self.F0:
+            print(f'F)N {F0.shape=} {s.shape=}\n')
+            # )N F0.shape=torch.Size([1, 512, 147]) s.shape=torch.Size([1, 128])
+
             F0 = block(F0, s)
         F0 = self.F0_proj(F0)
 
@@ -452,21 +450,30 @@ class DurationEncoder(nn.Module):
     def forward(self, x, style, text_lengths, m):
         masks = m.to(text_lengths.device)
 
-
-
-        x
-        x.
+
+
+        # x : [bs, 512, 987]
+        # print('DURATION ENCODER', x.shape, style.shape, masks.shape)
+        # s = style.expand(x.shape[0], x.shape[1], -1)
+        style = style[:, :, 0, :].transpose(2, 1)  # [bs, 128, 11]
+        # print("S IN DURATION ENC", style.shape, x.shape)
+        style = F.interpolate(style, x.shape[2])
+        print(f'L468 IN DURATION ENC {x.shape=}, {style.shape=} {masks.shape=}')  # mask = [1,75]
+        x = torch.cat([x, style], axis=1)  # [bs, 640, 75]
+        x.masked_fill_(masks[:, None, :], 0.0)
 
-
+
         input_lengths = text_lengths.cpu().numpy()
-        x = x.transpose(-1, -2)
 
         for block in self.lstms:
             if isinstance(block, AdaLayerNorm):
-
-
-                x
+
+                print(f'\n=========ENTER ADALAYNORM L479 models.py {x.shape=}, {style.shape=}')
+                x = block(x, style)  # [bs, 75, 512]
+                x = torch.cat([x.transpose(1, 2), style], axis=1)  # [bs, 512, 75]
+                x.masked_fill_(masks[:, None, :], 0.0)
             else:
+                # print(f'{x.shape=} ENTER LSTM')  # [bs, 640, 75] LSTM reduce ch 640 -> 512
                 x = x.transpose(-1, -2)
                 x = nn.utils.rnn.pack_padded_sequence(
                     x, input_lengths, batch_first=True, enforce_sorted=False)
@@ -481,6 +488,7 @@ class DurationEncoder(nn.Module):
 
                 x_pad[:, :, :x.shape[-1]] = x
                 x = x_pad.to(x.device)
+                # print(f'{x.shape=} EXIR LSTM')  # [bs, 512, 75]
         # print('Calling Duration Encoder\n\n\n\n',x.shape, x.min(), x.max())
         # Calling Duration Encoder
         # torch.Size([1, 640, 107]) tensor(-3.0903, device='cuda:0') tensor(2.3089, device='cuda:0')
@@ -493,7 +501,6 @@ def load_F0_models(path):
     # load F0 model
 
     F0_model = JDCNet(num_class=1, seq_len=192)
-    print(path, 'WHAT ARE YOU TRYING TO LOAD F0 L520')
    path = path.replace('.t7', '.pth')
    params = torch.load(path, map_location='cpu')['net']
    F0_model.load_state_dict(params)
@@ -524,37 +531,3 @@ def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
     _ = asr_model.train()
 
     return asr_model
-
-def build_model(args, text_aligner, pitch_extractor, bert):
-    print(f'\n==============\n {args.decoder.type=}\n==============L584 models.py @ build_model()\n')
-
-    from Modules.hifigan import Decoder
-    decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
-            resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
-            upsample_rates = args.decoder.upsample_rates,
-            upsample_initial_channel=args.decoder.upsample_initial_channel,
-            resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
-            upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
-
-    text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
-
-    predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
-
-    style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim)  # acoustic style encoder
-    predictor_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim)  # prosodic style encoder
-    nets = Munch(
-        bert=bert,
-        bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
-
-        predictor=predictor,
-        decoder=decoder,
-        text_encoder=text_encoder,
-
-        predictor_encoder=predictor_encoder,
-        style_encoder=style_encoder,
-
-        text_aligner=text_aligner,
-        pitch_extractor=pitch_extractor
-    )
-
-    return nets
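The StyleEncoder change above removes the AdaptiveAvgPool2d(1) that used to collapse all mel frames into one style vector, so the shared conv stack now yields one style frame per remaining time step; h.mean(3, keepdims=True) averages them back into a single global style when time-invariant behaviour is wanted. A small sketch of the two options (shapes assumed from the diff's comments, not the repo class):

# Sketch only: the two pooling options in StyleEncoder.forward, with assumed shapes.
import torch
import torch.nn as nn

h = torch.randn(1, 512, 1, 11)        # output of the shared conv stack: [bs, C, 1, T_frames]

h_global = h.mean(3, keepdim=True)     # time-invariant: average the frames -> [1, 512, 1, 1]
h_frames = h                           # time-varying: keep one style frame per time step

unshared = nn.Linear(512, 128)
s_global = unshared(h_global.transpose(1, 3))   # [1, 1, 1, 128]
s_frames = unshared(h_frames.transpose(1, 3))   # [1, 11, 1, 128]
print(s_global.shape, s_frames.shape)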
msinference.py
CHANGED
@@ -7,8 +7,7 @@ import numpy as np
 import yaml
 import torchaudio
 import librosa
-from models import
-from munch import Munch
+from models import ProsodyPredictor, TextEncoder, StyleEncoder, load_ASR_models, load_F0_models
 from nltk.tokenize import word_tokenize
 
 torch.manual_seed(0)
@@ -62,17 +61,6 @@ def alpha_num(f):
     return f
 
 
-
-def recursive_munch(d):
-    if isinstance(d, dict):
-        return Munch((k, recursive_munch(v)) for k, v in d.items())
-    elif isinstance(d, list):
-        return [recursive_munch(v) for v in d]
-    else:
-        return d
-
-
-
 # ======== UTILS ABOVE
 
 def length_to_mask(lengths):
@@ -94,10 +82,10 @@ def compute_style(path):
     mel_tensor = preprocess(audio).to(device)
 
     with torch.no_grad():
-        ref_s =
-        ref_p =
-
-        return torch.cat([ref_s, ref_p], dim=1)
+        ref_s = style_encoder(mel_tensor.unsqueeze(1))
+        ref_p = predictor_encoder(mel_tensor.unsqueeze(1))  # [bs, 11, 1, 128]
+        print(f'\n\n\n\nCOMPUTE STYLe {ref_s.shape=} {ref_p.shape=}')
+        return torch.cat([ref_s, ref_p], dim=3)  # [bs, 11, 1, 256]
 
 device = 'cpu'
 if torch.cuda.is_available():
@@ -112,50 +100,151 @@ global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_
 # phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))
 
 
-
+args = yaml.safe_load(open(str('Utils/config.yml')))
+ASR_config = args['ASR_config']
 
-
-
-ASR_path = config.get('ASR_path', False)
-text_aligner = load_ASR_models(ASR_path, ASR_config)
+ASR_path = args['ASR_path']
+text_aligner = load_ASR_models(ASR_path, ASR_config).eval().to(device)
 
-
-
-pitch_extractor = load_F0_models(F0_path)
+F0_path = args['F0_path']
+pitch_extractor = load_F0_models(F0_path).eval().to(device)
 
-# load BERT model
 from Utils.PLBERT.util import load_plbert
-
-
-
-
-
-
-
+bert = load_plbert(args['PLBERT_dir']).eval().to(device)
+# model_params = recursive_munch(config['model_params'])
+# --
+# def build_model(args, text_aligner, pitch_extractor, bert):
+#     print(f'\n==============\n {args.decoder.type=}\n==============L584 models.py @ build_model()\n')
+# # ======================================
+# In [4]: args['model_params']
+# Out[4]:
+# {'decoder': {'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+#   'resblock_kernel_sizes': [3, 7, 11],
+#   'type': 'hifigan',
+#   'upsample_initial_channel': 512,
+#   'upsample_kernel_sizes': [20, 10, 6, 4],
+#   'upsample_rates': [10, 5, 3, 2]},
+#  'diffusion': {'dist': {'estimate_sigma_data': True,
+#    'mean': -3.0,
+#    'sigma_data': 0.19926648961191362,
+#    'std': 1.0},
+#   'embedding_mask_proba': 0.1,
+#   'transformer': {'head_features': 64,
+#    'multiplier': 2,
+#    'num_heads': 8,
+#    'num_layers': 3}},
+#  'dim_in': 64,
+#  'dropout': 0.2,
+#  'hidden_dim': 512,
+#  'max_conv_dim': 512,
+#  'max_dur': 50,
+#  'multispeaker': True,
+#  'n_layer': 3,
+#  'n_mels': 80,
+#  'n_token': 178,
+#  'slm': {'hidden': 768,
+#   'initial_channel': 64,
+#   'model': 'microsoft/wavlm-base-plus',
+#   'nlayers': 13,
+#   'sr': 16000},
+#  'style_dim': 128}
+# # ===============================================
+from Modules.hifigan import Decoder
+decoder = Decoder(dim_in=512,
+                  style_dim=128,
+                  dim_out=80,  # n_mels
+                  resblock_kernel_sizes=[3, 7, 11],
+                  upsample_rates=[10, 5, 3, 2],
+                  upsample_initial_channel=512,
+                  resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+                  upsample_kernel_sizes=[20, 10, 6, 4]).eval().to(device)
+
+text_encoder = TextEncoder(channels=512,
+                           kernel_size=5,
+                           depth=3,        # args['model_params']['n_layer']
+                           n_symbols=178,  # args['model_params']['n_token']
+                           ).eval().to(device)
+
+predictor = ProsodyPredictor(style_dim=128,
+                             d_hid=512,
+                             nlayers=3,  # OFFICIAL config.nlayers=5
+                             max_dur=50,
+                             dropout=.2).eval().to(device)
+
+style_encoder = StyleEncoder(dim_in=64,
+                             style_dim=128,
+                             max_conv_dim=512).eval().to(device)      # acoustic style encoder
+predictor_encoder = StyleEncoder(dim_in=64,
+                                 style_dim=128,
+                                 max_conv_dim=512).eval().to(device)  # prosodic style encoder
+bert_encoder = torch.nn.Linear(bert.config.hidden_size, 512).eval().to(device)
+# --
+# model = build_model(model_params, text_aligner, pitch_extractor, plbert)
+# _ = [model[key].eval() for key in model]
+# _ = [model[key].to(device) for key in model]
 
 # params_whole = torch.load("Models/LibriTTS/epochs_2nd_00020.pth", map_location='cpu')
 # params_whole = torch.load('freevc2/yl4579_styletts2.pth' map_location='cpu')
 params_whole = torch.load(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")), map_location='cpu')
 params = params_whole['net']
 
-for key in model:
-    if key in params:
-        print('%s loaded' % key)
-        try:
-            model[key].load_state_dict(params[key])
-        except:
-            from collections import OrderedDict
-            state_dict = params[key]
-            new_state_dict = OrderedDict()
-            for k, v in state_dict.items():
-                name = k[7:]  # remove `module.`
-                new_state_dict[name] = v
-            # load params
-            model[key].load_state_dict(new_state_dict, strict=False)
-        # except:
-        #     _load(params[key], model[key])
-_ = [model[key].eval() for key in model]
 
+# 'bert',
+# 'bert_encoder',
+# 'predictor',
+# 'decoder',
+# 'text_encoder',
+# 'predictor_encoder',
+# 'style_encoder',
+# 'text_aligner',
+# 'pitch_extractor'
+# --
+from collections import OrderedDict
+
+new_state_dict = OrderedDict()
+for k, v in params['bert'].items():
+    new_state_dict[k[7:]] = v  # del 'module.'
+bert.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['bert_encoder'].items():
+    new_state_dict[k[7:]] = v  # del 'module.'
+bert_encoder.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['predictor'].items():
+    new_state_dict[k[7:]] = v  # del 'module.'
+predictor.load_state_dict(new_state_dict, strict=True)  # XTRA non-ckpt LSTMs nlayers add slowiness to voice
+# --
+new_state_dict = OrderedDict()
+for k, v in params['decoder'].items():
+    new_state_dict[k[7:]] = v
+decoder.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['text_encoder'].items():
+    new_state_dict[k[7:]] = v
+text_encoder.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['predictor_encoder'].items():
+    new_state_dict[k[7:]] = v
+predictor_encoder.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['style_encoder'].items():
+    new_state_dict[k[7:]] = v
+style_encoder.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['text_aligner'].items():
+    new_state_dict[k[7:]] = v  # del 'module.'
+text_aligner.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['pitch_extractor'].items():
+    new_state_dict[k[7:]] = v
+pitch_extractor.load_state_dict(new_state_dict, strict=True)
 
 
 def inference(text,
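The nine load_state_dict blocks above all repeat the same "strip the DataParallel 'module.' prefix" pattern. A compact, hypothetical equivalent (the helper name and loop below are not part of the commit) would be:

# Hypothetical refactor of the repeated loading blocks above; not part of the commit.
from collections import OrderedDict

def load_submodule(module, state_dict, strict=True):
    cleaned = OrderedDict((k[len('module.'):], v) for k, v in state_dict.items())  # del 'module.'
    module.load_state_dict(cleaned, strict=strict)

# for name, module in [('bert', bert), ('bert_encoder', bert_encoder), ('predictor', predictor),
#                      ('decoder', decoder), ('text_encoder', text_encoder),
#                      ('predictor_encoder', predictor_encoder), ('style_encoder', style_encoder),
#                      ('text_aligner', text_aligner), ('pitch_extractor', pitch_extractor)]:
#     load_submodule(module, params[name])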
@@ -193,24 +282,31 @@ def inference(text,
     # 54, 156, 63, 158, 147, 83, 56, 16, 4]], device='cuda:0')
 
 
-    t_en =
-    bert_dur =
-    d_en =
+    t_en = text_encoder(tokens, input_lengths, text_mask)
+    bert_dur = bert(tokens, attention_mask=(~text_mask).int())
+    d_en = bert_encoder(bert_dur).transpose(-1, -2)
     # print('BERTdu', bert_dur.shape, tokens.shape, '\n')  # bert what is the 768 per token -> IS USED in sampler
     # BERTdu torch.Size([1, 11, 768]) torch.Size([1, 11])
 
 
 
-    ref = ref_s[:, :128]
-    s = ref_s[:, 128:]
+    ref = ref_s[:, :, :, :128]  # [bs, 11, 1, 128]
+    s = ref_s[:, :, :, 128:]    # have channels as last dim so it can go through nn.Linear layers
 
-    # s = .74 * s  # prosody / arousal & fading unvoiced syllabes [x0.7 - x1.2]
 
-
-
+    # ON compute style we dont know yet the size to interpolate
+    # Perhaps we can interpolate ref_s here as now we know how many bert time-frames the text needs
+    # s = .74 * s  # prosody / arousal & fading unvoiced syllabes [x0.7 - x1.2]
+
+
+    print(f'{d_en.shape=} {s.shape=} {input_lengths.shape=} {text_mask.shape=}')
+    d = predictor.text_encoder(d_en,
+                               s,
+                               input_lengths,
+                               text_mask)
 
-    x, _ =
-    duration =
+    x, _ = predictor.lstm(d)
+    duration = predictor.duration_proj(x)
 
     duration = torch.sigmoid(duration).sum(axis=-1)
     pred_dur = torch.round(duration.squeeze()).clamp(min=1)
@@ -224,23 +320,25 @@ def inference(text,
 
     # encode prosody
     en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
-
-
-
-
-
-
-    F0_pred, N_pred =
+
+    asr_new = torch.zeros_like(en)
+    asr_new[:, :, 0] = en[:, :, 0]
+    asr_new[:, :, 1:] = en[:, :, 0:-1]
+    en = asr_new
+    print('_________________________________________F0_____________________________')
+    F0_pred, N_pred = predictor.F0Ntrain(en, s)
 
     asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
-
-
-
-
-
-
-    x =
-
+
+    asr_new = torch.zeros_like(asr)
+    asr_new[:, :, 0] = asr[:, :, 0]
+    asr_new[:, :, 1:] = asr[:, :, 0:-1]
+    asr = asr_new
+    print('_________________________________________HiFI_____________________________')
+    x = decoder(asr=asr,
+                F0_curve=F0_pred,
+                N=N_pred,
+                s=ref)
 
     x = x.cpu().numpy()[0, 0, :-400]  # weird pulse at the end of sentences
 
@@ -299,6 +397,11 @@ import re
 from num2words import num2words
 
 PHONEME_MAP = {
+    'služ': 'sloooozz',     # 'službeno'
+    'suver': 'siuveeerra',  # 'suverena'
+    'država': 'dirrezav',   # 'država'
+    'iči': 'ici',           # 'Graniči'
+    's ': 'se',             # an 's' followed by a space
     'q': 'ku',
     'w': 'aou',
     'z': 's',
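For orientation, a shape walkthrough of the new time-varying reference style as it moves through inference (a sketch only, with an assumed frame count of 11 and text length of 75; the dimensions follow the comments in the diff):

# Sketch only: random tensors standing in for the style frames from compute_style().
import torch
import torch.nn.functional as F

ref_s = torch.randn(1, 11, 1, 256)      # compute_style() output: [bs, frames, 1, 256]

ref = ref_s[:, :, :, :128]              # acoustic half -> HiFi-GAN decoder (AdaIN1d)
s = ref_s[:, :, :, 128:]                # prosodic half -> ProsodyPredictor / DurationEncoder

# DurationEncoder stretches the prosodic frames to the text length
style = s[:, :, 0, :].transpose(2, 1)   # [1, 128, 11]
style = F.interpolate(style, 75)        # [1, 128, 75]
print(ref.shape, style.shape)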
requirements.txt
CHANGED
@@ -13,7 +13,7 @@ omegaconf
 opencv-python
 soundfile
 transformers
-
+audresample
 srt
 nltk
 phonemizer