kevinwang676 committed on
Commit
e3a3acc
·
verified ·
1 Parent(s): c3b1ea7

Create model.py

Browse files
Files changed (1) hide show
  1. GPT_SoVITS/module/model.py +1030 -0
GPT_SoVITS/module/model.py ADDED
@@ -0,0 +1,1030 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ warnings.filterwarnings("ignore")
3
+ import copy
4
+ import math
5
+ import os
6
+ import pdb
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+
12
+ from module import commons
13
+ from module import modules
14
+ from module import attentions
15
+
16
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
17
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
18
+ from module.commons import init_weights, get_padding
19
+ from module.mrte_model import MRTE
20
+ from module.quantize import ResidualVectorQuantizer
21
+ # from text import symbols
22
+ from text import symbols as symbols_v1
23
+ from text import symbols2 as symbols_v2
24
+ from torch.cuda.amp import autocast
25
+ import contextlib
26
+
27
+
28
class StochasticDurationPredictor(nn.Module):
    """Flow-based stochastic predictor of log-durations.

    In the default (``reverse=False``) direction ``forward`` needs the target
    ``w`` and returns a per-batch-item value: the negative log-likelihood of
    the main flow output plus the log-density term from the posterior flows.
    With ``reverse=True`` it draws noise scaled by ``noise_scale`` and runs the
    main flows backwards, returning the first output channel as ``logw``.
    """

    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        # The caller-supplied width is overridden: every internal layer uses
        # ``in_channels`` as its filter width.
        filter_channels = in_channels  # it needs to be removed from future version.
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.log_flow = modules.Log()
        self.flows = self._flow_stack(n_flows, filter_channels, kernel_size)

        # Layers that process the target durations ``w`` for the posterior.
        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        self.post_flows = self._flow_stack(4, filter_channels, kernel_size)

        # Layers that process the (detached) conditioning input ``x``.
        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    @staticmethod
    def _flow_stack(count, filter_channels, kernel_size):
        """Build an ElementwiseAffine followed by ``count`` (ConvFlow, Flip) pairs."""
        stack = nn.ModuleList()
        stack.append(modules.ElementwiseAffine(2))
        for _ in range(count):
            stack.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            stack.append(modules.Flip())
        return stack

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        # Gradients are cut so this predictor does not update its inputs.
        h = self.pre(torch.detach(x))
        if g is not None:
            h = h + self.cond(torch.detach(g))
        h = self.convs(h, x_mask)
        h = self.proj(h) * x_mask

        if reverse:
            # Sampling path: traverse the main flows back to front.
            ordered = list(reversed(self.flows))
            ordered = ordered[:-2] + ordered[-1:]  # remove a useless vflow
            z = (
                torch.randn(h.size(0), 2, h.size(2)).to(device=h.device, dtype=h.dtype)
                * noise_scale
            )
            for flow in ordered:
                z = flow(z, x_mask, g=h, reverse=reverse)
            logw, _ = torch.split(z, [1, 1], 1)
            return logw

        assert w is not None

        # Posterior flows conditioned on both the input and the target.
        hw = self.post_pre(w)
        hw = self.post_convs(hw, x_mask)
        hw = self.post_proj(hw) * x_mask
        noise = (
            torch.randn(w.size(0), 2, w.size(2)).to(device=h.device, dtype=h.dtype)
            * x_mask
        )
        z_q = noise
        logdet_q_sum = 0
        for flow in self.post_flows:
            z_q, logdet_q = flow(z_q, x_mask, g=(h + hw))
            logdet_q_sum = logdet_q_sum + logdet_q
        z_u, z1 = torch.split(z_q, [1, 1], 1)
        u = torch.sigmoid(z_u) * x_mask
        z0 = (w - u) * x_mask
        logdet_q_sum = logdet_q_sum + torch.sum(
            (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
        )
        logq = (
            torch.sum(-0.5 * (math.log(2 * math.pi) + (noise**2)) * x_mask, [1, 2])
            - logdet_q_sum
        )

        # Main flows in the forward direction.
        z0, logdet = self.log_flow(z0, x_mask)
        logdet_sum = 0 + logdet
        z = torch.cat([z0, z1], 1)
        for flow in self.flows:
            z, logdet = flow(z, x_mask, g=h, reverse=reverse)
            logdet_sum = logdet_sum + logdet
        nll = (
            torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
            - logdet_sum
        )
        return nll + logq  # [b]
137
+
138
+
139
class DurationPredictor(nn.Module):
    """Two-layer convolutional predictor producing one value per time step.

    The input (and optional conditioning ``g``) is detached, passed through
    two conv -> ReLU -> LayerNorm -> dropout stages and projected to a single
    channel. The mask is applied before every convolution and on the output.
    """

    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        same_pad = kernel_size // 2
        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        h = torch.detach(x)
        if g is not None:
            h = h + self.cond(torch.detach(g))
        stages = ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2))
        for conv, norm in stages:
            h = conv(h * x_mask)
            h = torch.relu(h)
            h = norm(h)
            h = self.drop(h)
        return self.proj(h * x_mask) * x_mask
