hexgrad committed on
Commit
be929c2
·
verified ·
1 Parent(s): ab10220

Delete models.py

Browse files
Files changed (1) hide show
  1. models.py +0 -577
models.py DELETED
@@ -1,577 +0,0 @@
1
- # https://github.com/yl4579/StyleTTS2/blob/main/models.py
2
- from istftnet import Decoder
3
- from munch import Munch
4
- from plbert import load_plbert
5
- from torch.nn.utils import weight_norm, spectral_norm
6
- import numpy as np
7
- import os
8
- import os.path as osp
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
-
13
class LearnedDownSample(nn.Module):
    """Learned down-sampling built from a depthwise, spectrally-normalised Conv2d.

    ``layer_type`` selects the operation:

    * ``'none'``         -- ``nn.Identity``; the input is returned untouched.
    * ``'timepreserve'`` -- 3x1 kernel with stride (2, 1): only dim -2 is halved.
    * ``'half'``         -- 3x3 kernel with stride (2, 2): both trailing dims are halved.

    Raises:
        RuntimeError: if ``layer_type`` is none of the values above.
    """

    def __init__(self, layer_type, dim_in):
        super().__init__()
        self.layer_type = layer_type

        if layer_type == 'none':
            self.conv = nn.Identity()
            return

        if layer_type == 'timepreserve':
            conv_kwargs = dict(kernel_size=(3, 1), stride=(2, 1), padding=(1, 0))
        elif layer_type == 'half':
            conv_kwargs = dict(kernel_size=(3, 3), stride=(2, 2), padding=1)
        else:
            raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)

        # groups=dim_in: every channel gets its own filter (depthwise convolution).
        self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, groups=dim_in, **conv_kwargs))

    def forward(self, x):
        """Run ``x`` through the layer chosen at construction time."""
        return self.conv(x)
29
-
30
class LearnedUpSample(nn.Module):
    """Learned up-sampling built from a depthwise ``ConvTranspose2d``.

    ``layer_type`` selects the operation:

    * ``'none'``         -- ``nn.Identity``; the input is returned untouched.
    * ``'timepreserve'`` -- 3x1 kernel with stride (2, 1): only dim -2 is doubled.
    * ``'half'``         -- 3x3 kernel with stride (2, 2): both trailing dims are doubled.

    Raises:
        RuntimeError: if ``layer_type`` is none of the values above.
    """

    def __init__(self, layer_type, dim_in):
        super().__init__()
        self.layer_type = layer_type

        if layer_type == 'none':
            self.conv = nn.Identity()
            return

        if layer_type == 'timepreserve':
            deconv_kwargs = dict(kernel_size=(3, 1), stride=(2, 1), output_padding=(1, 0), padding=(1, 0))
        elif layer_type == 'half':
            deconv_kwargs = dict(kernel_size=(3, 3), stride=(2, 2), output_padding=1, padding=1)
        else:
            raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)

        # groups=dim_in: every channel gets its own filter (depthwise transposed convolution).
        self.conv = nn.ConvTranspose2d(dim_in, dim_in, groups=dim_in, **deconv_kwargs)

    def forward(self, x):
        """Run ``x`` through the layer chosen at construction time."""
        return self.conv(x)
47
-
48
class DownSample(nn.Module):
    """Parameter-free down-sampling by average pooling.

    ``layer_type`` is checked on every forward call:

    * ``'none'``         -- the input is returned as is.
    * ``'timepreserve'`` -- average-pool with a (2, 1) window.
    * ``'half'``         -- average-pool with a 2x2 window; when the last dim has odd
      length its final column is duplicated first so that no data is dropped.

    Raises:
        RuntimeError: (from ``forward``) if ``layer_type`` is none of the values above.
    """

    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        mode = self.layer_type
        if mode == 'none':
            return x
        if mode == 'timepreserve':
            return F.avg_pool2d(x, (2, 1))
        if mode == 'half':
            if x.shape[-1] % 2:
                # Odd width: repeat the last column so the 2x2 window covers everything.
                x = torch.cat([x, x[..., -1:]], dim=-1)
            return F.avg_pool2d(x, 2)
        raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
