Katock committed on
Commit 24bc27d · 1 Parent(s): 1cca979

Memory optimization (cuts memory usage by 20%)

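The mechanism behind the saving: previously every Svc instance built its own speech encoder (ContentVec) in its constructor, so hosting several voice models meant holding several copies of the same encoder weights in RAM. After this commit, app.py builds one CPU-resident encoder per encoder type up front and injects the shared instance into whichever model is about to run; Svc itself no longer loads an encoder. A minimal, self-contained sketch of that sharing pattern (DummyEncoder is a hypothetical stand-in for the real vencoder classes, not the app's exact code):

# Illustration only: one encoder instance per encoder type, shared by all models.
class DummyEncoder:
    """Stand-in for a heavy model such as ContentVec768L12."""

    def __init__(self, name, device="cpu"):
        self.name = name
        self.device = device


_cache = {}


def get_shared_encoder(name, device="cpu"):
    # Load each encoder type at most once; every voice model then reuses the
    # same instance, which is where the memory saving comes from.
    if name not in _cache:
        _cache[name] = DummyEncoder(name, device=device)
    return _cache[name]


a = get_shared_encoder("vec768l12")
b = get_shared_encoder("vec768l12")
assert a is b  # one shared instance, not one per model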
app.py CHANGED
@@ -2,7 +2,6 @@ import argparse
 import logging
 import os
 import re
-import subprocess
 import gradio.processing_utils as gr_pu
 import gradio as gr
 import librosa
@@ -11,6 +10,8 @@ import soundfile
 from scipy.io import wavfile
 import tempfile
 import edge_tts
+import utils
+import torch
 
 from inference.infer_tool import Svc
 
@@ -28,6 +29,11 @@ tts_voice = {
     "英文女": "en-US-AnaNeural"
 }
 
+hubert_dict = {
+    "vec768l12": utils.get_speech_encoder("vec768l12", device="cpu"),
+    "vec256l9": utils.get_speech_encoder("vec256l9", device="cpu")
+}
+
 
 def create_fn(model, spk):
     def svc_fn(input_audio, vc_transform, auto_f0, f0p):
@@ -39,6 +45,8 @@ def create_fn(model, spk):
         audio = librosa.to_mono(audio.transpose(1, 0))
         temp_path = "temp.wav"
         soundfile.write(temp_path, audio, sr, format="wav")
+
+        model.hubert_model = hubert_dict[model.speech_encoder]
         out_audio = model.slice_inference(raw_audio_path=temp_path,
                                           spk=spk,
                                           slice_db=-40,
@@ -58,15 +66,6 @@ def create_fn(model, spk):
         input_text = re.sub(r"[\n\,\(\) ]", "", input_text)
         voice = tts_voice[gender]
         ratestr = "+{:.0%}".format(tts_rate) if tts_rate >= 0 else "{:.0%}".format(tts_rate)
-        # temp_path = "temp.wav"
-        # p = subprocess.Popen("edge-tts " +
-        #                      " --text " + input_text +
-        #                      " --write-media " + temp_path +
-        #                      " --voice " + voice +
-        #                      " --rate=" + ratestr, shell=True,
-        #                      stdout=subprocess.PIPE,
-        #                      stdin=subprocess.PIPE)
-        # p.wait()
         communicate = edge_tts.Communicate(text=input_text,
                                            voice=voice,
                                            rate=ratestr)
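One caveat with the hubert_dict above: both encoder types are loaded eagerly at import time, even if every deployed model uses only one of them. A hedged variation is to fill the cache lazily on first use; this sketch assumes nothing beyond utils.get_speech_encoder as it appears in the diff:

import utils

_hubert_cache = {}


def get_encoder(name):
    # Build an encoder the first time a model of this type runs, then reuse it.
    if name not in _hubert_cache:
        _hubert_cache[name] = utils.get_speech_encoder(name, device="cpu")
    return _hubert_cache[name]

svc_fn would then assign model.hubert_model = get_encoder(model.speech_encoder) instead of indexing hubert_dict directly.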
inference/infer_tool.py CHANGED
@@ -131,6 +131,7 @@ class Svc(object):
                  spk_mix_enable=False,
                  feature_retrieval=False
                  ):
+        self.hubert_model = None
         self.net_g_path = net_g_path
         self.only_diffusion = only_diffusion
         self.shallow_diffusion = shallow_diffusion
@@ -172,13 +173,9 @@ class Svc(object):
             self.shallow_diffusion = self.only_diffusion = False
 
         # load hubert and model
-        if not self.only_diffusion:
-            self.load_model(spk_mix_enable)
-            self.hubert_model = utils.get_speech_encoder(self.speech_encoder, device=self.dev)
-            self.volume_extractor = utils.Volume_Extractor(self.hop_size)
-        else:
-            self.hubert_model = utils.get_speech_encoder(self.diffusion_args.data.encoder, device=self.dev)
-            self.volume_extractor = utils.Volume_Extractor(self.diffusion_args.data.block_size)
+        self.load_model(spk_mix_enable)
+        # self.hubert_model = utils.get_speech_encoder(self.speech_encoder, device=self.dev) // ram optimize
+        self.volume_extractor = utils.Volume_Extractor(self.hop_size)
 
         if os.path.exists(cluster_model_path):
            if self.feature_retrieval:
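Since Svc.__init__ now leaves hubert_model as None, a caller that runs inference without first injecting an encoder (the way app.py does) will fail somewhere deep in the pipeline. A hypothetical guard, not part of this commit, that fails early with a clear message:

def require_encoder(svc):
    # Hypothetical helper: call right before svc.slice_inference(...).
    if svc.hubert_model is None:
        raise RuntimeError(
            "Svc.hubert_model is None; assign a shared speech encoder, e.g. "
            "svc.hubert_model = hubert_dict[svc.speech_encoder], before inference.")
    return svc.hubert_model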
utils.py CHANGED
@@ -23,6 +23,7 @@ f0_min = 50.0
 f0_mel_min = 1127 * np.log(1 + f0_min / 700)
 f0_mel_max = 1127 * np.log(1 + f0_max / 700)
 
+
 def normalize_f0(f0, x_mask, uv, random_scale=True):
     # calculate means based on x_mask
     uv_sum = torch.sum(uv, dim=1, keepdim=True)
@@ -39,6 +40,7 @@ def normalize_f0(f0, x_mask, uv, random_scale=True):
         exit(0)
     return f0_norm * x_mask
 
+
 def plot_data_to_numpy(x, y):
     global MATPLOTLIB_FLAG
     if not MATPLOTLIB_FLAG:
@@ -61,6 +63,7 @@ def plot_data_to_numpy(x, y):
     plt.close()
     return data
 
+
 def interpolate_f0(f0):
     '''
     Interpolate the F0 sequence
@@ -97,15 +100,16 @@ def interpolate_f0(f0):
             ip_data[i] = data[i]
             last_value = data[i]
 
-    return ip_data[:,0], vuv_vector[:,0]
+    return ip_data[:, 0], vuv_vector[:, 0]
+
 
 def compute_f0_parselmouth(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
     import parselmouth
     x = wav_numpy
     if p_len is None:
-        p_len = x.shape[0]//hop_length
+        p_len = x.shape[0] // hop_length
     else:
-        assert abs(p_len-x.shape[0]//hop_length) < 4, "pad length error"
+        assert abs(p_len - x.shape[0] // hop_length) < 4, "pad length error"
     time_step = hop_length / sampling_rate * 1000
     f0_min = 50
     f0_max = 1100
@@ -113,22 +117,25 @@ def compute_f0_parselmouth(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
         time_step=time_step / 1000, voicing_threshold=0.6,
         pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
 
-    pad_size=(p_len - len(f0) + 1) // 2
-    if(pad_size>0 or p_len - len(f0) - pad_size>0):
-        f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
+    pad_size = (p_len - len(f0) + 1) // 2
+    if (pad_size > 0 or p_len - len(f0) - pad_size > 0):
+        f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode='constant')
     return f0
 
+
 def resize_f0(x, target_len):
     source = np.array(x)
-    source[source<0.001] = np.nan
-    target = np.interp(np.arange(0, len(source)*target_len, len(source))/ target_len, np.arange(0, len(source)), source)
+    source[source < 0.001] = np.nan
+    target = np.interp(np.arange(0, len(source) * target_len, len(source)) / target_len, np.arange(0, len(source)),
+                       source)
     res = np.nan_to_num(target)
     return res
 
+
 def compute_f0_dio(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
     import pyworld
     if p_len is None:
-        p_len = wav_numpy.shape[0]//hop_length
+        p_len = wav_numpy.shape[0] // hop_length
     f0, t = pyworld.dio(
         wav_numpy.astype(np.double),
         fs=sampling_rate,
@@ -140,45 +147,49 @@ def compute_f0_dio(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
         f0[index] = round(pitch, 1)
     return resize_f0(f0, p_len)
 
+
 def f0_to_coarse(f0):
-  is_torch = isinstance(f0, torch.Tensor)
-  f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
-  f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
+    is_torch = isinstance(f0, torch.Tensor)
+    f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
+    f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
 
-  f0_mel[f0_mel <= 1] = 1
-  f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
-  f0_coarse = (f0_mel + 0.5).int() if is_torch else np.rint(f0_mel).astype(np.int)
-  assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min())
-  return f0_coarse
+    f0_mel[f0_mel <= 1] = 1
+    f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
+    f0_coarse = (f0_mel + 0.5).int() if is_torch else np.rint(f0_mel).astype(np.int)
+    assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min())
+    return f0_coarse
+
 
 def get_hubert_model():
-  vec_path = "hubert/checkpoint_best_legacy_500.pt"
-  print("load model(s) from {}".format(vec_path))
-  from fairseq import checkpoint_utils
-  models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
-    [vec_path],
-    suffix="",
-  )
-  model = models[0]
-  model.eval()
-  return model
+    vec_path = "hubert/checkpoint_best_legacy_500.pt"
+    print("load model(s) from {}".format(vec_path))
+    from fairseq import checkpoint_utils
+    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+        [vec_path],
+        suffix="",
+    )
+    model = models[0]
+    model.eval()
+    return model
+
 
 def get_hubert_content(hmodel, wav_16k_tensor):
-  feats = wav_16k_tensor
-  if feats.dim() == 2:  # double channels
-    feats = feats.mean(-1)
-  assert feats.dim() == 1, feats.dim()
-  feats = feats.view(1, -1)
-  padding_mask = torch.BoolTensor(feats.shape).fill_(False)
-  inputs = {
-    "source": feats.to(wav_16k_tensor.device),
-    "padding_mask": padding_mask.to(wav_16k_tensor.device),
-    "output_layer": 9,  # layer 9
-  }
-  with torch.no_grad():
-    logits = hmodel.extract_features(**inputs)
-    feats = hmodel.final_proj(logits[0])
-  return feats.transpose(1, 2)
+    feats = wav_16k_tensor
+    if feats.dim() == 2:  # double channels
+        feats = feats.mean(-1)
+    assert feats.dim() == 1, feats.dim()
+    feats = feats.view(1, -1)
+    padding_mask = torch.BoolTensor(feats.shape).fill_(False)
+    inputs = {
+        "source": feats.to(wav_16k_tensor.device),
+        "padding_mask": padding_mask.to(wav_16k_tensor.device),
+        "output_layer": 9,  # layer 9
+    }
+    with torch.no_grad():
+        logits = hmodel.extract_features(**inputs)
+        feats = hmodel.final_proj(logits[0])
    return feats.transpose(1, 2)
+
 
 def get_content(cmodel, y):
     with torch.no_grad():
@@ -186,63 +197,37 @@ def get_content(cmodel, y):
         c = c.transpose(1, 2)
     return c
 
-def get_f0_predictor(f0_predictor,hop_length,sampling_rate,**kargs):
+
+def get_f0_predictor(f0_predictor, hop_length, sampling_rate, **kargs):
     if f0_predictor == "pm":
         from modules.F0Predictor.PMF0Predictor import PMF0Predictor
-        f0_predictor_object = PMF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
+        f0_predictor_object = PMF0Predictor(hop_length=hop_length, sampling_rate=sampling_rate)
     elif f0_predictor == "crepe":
         from modules.F0Predictor.CrepeF0Predictor import CrepeF0Predictor
-        f0_predictor_object = CrepeF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,device=kargs["device"],threshold=kargs["threshold"])
+        f0_predictor_object = CrepeF0Predictor(hop_length=hop_length, sampling_rate=sampling_rate,
+                                               device=kargs["device"], threshold=kargs["threshold"])
     elif f0_predictor == "harvest":
         from modules.F0Predictor.HarvestF0Predictor import HarvestF0Predictor
-        f0_predictor_object = HarvestF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
+        f0_predictor_object = HarvestF0Predictor(hop_length=hop_length, sampling_rate=sampling_rate)
     elif f0_predictor == "dio":
         from modules.F0Predictor.DioF0Predictor import DioF0Predictor
-        f0_predictor_object = DioF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
+        f0_predictor_object = DioF0Predictor(hop_length=hop_length, sampling_rate=sampling_rate)
     else:
         raise Exception("Unknown f0 predictor")
     return f0_predictor_object
 
-def get_speech_encoder(speech_encoder,device=None,**kargs):
+
+def get_speech_encoder(speech_encoder, device=None, **kargs):
     if speech_encoder == "vec768l12":
         from vencoder.ContentVec768L12 import ContentVec768L12
-        speech_encoder_object = ContentVec768L12(device = device)
+        speech_encoder_object = ContentVec768L12(device=device)
     elif speech_encoder == "vec256l9":
         from vencoder.ContentVec256L9 import ContentVec256L9
-        speech_encoder_object = ContentVec256L9(device = device)
-    elif speech_encoder == "vec256l9-onnx":
-        from vencoder.ContentVec256L9_Onnx import ContentVec256L9_Onnx
-        speech_encoder_object = ContentVec256L9_Onnx(device = device)
-    elif speech_encoder == "vec256l12-onnx":
-        from vencoder.ContentVec256L12_Onnx import ContentVec256L12_Onnx
-        speech_encoder_object = ContentVec256L12_Onnx(device = device)
-    elif speech_encoder == "vec768l9-onnx":
-        from vencoder.ContentVec768L9_Onnx import ContentVec768L9_Onnx
-        speech_encoder_object = ContentVec768L9_Onnx(device = device)
-    elif speech_encoder == "vec768l12-onnx":
-        from vencoder.ContentVec768L12_Onnx import ContentVec768L12_Onnx
-        speech_encoder_object = ContentVec768L12_Onnx(device = device)
-    elif speech_encoder == "hubertsoft-onnx":
-        from vencoder.HubertSoft_Onnx import HubertSoft_Onnx
-        speech_encoder_object = HubertSoft_Onnx(device = device)
-    elif speech_encoder == "hubertsoft":
-        from vencoder.HubertSoft import HubertSoft
-        speech_encoder_object = HubertSoft(device = device)
-    elif speech_encoder == "whisper-ppg":
-        from vencoder.WhisperPPG import WhisperPPG
-        speech_encoder_object = WhisperPPG(device = device)
-    elif speech_encoder == "cnhubertlarge":
-        from vencoder.CNHubertLarge import CNHubertLarge
-        speech_encoder_object = CNHubertLarge(device = device)
-    elif speech_encoder == "dphubert":
-        from vencoder.DPHubert import DPHubert
-        speech_encoder_object = DPHubert(device = device)
-    elif speech_encoder == "whisper-ppg-large":
-        from vencoder.WhisperPPGLarge import WhisperPPGLarge
-        speech_encoder_object = WhisperPPGLarge(device = device)
+        speech_encoder_object = ContentVec256L9(device=device)
     else:
         raise Exception("Unknown speech encoder")
-    return speech_encoder_object
+    return speech_encoder_object
+
 
 def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
     assert os.path.isfile(checkpoint_path)
@@ -276,164 +261,168 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
         checkpoint_path, iteration))
     return model, optimizer, learning_rate, iteration
 
+
 def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
-  logger.info("Saving model and optimizer state at iteration {} to {}".format(
-    iteration, checkpoint_path))
-  if hasattr(model, 'module'):
-    state_dict = model.module.state_dict()
-  else:
-    state_dict = model.state_dict()
-  torch.save({'model': state_dict,
-              'iteration': iteration,
-              'optimizer': optimizer.state_dict(),
-              'learning_rate': learning_rate}, checkpoint_path)
+    logger.info("Saving model and optimizer state at iteration {} to {}".format(
+        iteration, checkpoint_path))
+    if hasattr(model, 'module'):
+        state_dict = model.module.state_dict()
+    else:
+        state_dict = model.state_dict()
+    torch.save({'model': state_dict,
+                'iteration': iteration,
+                'optimizer': optimizer.state_dict(),
+                'learning_rate': learning_rate}, checkpoint_path)
+
 
 def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True):
-  """Freeing up space by deleting saved ckpts
-
-  Arguments:
-  path_to_models    --  Path to the model directory
-  n_ckpts_to_keep   --  Number of ckpts to keep, excluding G_0.pth and D_0.pth
-  sort_by_time      --  True -> chronologically delete ckpts
-                        False -> lexicographically delete ckpts
-  """
-  ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
-  name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1)))
-  time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
-  sort_key = time_key if sort_by_time else name_key
-  x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')], key=sort_key)
-  to_del = [os.path.join(path_to_models, fn) for fn in
-            (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
-  del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
-  del_routine = lambda x: [os.remove(x), del_info(x)]
-  rs = [del_routine(fn) for fn in to_del]
+    """Freeing up space by deleting saved ckpts
+
+    Arguments:
+    path_to_models    --  Path to the model directory
+    n_ckpts_to_keep   --  Number of ckpts to keep, excluding G_0.pth and D_0.pth
+    sort_by_time      --  True -> chronologically delete ckpts
+                          False -> lexicographically delete ckpts
+    """
+    ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
+    name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1)))
+    time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
+    sort_key = time_key if sort_by_time else name_key
+    x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')],
+                                 key=sort_key)
+    to_del = [os.path.join(path_to_models, fn) for fn in
+              (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
+    del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
+    del_routine = lambda x: [os.remove(x), del_info(x)]
+    rs = [del_routine(fn) for fn in to_del]
+
 
 def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
-  for k, v in scalars.items():
-    writer.add_scalar(k, v, global_step)
-  for k, v in histograms.items():
-    writer.add_histogram(k, v, global_step)
-  for k, v in images.items():
-    writer.add_image(k, v, global_step, dataformats='HWC')
-  for k, v in audios.items():
-    writer.add_audio(k, v, global_step, audio_sampling_rate)
+    for k, v in scalars.items():
+        writer.add_scalar(k, v, global_step)
+    for k, v in histograms.items():
+        writer.add_histogram(k, v, global_step)
+    for k, v in images.items():
+        writer.add_image(k, v, global_step, dataformats='HWC')
+    for k, v in audios.items():
+        writer.add_audio(k, v, global_step, audio_sampling_rate)
 
 
 def latest_checkpoint_path(dir_path, regex="G_*.pth"):
-  f_list = glob.glob(os.path.join(dir_path, regex))
-  f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
-  x = f_list[-1]
-  print(x)
-  return x
+    f_list = glob.glob(os.path.join(dir_path, regex))
+    f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
+    x = f_list[-1]
+    print(x)
+    return x
 
 
 def plot_spectrogram_to_numpy(spectrogram):
-  global MATPLOTLIB_FLAG
-  if not MATPLOTLIB_FLAG:
-    import matplotlib
-    matplotlib.use("Agg")
-    MATPLOTLIB_FLAG = True
-    mpl_logger = logging.getLogger('matplotlib')
-    mpl_logger.setLevel(logging.WARNING)
-  import matplotlib.pylab as plt
-  import numpy as np
-
-  fig, ax = plt.subplots(figsize=(10,2))
-  im = ax.imshow(spectrogram, aspect="auto", origin="lower",
-                 interpolation='none')
-  plt.colorbar(im, ax=ax)
-  plt.xlabel("Frames")
-  plt.ylabel("Channels")
-  plt.tight_layout()
-
-  fig.canvas.draw()
-  data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
-  data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-  plt.close()
-  return data
+    global MATPLOTLIB_FLAG
+    if not MATPLOTLIB_FLAG:
+        import matplotlib
+        matplotlib.use("Agg")
+        MATPLOTLIB_FLAG = True
+        mpl_logger = logging.getLogger('matplotlib')
+        mpl_logger.setLevel(logging.WARNING)
+    import matplotlib.pylab as plt
+    import numpy as np
+
+    fig, ax = plt.subplots(figsize=(10, 2))
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+                   interpolation='none')
+    plt.colorbar(im, ax=ax)
+    plt.xlabel("Frames")
+    plt.ylabel("Channels")
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
 
 
 def plot_alignment_to_numpy(alignment, info=None):
-  global MATPLOTLIB_FLAG
-  if not MATPLOTLIB_FLAG:
-    import matplotlib
-    matplotlib.use("Agg")
-    MATPLOTLIB_FLAG = True
-    mpl_logger = logging.getLogger('matplotlib')
-    mpl_logger.setLevel(logging.WARNING)
-  import matplotlib.pylab as plt
-  import numpy as np
-
-  fig, ax = plt.subplots(figsize=(6, 4))
-  im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
-                 interpolation='none')
-  fig.colorbar(im, ax=ax)
-  xlabel = 'Decoder timestep'
-  if info is not None:
-    xlabel += '\n\n' + info
-  plt.xlabel(xlabel)
-  plt.ylabel('Encoder timestep')
-  plt.tight_layout()
-
-  fig.canvas.draw()
-  data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
-  data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-  plt.close()
-  return data
+    global MATPLOTLIB_FLAG
+    if not MATPLOTLIB_FLAG:
+        import matplotlib
+        matplotlib.use("Agg")
+        MATPLOTLIB_FLAG = True
+        mpl_logger = logging.getLogger('matplotlib')
+        mpl_logger.setLevel(logging.WARNING)
+    import matplotlib.pylab as plt
+    import numpy as np
+
+    fig, ax = plt.subplots(figsize=(6, 4))
+    im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
+                   interpolation='none')
+    fig.colorbar(im, ax=ax)
+    xlabel = 'Decoder timestep'
+    if info is not None:
+        xlabel += '\n\n' + info
+    plt.xlabel(xlabel)
+    plt.ylabel('Encoder timestep')
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
 
 
 def load_wav_to_torch(full_path):
-  sampling_rate, data = read(full_path)
-  return torch.FloatTensor(data.astype(np.float32)), sampling_rate
+    sampling_rate, data = read(full_path)
+    return torch.FloatTensor(data.astype(np.float32)), sampling_rate
 
 
 def load_filepaths_and_text(filename, split="|"):
-  with open(filename, encoding='utf-8') as f:
-    filepaths_and_text = [line.strip().split(split) for line in f]
-  return filepaths_and_text
+    with open(filename, encoding='utf-8') as f:
+        filepaths_and_text = [line.strip().split(split) for line in f]
+    return filepaths_and_text
 
 
 def get_hparams_from_file(config_path):
-  with open(config_path, "r") as f:
-    data = f.read()
-  config = json.loads(data)
-  hparams =HParams(**config)
-  return hparams
+    with open(config_path, "r") as f:
+        data = f.read()
+    config = json.loads(data)
+    hparams = HParams(**config)
+    return hparams
 
 
 def check_git_hash(model_dir):
-  source_dir = os.path.dirname(os.path.realpath(__file__))
-  if not os.path.exists(os.path.join(source_dir, ".git")):
-    logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
-      source_dir
-    ))
-    return
-
-  cur_hash = subprocess.getoutput("git rev-parse HEAD")
-
-  path = os.path.join(model_dir, "githash")
-  if os.path.exists(path):
-    saved_hash = open(path).read()
-    if saved_hash != cur_hash:
-      logger.warn("git hash values are different. {}(saved) != {}(current)".format(
-        saved_hash[:8], cur_hash[:8]))
-  else:
-    open(path, "w").write(cur_hash)
+    source_dir = os.path.dirname(os.path.realpath(__file__))
+    if not os.path.exists(os.path.join(source_dir, ".git")):
+        logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
+            source_dir
+        ))
+        return
+
+    cur_hash = subprocess.getoutput("git rev-parse HEAD")
+
+    path = os.path.join(model_dir, "githash")
+    if os.path.exists(path):
+        saved_hash = open(path).read()
+        if saved_hash != cur_hash:
+            logger.warn("git hash values are different. {}(saved) != {}(current)".format(
+                saved_hash[:8], cur_hash[:8]))
+    else:
+        open(path, "w").write(cur_hash)
 
 
 def get_logger(model_dir, filename="train.log"):
-  global logger
-  logger = logging.getLogger(os.path.basename(model_dir))
-  logger.setLevel(logging.DEBUG)
+    global logger
+    logger = logging.getLogger(os.path.basename(model_dir))
+    logger.setLevel(logging.DEBUG)
 
-  formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
-  if not os.path.exists(model_dir):
-    os.makedirs(model_dir)
-  h = logging.FileHandler(os.path.join(model_dir, filename))
-  h.setLevel(logging.DEBUG)
-  h.setFormatter(formatter)
-  logger.addHandler(h)
-  return logger
+    formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+    h = logging.FileHandler(os.path.join(model_dir, filename))
+    h.setLevel(logging.DEBUG)
+    h.setFormatter(formatter)
+    logger.addHandler(h)
+    return logger
 
 
 def repeat_expand_2d(content, target_len):
@@ -441,10 +430,10 @@ def repeat_expand_2d(content, target_len):
 
     src_len = content.shape[-1]
     target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(content.device)
-    temp = torch.arange(src_len+1) * target_len / src_len
+    temp = torch.arange(src_len + 1) * target_len / src_len
     current_pos = 0
     for i in range(target_len):
-        if i < temp[current_pos+1]:
+        if i < temp[current_pos + 1]:
             target[:, i] = content[:, current_pos]
         else:
            current_pos += 1
@@ -453,7 +442,6 @@ def repeat_expand_2d(content, target_len):
     return target
 
 
-
 def change_rms(data1, sr1, data2, sr2, rate):  # 1 is the input audio, 2 is the output audio; rate is the proportion of 2 (from RVC)
     # print(data1.max(),data2.max())
     rms1 = librosa.feature.rms(
@@ -470,56 +458,59 @@ def change_rms(data1, sr1, data2, sr2, rate):
     ).squeeze()
     rms2 = torch.max(rms2, torch.zeros_like(rms2) + 1e-6)
     data2 *= (
-        torch.pow(rms1, torch.tensor(1 - rate))
-        * torch.pow(rms2, torch.tensor(rate - 1))
+            torch.pow(rms1, torch.tensor(1 - rate))
+            * torch.pow(rms2, torch.tensor(rate - 1))
     )
     return data2
 
 
 class HParams():
-  def __init__(self, **kwargs):
-    for k, v in kwargs.items():
-      if type(v) == dict:
-        v = HParams(**v)
-      self[k] = v
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            if type(v) == dict:
+                v = HParams(**v)
+            self[k] = v
 
-  def keys(self):
-    return self.__dict__.keys()
+    def keys(self):
+        return self.__dict__.keys()
 
-  def items(self):
-    return self.__dict__.items()
+    def items(self):
+        return self.__dict__.items()
 
-  def values(self):
-    return self.__dict__.values()
+    def values(self):
+        return self.__dict__.values()
 
-  def __len__(self):
-    return len(self.__dict__)
+    def __len__(self):
+        return len(self.__dict__)
 
-  def __getitem__(self, key):
-    return getattr(self, key)
+    def __getitem__(self, key):
+        return getattr(self, key)
 
-  def __setitem__(self, key, value):
-    return setattr(self, key, value)
+    def __setitem__(self, key, value):
+        return setattr(self, key, value)
 
-  def __contains__(self, key):
-    return key in self.__dict__
+    def __contains__(self, key):
+        return key in self.__dict__
 
-  def __repr__(self):
-    return self.__dict__.__repr__()
+    def __repr__(self):
+        return self.__dict__.__repr__()
 
-  def get(self,index):
-    return self.__dict__.get(index)
+    def get(self, index):
+        return self.__dict__.get(index)
 
 
 class Volume_Extractor:
-    def __init__(self, hop_size = 512):
+    def __init__(self, hop_size=512):
         self.hop_size = hop_size
-
-    def extract(self, audio): # audio: 2d tensor array
-        if not isinstance(audio,torch.Tensor):
-            audio = torch.Tensor(audio)
+
+    def extract(self, audio):  # audio: 2d tensor array
+        if not isinstance(audio, torch.Tensor):
+            audio = torch.Tensor(audio)
        n_frames = int(audio.size(-1) // self.hop_size)
         audio2 = audio ** 2
-        audio2 = torch.nn.functional.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode = 'reflect')
-        volume = torch.FloatTensor([torch.mean(audio2[:,int(n * self.hop_size) : int((n + 1) * self.hop_size)]) for n in range(n_frames)])
+        audio2 = torch.nn.functional.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)),
+                                         mode='reflect')
+        volume = torch.FloatTensor(
+            [torch.mean(audio2[:, int(n * self.hop_size): int((n + 1) * self.hop_size)]) for n in range(n_frames)])
         volume = torch.sqrt(volume)
         return volume
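Beyond the whitespace cleanup, note that get_speech_encoder was also pruned: only the "vec768l12" and "vec256l9" branches survive, so the ONNX, hubertsoft, whisper-ppg, cnhubertlarge and dphubert options now raise "Unknown speech encoder". A quick check of the pruned factory (assuming this repo is on the path and the pretrain/checkpoint_best_legacy_500.pt weights are present):

import utils

enc = utils.get_speech_encoder("vec768l12", device="cpu")
print(type(enc).__name__)  # ContentVec768L12

try:
    utils.get_speech_encoder("hubertsoft")  # branch removed by this commit
except Exception as err:
    print(err)  # Unknown speech encoder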
vencoder/ContentVec256L9.py CHANGED
@@ -2,12 +2,13 @@ from vencoder.encoder import SpeechEncoder
 import torch
 from fairseq import checkpoint_utils
 
+
 class ContentVec256L9(SpeechEncoder):
-    def __init__(self,vec_path = "pretrain/checkpoint_best_legacy_500.pt",device=None):
+    def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None):
         print("load model(s) from {}".format(vec_path))
         models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
-          [vec_path],
-          suffix="",
+            [vec_path],
+            suffix="",
         )
         self.hidden_dim = 256
         if device is None:
@@ -20,16 +21,16 @@ class ContentVec256L9(SpeechEncoder):
     def encoder(self, wav):
         feats = wav
         if feats.dim() == 2:  # double channels
-          feats = feats.mean(-1)
+            feats = feats.mean(-1)
         assert feats.dim() == 1, feats.dim()
         feats = feats.view(1, -1)
         padding_mask = torch.BoolTensor(feats.shape).fill_(False)
         inputs = {
-          "source": feats.to(wav.device),
-          "padding_mask": padding_mask.to(wav.device),
-          "output_layer": 9, # layer 9
+            "source": feats.to(wav.device),
+            "padding_mask": padding_mask.to(wav.device),
+            "output_layer": 9,  # layer 9
         }
         with torch.no_grad():
-          logits = self.model.extract_features(**inputs)
-          feats = self.model.final_proj(logits[0])
+            logits = self.model.extract_features(**inputs)
+            feats = self.model.final_proj(logits[0])
         return feats.transpose(1, 2)
vencoder/ContentVec768L12.py CHANGED
@@ -2,13 +2,14 @@ from vencoder.encoder import SpeechEncoder
 import torch
 from fairseq import checkpoint_utils
 
+
 class ContentVec768L12(SpeechEncoder):
-    def __init__(self,vec_path = "pretrain/checkpoint_best_legacy_500.pt",device=None):
+    def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None):
         print("load model(s) from {}".format(vec_path))
         self.hidden_dim = 768
         models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
-          [vec_path],
-          suffix="",
+            [vec_path],
+            suffix="",
         )
         if device is None:
             self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -20,15 +21,15 @@ class ContentVec768L12(SpeechEncoder):
     def encoder(self, wav):
         feats = wav
         if feats.dim() == 2:  # double channels
-          feats = feats.mean(-1)
+            feats = feats.mean(-1)
         assert feats.dim() == 1, feats.dim()
         feats = feats.view(1, -1)
         padding_mask = torch.BoolTensor(feats.shape).fill_(False)
         inputs = {
-          "source": feats.to(wav.device),
-          "padding_mask": padding_mask.to(wav.device),
-          "output_layer": 12, # layer 12
+            "source": feats.to(wav.device),
+            "padding_mask": padding_mask.to(wav.device),
+            "output_layer": 12,  # layer 12
         }
         with torch.no_grad():
-          logits = self.model.extract_features(**inputs)
-          return logits[0].transpose(1, 2)
+            logits = self.model.extract_features(**inputs)
+            return logits[0].transpose(1, 2)