180
+
181
+
182
class TextEncoder(nn.Module):
    """Prior encoder combining SSL features with embedded text.

    ``forward`` projects the 768-channel SSL features, encodes them, embeds and
    encodes the text, fuses both through ``MRTE`` together with the style
    vector ``ge``, and finally projects to a mean / log-scale pair of
    ``out_channels`` each.
    """

    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        latent_channels=192,
        version="v2",
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.latent_channels = latent_channels
        self.version = version

        self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)

        self.encoder_ssl = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers // 2,
            kernel_size,
            p_dropout,
        )

        self.encoder_text = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )

        # The symbol table (and therefore the embedding size) depends on the
        # model version; anything other than "v1" uses the v2 table.
        if self.version == "v1":
            symbols = symbols_v1.symbols
        else:
            symbols = symbols_v2.symbols
        self.text_embedding = nn.Embedding(len(symbols), hidden_channels)

        self.mrte = MRTE()

        self.encoder2 = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers // 2,
            kernel_size,
            p_dropout,
        )

        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None):
        """Encode SSL features ``y`` and token ids ``text``.

        Args:
            y: SSL features with 768 channels on dim 1 and time on dim 2.
            y_lengths: valid lengths of ``y`` used to build its mask.
            text: token ids, indexed as ``[batch, length]``.
            text_lengths: valid lengths of ``text``.
            ge: style/reference embedding forwarded to ``MRTE``.
            speed: when not 1, the encoded sequence is linearly resampled to
                ``int(T / speed) + 1`` frames.
            test: when equal to 1, the text ids are replaced by zeros.

        Returns:
            ``(y, m, logs, y_mask)``: encoded features, mean, log-scale and
            the (possibly resampled) mask.
        """
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
            y.dtype
        )

        y = self.ssl_proj(y * y_mask) * y_mask

        y = self.encoder_ssl(y * y_mask, y_mask)

        text_mask = torch.unsqueeze(
            commons.sequence_mask(text_lengths, text.size(1)), 1
        ).to(y.dtype)
        if test == 1:
            # Use a fresh zero tensor instead of zeroing in place, so the
            # caller's ``text`` tensor is left untouched.
            text = torch.zeros_like(text)
        text = self.text_embedding(text).transpose(1, 2)
        text = self.encoder_text(text * text_mask, text_mask)
        y = self.mrte(y, y_mask, text, text_mask, ge)
        y = self.encoder2(y * y_mask, y_mask)
        if speed != 1:
            # Resample features along time; the mask follows with nearest
            # interpolation so it keeps its original values.
            y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
            y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
        stats = self.proj(y) * y_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return y, m, logs, y_mask

    def extract_latent(self, x):
        # NOTE(review): ``self.quantizer`` is not assigned anywhere in this
        # class's __init__, so this call fails unless the attribute is
        # attached externally - confirm whether this method is still used.
        x = self.ssl_proj(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(x)
        return codes.transpose(0, 1)

    def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
        # NOTE(review): relies on ``self.quantizer`` and ``self.vq_proj``,
        # neither of which is created in __init__ - confirm before use.
        quantized = self.quantizer.decode(codes)

        y = self.vq_proj(quantized) * y_mask
        y = self.encoder_ssl(y * y_mask, y_mask)

        y = self.mrte(y, y_mask, refer, refer_mask, ge)

        y = self.encoder2(y * y_mask, y_mask)

        stats = self.proj(y) * y_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return y, m, logs, y_mask, quantized
283
+
284
+
285
class ResidualCouplingBlock(nn.Module):
    """Chain of ``n_flows`` (ResidualCouplingLayer, Flip) pairs.

    ``forward`` applies the chain front to back, discarding each layer's
    second return value; with ``reverse=True`` it applies the layers back to
    front and uses each layer's single return value.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = modules.ResidualCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                dilation_rate,
                n_layers,
                gin_channels=gin_channels,
                mean_only=True,
            )
            self.flows.append(coupling)
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if reverse:
            for layer in reversed(self.flows):
                x = layer(x, x_mask, g=g, reverse=reverse)
            return x
        for layer in self.flows:
            x, _ = layer(x, x_mask, g=g, reverse=reverse)
        return x
328
+
329
+
330
class PosteriorEncoder(nn.Module):
    """WN-based encoder that returns a sampled latent and its statistics.

    The input is projected, processed by ``modules.WN`` (optionally
    conditioned on a detached ``g``) and projected to a mean / log-scale pair;
    a latent ``z`` is drawn as ``m + eps * exp(logs)`` with ``eps`` standard
    normal noise.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """Encode ``x`` (time on dim 2) and sample a latent.

        Returns:
            ``(z, m, logs, x_mask)``: masked sample, mean, log-scale and the
            mask built from ``x_lengths``.
        """
        # Identity check: ``!=`` on a tensor goes through tensor comparison
        # semantics, while ``is not None`` is the unambiguous presence test.
        if g is not None:
            # The conditioning is detached so this encoder does not
            # back-propagate into whatever produced it.
            g = g.detach()
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask
372
+
373
+
374
class WNEncoder(nn.Module):
    """Deterministic WN encoder: 1x1 conv -> WN -> 1x1 conv -> LayerNorm.

    Unlike a posterior encoder it returns only the normalised projection, with
    no sampling. The mask is applied before the final normalisation.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.norm = modules.LayerNorm(out_channels)

    def forward(self, x, x_lengths, g=None):
        # Mask built from the valid lengths, with time taken from dim 2 of x.
        valid = commons.sequence_mask(x_lengths, x.size(2))
        x_mask = torch.unsqueeze(valid, 1).to(x.dtype)
        hidden = self.enc(self.pre(x) * x_mask, x_mask, g=g)
        projected = self.proj(hidden) * x_mask
        return self.norm(projected)
414
+
415
+
416
class Generator(torch.nn.Module):
    """Upsampling decoder: transposed convs, each followed by residual blocks.

    Every upsampling stage halves the channel count; the outputs of the
    ``num_kernels`` residual blocks attached to a stage are averaged. The
    final convolution emits one channel squashed by ``tanh``.
    """

    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        # The string "1" selects ResBlock1; any other value selects ResBlock2.
        block_cls = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for stage, (rate, kernel) in enumerate(
            zip(upsample_rates, upsample_kernel_sizes)
        ):
            transposed = ConvTranspose1d(
                upsample_initial_channel // (2**stage),
                upsample_initial_channel // (2 ** (stage + 1)),
                kernel,
                rate,
                padding=(kernel - rate) // 2,
            )
            self.ups.append(weight_norm(transposed))

        self.resblocks = nn.ModuleList()
        for stage in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (stage + 1))
            for kernel, dilation in zip(
                resblock_kernel_sizes, resblock_dilation_sizes
            ):
                self.resblocks.append(block_cls(ch, kernel, dilation))

        # ``ch`` is the channel width of the last upsampling stage.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for stage, upsample in enumerate(self.ups):
            if stage >= self.num_upsamples:
                break
            x = upsample(F.leaky_relu(x, modules.LRELU_SLOPE))
            total = None
            first = stage * self.num_kernels
            for offset in range(self.num_kernels):
                branch = self.resblocks[first + offset](x)
                total = branch if total is None else total + branch
            x = total / self.num_kernels
        # The final activation uses leaky_relu's default slope.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for layer in self.ups:
            # Inside the method this name resolves to the module-level
            # function, not to this method.
            remove_weight_norm(layer)
        for block in self.resblocks:
            block.remove_weight_norm()
491
+
492
+
493
class DiscriminatorP(torch.nn.Module):
    """Period discriminator operating on a waveform folded by ``period``.

    The 1-D input is right-padded (reflect) to a multiple of ``period``,
    viewed as ``[b, c, t // period, period]`` and passed through a stack of
    2-D convolutions whose kernels span only the folded time axis.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        # Truthiness test: with ``== False`` a value such as None silently
        # selected spectral norm.
        norm_f = spectral_norm if use_spectral_norm else weight_norm

        # (in_channels, out_channels) of the strided layers; a final
        # stride-1 layer keeps 1024 channels.
        channel_plan = [(1, 32), (32, 128), (128, 512), (512, 1024)]
        layers = [
            norm_f(
                Conv2d(
                    c_in,
                    c_out,
                    (kernel_size, 1),
                    (stride, 1),
                    padding=(get_padding(kernel_size, 1), 0),
                )
            )
            for c_in, c_out in channel_plan
        ]
        layers.append(
            norm_f(
                Conv2d(
                    1024,
                    1024,
                    (kernel_size, 1),
                    1,
                    padding=(get_padding(kernel_size, 1), 0),
                )
            )
        )
        self.convs = nn.ModuleList(layers)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return ``(flattened_score, feature_maps)`` for a ``[b, c, t]`` input.

        ``feature_maps`` holds the activation after every conv layer plus the
        output of the final projection.
        """
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
570
+
571
+
572
class DiscriminatorS(torch.nn.Module):
    """Scale discriminator: a stack of grouped 1-D convolutions.

    ``forward`` returns the flattened output of the last projection together
    with the activation of every layer, the projection included.
    """

    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        # Truthiness test: with ``== False`` a value such as None silently
        # selected spectral norm.
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
600
+
601
+
602
class MultiPeriodDiscriminator(torch.nn.Module):
    """Bundle of one ``DiscriminatorS`` and five ``DiscriminatorP`` instances.

    ``forward`` runs every discriminator on the real signal ``y`` and on the
    generated signal ``y_hat`` and collects scores and feature maps.
    """

    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]

        members = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        for period in periods:
            members.append(
                DiscriminatorP(period, use_spectral_norm=use_spectral_norm)
            )
        self.discriminators = nn.ModuleList(members)

    def forward(self, y, y_hat):
        # Buckets, in return order: real scores, generated scores,
        # real feature maps, generated feature maps.
        collected = ([], [], [], [])
        for disc in self.discriminators:
            real_score, real_maps = disc(y)
            fake_score, fake_maps = disc(y_hat)
            outputs = (real_score, fake_score, real_maps, fake_maps)
            for bucket, item in zip(collected, outputs):
                bucket.append(item)
        return collected
627
+
628
+
629
class ReferenceEncoder(nn.Module):
    """Strided-conv + GRU encoder that condenses a spectrogram to one vector.

    ``forward`` views its input as ``[N, 1, -1, spec_channels]``, applies six
    stride-2 convolutions with ReLU, runs a GRU over the remaining time steps
    and projects the final hidden state, returning ``[N, gin_channels, 1]``.
    """

    def __init__(self, spec_channels, gin_channels=0):
        super().__init__()
        self.spec_channels = spec_channels
        # Channel widths of consecutive conv layers, starting from the single
        # input channel.
        widths = [1, 32, 32, 64, 64, 128, 128]
        n_convs = len(widths) - 1
        conv_layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            conv_layers.append(
                weight_norm(
                    nn.Conv2d(
                        in_channels=c_in,
                        out_channels=c_out,
                        kernel_size=(3, 3),
                        stride=(2, 2),
                        padding=(1, 1),
                    )
                )
            )
        self.convs = nn.ModuleList(conv_layers)

        # Size of the frequency axis left after all the strided convolutions.
        freq_bins = self.calculate_channels(spec_channels, 3, 2, 1, n_convs)
        self.gru = nn.GRU(
            input_size=widths[-1] * freq_bins,
            hidden_size=256 // 2,
            batch_first=True,
        )
        self.proj = nn.Linear(128, gin_channels)

    def forward(self, inputs):
        batch = inputs.size(0)
        feats = inputs.view(batch, 1, -1, self.spec_channels)  # [N, 1, Ty, n_freqs]
        for conv in self.convs:
            feats = F.relu(conv(feats))

        # Move time in front of channels, then merge channels and frequency.
        feats = feats.transpose(1, 2)
        steps = feats.size(1)
        batch = feats.size(0)
        feats = feats.contiguous().view(batch, steps, -1)

        self.gru.flatten_parameters()
        _, hidden = self.gru(feats)  # hidden: [1, N, 128]

        return self.proj(hidden.squeeze(0)).unsqueeze(-1)

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        """Return the length left after ``n_convs`` convs applied to length ``L``."""
        remaining = L
        for _ in range(n_convs):
            remaining = (remaining - kernel_size + 2 * pad) // stride + 1
        return remaining
686
+
687
+
688
class Quantizer_module(torch.nn.Module):
    """Single codebook with nearest-neighbour lookup.

    The codebook holds ``n_e`` vectors of size ``e_dim``, initialised
    uniformly in ``[-1/n_e, 1/n_e]``.
    """

    def __init__(self, n_e, e_dim):
        super(Quantizer_module, self).__init__()
        self.embedding = nn.Embedding(n_e, e_dim)
        bound = 1.0 / n_e
        self.embedding.weight.data.uniform_(-bound, bound)

    def forward(self, x):
        """Return ``(nearest_codes, indices)`` for row vectors ``x``.

        Squared Euclidean distance is expanded as
        ``|x|^2 + |e|^2 - 2 x.e`` so it is computed with one matmul.
        """
        codebook = self.embedding.weight
        x_sq = torch.sum(x**2, 1, keepdim=True)
        code_sq = torch.sum(codebook**2, 1)
        cross = torch.matmul(x, codebook.T)
        distances = x_sq + code_sq - 2 * cross
        nearest = torch.argmin(distances, 1)
        return self.embedding(nearest), nearest
703
+
704
+
705
class Quantizer(torch.nn.Module):
    """Grouped vector quantizer.

    The channel dimension is split into ``n_code_groups`` equal slices and
    each slice is quantised by its own ``Quantizer_module`` codebook.
    """

    def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
        super(Quantizer, self).__init__()
        assert embed_dim % n_code_groups == 0
        group_dim = embed_dim // n_code_groups
        self.quantizer_modules = nn.ModuleList(
            [Quantizer_module(n_codes, group_dim) for _ in range(n_code_groups)]
        )
        self.n_code_groups = n_code_groups
        self.embed_dim = embed_dim

    def forward(self, xin):
        """Quantise ``xin`` of shape ``[B, C, T]``.

        Returns:
            ``(z_q, loss, codes)``: quantised features ``[B, C, T]`` with
            straight-through gradients, the codebook / commitment loss, and
            the code indices ``[B, n_code_groups, T]``.
        """
        B, C, T = xin.shape
        xin = xin.transpose(1, 2)
        flat = xin.reshape(-1, self.embed_dim)
        slices = torch.split(flat, self.embed_dim // self.n_code_groups, dim=-1)

        quantised_parts = []
        index_parts = []
        for piece, codebook in zip(slices, self.quantizer_modules):
            part, part_idx = codebook(piece)
            quantised_parts.append(part)
            index_parts.append(part_idx)  # B * T,

        z_q = torch.cat(quantised_parts, -1).reshape(xin.shape)
        commit_term = torch.mean((z_q.detach() - xin) ** 2)
        codebook_term = torch.mean((z_q - xin.detach()) ** 2)
        loss = 0.25 * commit_term + codebook_term
        # Straight-through estimator: forward value is z_q, gradient goes to xin.
        z_q = xin + (z_q - xin).detach()
        z_q = z_q.transpose(1, 2)
        codes = torch.stack(index_parts, -1).reshape(B, T, self.n_code_groups)
        return z_q, loss, codes.transpose(1, 2)

    def embed(self, x):
        """Look up code indices ``x`` (``[N, groups, T]``) and return ``[N, C, T]``."""
        per_group = torch.split(x.transpose(1, 2), 1, 2)
        looked_up = [
            codebook.embedding(idx.squeeze(-1))
            for idx, codebook in zip(per_group, self.quantizer_modules)
        ]
        return torch.cat(looked_up, -1).transpose(1, 2)  # N, C, T
749
+
750
+
751
class CodePredictor(nn.Module):
    """Predicts the codes of quantizer levels 1..n_q-1 from input features.

    In training mode ``forward`` returns a cross-entropy loss; with
    ``infer=True`` it prints top-10 / top-1 accuracy against ``codes`` and
    returns the arg-max predictions.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        n_q=8,
        dims=1024,
        ssl_dim=768,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
        self.ref_enc = modules.MelStyleEncoder(
            ssl_dim, style_vector_dim=hidden_channels
        )

        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )

        # One ``dims``-way classifier for each level after the first.
        self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
        self.n_q = n_q
        self.dims = dims

    def forward(self, x, x_mask, refer, codes, infer=False):
        h = self.vq_proj(x.detach() * x_mask) * x_mask
        h = h + self.ref_enc(refer, x_mask)
        h = self.encoder(h * x_mask, x_mask)
        h = self.out_proj(h * x_mask) * x_mask

        # [B, n_q - 1, T, dims]
        logits = h.reshape(h.shape[0], self.n_q - 1, self.dims, h.shape[-1])
        logits = logits.transpose(2, 3)
        # The first entry along dim 0 of ``codes`` is skipped.
        # NOTE(review): presumably dim 0 indexes the quantizer level - confirm.
        target = codes[1:].transpose(0, 1)

        if infer:
            _, top10_preds = torch.topk(logits, 10, dim=-1)
            hit_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1)
            top10_acc = 100 * torch.mean(hit_top10.float()).detach().cpu().item()

            print("Top-10 Accuracy:", top10_acc, "%")

            pred_codes = torch.argmax(logits, dim=-1)
            top1_acc = (
                100 * torch.mean((pred_codes == target).float()).detach().cpu().item()
            )
            print("Top-1 Accuracy:", top1_acc, "%")

            return pred_codes.transpose(0, 1)

        flat_logits = logits.reshape(-1, self.dims)
        flat_target = target.reshape(-1)
        return torch.nn.functional.cross_entropy(flat_logits, flat_target)