64
-
65
-
66
class UpSample(nn.Module):
    """Parameter-free up-sampling by nearest-neighbour interpolation.

    ``layer_type`` is checked on every forward call:

    * ``'none'``         -- the input is returned as is.
    * ``'timepreserve'`` -- scale factor (2, 1): only dim -2 is doubled.
    * ``'half'``         -- scale factor 2 on both trailing dims.

    Raises:
        RuntimeError: (from ``forward``) if ``layer_type`` is none of the values above.
    """

    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        mode = self.layer_type
        if mode == 'none':
            return x
        if mode == 'timepreserve':
            return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
        if mode == 'half':
            return F.interpolate(x, scale_factor=2, mode='nearest')
        raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
80
-
81
-
82
class ResBlk(nn.Module):
    """Pre-activation 2-D residual block with optional down-sampling.

    The shortcut path is ``[1x1 conv] -> DownSample`` (average pooling); the residual
    path is ``[norm] -> actv -> conv -> LearnedDownSample -> [norm] -> actv -> conv``.
    Both are summed and divided by sqrt(2).

    Args:
        dim_in: number of input channels.
        dim_out: number of output channels; a 1x1 shortcut conv is added when it
            differs from ``dim_in``.
        actv: activation module applied before each 3x3 convolution.
        normalize: when true, affine ``InstanceNorm2d`` precedes each activation.
        downsample: ``'none'``, ``'timepreserve'`` or ``'half'``; forwarded to both
            ``DownSample`` and ``LearnedDownSample``.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize=False, downsample='none'):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = DownSample(downsample)
        self.downsample_res = LearnedDownSample(downsample, dim_in)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        """Create the convolutions and, if enabled, the norms and the 1x1 shortcut."""
        def sn_conv(c_in, c_out, kernel, pad, **extra):
            # Every convolution in this block is stride-1 and spectrally normalised.
            return spectral_norm(nn.Conv2d(c_in, c_out, kernel, 1, pad, **extra))

        self.conv1 = sn_conv(dim_in, dim_in, 3, 1)
        self.conv2 = sn_conv(dim_in, dim_out, 3, 1)
        if self.normalize:
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        if self.learned_sc:
            self.conv1x1 = sn_conv(dim_in, dim_out, 1, 0, bias=False)

    def _shortcut(self, x):
        """Skip path: optional channel projection followed by pooling."""
        skip = self.conv1x1(x) if self.learned_sc else x
        # The pooling module is always truthy, so it is applied unconditionally;
        # with layer_type 'none' it simply returns its input.
        return self.downsample(skip)

    def _residual(self, x):
        """Main path: two (norm, activation, conv) stages with a learned down-sample between."""
        h = self.norm1(x) if self.normalize else x
        h = self.conv1(self.actv(h))
        h = self.downsample_res(h)
        if self.normalize:
            h = self.norm2(h)
        return self.conv2(self.actv(h))

    def forward(self, x):
        summed = self._shortcut(x) + self._residual(x)
        return summed / np.sqrt(2)  # unit variance
124
-
125
class LinearNorm(torch.nn.Module):
    """``torch.nn.Linear`` whose weight is re-initialised with Xavier-uniform.

    Args:
        in_dim: size of each input sample.
        out_dim: size of each output sample.
        bias: whether the linear layer has a bias term.
        w_init_gain: non-linearity name handed to ``torch.nn.init.calculate_gain``
            to obtain the Xavier gain.
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super().__init__()
        layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
        gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(layer.weight, gain=gain)
        self.linear_layer = layer

    def forward(self, x):
        """Apply the wrapped linear layer."""
        return self.linear_layer(x)
