oh-my-dear-ai committed
Commit 8114915 · verified · 1 Parent(s): d1a578e

Update app.py

Files changed (1): app.py +64 -20
app.py CHANGED
@@ -245,7 +245,18 @@ def splite_en_inf(sentence, language):
 
 
 def clean_text_inf(text, language):
-    phones, word2ph, norm_text = clean_text(text, language.replace("all_",""))
+    formattext = ""
+    language = language.replace("all_","")
+    for tmp in LangSegment.getTexts(text):
+        if language == "ja":
+            if tmp["lang"] == language or tmp["lang"] == "zh":
+                formattext += tmp["text"] + " "
+            continue
+        if tmp["lang"] == language:
+            formattext += tmp["text"] + " "
+    while "  " in formattext:
+        formattext = formattext.replace("  ", " ")
+    phones, word2ph, norm_text = clean_text(formattext, language)
     phones = cleaned_text_to_sequence(phones)
     return phones, word2ph, norm_text
 
 
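For context, the new clean_text_inf first segments mixed-language input and keeps only the target language's runs; when the target is Japanese it also keeps Chinese runs. Below is a minimal sketch of that filtering, assuming LangSegment.getTexts yields dicts with "lang" and "text" keys as the diff suggests; fake_get_texts is a hypothetical stand-in so the snippet runs without the library:

# Hypothetical stand-in for LangSegment.getTexts (assumed output shape:
# a list of {"lang": ..., "text": ...} dicts, as used in the diff above).
def fake_get_texts(text):
    return [
        {"lang": "ja", "text": "こんにちは"},
        {"lang": "en", "text": "hello"},
        {"lang": "zh", "text": "你好"},
    ]

def filter_by_language(segments, language):
    # Mirror of the new clean_text_inf loop: keep target-language runs;
    # when the target is Japanese, Chinese runs are kept as well.
    formattext = ""
    for tmp in segments:
        if language == "ja":
            if tmp["lang"] == language or tmp["lang"] == "zh":
                formattext += tmp["text"] + " "
            continue
        if tmp["lang"] == language:
            formattext += tmp["text"] + " "
    while "  " in formattext:
        formattext = formattext.replace("  ", " ")
    return formattext.strip()

print(filter_by_language(fake_get_texts("..."), "ja"))  # こんにちは 你好
print(filter_by_language(fake_get_texts("..."), "en"))  # hello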
@@ -305,9 +316,8 @@ def nonen_get_bert_inf(text, language):
     print(langlist)
     bert_list = []
     for i in range(len(textlist)):
-        text = textlist[i]
         lang = langlist[i]
-        phones, word2ph, norm_text = clean_text_inf(text, lang)
+        phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
         bert = get_bert_inf(phones, word2ph, norm_text, lang)
         bert_list.append(bert)
     bert = torch.cat(bert_list, dim=1)
@@ -342,6 +352,23 @@ def get_bert_final(phones, word2ph, norm_text,language,device):
         bert = torch.zeros((1024, len(phones))).to(device)
     return bert
 
+def merge_short_text_in_array(texts, threshold):
+    if (len(texts)) < 2:
+        return texts
+    result = []
+    text = ""
+    for ele in texts:
+        text += ele
+        if len(text) >= threshold:
+            result.append(text)
+            text = ""
+    if (len(text) > 0):
+        if len(result) == 0:
+            result.append(text)
+        else:
+            result[len(result) - 1] += text
+    return result
+
 def get_ref_path_decor(func):
     # Decorator added for the HF deployment: replaces the reference-text content with its file path
     def inner(ref_wav_content, *args):
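The merge_short_text_in_array helper added above greedily concatenates consecutive sentences until each chunk reaches the threshold, folding any short remainder into the last chunk. Expected behaviour, using the function exactly as added:

def merge_short_text_in_array(texts, threshold):  # copied from the diff above
    if (len(texts)) < 2:
        return texts
    result = []
    text = ""
    for ele in texts:
        text += ele
        if len(text) >= threshold:
            result.append(text)
            text = ""
    if (len(text) > 0):
        if len(result) == 0:
            result.append(text)
        else:
            result[len(result) - 1] += text
    return result

print(merge_short_text_in_array(["ab", "cd", "efghi"], 5))  # ['abcdefghi']
print(merge_short_text_in_array(["abcdef", "gh"], 5))       # ['abcdefgh']
print(merge_short_text_in_array(["solo"], 5))               # ['solo']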
@@ -373,13 +400,19 @@ audio_folder_path = 'audio'
 text_to_audio_mappings, audio_to_text_mappings = load_audio_text_mappings(audio_folder_path, 'slicer_opt.list')
 
 @get_ref_path_decor
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切")):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free=False):
+    if prompt_text is None or len(prompt_text) == 0:
+        ref_free = True
     t0 = ttime()
-    prompt_text = prompt_text.strip("\n")
-    if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
+    prompt_language = dict_language[prompt_language]
+    text_language = dict_language[text_language]
+    if not ref_free:
+        prompt_text = prompt_text.strip("\n")
+        if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
+        print(i18n("实际输入的参考文本:"), prompt_text)
     text = text.strip("\n")
     if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
-    print(i18n("实际输入的参考文本:"), prompt_text)
+
     print(i18n("实际输入的目标文本:"), text)
     zero_wav = np.zeros(
         int(hps.data.sampling_rate * 0.3),
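The reference-text handling is now gated on ref_free, which is forced on when no prompt text is given. As a standalone illustration of the punctuation guard inside that gate (the splits set here is an assumption; app.py defines its own set of sentence-final marks):

# Assumed sentence-final punctuation set; app.py defines the real `splits`.
splits = {"，", "。", "？", "！", ",", ".", "?", "!", "~", ":", "：", "—", "…"}

def normalize_prompt(prompt_text, prompt_language):
    # Mirrors the gated ref-text handling: strip newlines, then make sure
    # the text ends in sentence-final punctuation so cutting works later.
    prompt_text = prompt_text.strip("\n")
    if prompt_text and prompt_text[-1] not in splits:
        prompt_text += "。" if prompt_language != "en" else "."
    return prompt_text

print(normalize_prompt("hello world\n", "en"))  # hello world.
print(normalize_prompt("你好\n", "zh"))          # 你好。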
@@ -404,12 +437,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         1, 2
     ) # .float()
     codes = vq_model.extract_latent(ssl_content)
+
     prompt_semantic = codes[0, 0]
     t1 = ttime()
-    prompt_language = dict_language[prompt_language]
-    text_language = dict_language[text_language]
-
-    phones1, word2ph1, norm_text1=get_cleaned_text_fianl(prompt_text, prompt_language)
 
     if (how_to_cut == i18n("凑四句一切")):
         text = cut1(text)
@@ -421,11 +451,15 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         text = cut4(text)
     elif (how_to_cut == i18n("按标点符号切")):
         text = cut5(text)
-    text = text.replace("\n\n", "\n").replace("\n\n", "\n").replace("\n\n", "\n")
+    while "\n\n" in text:
+        text = text.replace("\n\n", "\n")
     print(i18n("实际输入的目标文本(切句后):"), text)
     texts = text.split("\n")
+    texts = merge_short_text_in_array(texts, 5)
     audio_opt = []
-    bert1=get_bert_final(phones1, word2ph1, norm_text1,prompt_language,device).to(dtype)
+    if not ref_free:
+        phones1, word2ph1, norm_text1 = get_cleaned_text_final(prompt_text, prompt_language)
+        bert1 = get_bert_final(phones1, word2ph1, norm_text1, prompt_language, device).to(dtype)
 
     for text in texts:
         # Work around errors caused by blank lines in the input target text
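The removed fixed-depth chain of .replace("\n\n", "\n") calls can leave blank lines behind when many are adjacent; the while-loop version collapses runs of any length. A quick check:

text = "a" + "\n" * 16 + "b"
# Old approach: three chained replaces halve the run three times (16 -> 2),
# so a "\n\n" can survive.
old = text.replace("\n\n", "\n").replace("\n\n", "\n").replace("\n\n", "\n")
print("\n\n" in old)  # True

# New approach: loop until no double newline remains.
while "\n\n" in text:
    text = text.replace("\n\n", "\n")
print(repr(text))  # 'a\nb'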
@@ -433,12 +467,15 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
             continue
         if (text[-1] not in splits): text += "。" if text_language != "en" else "."
         print(i18n("实际输入的目标文本(每句):"), text)
-        phones2, word2ph2, norm_text2 = get_cleaned_text_fianl(text, text_language)
+        phones2, word2ph2, norm_text2 = get_cleaned_text_final(text, text_language)
         bert2 = get_bert_final(phones2, word2ph2, norm_text2, text_language, device).to(dtype)
+        if not ref_free:
+            bert = torch.cat([bert1, bert2], 1)
+            all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
+        else:
+            bert = bert2
+            all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
 
-        bert = torch.cat([bert1, bert2], 1)
-
-        all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
         bert = bert.to(device).unsqueeze(0)
         all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
         prompt = prompt_semantic.unsqueeze(0).to(device)
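The ref_free branch above has to keep the BERT feature width and the phoneme-id length in step, since prompt and target contributions are concatenated along the same axis. A toy shape check (the 1024 feature dim matches get_bert_final's zero tensor earlier in the diff; the phoneme ids are made up):

import torch

phones1, phones2 = [3, 7, 9], [4, 4, 2, 8]   # made-up phoneme ids
bert1 = torch.zeros(1024, len(phones1))      # (feat, n_prompt_phones)
bert2 = torch.zeros(1024, len(phones2))      # (feat, n_target_phones)

for ref_free in (False, True):
    if not ref_free:
        bert = torch.cat([bert1, bert2], 1)
        all_phoneme_ids = torch.LongTensor(phones1 + phones2).unsqueeze(0)
    else:
        bert = bert2
        all_phoneme_ids = torch.LongTensor(phones2).unsqueeze(0)
    # One BERT frame per phoneme, in both modes.
    assert bert.shape[1] == all_phoneme_ids.shape[-1]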
@@ -448,10 +485,12 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         pred_semantic, idx = t2s_model.model.infer_panel(
             all_phoneme_ids,
             all_phoneme_len,
-            prompt,
+            None if ref_free else prompt,
             bert,
             # prompt_phone_len=ph_offset,
-            top_k=config["inference"]["top_k"],
+            top_k=top_k,
+            top_p=top_p,
+            temperature=temperature,
             early_stop_num=hz * max_sec,
         )
         t3 = ttime()
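infer_panel's internals are not shown here, so the following is only a generic sketch of what the three new knobs usually mean for a categorical sampler: temperature rescales the logits, top_k keeps the k most likely tokens, and top_p then trims to the smallest nucleus of cumulative probability p.

import torch

def sample_token(logits, top_k=20, top_p=0.6, temperature=0.6):
    # Generic top-k / top-p / temperature sampling; not the infer_panel code.
    probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
    probs, idx = probs.sort(descending=True)
    probs, idx = probs[:top_k], idx[:top_k]          # top-k cut
    keep = probs.cumsum(0) - probs < top_p           # nucleus (top-p) cut
    probs = probs * keep
    return idx[torch.multinomial(probs / probs.sum(), 1)]

print(sample_token(torch.randn(128)))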
@@ -648,12 +687,17 @@ with gr.Blocks(title=f"GPT-SoVITS WebUI") as app:
             value=i18n("凑四句一切"),
             interactive=True,
         )
+        with gr.Row():
+            gr.Markdown("gpt采样参数(无参考文本时不要太低):")  # "GPT sampling parameters (avoid very low values when there is no reference text)"
+            top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
+            top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
+            temperature = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True)
         inference_button = gr.Button(i18n("合成语音"), variant="primary")
         output = gr.Audio(label=i18n("输出的语音"))
 
         inference_button.click(
             get_tts_wav,
-            [select_ref, ref_text, prompt_language, text, text_language, how_to_cut],
+            [select_ref, ref_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature],
             [output],
         )
 
 
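Gradio binds the inputs list to the callback positionally, so top_k, top_p and temperature must appear in the click inputs in the same order as in get_tts_wav's signature; dropping one would shift every later value by one slot. A minimal sketch of the pattern with stand-in components:

import gradio as gr

def tts_stub(text, top_k, top_p, temperature):
    # Stand-in callback: slider values arrive in inputs-list order.
    return f"{text} | top_k={top_k} top_p={top_p} temperature={temperature}"

with gr.Blocks() as demo:
    text = gr.Textbox(label="text")
    top_k = gr.Slider(minimum=1, maximum=100, step=1, value=5, label="top_k")
    top_p = gr.Slider(minimum=0, maximum=1, step=0.05, value=1, label="top_p")
    temperature = gr.Slider(minimum=0, maximum=1, step=0.05, value=1, label="temperature")
    out = gr.Textbox(label="output")
    gr.Button("run").click(tts_stub, [text, top_k, top_p, temperature], [out])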
 