813
+
814
+
815
class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training

    Combines the prior encoder (``enc_p``), posterior encoder (``enc_q``),
    flow, waveform decoder (``dec``), reference/style encoder (``ref_enc``)
    and a single-level residual vector quantizer over projected SSL features.
    """

    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=0,
        gin_channels=0,
        use_sdp=True,
        semantic_frame_rate=None,
        freeze_quantizer=None,
        version="v2",
        **kwargs
    ):
        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        self.version = version

        self.use_sdp = use_sdp
        self.enc_p = TextEncoder(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            version=version,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
        )

        # Both model versions build the same style encoder here, so the
        # former v1 / v2 branches (which were identical) are merged.
        self.ref_enc = modules.MelStyleEncoder(
            spec_channels, style_vector_dim=gin_channels
        )

        ssl_dim = 768
        assert semantic_frame_rate in ["25hz", "50hz"]
        self.semantic_frame_rate = semantic_frame_rate
        if semantic_frame_rate == "25hz":
            # Stride-2 projection halves the frame rate of the SSL features.
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
        else:
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)

        self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
        self.freeze_quantizer = freeze_quantizer

    def _reference_embedding(self, spec, spec_lengths):
        """Return the style vector of ``spec`` (time on dim 2), masked by length."""
        mask = torch.unsqueeze(
            commons.sequence_mask(spec_lengths, spec.size(2)), 1
        ).to(spec.dtype)
        return self.ref_enc(spec * mask, mask)

    def _match_frame_rate(self, quantized):
        """Double the time axis (nearest) when the model runs at 25 Hz."""
        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(
                quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
            )
        return quantized

    def forward(self, ssl, y, y_lengths, text, text_lengths):
        """Training pass.

        Returns:
            ``(o, commit_loss, ids_slice, y_mask, y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q), quantized)`` where ``o`` is the
            decoded audio for a random latent segment of ``segment_size``.
        """
        ge = self._reference_embedding(y, y_lengths)
        with autocast(enabled=False):
            maybe_no_grad = (
                torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
            )
            # NOTE(review): indentation of this region was lost in the source
            # this block was recovered from. It is reconstructed so that the
            # no-grad context wraps only the eval() switch - confirm against
            # the repository before relying on gradient flow here.
            with maybe_no_grad:
                if self.freeze_quantizer:
                    self.ssl_proj.eval()
                    self.quantizer.eval()
            ssl = self.ssl_proj(ssl)
            quantized, codes, commit_loss, quantized_list = self.quantizer(
                ssl, layers=[0]
            )

        quantized = self._match_frame_rate(quantized)

        x, m_p, logs_p, y_mask = self.enc_p(
            quantized, y_lengths, text, text_lengths, ge
        )
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
        z_p = self.flow(z, y_mask, g=ge)

        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=ge)
        return (
            o,
            commit_loss,
            ids_slice,
            y_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
            quantized,
        )

    def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
        """Synthesize from SSL features, using ``y`` as the style reference.

        Returns ``(o, y_mask, (z, z_p, m_p, logs_p))``.
        """
        ge = self._reference_embedding(y, y_lengths)

        ssl = self.ssl_proj(ssl)
        quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
        quantized = self._match_frame_rate(quantized)

        x, m_p, logs_p, y_mask = self.enc_p(
            quantized, y_lengths, text, text_lengths, ge, test=test
        )
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

        z = self.flow(z_p, y_mask, g=ge, reverse=True)

        o = self.dec((z * y_mask)[:, :, :], g=ge)
        return o, y_mask, (z, z_p, m_p, logs_p)

    def _decode_reference(self, refer):
        """Style vector for one reference spectrogram, or None when absent."""
        if refer is None:
            return None
        refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
        return self._reference_embedding(refer, refer_lengths)

    @torch.no_grad()
    def decode(self, codes, text, refer, noise_scale=0.5, speed=1):
        """Decode semantic ``codes`` and ``text`` to audio.

        Args:
            codes: quantizer codes; dim 2 is used as the frame count.
            text: token ids; the last dim is used as the text length.
            refer: one reference spectrogram, or a list of them whose style
                vectors are averaged.
            noise_scale: scale of the noise added to the prior mean.
            speed: forwarded to the prior encoder to resample in time.
        """
        if isinstance(refer, list):
            ge = torch.stack(
                [self._decode_reference(item) for item in refer], 0
            ).mean(0)
        else:
            ge = self._decode_reference(refer)

        y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
        text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

        quantized = self.quantizer.decode(codes)
        quantized = self._match_frame_rate(quantized)
        x, m_p, logs_p, y_mask = self.enc_p(
            quantized, y_lengths, text, text_lengths, ge, speed
        )
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

        z = self.flow(z_p, y_mask, g=ge, reverse=True)

        o = self.dec((z * y_mask)[:, :, :], g=ge)
        return o

    def extract_latent(self, x):
        """Project SSL features ``x`` and return the quantizer codes, dims 0/1 swapped."""
        ssl = self.ssl_proj(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
        return codes.transpose(0, 1)