136
-
137
class Discriminator2d(nn.Module):
    """2-D convolutional discriminator made of down-sampling ``ResBlk`` stages.

    Args:
        dim_in: channel count produced by the stem convolution (the stem itself
            takes a single-channel input).
        num_domains: number of output logits per sample.
        max_conv_dim: upper bound for the channel count, which doubles per stage.
        repeat_num: number of ``ResBlk(downsample='half')`` stages.
    """

    def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
        super().__init__()
        blocks = []
        blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]

        # Keep dim_out defined even when repeat_num == 0 (no residual stages).
        dim_out = dim_in
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            blocks += [ResBlk(dim_in, dim_out, downsample='half')]
            dim_in = dim_out

        blocks += [nn.LeakyReLU(0.2)]
        blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
        blocks += [nn.LeakyReLU(0.2)]
        blocks += [nn.AdaptiveAvgPool2d(1)]
        blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
        self.main = nn.Sequential(*blocks)

    def get_feature(self, x):
        """Run ``x`` through every layer, keeping all intermediate activations.

        Returns:
            ``(out, features)`` where ``features`` lists the output of each layer in
            ``self.main`` and ``out`` is the last of them flattened to
            (batch, num_domains).
        """
        features = []
        for layer in self.main:
            x = layer(x)
            features.append(x)
        out = features[-1]
        out = out.view(out.size(0), -1)  # (batch, num_domains)
        return out, features

    def forward(self, x):
        """Return ``(logits, features)``.

        ``logits`` has shape (batch,) when ``num_domains == 1`` and
        (batch, num_domains) otherwise.
        """
        out, features = self.get_feature(x)
        # Squeeze only the domain axis: a bare squeeze() would also drop the batch
        # axis for a batch of one and hand back a 0-d (or (num_domains,)) tensor.
        out = out.squeeze(-1)  # (batch)
        return out, features
168
-
169
class ResBlk1d(nn.Module):
    """Pre-activation 1-D residual block with optional stride-2 down-sampling.

    Args:
        dim_in: number of input channels.
        dim_out: number of output channels; a 1x1 shortcut conv is added when it
            differs from ``dim_in``.
        actv: activation module applied before each convolution.
        normalize: when true, affine ``InstanceNorm1d`` precedes each activation.
        downsample: ``'none'`` keeps the length; any other value halves it
            (average pooling on the shortcut, a depthwise strided conv on the
            residual path).
        dropout_p: dropout probability applied after each activation.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize=False, downsample='none', dropout_p=0.2):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample_type = downsample
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out)
        self.dropout_p = dropout_p

        self.pool = (
            nn.Identity()
            if self.downsample_type == 'none'
            else weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
        )

    def _build_weights(self, dim_in, dim_out):
        """Create the convolutions and, if enabled, the norms and the 1x1 shortcut."""
        def wn_conv(c_in, c_out, kernel, pad, **extra):
            # Every convolution built here is stride-1 and weight-normalised.
            return weight_norm(nn.Conv1d(c_in, c_out, kernel, 1, pad, **extra))

        self.conv1 = wn_conv(dim_in, dim_in, 3, 1)
        self.conv2 = wn_conv(dim_in, dim_out, 3, 1)
        if self.normalize:
            self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
        if self.learned_sc:
            self.conv1x1 = wn_conv(dim_in, dim_out, 1, 0, bias=False)

    def downsample(self, x):
        """Halve the last dim by average pooling (identity when down-sampling is off)."""
        if self.downsample_type == 'none':
            return x
        if x.shape[-1] % 2:
            # Odd length: repeat the final step so the pooling window covers it.
            x = torch.cat([x, x[..., -1:]], dim=-1)
        return F.avg_pool1d(x, 2)

    def _shortcut(self, x):
        """Skip path: optional channel projection followed by pooling."""
        skip = self.conv1x1(x) if self.learned_sc else x
        return self.downsample(skip)

    def _act_drop(self, x):
        """Activation followed by dropout (active only while training)."""
        return F.dropout(self.actv(x), p=self.dropout_p, training=self.training)

    def _residual(self, x):
        """Main path: two (norm, activation, dropout, conv) stages with a strided conv between."""
        h = self.norm1(x) if self.normalize else x
        h = self.conv1(self._act_drop(h))
        h = self.pool(h)
        if self.normalize:
            h = self.norm2(h)
        return self.conv2(self._act_drop(h))

    def forward(self, x):
        summed = self._shortcut(x) + self._residual(x)
        return summed / np.sqrt(2)  # unit variance
228
-
229
class LayerNorm(nn.Module):
    """Layer normalisation over dim 1 with a learnable per-channel scale and shift.

    The input is normalised along dim 1 (which must have size ``channels``) by
    temporarily swapping that axis with the last one.

    Args:
        channels: size of dim 1 of the input.
        eps: numerical-stability term passed to ``F.layer_norm``.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        channels_last = x.transpose(1, -1)
        normed = F.layer_norm(channels_last, (self.channels,), self.gamma, self.beta, self.eps)
        # Swap back so the output layout matches the input layout.
        return normed.transpose(1, -1)
