Dionyssos committed
Commit 62ef231 · 1 Parent(s): 8a9a2fe

time varying speaker style

Files changed (4):
  1. Modules/hifigan.py +13 -5
  2. models.py +53 -80
  3. msinference.py +178 -75
  4. requirements.txt +1 -1
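
The gist of the commit: the reference-audio style is no longer pooled into a single vector but kept as one style vector per mel frame, and each adaptive-normalisation layer stretches that sequence to the time axis of whatever it modulates. Below is a minimal sketch of that mechanism, assuming the shapes printed in the diff ([bs, 11, 1, 128] styles, 147 content frames); it is illustrative only, not the repo's exact code.

import torch
import torch.nn.functional as F

bs, style_dim, style_frames = 1, 128, 11             # per-frame styles from the reference audio
channels, content_frames = 512, 147                   # features inside the decoder / predictor

style = torch.randn(bs, style_frames, 1, style_dim)   # [bs, 11, 1, 128], as printed by compute_style()
fc = torch.nn.Linear(style_dim, channels * 2)          # per-frame affine parameters

h = fc(style)                                          # [bs, 11, 1, 1024]
h = F.interpolate(h[:, :, 0, :].transpose(1, 2),       # [bs, 1024, 11] -> stretch along time
                  content_frames, mode='linear')       # [bs, 1024, 147]
gamma, beta = torch.chunk(h, chunks=2, dim=1)          # each [bs, 512, 147]

x = torch.randn(bs, channels, content_frames)
norm = torch.nn.InstanceNorm1d(channels, affine=False)
out = (1 + gamma) * norm(x) + beta                     # the style now varies over time
print(out.shape)                                       # torch.Size([1, 512, 147])
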
Modules/hifigan.py CHANGED
@@ -12,16 +12,24 @@ import numpy as np
 LRELU_SLOPE = 0.1
 
 class AdaIN1d(nn.Module):
+
+    # used by HiFiGan & ProsodyPredictor
+
     def __init__(self, style_dim, num_features):
         super().__init__()
         self.norm = nn.InstanceNorm1d(num_features, affine=False)
         self.fc = nn.Linear(style_dim, num_features*2)
 
     def forward(self, x, s):
-        h = self.fc(s)
-        h = h.view(h.size(0), h.size(1), 1)
-        gamma, beta = torch.chunk(h, chunks=2, dim=1)
-        return (1 + gamma) * self.norm(x) + beta
+
+        s = self.fc(s)  # [bs, 1024, 130]
+        s = F.interpolate(s[:, :, 0, :].transpose(1, 2), x.shape[2], mode='linear')  # different time-resolution than Dur
+
+        gamma, beta = torch.chunk(s, chunks=2, dim=1)  # channels vary in the for loop
+
+        # affine (1 + lin(s)) * inst(x) + lin(s) -- is this a skip connection where the weight is a lin of itself?
+
+        return (1 + gamma) * self.norm(x) + beta  # norm(x): PLBERT features get norm / beta & gamma: style has no norm()
 
 class AdaINResBlock1(torch.nn.Module):
     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
@@ -443,7 +451,7 @@ class Decoder(nn.Module):
         self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes)
 
 
-    def forward(self, asr, F0_curve, N, s):
+    def forward(self, asr=None, F0_curve=None, N=None, s=None):
        if self.training:
            downlist = [0, 3, 7]
            F0_down = downlist[random.randint(0, 2)]
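
A short illustrative sketch (not repo code, shapes assumed) of why the interpolation now lives inside AdaIN1d: the HiFiGAN generator upsamples its features by 10·5·3·2, so every AdaINResBlock1 sees a different number of frames, and interpolating the per-frame style inside each AdaIN lets one reference style tensor modulate every resolution.

import torch
import torch.nn.functional as F

style = torch.randn(1, 11, 1, 128)             # per-frame style, 11 reference frames (assumed shape)
for frames in (100, 1000, 3000, 15000):        # feature lengths after successive upsampling stages
    s = F.interpolate(style[:, :, 0, :].transpose(1, 2), frames, mode='linear')
    print(frames, s.shape)                      # -> [1, 128, frames]
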
models.py CHANGED
@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from torch.nn.utils import weight_norm, spectral_norm
 from Utils.ASR.models import ASRCNN
 from Utils.JDC.model import JDCNet
-from munch import Munch
+from Modules.hifigan import AdaIN1d
 import yaml
 
 
@@ -110,7 +110,7 @@ class ResBlk(nn.Module):
 
 class StyleEncoder(nn.Module):
 
-    # used for both acoustic & prosodic ref_s/p
+    # for both acoustic & prosodic ref_s/p
 
     def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
         super().__init__()
@@ -125,15 +125,20 @@ class StyleEncoder(nn.Module):
 
         blocks += [nn.LeakyReLU(0.2)]
         blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
-        blocks += [nn.AdaptiveAvgPool2d(1)]
+
+        # blocks += [nn.AdaptiveAvgPool2d(1)]  # THIS AVERAGES THE TIME-FRAMES OF SPEAKER STYLE
+
         blocks += [nn.LeakyReLU(0.2)]
         self.shared = nn.Sequential(*blocks)
 
         self.unshared = nn.Linear(dim_out, style_dim)
 
     def forward(self, x):
-        h = self.shared(x)
-        h = h.view(h.size(0), -1)
+        h = self.shared(x)  # [bs, 512, 1, 11]
+
+        h = h.mean(3, keepdims=True)  # UN-COMMENT FOR TIME-INVARIANT GLOBAL SPEAKER STYLE
+
+        h = h.transpose(1, 3)
         s = self.unshared(h)
         return s
 
@@ -289,21 +294,6 @@ class TextEncoder(nn.Module):
         mask = torch.gt(mask+1, lengths.unsqueeze(1))
         return mask
 
-
-
-class AdaIN1d(nn.Module):
-    def __init__(self, style_dim, num_features):
-        super().__init__()
-        self.norm = nn.InstanceNorm1d(num_features, affine=False)
-        self.fc = nn.Linear(style_dim, num_features*2)
-
-    def forward(self, x, s):
-        h = self.fc(s)
-        h = h.view(h.size(0), h.size(1), 1)
-        gamma, beta = torch.chunk(h, chunks=2, dim=1)
-        # affine (1 + lin(x)) * inst(x) + lin(x) is this a skip connection where the weight is a lin of itself
-        return (1 + gamma) * self.norm(x) + beta  # norm(x) = PLBERT has norm / beta&gamma = style has no norm()
-
 class UpSample1d(nn.Module):
     def __init__(self, layer_type):
         super().__init__()
@@ -316,8 +306,15 @@ class UpSample1d(nn.Module):
         return F.interpolate(x, scale_factor=2, mode='nearest')
 
 class AdainResBlk1d(nn.Module):
-    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
-                 upsample='none', dropout_p=0.0):
+
+    # only instantiated in ProsodyPredictor
+
+    def __init__(self, dim_in,
+                 dim_out,
+                 style_dim=64,
+                 actv=nn.LeakyReLU(0.2),
+                 upsample='none',
+                 dropout_p=0.0):
         super().__init__()
         self.actv = actv
         self.upsample_type = upsample
@@ -362,26 +359,22 @@ class AdainResBlk1d(nn.Module):
         return out
 
 class AdaLayerNorm(nn.Module):
-    def __init__(self, style_dim, channels, eps=1e-5):
+
+    # only instantiated in DurationEncoder()
+
+    def __init__(self, style_dim, channels=None, eps=1e-5):
         super().__init__()
-        self.channels = channels
         self.eps = eps
-
-        self.fc = nn.Linear(style_dim, channels*2)
+        self.fc = nn.Linear(style_dim, 1024)
 
     def forward(self, x, s):
-        x = x.transpose(-1, -2)
-        x = x.transpose(1, -1)
-
-        h = self.fc(s)
-        h = h.view(h.size(0), h.size(1), 1)
-        gamma, beta = torch.chunk(h, chunks=2, dim=1)
-        gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
+        h = self.fc(s.transpose(1, 2))  # has to be transposed: interpolate needs the last dim to be frames
+        gamma = h[:, :, :512]
+        beta = h[:, :, 512:1024]
 
-
-        x = F.layer_norm(x, (self.channels,), eps=self.eps)
+        x = F.layer_norm(x.transpose(1, 2), (512, ), eps=self.eps)
         x = (1 + gamma) * x + beta
-        return x.transpose(1, -1).transpose(-1, -2)
+        return x  # [1, 75, 512]
 
 class ProsodyPredictor(nn.Module):
 
@@ -414,7 +407,12 @@ class ProsodyPredictor(nn.Module):
         x, _ = self.shared(x.transpose(-1, -2))
 
         F0 = x.transpose(-1, -2)
+
+
         for block in self.F0:
+            print(f'F0N {F0.shape=} {s.shape=}\n')
+            # F0N F0.shape=torch.Size([1, 512, 147]) s.shape=torch.Size([1, 128])
+
             F0 = block(F0, s)
         F0 = self.F0_proj(F0)
 
@@ -452,21 +450,30 @@ class DurationEncoder(nn.Module):
     def forward(self, x, style, text_lengths, m):
         masks = m.to(text_lengths.device)
 
-        x = x.permute(2, 0, 1)
-        s = style.expand(x.shape[0], x.shape[1], -1)
-        x = torch.cat([x, s], axis=-1)
-        x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
+
+        # x : [bs, 512, 987]
+        # print('DURATION ENCODER', x.shape, style.shape, masks.shape)
+        # s = style.expand(x.shape[0], x.shape[1], -1)
+        style = style[:, :, 0, :].transpose(2, 1)  # [bs, 128, 11]
+        # print("S IN DURATION ENC", style.shape, x.shape)
+        style = F.interpolate(style, x.shape[2])
+        print(f'L468 IN DURATION ENC {x.shape=}, {style.shape=} {masks.shape=}')  # mask = [1, 75]
+        x = torch.cat([x, style], axis=1)  # [bs, 640, 75]
+        x.masked_fill_(masks[:, None, :], 0.0)
 
-        x = x.transpose(0, 1)
         input_lengths = text_lengths.cpu().numpy()
-        x = x.transpose(-1, -2)
 
         for block in self.lstms:
             if isinstance(block, AdaLayerNorm):
-                x = block(x.transpose(-1, -2), style).transpose(-1, -2)
-                x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
-                x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
+
+                print(f'\n=========ENTER ADALAYNORM L479 models.py {x.shape=}, {style.shape=}')
+                x = block(x, style)  # [bs, 75, 512]
+                x = torch.cat([x.transpose(1, 2), style], axis=1)  # [bs, 512, 75]
+                x.masked_fill_(masks[:, None, :], 0.0)
             else:
+                # print(f'{x.shape=} ENTER LSTM')  # [bs, 640, 75]; the LSTM reduces channels 640 -> 512
                 x = x.transpose(-1, -2)
                 x = nn.utils.rnn.pack_padded_sequence(
                     x, input_lengths, batch_first=True, enforce_sorted=False)
@@ -481,6 +488,7 @@ class DurationEncoder(nn.Module):
 
             x_pad[:, :, :x.shape[-1]] = x
             x = x_pad.to(x.device)
+            # print(f'{x.shape=} EXIT LSTM')  # [bs, 512, 75]
        # print('Calling Duration Encoder\n\n\n\n', x.shape, x.min(), x.max())
        # Calling Duration Encoder
        # torch.Size([1, 640, 107]) tensor(-3.0903, device='cuda:0') tensor(2.3089, device='cuda:0')
@@ -493,7 +501,6 @@ def load_F0_models(path):
     # load F0 model
 
     F0_model = JDCNet(num_class=1, seq_len=192)
-    print(path, 'WHAT ARE YOU TRYING TO LOAD F0 L520')
     path = path.replace('.t7', '.pth')
     params = torch.load(path, map_location='cpu')['net']
     F0_model.load_state_dict(params)
@@ -524,37 +531,3 @@ def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
     _ = asr_model.train()
 
     return asr_model
-
-def build_model(args, text_aligner, pitch_extractor, bert):
-    print(f'\n==============\n {args.decoder.type=}\n==============L584 models.py @ build_model()\n')
-
-    from Modules.hifigan import Decoder
-    decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
-                      resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
-                      upsample_rates=args.decoder.upsample_rates,
-                      upsample_initial_channel=args.decoder.upsample_initial_channel,
-                      resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
-                      upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
-
-    text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
-
-    predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
-
-    style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim)  # acoustic style encoder
-    predictor_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim)  # prosodic style encoder
-
-    nets = Munch(
-        bert=bert,
-        bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
-        predictor=predictor,
-        decoder=decoder,
-        text_encoder=text_encoder,
-        predictor_encoder=predictor_encoder,
-        style_encoder=style_encoder,
-        text_aligner=text_aligner,
-        pitch_extractor=pitch_extractor
-    )
-
-    return nets
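
A rough shape walk-through of the new style path in DurationEncoder.forward, using the shapes from the debug prints in the diff (assumed, not verified against the repo):

import torch
import torch.nn.functional as F

bs, n_tokens = 1, 75
x = torch.randn(bs, 512, n_tokens)            # text features [bs, 512, T_text]
style = torch.randn(bs, 11, 1, 128)           # per-frame prosodic style from StyleEncoder

style = style[:, :, 0, :].transpose(2, 1)     # [bs, 128, 11]
style = F.interpolate(style, x.shape[2])      # [bs, 128, 75] -- one style vector per token
x = torch.cat([x, style], dim=1)              # [bs, 640, 75], fed to the LSTM / AdaLayerNorm blocks
print(x.shape)
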
msinference.py CHANGED
@@ -7,8 +7,7 @@ import numpy as np
 import yaml
 import torchaudio
 import librosa
-from models import *
-from munch import Munch
+from models import ProsodyPredictor, TextEncoder, StyleEncoder, load_ASR_models, load_F0_models
 from nltk.tokenize import word_tokenize
 
 torch.manual_seed(0)
@@ -62,17 +61,6 @@ def alpha_num(f):
     return f
 
 
-
-def recursive_munch(d):
-    if isinstance(d, dict):
-        return Munch((k, recursive_munch(v)) for k, v in d.items())
-    elif isinstance(d, list):
-        return [recursive_munch(v) for v in d]
-    else:
-        return d
-
-
-
 # ======== UTILS ABOVE
 
 def length_to_mask(lengths):
@@ -94,10 +82,10 @@ def compute_style(path):
     mel_tensor = preprocess(audio).to(device)
 
     with torch.no_grad():
-        ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
-        ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
-
-        return torch.cat([ref_s, ref_p], dim=1)
+        ref_s = style_encoder(mel_tensor.unsqueeze(1))
+        ref_p = predictor_encoder(mel_tensor.unsqueeze(1))  # [bs, 11, 1, 128]
+        print(f'\n\n\n\nCOMPUTE STYLE {ref_s.shape=} {ref_p.shape=}')
+        return torch.cat([ref_s, ref_p], dim=3)  # [bs, 11, 1, 256]
 
 device = 'cpu'
 if torch.cuda.is_available():
@@ -112,50 +100,151 @@ global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_
 # phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))
 
 
-config = yaml.safe_load(open(str('Utils/config.yml')))
+args = yaml.safe_load(open(str('Utils/config.yml')))
+ASR_config = args['ASR_config']
 
-# load pretrained ASR model
-ASR_config = config.get('ASR_config', False)
-ASR_path = config.get('ASR_path', False)
-text_aligner = load_ASR_models(ASR_path, ASR_config)
+ASR_path = args['ASR_path']
+text_aligner = load_ASR_models(ASR_path, ASR_config).eval().to(device)
 
-# load pretrained F0 model
-F0_path = config.get('F0_path', False)
-pitch_extractor = load_F0_models(F0_path)
+F0_path = args['F0_path']
+pitch_extractor = load_F0_models(F0_path).eval().to(device)
 
-# load BERT model
 from Utils.PLBERT.util import load_plbert
-BERT_path = config.get('PLBERT_dir', False)
-plbert = load_plbert(BERT_path)
-
-model_params = recursive_munch(config['model_params'])
-model = build_model(model_params, text_aligner, pitch_extractor, plbert)
-_ = [model[key].eval() for key in model]
-_ = [model[key].to(device) for key in model]
+bert = load_plbert(args['PLBERT_dir']).eval().to(device)
+# model_params = recursive_munch(config['model_params'])
+# --
+# def build_model(args, text_aligner, pitch_extractor, bert):
+#     print(f'\n==============\n {args.decoder.type=}\n==============L584 models.py @ build_model()\n')
+# ======================================
+# In [4]: args['model_params']
+# Out[4]:
+# {'decoder': {'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+#              'resblock_kernel_sizes': [3, 7, 11],
+#              'type': 'hifigan',
+#              'upsample_initial_channel': 512,
+#              'upsample_kernel_sizes': [20, 10, 6, 4],
+#              'upsample_rates': [10, 5, 3, 2]},
+#  'diffusion': {'dist': {'estimate_sigma_data': True,
+#                         'mean': -3.0,
+#                         'sigma_data': 0.19926648961191362,
+#                         'std': 1.0},
+#                'embedding_mask_proba': 0.1,
+#                'transformer': {'head_features': 64,
+#                                'multiplier': 2,
+#                                'num_heads': 8,
+#                                'num_layers': 3}},
+#  'dim_in': 64,
+#  'dropout': 0.2,
+#  'hidden_dim': 512,
+#  'max_conv_dim': 512,
+#  'max_dur': 50,
+#  'multispeaker': True,
+#  'n_layer': 3,
+#  'n_mels': 80,
+#  'n_token': 178,
+#  'slm': {'hidden': 768,
+#          'initial_channel': 64,
+#          'model': 'microsoft/wavlm-base-plus',
+#          'nlayers': 13,
+#          'sr': 16000},
+#  'style_dim': 128}
+# ===============================================
+from Modules.hifigan import Decoder
+decoder = Decoder(dim_in=512,
+                  style_dim=128,
+                  dim_out=80,  # n_mels
+                  resblock_kernel_sizes=[3, 7, 11],
+                  upsample_rates=[10, 5, 3, 2],
+                  upsample_initial_channel=512,
+                  resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+                  upsample_kernel_sizes=[20, 10, 6, 4]).eval().to(device)
+
+text_encoder = TextEncoder(channels=512,
+                           kernel_size=5,
+                           depth=3,        # args['model_params']['n_layer']
+                           n_symbols=178,  # args['model_params']['n_token']
+                           ).eval().to(device)
+
+predictor = ProsodyPredictor(style_dim=128,
+                             d_hid=512,
+                             nlayers=3,  # OFFICIAL config nlayers=5
+                             max_dur=50,
+                             dropout=.2).eval().to(device)
+
+style_encoder = StyleEncoder(dim_in=64,
+                             style_dim=128,
+                             max_conv_dim=512).eval().to(device)  # acoustic style encoder
+predictor_encoder = StyleEncoder(dim_in=64,
+                                 style_dim=128,
+                                 max_conv_dim=512).eval().to(device)  # prosodic style encoder
+bert_encoder = torch.nn.Linear(bert.config.hidden_size, 512).eval().to(device)
+# --
+# model = build_model(model_params, text_aligner, pitch_extractor, plbert)
+# _ = [model[key].eval() for key in model]
+# _ = [model[key].to(device) for key in model]
 
 # params_whole = torch.load("Models/LibriTTS/epochs_2nd_00020.pth", map_location='cpu')
 # params_whole = torch.load('freevc2/yl4579_styletts2.pth' map_location='cpu')
 params_whole = torch.load(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")), map_location='cpu')
 params = params_whole['net']
 
-for key in model:
-    if key in params:
-        print('%s loaded' % key)
-        try:
-            model[key].load_state_dict(params[key])
-        except:
-            from collections import OrderedDict
-            state_dict = params[key]
-            new_state_dict = OrderedDict()
-            for k, v in state_dict.items():
-                name = k[7:]  # remove `module.`
-                new_state_dict[name] = v
-            # load params
-            model[key].load_state_dict(new_state_dict, strict=False)
-        # except:
-        #     _load(params[key], model[key])
-_ = [model[key].eval() for key in model]
+# checkpoint keys: 'bert', 'bert_encoder', 'predictor', 'decoder', 'text_encoder',
+#                  'predictor_encoder', 'style_encoder', 'text_aligner', 'pitch_extractor'
+# --
+from collections import OrderedDict
+
+new_state_dict = OrderedDict()
+for k, v in params['bert'].items():
+    new_state_dict[k[7:]] = v  # del 'module.'
+bert.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['bert_encoder'].items():
+    new_state_dict[k[7:]] = v  # del 'module.'
+bert_encoder.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['predictor'].items():
+    new_state_dict[k[7:]] = v  # del 'module.'
+predictor.load_state_dict(new_state_dict, strict=True)  # extra non-checkpoint LSTM nlayers add slowness to the voice
+# --
+new_state_dict = OrderedDict()
+for k, v in params['decoder'].items():
+    new_state_dict[k[7:]] = v
+decoder.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['text_encoder'].items():
+    new_state_dict[k[7:]] = v
+text_encoder.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['predictor_encoder'].items():
+    new_state_dict[k[7:]] = v
+predictor_encoder.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['style_encoder'].items():
+    new_state_dict[k[7:]] = v
+style_encoder.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['text_aligner'].items():
+    new_state_dict[k[7:]] = v  # del 'module.'
+text_aligner.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['pitch_extractor'].items():
+    new_state_dict[k[7:]] = v
+pitch_extractor.load_state_dict(new_state_dict, strict=True)
 
 
 def inference(text,
@@ -193,24 +282,31 @@ def inference(text,
     # 54, 156, 63, 158, 147, 83, 56, 16, 4]], device='cuda:0')
 
 
-    t_en = model.text_encoder(tokens, input_lengths, text_mask)
-    bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
-    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+    t_en = text_encoder(tokens, input_lengths, text_mask)
+    bert_dur = bert(tokens, attention_mask=(~text_mask).int())
+    d_en = bert_encoder(bert_dur).transpose(-1, -2)
     # print('BERTdu', bert_dur.shape, tokens.shape, '\n') # bert what is the 768 per token -> IS USED in sampler
     # BERTdu torch.Size([1, 11, 768]) torch.Size([1, 11])
 
 
-    ref = ref_s[:, :128]
-    s = ref_s[:, 128:]
+    ref = ref_s[:, :, :, :128]  # [bs, 11, 1, 128]
+    s = ref_s[:, :, :, 128:]    # channels last so it can go through nn.Linear layers
 
+    # At compute_style() we don't yet know the length to interpolate to;
+    # perhaps ref_s could be interpolated here, now that we know how many BERT time-frames the text needs.
     # s = .74 * s # prosody / arousal & fading unvoiced syllables [x0.7 - x1.2]
 
-    d = model.predictor.text_encoder(d_en,
-                                     s, input_lengths, text_mask)
+    print(f'{d_en.shape=} {s.shape=} {input_lengths.shape=} {text_mask.shape=}')
+    d = predictor.text_encoder(d_en,
+                               s,
+                               input_lengths,
+                               text_mask)
 
-    x, _ = model.predictor.lstm(d)
-    duration = model.predictor.duration_proj(x)
+    x, _ = predictor.lstm(d)
+    duration = predictor.duration_proj(x)
 
     duration = torch.sigmoid(duration).sum(axis=-1)
     pred_dur = torch.round(duration.squeeze()).clamp(min=1)
@@ -224,23 +320,25 @@ def inference(text,
 
     # encode prosody
     en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
-    if model_params.decoder.type == "hifigan":
-        asr_new = torch.zeros_like(en)
-        asr_new[:, :, 0] = en[:, :, 0]
-        asr_new[:, :, 1:] = en[:, :, 0:-1]
-        en = asr_new
-
-    F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
+
+    asr_new = torch.zeros_like(en)
+    asr_new[:, :, 0] = en[:, :, 0]
+    asr_new[:, :, 1:] = en[:, :, 0:-1]
+    en = asr_new
+    print('_________________________________________F0_____________________________')
+    F0_pred, N_pred = predictor.F0Ntrain(en, s)
 
     asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
-    if model_params.decoder.type == "hifigan":
-        asr_new = torch.zeros_like(asr)
-        asr_new[:, :, 0] = asr[:, :, 0]
-        asr_new[:, :, 1:] = asr[:, :, 0:-1]
-        asr = asr_new
-
-    x = model.decoder(asr,
-                      F0_pred, N_pred, ref.squeeze().unsqueeze(0))
+
+    asr_new = torch.zeros_like(asr)
+    asr_new[:, :, 0] = asr[:, :, 0]
+    asr_new[:, :, 1:] = asr[:, :, 0:-1]
+    asr = asr_new
+    print('_________________________________________HiFi_____________________________')
+    x = decoder(asr=asr,
+                F0_curve=F0_pred,
+                N=N_pred,
+                s=ref)
 
     x = x.cpu().numpy()[0, 0, :-400]  # weird pulse at the end of sentences
 
@@ -299,6 +397,11 @@ import re
 from num2words import num2words
 
 PHONEME_MAP = {
+    'služ': 'sloooozz',     # 'službeno'
+    'suver': 'siuveeerra',  # 'suverena'
+    'država': 'dirrezav',   # 'država'
+    'iči': 'ici',           # 'Graniči'
+    's ': 'se',             # an 's' followed by a space
     'q': 'ku',
     'w': 'aou',
     'z': 's',
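
The checkpoint was saved from nn.DataParallel-wrapped modules, so every parameter key carries a 'module.' prefix, and the diff repeats the same stripping loop once per sub-model. A small helper that does the same thing (illustrative only, not part of the repo):

from collections import OrderedDict

def strip_module_prefix(state_dict):
    """Drop the leading 'module.' that nn.DataParallel adds to parameter names."""
    return OrderedDict((k[len('module.'):] if k.startswith('module.') else k, v)
                       for k, v in state_dict.items())

# e.g. bert.load_state_dict(strip_module_prefix(params['bert']), strict=True)
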
requirements.txt CHANGED
@@ -13,7 +13,7 @@ omegaconf
 opencv-python
 soundfile
 transformers
-munch
+audresample
 srt
 nltk
 phonemizer