hexgrad committed (verified)
Commit 0d0ce74 · 1 Parent(s): a04ef58

Delete models.py

Files changed (1)
  1. models.py +0 -372
models.py DELETED
@@ -1,372 +0,0 @@
- # https://github.com/yl4579/StyleTTS2/blob/main/models.py
- from istftnet import AdaIN1d, Decoder
- from munch import Munch
- from pathlib import Path
- from plbert import load_plbert
- from torch.nn.utils import weight_norm, spectral_norm
- import json
- import numpy as np
- import os
- import os.path as osp
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- class LinearNorm(torch.nn.Module):
-     def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
-         super(LinearNorm, self).__init__()
-         self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
-
-         torch.nn.init.xavier_uniform_(
-             self.linear_layer.weight,
-             gain=torch.nn.init.calculate_gain(w_init_gain))
-
-     def forward(self, x):
-         return self.linear_layer(x)
-
- class LayerNorm(nn.Module):
-     def __init__(self, channels, eps=1e-5):
-         super().__init__()
-         self.channels = channels
-         self.eps = eps
-
-         self.gamma = nn.Parameter(torch.ones(channels))
-         self.beta = nn.Parameter(torch.zeros(channels))
-
-     def forward(self, x):
-         x = x.transpose(1, -1)
-         x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
-         return x.transpose(1, -1)
-
- class TextEncoder(nn.Module):
-     def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
-         super().__init__()
-         self.embedding = nn.Embedding(n_symbols, channels)
-
-         padding = (kernel_size - 1) // 2
-         self.cnn = nn.ModuleList()
-         for _ in range(depth):
-             self.cnn.append(nn.Sequential(
-                 weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
-                 LayerNorm(channels),
-                 actv,
-                 nn.Dropout(0.2),
-             ))
-         # self.cnn = nn.Sequential(*self.cnn)
-
-         self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
-
-     def forward(self, x, input_lengths, m):
-         x = self.embedding(x) # [B, T, emb]
-         x = x.transpose(1, 2) # [B, emb, T]
-         m = m.to(input_lengths.device).unsqueeze(1)
-         x.masked_fill_(m, 0.0)
-
-         for c in self.cnn:
-             x = c(x)
-             x.masked_fill_(m, 0.0)
-
-         x = x.transpose(1, 2) # [B, T, chn]
-
-         input_lengths = input_lengths.cpu().numpy()
-         x = nn.utils.rnn.pack_padded_sequence(
-             x, input_lengths, batch_first=True, enforce_sorted=False)
-
-         self.lstm.flatten_parameters()
-         x, _ = self.lstm(x)
-         x, _ = nn.utils.rnn.pad_packed_sequence(
-             x, batch_first=True)
-
-         x = x.transpose(-1, -2)
-         x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
-
-         x_pad[:, :, :x.shape[-1]] = x
-         x = x_pad.to(x.device)
-
-         x.masked_fill_(m, 0.0)
-
-         return x
-
-     def inference(self, x):
-         x = self.embedding(x)
-         x = x.transpose(1, 2)
-         x = self.cnn(x)
-         x = x.transpose(1, 2)
-         self.lstm.flatten_parameters()
-         x, _ = self.lstm(x)
-         return x
-
-     def length_to_mask(self, lengths):
-         mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
-         mask = torch.gt(mask+1, lengths.unsqueeze(1))
-         return mask
-
-
- class UpSample1d(nn.Module):
-     def __init__(self, layer_type):
-         super().__init__()
-         self.layer_type = layer_type
-
-     def forward(self, x):
-         if self.layer_type == 'none':
-             return x
-         else:
-             return F.interpolate(x, scale_factor=2, mode='nearest')
-
- class AdainResBlk1d(nn.Module):
-     def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
-                  upsample='none', dropout_p=0.0):
-         super().__init__()
-         self.actv = actv
-         self.upsample_type = upsample
-         self.upsample = UpSample1d(upsample)
-         self.learned_sc = dim_in != dim_out
-         self._build_weights(dim_in, dim_out, style_dim)
-         self.dropout = nn.Dropout(dropout_p)
-
-         if upsample == 'none':
-             self.pool = nn.Identity()
-         else:
-             self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
-
-
-     def _build_weights(self, dim_in, dim_out, style_dim):
-         self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
-         self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
-         self.norm1 = AdaIN1d(style_dim, dim_in)
-         self.norm2 = AdaIN1d(style_dim, dim_out)
-         if self.learned_sc:
-             self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
-
-     def _shortcut(self, x):
-         x = self.upsample(x)
-         if self.learned_sc:
-             x = self.conv1x1(x)
-         return x
-
-     def _residual(self, x, s):
-         x = self.norm1(x, s)
-         x = self.actv(x)
-         x = self.pool(x)
-         x = self.conv1(self.dropout(x))
-         x = self.norm2(x, s)
-         x = self.actv(x)
-         x = self.conv2(self.dropout(x))
-         return x
-
-     def forward(self, x, s):
-         out = self._residual(x, s)
-         out = (out + self._shortcut(x)) / np.sqrt(2)
-         return out
-
- class AdaLayerNorm(nn.Module):
-     def __init__(self, style_dim, channels, eps=1e-5):
-         super().__init__()
-         self.channels = channels
-         self.eps = eps
-
-         self.fc = nn.Linear(style_dim, channels*2)
-
-     def forward(self, x, s):
-         x = x.transpose(-1, -2)
-         x = x.transpose(1, -1)
-
-         h = self.fc(s)
-         h = h.view(h.size(0), h.size(1), 1)
-         gamma, beta = torch.chunk(h, chunks=2, dim=1)
-         gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
-
-
-         x = F.layer_norm(x, (self.channels,), eps=self.eps)
-         x = (1 + gamma) * x + beta
-         return x.transpose(1, -1).transpose(-1, -2)
-
- class ProsodyPredictor(nn.Module):
-
-     def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
-         super().__init__()
-
-         self.text_encoder = DurationEncoder(sty_dim=style_dim,
-                                             d_model=d_hid,
-                                             nlayers=nlayers,
-                                             dropout=dropout)
-
-         self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
-         self.duration_proj = LinearNorm(d_hid, max_dur)
-
-         self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
-         self.F0 = nn.ModuleList()
-         self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
-         self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
-         self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
-
-         self.N = nn.ModuleList()
-         self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
-         self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
-         self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
-
-         self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
-         self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
-
-
-     def forward(self, texts, style, text_lengths, alignment, m):
-         d = self.text_encoder(texts, style, text_lengths, m)
-
-         batch_size = d.shape[0]
-         text_size = d.shape[1]
-
-         # predict duration
-         input_lengths = text_lengths.cpu().numpy()
-         x = nn.utils.rnn.pack_padded_sequence(
-             d, input_lengths, batch_first=True, enforce_sorted=False)
-
-         m = m.to(text_lengths.device).unsqueeze(1)
-
-         self.lstm.flatten_parameters()
-         x, _ = self.lstm(x)
-         x, _ = nn.utils.rnn.pad_packed_sequence(
-             x, batch_first=True)
-
-         x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
-
-         x_pad[:, :x.shape[1], :] = x
-         x = x_pad.to(x.device)
-
-         duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
-
-         en = (d.transpose(-1, -2) @ alignment)
-
-         return duration.squeeze(-1), en
-
-     def F0Ntrain(self, x, s):
-         x, _ = self.shared(x.transpose(-1, -2))
-
-         F0 = x.transpose(-1, -2)
-         for block in self.F0:
-             F0 = block(F0, s)
-         F0 = self.F0_proj(F0)
-
-         N = x.transpose(-1, -2)
-         for block in self.N:
-             N = block(N, s)
-         N = self.N_proj(N)
-
-         return F0.squeeze(1), N.squeeze(1)
-
-     def length_to_mask(self, lengths):
-         mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
-         mask = torch.gt(mask+1, lengths.unsqueeze(1))
-         return mask
-
- class DurationEncoder(nn.Module):
-
-     def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
-         super().__init__()
-         self.lstms = nn.ModuleList()
-         for _ in range(nlayers):
-             self.lstms.append(nn.LSTM(d_model + sty_dim,
-                                       d_model // 2,
-                                       num_layers=1,
-                                       batch_first=True,
-                                       bidirectional=True,
-                                       dropout=dropout))
-             self.lstms.append(AdaLayerNorm(sty_dim, d_model))
-
-
-         self.dropout = dropout
-         self.d_model = d_model
-         self.sty_dim = sty_dim
-
-     def forward(self, x, style, text_lengths, m):
-         masks = m.to(text_lengths.device)
-
-         x = x.permute(2, 0, 1)
-         s = style.expand(x.shape[0], x.shape[1], -1)
-         x = torch.cat([x, s], axis=-1)
-         x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
-
-         x = x.transpose(0, 1)
-         input_lengths = text_lengths.cpu().numpy()
-         x = x.transpose(-1, -2)
-
-         for block in self.lstms:
-             if isinstance(block, AdaLayerNorm):
-                 x = block(x.transpose(-1, -2), style).transpose(-1, -2)
-                 x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
-                 x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
-             else:
-                 x = x.transpose(-1, -2)
-                 x = nn.utils.rnn.pack_padded_sequence(
-                     x, input_lengths, batch_first=True, enforce_sorted=False)
-                 block.flatten_parameters()
-                 x, _ = block(x)
-                 x, _ = nn.utils.rnn.pad_packed_sequence(
-                     x, batch_first=True)
-                 x = F.dropout(x, p=self.dropout, training=self.training)
-                 x = x.transpose(-1, -2)
-
-                 x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
-
-                 x_pad[:, :, :x.shape[-1]] = x
-                 x = x_pad.to(x.device)
-
-         return x.transpose(-1, -2)
-
-     def inference(self, x, style):
-         x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model)
-         style = style.expand(x.shape[0], x.shape[1], -1)
-         x = torch.cat([x, style], axis=-1)
-         src = self.pos_encoder(x)
-         output = self.transformer_encoder(src).transpose(0, 1)
-         return output
-
-     def length_to_mask(self, lengths):
-         mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
-         mask = torch.gt(mask+1, lengths.unsqueeze(1))
-         return mask
-
- # https://github.com/yl4579/StyleTTS2/blob/main/utils.py
- def recursive_munch(d):
-     if isinstance(d, dict):
-         return Munch((k, recursive_munch(v)) for k, v in d.items())
-     elif isinstance(d, list):
-         return [recursive_munch(v) for v in d]
-     else:
-         return d
-
- def build_model(path, device):
-     config = Path(__file__).parent / 'config.json'
-     assert config.exists(), f'Config path incorrect: config.json not found at {config}'
-     with open(config, 'r') as r:
-         args = recursive_munch(json.load(r))
-     assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}'
-     decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
-             resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
-             upsample_rates = args.decoder.upsample_rates,
-             upsample_initial_channel=args.decoder.upsample_initial_channel,
-             resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
-             upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
-             gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
-     text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
-     predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
-     bert = load_plbert()
-     bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim)
-     for parent in [bert, bert_encoder, predictor, decoder, text_encoder]:
-         for child in parent.children():
-             if isinstance(child, nn.RNNBase):
-                 child.flatten_parameters()
-     model = Munch(
-         bert=bert.to(device).eval(),
-         bert_encoder=bert_encoder.to(device).eval(),
-         predictor=predictor.to(device).eval(),
-         decoder=decoder.to(device).eval(),
-         text_encoder=text_encoder.to(device).eval(),
-     )
-     for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items():
-         assert key in model, key
-         try:
-             model[key].load_state_dict(state_dict)
-         except:
-             state_dict = {k[7:]: v for k, v in state_dict.items()}
-             model[key].load_state_dict(state_dict, strict=False)
-     return model
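
For reference, the deleted `build_model` was this file's entry point. A minimal usage sketch under assumptions (the checkpoint filename and device choice are illustrative, not part of this commit):

import torch
from models import build_model  # this file, as it existed before this commit

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 'kokoro-v0_19.pth' is an assumed checkpoint name, used here for illustration only
model = build_model('kokoro-v0_19.pth', device)
# build_model returns a Munch of eval-mode sub-modules:
# bert, bert_encoder, predictor, decoder, text_encoder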