242
-
243
class TextEncoder(nn.Module):
    """Token encoder: embedding -> stack of Conv1d blocks -> bidirectional LSTM.

    Args:
        channels: embedding size and width of every conv / LSTM output.
        kernel_size: kernel size of the Conv1d layers ("same" padding is derived from it).
        depth: number of conv blocks.
        n_symbols: vocabulary size of the embedding table.
        actv: activation module placed inside every conv block.
    """

    def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
        super().__init__()
        self.embedding = nn.Embedding(n_symbols, channels)

        padding = (kernel_size - 1) // 2
        # Kept as a ModuleList (not nn.Sequential) because forward() re-applies the
        # padding mask after every block.
        self.cnn = nn.ModuleList()
        for _ in range(depth):
            self.cnn.append(nn.Sequential(
                weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
                LayerNorm(channels),
                actv,
                nn.Dropout(0.2),
            ))

        self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)

    def forward(self, x, input_lengths, m):
        """Encode a padded batch of token ids.

        Args:
            x: token ids, [B, T].
            input_lengths: true length of every sequence, [B].
            m: boolean padding mask, [B, T], true at padded positions; T must match
                the time axis of ``x``.

        Returns:
            Encoded sequence of shape [B, channels, T] with padded positions zeroed.
        """
        x = self.embedding(x)  # [B, T, emb]
        x = x.transpose(1, 2)  # [B, emb, T]
        m = m.to(input_lengths.device).unsqueeze(1)
        x.masked_fill_(m, 0.0)

        for c in self.cnn:
            x = c(x)
            x.masked_fill_(m, 0.0)

        x = x.transpose(1, 2)  # [B, T, chn]

        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True, enforce_sorted=False)

        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        # total_length pads the output back to the mask width directly on the LSTM's
        # device and dtype, instead of copying through a CPU float32 buffer.
        x, _ = nn.utils.rnn.pad_packed_sequence(
            x, batch_first=True, total_length=m.shape[-1])

        x = x.transpose(-1, -2)  # [B, chn, T]
        return x.masked_fill(m, 0.0)

    def inference(self, x):
        """Encode token ids without any length / mask handling.

        Returns the LSTM output in batch-first layout, [B, T, channels].
        """
        x = self.embedding(x)
        x = x.transpose(1, 2)
        # self.cnn is an nn.ModuleList, which cannot be called directly; apply the
        # blocks one after another.
        for c in self.cnn:
            x = c(x)
        x = x.transpose(1, 2)
        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        return x

    def length_to_mask(self, lengths):
        """Build a [B, max(lengths)] boolean mask that is true at padded positions."""
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask
305
-
306
-
307
-
308
class AdaIN1d(nn.Module):
    """Adaptive instance normalisation for 1-D feature maps.

    The input is instance-normalised (no affine parameters) and then modulated by a
    per-channel scale and shift predicted from the style vector by a linear layer.

    Args:
        style_dim: size of the style vector.
        num_features: number of channels of the normalised input.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm1d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        """Return ``(1 + gamma) * InstanceNorm(x) + beta`` with gamma/beta derived from ``s``."""
        stats = self.fc(s)
        # Trailing singleton axis lets gamma/beta broadcast along the last dim of x.
        stats = stats.view(stats.size(0), stats.size(1), 1)
        gamma, beta = stats.chunk(2, dim=1)
        normed = self.norm(x)
        return normed * (gamma + 1) + beta
319
-
320
class UpSample1d(nn.Module):
    """Optional 2x nearest-neighbour up-sampling.

    ``layer_type == 'none'`` disables the layer; every other value enables it.
    """

    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type != 'none':
            return F.interpolate(x, scale_factor=2, mode='nearest')
        return x
330
-
331
class AdainResBlk1d(nn.Module):
    """Style-conditioned 1-D residual block (AdaIN) with optional 2x up-sampling.

    Args:
        dim_in: number of input channels.
        dim_out: number of output channels; a 1x1 shortcut conv is added when it
            differs from ``dim_in``.
        style_dim: size of the style vector fed to both ``AdaIN1d`` layers.
        actv: activation module applied after each AdaIN layer.
        upsample: ``'none'`` keeps the length; any other value doubles it
            (nearest-neighbour on the shortcut, a depthwise transposed conv on the
            residual path).
        dropout_p: dropout probability applied before each convolution.
    """

    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
                 upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)

        self.pool = (
            nn.Identity()
            if upsample == 'none'
            else weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
        )

    def _build_weights(self, dim_in, dim_out, style_dim):
        """Create the two convolutions, the two AdaIN layers and the optional 1x1 shortcut."""
        def wn_conv(c_in, c_out, kernel, pad, **extra):
            # Every convolution built here is stride-1 and weight-normalised.
            return weight_norm(nn.Conv1d(c_in, c_out, kernel, 1, pad, **extra))

        self.conv1 = wn_conv(dim_in, dim_out, 3, 1)
        self.conv2 = wn_conv(dim_out, dim_out, 3, 1)
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = wn_conv(dim_in, dim_out, 1, 0, bias=False)

    def _shortcut(self, x):
        """Skip path: up-sample first, then project channels if needed."""
        skip = self.upsample(x)
        return self.conv1x1(skip) if self.learned_sc else skip

    def _residual(self, x, s):
        """Main path: (AdaIN, activation, [transposed conv], dropout, conv) twice."""
        h = self.actv(self.norm1(x, s))
        h = self.pool(h)
        h = self.conv1(self.dropout(h))
        h = self.actv(self.norm2(h, s))
        return self.conv2(self.dropout(h))

    def forward(self, x, s):
        residual = self._residual(x, s)
        # Dividing by sqrt(2) rescales the sum of the two paths.
        return (residual + self._shortcut(x)) / np.sqrt(2)
376
-
377
class AdaLayerNorm(nn.Module):
    """Layer normalisation whose scale and shift are predicted from a style vector.

    Args:
        style_dim: size of the style vector.
        channels: size of the normalised (last, after the internal axis swaps) dim.
        eps: numerical-stability term passed to ``F.layer_norm``.
    """

    def __init__(self, style_dim, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.fc = nn.Linear(style_dim, channels*2)

    def forward(self, x, s):
        """Return ``(1 + gamma) * LayerNorm(x) + beta`` with gamma/beta derived from ``s``."""
        # Two successive axis swaps; for a 3-D input they cancel each other, so the
        # last dim of x is the one being normalised.
        x = x.transpose(-1, -2).transpose(1, -1)

        stats = self.fc(s)
        stats = stats.view(stats.size(0), stats.size(1), 1)
        # Move the channel axis last so gamma/beta broadcast over the middle axis.
        gamma, beta = (part.transpose(1, -1) for part in torch.chunk(stats, chunks=2, dim=1))

        normed = F.layer_norm(x, (self.channels,), eps=self.eps)
        out = (1 + gamma) * normed + beta
        # Undo the axis swaps applied on entry.
        return out.transpose(1, -1).transpose(-1, -2)
398
-
399
class ProsodyPredictor(nn.Module):
    """Predicts durations and, via ``F0Ntrain``, two per-frame curves (F0 and N).

    Args:
        style_dim: size of the style vector.
        d_hid: hidden width of the encoder / LSTMs.
        nlayers: number of LSTM + AdaLayerNorm pairs in the ``DurationEncoder``.
        max_dur: output size of the duration projection.
        dropout: dropout probability for the encoder and the AdaIN residual blocks.
    """

    def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
        super().__init__()

        self.text_encoder = DurationEncoder(sty_dim=style_dim,
                                            d_model=d_hid,
                                            nlayers=nlayers,
                                            dropout=dropout)

        self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
        self.duration_proj = LinearNorm(d_hid, max_dur)

        self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
        # The F0 and N branches share one layout: the middle block halves the
        # channel count and doubles the length (any upsample value other than
        # 'none' enables up-sampling in AdainResBlk1d).
        self.F0 = nn.ModuleList()
        self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
        self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
        self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))

        self.N = nn.ModuleList()
        self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
        self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
        self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))

        self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
        self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)

    def forward(self, texts, style, text_lengths, alignment, m):
        """Predict durations and align the encoder output.

        Args:
            texts, style, text_lengths, m: passed unchanged to the ``DurationEncoder``.
            text_lengths: true length of every sequence; also used to pack the LSTM input.
            alignment: matrix that the transposed encoder output is multiplied with
                (assumed [B, T, frames] -- TODO confirm against the caller).
            m: padding mask; its last dim fixes the padded output length.

        Returns:
            ``(duration, en)``: the duration projection (last dim squeezed if it has
            size 1) and ``d^T @ alignment`` where ``d`` is the encoder output.
        """
        d = self.text_encoder(texts, style, text_lengths, m)

        # predict duration
        input_lengths = text_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            d, input_lengths, batch_first=True, enforce_sorted=False)

        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        # total_length pads the output back to the mask width directly on the LSTM's
        # device and dtype, instead of copying through a CPU float32 buffer.
        x, _ = nn.utils.rnn.pad_packed_sequence(
            x, batch_first=True, total_length=m.shape[-1])

        duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))

        en = (d.transpose(-1, -2) @ alignment)

        return duration.squeeze(-1), en

    def F0Ntrain(self, x, s):
        """Predict the F0 and N curves from aligned features ``x`` and style ``s``.

        ``x`` is transposed before the shared LSTM, so it is expected channel-first.
        Both outputs have their channel axis (size 1) squeezed away.
        """
        x, _ = self.shared(x.transpose(-1, -2))

        F0 = x.transpose(-1, -2)
        for block in self.F0:
            F0 = block(F0, s)
        F0 = self.F0_proj(F0)

        N = x.transpose(-1, -2)
        for block in self.N:
            N = block(N, s)
        N = self.N_proj(N)

        return F0.squeeze(1), N.squeeze(1)

    def length_to_mask(self, lengths):
        """Build a [B, max(lengths)] boolean mask that is true at padded positions."""
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask
475
-
476
class DurationEncoder(nn.Module):
    """Stack of bidirectional LSTMs interleaved with style-conditioned layer norms.

    ``self.lstms`` alternates ``nn.LSTM`` and ``AdaLayerNorm`` modules; the style
    vector is concatenated to the features again after every normalisation.

    Args:
        sty_dim: size of the style vector.
        d_model: feature width produced by every LSTM.
        nlayers: number of LSTM + AdaLayerNorm pairs.
        dropout: dropout probability applied to every LSTM output.
    """

    def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
        super().__init__()
        self.lstms = nn.ModuleList()
        for _ in range(nlayers):
            # No `dropout=` here: nn.LSTM only applies it between stacked layers, so
            # with num_layers=1 it has no effect and merely triggers a UserWarning.
            # Dropout is applied explicitly on the LSTM output in forward().
            self.lstms.append(nn.LSTM(d_model + sty_dim,
                                      d_model // 2,
                                      num_layers=1,
                                      batch_first=True,
                                      bidirectional=True))
            self.lstms.append(AdaLayerNorm(sty_dim, d_model))

        self.dropout = dropout
        self.d_model = d_model
        self.sty_dim = sty_dim

    def forward(self, x, style, text_lengths, m):
        """Encode ``x`` conditioned on ``style``.

        Args:
            x: input features (assumed [B, d_model, T] -- TODO confirm against the caller).
            style: style vector, broadcast over the time axis.
            text_lengths: true length of every sequence, used for packing.
            m: boolean padding mask (assumed [B, T]), true at padded positions.

        Returns:
            Tensor with the time axis second-to-last and ``d_model + sty_dim``
            features last (when ``self.lstms`` is non-empty).
        """
        masks = m.to(text_lengths.device)

        x = x.permute(2, 0, 1)
        s = style.expand(x.shape[0], x.shape[1], -1)
        x = torch.cat([x, s], dim=-1)
        x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)

        x = x.transpose(0, 1)
        input_lengths = text_lengths.cpu().numpy()
        x = x.transpose(-1, -2)

        for block in self.lstms:
            if isinstance(block, AdaLayerNorm):
                x = block(x.transpose(-1, -2), style).transpose(-1, -2)
                # Re-attach the style so the next LSTM sees d_model + sty_dim features.
                x = torch.cat([x, s.permute(1, -1, 0)], dim=1)
                x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
            else:
                x = x.transpose(-1, -2)
                x = nn.utils.rnn.pack_padded_sequence(
                    x, input_lengths, batch_first=True, enforce_sorted=False)
                block.flatten_parameters()
                x, _ = block(x)
                x, _ = nn.utils.rnn.pad_packed_sequence(
                    x, batch_first=True)
                x = F.dropout(x, p=self.dropout, training=self.training)
                x = x.transpose(-1, -2)

                # pad_packed_sequence trims to the longest sequence in the batch;
                # pad back to the mask width. new_zeros keeps the buffer on x's
                # device and dtype instead of allocating on the CPU.
                x_pad = x.new_zeros([x.shape[0], x.shape[1], m.shape[-1]])
                x_pad[:, :, :x.shape[-1]] = x
                x = x_pad

        return x.transpose(-1, -2)

    def inference(self, x, style):
        # NOTE(review): this method reads self.embedding, self.pos_encoder and
        # self.transformer_encoder, none of which are created in __init__, so
        # calling it raises AttributeError. It looks like a leftover from a
        # transformer-based variant -- confirm before relying on it.
        x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model)
        style = style.expand(x.shape[0], x.shape[1], -1)
        x = torch.cat([x, style], axis=-1)
        src = self.pos_encoder(x)
        output = self.transformer_encoder(src).transpose(0, 1)
        return output

    def length_to_mask(self, lengths):
        """Build a [B, max(lengths)] boolean mask that is true at padded positions."""
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask
542
-
543
- # https://github.com/yl4579/StyleTTS2/blob/main/utils.py
544
def recursive_munch(d):
    """Recursively convert nested dicts into ``Munch`` objects.

    Dicts become ``Munch`` instances, lists are rebuilt element by element, and
    every other value is returned unchanged.
    """
    if isinstance(d, list):
        return [recursive_munch(item) for item in d]
    if not isinstance(d, dict):
        return d
    return Munch((key, recursive_munch(value)) for key, value in d.items())
551
-
552
def build_model(args, device):
    """Instantiate all sub-models from a config and return them in a ``Munch``.

    Args:
        args: (possibly nested) config dict; it is converted with ``recursive_munch``
            so that fields can be read as attributes.
        device: device every component is moved to.

    Returns:
        ``Munch`` with the keys ``bert``, ``bert_encoder``, ``predictor``,
        ``decoder`` and ``text_encoder``; every module is on ``device`` and in
        eval mode.

    Raises:
        AssertionError: if ``args.decoder.type`` is not ``'istftnet'``.
    """
    args = recursive_munch(args)
    assert args.decoder.type == 'istftnet', 'Decoder type unknown'

    dec_cfg = args.decoder
    decoder = Decoder(
        dim_in=args.hidden_dim,
        style_dim=args.style_dim,
        dim_out=args.n_mels,
        resblock_kernel_sizes=dec_cfg.resblock_kernel_sizes,
        upsample_rates=dec_cfg.upsample_rates,
        upsample_initial_channel=dec_cfg.upsample_initial_channel,
        resblock_dilation_sizes=dec_cfg.resblock_dilation_sizes,
        upsample_kernel_sizes=dec_cfg.upsample_kernel_sizes,
        gen_istft_n_fft=dec_cfg.gen_istft_n_fft,
        gen_istft_hop_size=dec_cfg.gen_istft_hop_size,
    )
    text_encoder = TextEncoder(
        channels=args.hidden_dim,
        kernel_size=5,
        depth=args.n_layer,
        n_symbols=args.n_token,
    )
    predictor = ProsodyPredictor(
        style_dim=args.style_dim,
        d_hid=args.hidden_dim,
        nlayers=args.n_layer,
        max_dur=args.max_dur,
        dropout=args.dropout,
    )
    bert = load_plbert()
    bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim)

    # Flatten the weights of every RNN that is a *direct* child of a component.
    components = [bert, bert_encoder, predictor, decoder, text_encoder]
    for parent in components:
        for child in parent.children():
            if isinstance(child, nn.RNNBase):
                child.flatten_parameters()

    return Munch(
        bert=bert.to(device).eval(),
        bert_encoder=bert_encoder.to(device).eval(),
        predictor=predictor.to(device).eval(),
        decoder=decoder.to(device).eval(),
        text_encoder=text_encoder.to(device).eval(),
    )