Respair committed (verified)
Commit 580e3a4 · 1 Parent(s): e222005

Create models.py

Hiformer_Checkpoint_Libri_24khz/models.py ADDED
@@ -0,0 +1,943 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from utils import init_weights, get_padding
import numpy as np
from stft import TorchSTFT
import torchaudio
from nnAudio import features
from einops import rearrange
from norm2d import NormConv2d
from munch import Munch
from conformer import Conformer

LRELU_SLOPE = 0.1


def get_2d_padding(kernel_size, dilation=(1, 1)):
    return (
        ((kernel_size[0] - 1) * dilation[0]) // 2,
        ((kernel_size[1] - 1) * dilation[1]) // 2,
    )

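def _demo_get_2d_padding():
    # Hypothetical smoke test, kept out of the import path: the "same"-padding
    # formula above gives (1, 4) for the (3, 9) kernels used by the 2-D
    # discriminators below, matching their hard-coded paddings.
    assert get_2d_padding((3, 9)) == (1, 4)
    assert get_2d_padding((3, 9), dilation=(2, 1)) == (2, 4)
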
class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

        self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for _ in range(len(self.convs1))])
        self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for _ in range(len(self.convs2))])

    def forward(self, x):
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.alpha1, self.alpha2):
            xt = x + (1 / a1) * (torch.sin(a1 * x) ** 2)  # Snake1D
            xt = c1(xt)
            xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2)  # Snake1D
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

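def _demo_snake1d():
    # Illustrative sketch of the Snake1D activation used in ResBlock1:
    # x + (1/alpha) * sin(alpha * x)^2 with a learnable per-channel alpha
    # (initialised to ones above). Shapes here are made up for the demo.
    x = torch.randn(2, 4, 16)       # (batch, channels, time)
    alpha = torch.ones(1, 4, 1)     # alpha as initialised in ResBlock1
    y = x + (1 / alpha) * (torch.sin(alpha * x) ** 2)
    assert y.shape == x.shape       # Snake1D is shape-preserving
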
class ResBlock1_old(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1_old, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1])))
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)

class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, upsample_scale, harmonic_num=0,
            sine_amp=0.1, noise_std=0.003,
            voiced_threshold=0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    upsample_scale: factor between the F0 frame rate and the sample rate
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
    segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0,
                 flag_for_pulse=False):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale

    def _f02uv(self, f0):
        # generate uv signal
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def _f02sine(self, f0_values):
        """ f0_values: (batchsize, length, dim)
        where dim indicates fundamental tone and overtones
        """
        # convert F0 to rad. The integer part n can be ignored
        # because 2 * np.pi * n doesn't affect the phase
        rad_values = (f0_values / self.sampling_rate) % 1

        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2],
                              device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        # instantaneous phase: sine[t] = sin(2*pi \sum_{i=1}^{t} rad)
        if not self.flag_for_pulse:
            # normal case. To prevent torch.cumsum numerical overflow on long
            # sequences, the per-sample phase increments are downsampled to the
            # frame rate, integrated there, and the phase is upsampled back to
            # the sample rate; shifting by -1 whenever the cumulative sum
            # exceeds 1 would not change the F0, since (x - 1) * 2*pi and
            # x * 2*pi agree modulo 2*pi.
            rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
                                                         scale_factor=1 / self.upsample_scale,
                                                         mode="linear").transpose(1, 2)

            phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
            phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
                                                    scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
            sines = torch.sin(phase)
        else:
            # If necessary, make sure that the first time step of every
            # voiced segment is sin(pi) or cos(0).
            # This is used for pulse-train generation.

            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)

            # get the instantaneous phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batches need to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segment
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum

            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)

            # get the sines
            sines = torch.cos(i_phase * 2 * np.pi)
        return sines

    def forward(self, f0):
        """ sine_tensor, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
                  f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        """
        f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
                             device=f0.device)
        # fundamental component plus harmonic overtones
        fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))

        # generate sine waveforms
        sine_waves = self._f02sine(fn) * self.sine_amp

        # generate uv signal
        uv = self._f02uv(f0)

        # noise: for unvoiced regions it should be similar to sine_amp
        #        std = self.sine_amp/3 -> max value ~ self.sine_amp
        #        for voiced regions it is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: add the noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise

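def _demo_sine_gen():
    # Illustrative sketch: drive SineGen with a constant 100 Hz F0 contour.
    # `upsample_scale` mirrors how Generator wires it up
    # (prod(upsample_rates) * gen_istft_hop_size); the concrete values here
    # are made up for the demo, not taken from the checkpoint config.
    sine_gen = SineGen(samp_rate=24000, upsample_scale=120, harmonic_num=8)
    f0 = torch.full((1, 1200, 1), 100.0)   # (batch, length, 1), sample rate
    sine_waves, uv, noise = sine_gen(f0)
    # one channel per harmonic: the fundamental plus 8 overtones
    assert sine_waves.shape == (1, 1200, 9)
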
class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshold=0)
    sampling_rate: sampling rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that the amplitude of noise in unvoiced regions is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length, 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshold=0):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshold)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length, 1)
        """
        # source for the harmonic branch
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for the noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv


def padDiff(x):
    return F.pad(F.pad(x, (0, 0, -1, 1), 'constant', 0) - x, (0, 0, 0, -1), 'constant', 0)

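def _demo_source_module():
    # Illustrative sketch: fold the fundamental plus 8 overtones into a single
    # excitation track. Shapes follow the class docstring; parameter values
    # are made up for the demo.
    src = SourceModuleHnNSF(sampling_rate=24000, upsample_scale=120, harmonic_num=8)
    f0 = torch.full((1, 1200, 1), 100.0)
    sine_merge, noise, uv = src(f0)
    assert sine_merge.shape == (1, 1200, 1) and uv.shape == (1, 1200, 1)
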
class Generator(torch.nn.Module):
    def __init__(self, h, F0_model):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))

        resblock = ResBlock1 if h.resblock == '1' else ResBlock2

        self.m_source = SourceModuleHnNSF(
            sampling_rate=h.sampling_rate,
            upsample_scale=np.prod(h.upsample_rates) * h.gen_istft_hop_size,
            harmonic_num=8, voiced_threshold=10)

        self.f0_upsamp = torch.nn.Upsample(
            scale_factor=np.prod(h.upsample_rates) * h.gen_istft_hop_size)
        self.noise_convs = nn.ModuleList()
        self.noise_res = nn.ModuleList()

        self.F0_model = F0_model

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
                                h.upsample_initial_channel // (2 ** (i + 1)),
                                k,
                                u,
                                padding=(k - u) // 2)))

            c_cur = h.upsample_initial_channel // (2 ** (i + 1))

            if i + 1 < len(h.upsample_rates):
                stride_f0 = np.prod(h.upsample_rates[i + 1:])
                self.noise_convs.append(Conv1d(
                    h.gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0 + 1) // 2))
                self.noise_res.append(resblock(h, c_cur, 7, [1, 3, 5]))
            else:
                self.noise_convs.append(Conv1d(h.gen_istft_n_fft + 2, c_cur, kernel_size=1))
                self.noise_res.append(resblock(h, c_cur, 11, [1, 3, 5]))

        self.alphas = nn.ParameterList()
        self.alphas.append(nn.Parameter(torch.ones(1, h.upsample_initial_channel, 1)))
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))
            for j, (k, d) in enumerate(
                    zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conformers = nn.ModuleList()
        self.post_n_fft = h.gen_istft_n_fft
        self.conv_post = weight_norm(Conv1d(128, self.post_n_fft + 2, 7, 1, padding=3))

        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** i)
            self.conformers.append(
                Conformer(
                    dim=ch,
                    depth=2,
                    dim_head=64,
                    heads=8,
                    ff_mult=4,
                    conv_expansion_factor=2,
                    conv_kernel_size=31,
                    attn_dropout=0.1,
                    ff_dropout=0.1,
                    conv_dropout=0.1,
                )
            )

        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
        self.stft = TorchSTFT(filter_length=h.gen_istft_n_fft,
                              hop_length=h.gen_istft_hop_size,
                              win_length=h.gen_istft_n_fft)

    def forward(self, x):
        f0, _, _ = self.F0_model(x.unsqueeze(1))
        if len(f0.shape) == 1:
            f0 = f0.unsqueeze(0)

        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # (bs, t, 1)

        har_source, _, _ = self.m_source(f0)
        har_source = har_source.transpose(1, 2).squeeze(1)
        har_spec, har_phase = self.stft.transform(har_source)
        har = torch.cat([har_spec, har_phase], dim=1)

        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)  # Snake1D
            x = rearrange(x, "b f t -> b t f")
            x = self.conformers[i](x)
            x = rearrange(x, "b t f -> b f t")

            x_source = self.noise_convs[i](har)
            x_source = self.noise_res[i](x_source)

            x = self.ups[i](x)
            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            x = x + x_source

            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = x + (1 / self.alphas[i + 1]) * (torch.sin(self.alphas[i + 1] * x) ** 2)

        x = self.conv_post(x)
        spec = torch.exp(x[:, :self.post_n_fft // 2 + 1, :])
        phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])

        return spec, phase

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)

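def _demo_generator_inference(generator, mel):
    # Minimal inference sketch. `generator` is a Generator built from the
    # checkpoint's config `h`; `mel` is an 80-band mel spectrogram of shape
    # (batch, 80, frames). The model predicts a magnitude/phase pair; this
    # assumes stft.py's TorchSTFT exposes an inverse(spec, phase) method, as
    # iSTFTNet-style decoders typically do -- adjust if the local API differs.
    with torch.no_grad():
        spec, phase = generator(mel)
        audio = generator.stft.inverse(spec, phase)  # assumed TorchSTFT API
    return audio
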
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.
    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (Tensor): Window function tensor.
    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    x_stft = torch.stft(x, fft_size, hop_size, win_length, window,
                        return_complex=True)

    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
    return torch.sqrt(torch.clamp(x_stft.real ** 2 + x_stft.imag ** 2, min=1e-7)).transpose(2, 1)

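def _demo_stft_shapes():
    # Illustrative sanity check: one second of 24 kHz audio at the
    # SpecDiscriminator defaults gives fft_size // 2 + 1 = 513 magnitude bins.
    x = torch.randn(2, 24000)
    window = torch.hann_window(600)
    mag = stft(x, fft_size=1024, hop_size=120, win_length=600, window=window)
    assert mag.shape[0] == 2 and mag.shape[2] == 513
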
class SpecDiscriminator(nn.Module):
    """docstring for Discriminator."""

    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
        super(SpecDiscriminator, self).__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.window = getattr(torch, window)(win_length)
        self.discriminators = nn.ModuleList([
            norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
        ])

        self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))

    def forward(self, y):
        fmap = []
        y = y.squeeze(1)
        y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
        y = y.unsqueeze(1)
        for i, d in enumerate(self.discriminators):
            y = d(y)
            y = F.leaky_relu(y, LRELU_SLOPE)
            fmap.append(y)

        y = self.out(y)
        fmap.append(y)

        return torch.flatten(y, 1, -1), fmap

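def _demo_spec_discriminator():
    # Illustrative sketch: score a batch of raw waveforms. The discriminator
    # returns flattened per-bin logits plus the feature maps consumed by
    # feature_loss below (5 conv layers + output head = 6 maps).
    disc = SpecDiscriminator()
    y = torch.randn(2, 1, 24000)
    logits, fmap = disc(y)
    assert logits.dim() == 2 and len(fmap) == 6
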
# class MultiResSpecDiscriminator(torch.nn.Module):
#
#     def __init__(self,
#                  fft_sizes=[1024, 2048, 512],
#                  hop_sizes=[120, 240, 50],
#                  win_lengths=[600, 1200, 240],
#                  window="hann_window"):
#
#         super(MultiResSpecDiscriminator, self).__init__()
#         self.discriminators = nn.ModuleList([
#             SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
#             SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
#             SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)
#         ])
#
#     def forward(self, y, y_hat):
#         y_d_rs = []
#         y_d_gs = []
#         fmap_rs = []
#         fmap_gs = []
#         for i, d in enumerate(self.discriminators):
#             y_d_r, fmap_r = d(y)
#             y_d_g, fmap_g = d(y_hat)
#             y_d_rs.append(y_d_r)
#             fmap_rs.append(fmap_r)
#             y_d_gs.append(y_d_g)
#             fmap_gs.append(fmap_g)
#
#         return y_d_rs, y_d_gs, fmap_rs, fmap_gs

class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap

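def _demo_period_reshape():
    # Illustrative sketch of the 1d-to-2d trick in DiscriminatorP.forward: pad
    # the waveform to a multiple of the period, then fold it so each column
    # holds samples exactly one period apart.
    period, x = 3, torch.randn(1, 1, 100)
    n_pad = period - (x.shape[-1] % period)          # 100 -> pad by 2 -> 102
    x = F.pad(x, (0, n_pad), "reflect")
    x = x.view(1, 1, x.shape[-1] // period, period)  # (1, 1, 34, 3)
    assert x.shape == (1, 1, 34, 3)
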
class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorP(2),
            DiscriminatorP(3),
            DiscriminatorP(5),
            DiscriminatorP(7),
            DiscriminatorP(11),
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, 128, 15, 1, padding=7)),
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap

class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=2),
            AvgPool1d(4, 2, padding=2)
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


########################### from ringformer

# multiscale_subband_cfg for 44.1 kHz, kept for reference:
# {
#     "hop_lengths": [1024, 512, 512],   # doubled to maintain similar time resolution
#     "sampling_rate": 44100,
#     "filters": 32,                     # controls the initial feature dimension
#     "max_filters": 1024,               # upper limit on channel count
#     "filters_scale": 1,                # channel scaling factor
#     "dilations": [1, 2, 4],            # control receptive-field growth
#     "in_channels": 1,                  # mono audio
#     "out_channels": 1,                 # mono audio
#     "n_octaves": [10, 10, 10],         # one extra octave for the higher frequency range
#     "bins_per_octaves": [24, 36, 48],  # frequency resolution per scale
# }


multiscale_subband_cfg = {
    "hop_lengths": [512, 256, 256],
    "sampling_rate": 24000,
    "filters": 32,
    "max_filters": 1024,
    "filters_scale": 1,
    "dilations": [1, 2, 4],
    "in_channels": 1,
    "out_channels": 1,
    "n_octaves": [9, 9, 9],
    "bins_per_octaves": [24, 36, 48],
}

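# Worked numbers for the config above: each CQT scale spans
# n_bins = bins_per_octave * n_octaves frequency bins, i.e. 24 * 9 = 216,
# 36 * 9 = 324 and 48 * 9 = 432 bins, at hop lengths 512, 256 and 256
# respectively, all at the checkpoint's 24 kHz sampling rate.
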
class DiscriminatorCQT(nn.Module):
    def __init__(self, cfg, hop_length, n_octaves, bins_per_octave):
        super(DiscriminatorCQT, self).__init__()
        self.cfg = cfg

        self.filters = cfg.filters
        self.max_filters = cfg.max_filters
        self.filters_scale = cfg.filters_scale
        self.kernel_size = (3, 9)
        self.dilations = cfg.dilations
        self.stride = (1, 2)

        self.in_channels = cfg.in_channels
        self.out_channels = cfg.out_channels
        self.fs = cfg.sampling_rate
        self.hop_length = hop_length
        self.n_octaves = n_octaves
        self.bins_per_octave = bins_per_octave

        self.cqt_transform = features.cqt.CQT2010v2(
            sr=self.fs * 2,
            hop_length=self.hop_length,
            n_bins=self.bins_per_octave * self.n_octaves,
            bins_per_octave=self.bins_per_octave,
            output_format="Complex",
            pad_mode="constant",
        )

        self.conv_pres = nn.ModuleList()
        for i in range(self.n_octaves):
            self.conv_pres.append(
                NormConv2d(
                    self.in_channels * 2,
                    self.in_channels * 2,
                    kernel_size=self.kernel_size,
                    padding=get_2d_padding(self.kernel_size),
                )
            )

        self.convs = nn.ModuleList()

        self.convs.append(
            NormConv2d(
                self.in_channels * 2,
                self.filters,
                kernel_size=self.kernel_size,
                padding=get_2d_padding(self.kernel_size),
            )
        )

        in_chs = min(self.filters_scale * self.filters, self.max_filters)
        for i, dilation in enumerate(self.dilations):
            out_chs = min(
                (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
            )
            self.convs.append(
                NormConv2d(
                    in_chs,
                    out_chs,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    dilation=(dilation, 1),
                    padding=get_2d_padding(self.kernel_size, (dilation, 1)),
                    norm="weight_norm",
                )
            )
            in_chs = out_chs
        out_chs = min(
            (self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
            self.max_filters,
        )
        self.convs.append(
            NormConv2d(
                in_chs,
                out_chs,
                kernel_size=(self.kernel_size[0], self.kernel_size[0]),
                padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
                norm="weight_norm",
            )
        )

        self.conv_post = NormConv2d(
            out_chs,
            self.out_channels,
            kernel_size=(self.kernel_size[0], self.kernel_size[0]),
            padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
            norm="weight_norm",
        )

        self.activation = torch.nn.LeakyReLU(negative_slope=LRELU_SLOPE)
        self.resample = torchaudio.transforms.Resample(
            orig_freq=self.fs, new_freq=self.fs * 2
        )

    def forward(self, x):
        fmap = []

        x = self.resample(x)

        z = self.cqt_transform(x)

        z_amplitude = z[:, :, :, 0].unsqueeze(1)
        z_phase = z[:, :, :, 1].unsqueeze(1)

        z = torch.cat([z_amplitude, z_phase], dim=1)
        z = rearrange(z, "b c w t -> b c t w")

        latent_z = []
        for i in range(self.n_octaves):
            latent_z.append(
                self.conv_pres[i](
                    z[
                        :,
                        :,
                        :,
                        i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
                    ]
                )
            )
        latent_z = torch.cat(latent_z, dim=-1)

        for i, l in enumerate(self.convs):
            latent_z = l(latent_z)
            latent_z = self.activation(latent_z)
            fmap.append(latent_z)

        latent_z = self.conv_post(latent_z)

        return latent_z, fmap

class MultiScaleSubbandCQTDiscriminator(nn.Module):  # replacing "MultiResSpecDiscriminator"
    def __init__(self):
        super(MultiScaleSubbandCQTDiscriminator, self).__init__()
        cfg = Munch(multiscale_subband_cfg)
        self.cfg = cfg
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorCQT(
                    cfg,
                    hop_length=cfg.hop_lengths[i],
                    n_octaves=cfg.n_octaves[i],
                    bins_per_octave=cfg.bins_per_octaves[i],
                )
                for i in range(len(cfg.hop_lengths))
            ]
        )

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for disc in self.discriminators:
            y_d_r, fmap_r = disc(y)
            y_d_g, fmap_g = disc(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


#############################

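def _demo_discriminator_ensemble():
    # Illustrative sketch: all three discriminators share the
    # (y, y_hat) -> (real_outs, fake_outs, real_fmaps, fake_fmaps) interface,
    # so a training loop can iterate them uniformly. One second of 24 kHz
    # audio is assumed so the lowest CQT octave fits in the signal.
    y, y_hat = torch.randn(1, 1, 24000), torch.randn(1, 1, 24000)
    for disc in (MultiPeriodDiscriminator(), MultiScaleDiscriminator(),
                 MultiScaleSubbandCQTDiscriminator()):
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = disc(y, y_hat)
        assert len(y_d_rs) == len(y_d_gs)
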
def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg ** 2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1 - dg) ** 2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses

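def _demo_adversarial_losses():
    # Illustrative sketch of how the LSGAN losses above are typically wired
    # up: the discriminator pushes real scores towards 1 and fake scores
    # towards 0, while the generator pushes its scores back towards 1 and
    # matches the discriminator's feature maps.
    mpd = MultiPeriodDiscriminator()
    y, y_g = torch.randn(1, 1, 8192), torch.randn(1, 1, 8192)

    # discriminator step (the generated audio is detached)
    y_d_rs, y_d_gs, _, _ = mpd(y, y_g.detach())
    loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)

    # generator step
    y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(y, y_g)
    loss_gen, _ = generator_loss(y_d_gs)
    loss_fm = feature_loss(fmap_rs, fmap_gs)
    return loss_disc, loss_gen + loss_fm
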
927
+ def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
928
+ loss = 0
929
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
930
+ tau = 0.04
931
+ m_DG = torch.median((dr-dg))
932
+ L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
933
+ loss += tau - F.relu(tau - L_rel)
934
+ return loss
935
+
936
+ def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
937
+ loss = 0
938
+ for dg, dr in zip(disc_real_outputs, disc_generated_outputs):
939
+ tau = 0.04
940
+ m_DG = torch.median((dr-dg))
941
+ L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
942
+ loss += tau - F.relu(tau - L_rel)
943
+ return loss