lj1995 committed
Commit 6bb3db9 · verified · 1 parent: e7d9a60

Update inference_webui.py

Files changed (1):
  1. inference_webui.py  +128 -125
inference_webui.py CHANGED

@@ -343,135 +343,138 @@ def merge_short_text_in_array(texts, threshold):
 cache= {}
 def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False,speed=1,if_freeze=False,inp_refs=123):
     global cache
-    if ref_wav_path:pass
-    else:gr.Warning(i18n('请上传参考音频'))
-    if text:pass
-    else:gr.Warning(i18n('请填入推理文本'))
-    t = []
-    if prompt_text is None or len(prompt_text) == 0:
-        ref_free = True
-    t0 = ttime()
-    prompt_language = dict_language[prompt_language]
-    text_language = dict_language[text_language]
-
-    if not ref_free:
-        prompt_text = prompt_text.strip("\n")
-        if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
-        print(i18n("实际输入的参考文本:"), prompt_text)
-    text = text.strip("\n")
-    if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
-
-    print(i18n("实际输入的目标文本:"), text)
-    zero_wav = np.zeros(
-        int(hps.data.sampling_rate * 0.3),
-        dtype=np.float16 if is_half == True else np.float32,
-    )
-    if not ref_free:
-        with torch.no_grad():
-            wav16k, sr = librosa.load(ref_wav_path, sr=16000)
-            if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
-                gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
-                raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
-            wav16k = torch.from_numpy(wav16k)
-            zero_wav_torch = torch.from_numpy(zero_wav)
-            if is_half == True:
-                wav16k = wav16k.half().to(device)
-                zero_wav_torch = zero_wav_torch.half().to(device)
-            else:
-                wav16k = wav16k.to(device)
-                zero_wav_torch = zero_wav_torch.to(device)
-            wav16k = torch.cat([wav16k, zero_wav_torch])
-            ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
-                "last_hidden_state"
-            ].transpose(
-                1, 2
-            )  # .float()
-            codes = vq_model.extract_latent(ssl_content)
-            prompt_semantic = codes[0, 0]
-            prompt = prompt_semantic.unsqueeze(0).to(device)
-
-    t1 = ttime()
-    t.append(t1-t0)
-
-    if (how_to_cut == i18n("凑四句一切")):
-        text = cut1(text)
-    elif (how_to_cut == i18n("凑50字一切")):
-        text = cut2(text)
-    elif (how_to_cut == i18n("按中文句号。切")):
-        text = cut3(text)
-    elif (how_to_cut == i18n("按英文句号.切")):
-        text = cut4(text)
-    elif (how_to_cut == i18n("按标点符号切")):
-        text = cut5(text)
-    while "\n\n" in text:
-        text = text.replace("\n\n", "\n")
-    print(i18n("实际输入的目标文本(切句后):"), text)
-    texts = text.split("\n")
-    texts = process_text(texts)
-    texts = merge_short_text_in_array(texts, 5)
-    audio_opt = []
-    if not ref_free:
-        phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version)
-
-    for i_text,text in enumerate(texts):
-        # 解决输入目标文本的空行导致报错的问题 (skip empty lines in the target text to avoid errors)
-        if (len(text.strip()) == 0):
-            continue
-        if (text[-1] not in splits): text += "。" if text_language != "en" else "."
-        print(i18n("实际输入的目标文本(每句):"), text)
-        phones2,bert2,norm_text2=get_phones_and_bert(text, text_language, version)
-        print(i18n("前端处理后的文本(每句):"), norm_text2)
-        if not ref_free:
-            bert = torch.cat([bert1, bert2], 1)
-            all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
-        else:
-            bert = bert2
-            all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
-
-        bert = bert.to(device).unsqueeze(0)
-        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
-
-        t2 = ttime()
-        # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
-        # print(cache.keys(),if_freeze)
-        if(i_text in cache and if_freeze==True):pred_semantic=cache[i_text]
-        else:
-            with torch.no_grad():
-                pred_semantic, idx = t2s_model.model.infer_panel(
-                    all_phoneme_ids,
-                    all_phoneme_len,
-                    None if ref_free else prompt,
-                    bert,
-                    # prompt_phone_len=ph_offset,
-                    top_k=top_k,
-                    top_p=top_p,
-                    temperature=temperature,
-                    early_stop_num=hz * max_sec,
-                )
-                pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
-                cache[i_text]=pred_semantic
-        t3 = ttime()
-        refers=[]
-        if(inp_refs):
-            for path in inp_refs:
-                try:
-                    refer = get_spepc(hps, path.name).to(dtype).to(device)
-                    refers.append(refer)
-                except:
-                    traceback.print_exc()
-        if(len(refers)==0):refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
-        audio = (vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers,speed=speed).detach().cpu().numpy()[0, 0])
-        max_audio=np.abs(audio).max()  # 简单防止16bit爆音 (simple guard against 16-bit clipping)
-        if max_audio>1:audio/=max_audio
-        audio_opt.append(audio)
-        audio_opt.append(zero_wav)
-        t4 = ttime()
-        t.extend([t2 - t1,t3 - t2, t4 - t3])
-        t1 = ttime()
-    print("%.3f\t%.3f\t%.3f\t%.3f" %
-          (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3]))
-          )
+    try:
+        if ref_wav_path:pass
+        else:gr.Warning(i18n('请上传参考音频'))
+        if text:pass
+        else:gr.Warning(i18n('请填入推理文本'))
+        t = []
+        if prompt_text is None or len(prompt_text) == 0:
+            ref_free = True
+        t0 = ttime()
+        prompt_language = dict_language[prompt_language]
+        text_language = dict_language[text_language]
+
+        if not ref_free:
+            prompt_text = prompt_text.strip("\n")
+            if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
+            print(i18n("实际输入的参考文本:"), prompt_text)
+        text = text.strip("\n")
+        if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
+
+        print(i18n("实际输入的目标文本:"), text)
+        zero_wav = np.zeros(
+            int(hps.data.sampling_rate * 0.3),
+            dtype=np.float16 if is_half == True else np.float32,
+        )
+        if not ref_free:
+            with torch.no_grad():
+                wav16k, sr = librosa.load(ref_wav_path, sr=16000)
+                if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
+                    gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
+                    raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
+                wav16k = torch.from_numpy(wav16k)
+                zero_wav_torch = torch.from_numpy(zero_wav)
+                if is_half == True:
+                    wav16k = wav16k.half().to(device)
+                    zero_wav_torch = zero_wav_torch.half().to(device)
+                else:
+                    wav16k = wav16k.to(device)
+                    zero_wav_torch = zero_wav_torch.to(device)
+                wav16k = torch.cat([wav16k, zero_wav_torch])
+                ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
+                    "last_hidden_state"
+                ].transpose(
+                    1, 2
+                )  # .float()
+                codes = vq_model.extract_latent(ssl_content)
+                prompt_semantic = codes[0, 0]
+                prompt = prompt_semantic.unsqueeze(0).to(device)
+
+        t1 = ttime()
+        t.append(t1-t0)
+
+        if (how_to_cut == i18n("凑四句一切")):
+            text = cut1(text)
+        elif (how_to_cut == i18n("凑50字一切")):
+            text = cut2(text)
+        elif (how_to_cut == i18n("按中文句号。切")):
+            text = cut3(text)
+        elif (how_to_cut == i18n("按英文句号.切")):
+            text = cut4(text)
+        elif (how_to_cut == i18n("按标点符号切")):
+            text = cut5(text)
+        while "\n\n" in text:
+            text = text.replace("\n\n", "\n")
+        print(i18n("实际输入的目标文本(切句后):"), text)
+        texts = text.split("\n")
+        texts = process_text(texts)
+        texts = merge_short_text_in_array(texts, 5)
+        audio_opt = []
+        if not ref_free:
+            phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version)
+
+        for i_text,text in enumerate(texts):
+            # 解决输入目标文本的空行导致报错的问题 (skip empty lines in the target text to avoid errors)
+            if (len(text.strip()) == 0):
+                continue
+            if (text[-1] not in splits): text += "。" if text_language != "en" else "."
+            print(i18n("实际输入的目标文本(每句):"), text)
+            phones2,bert2,norm_text2=get_phones_and_bert(text, text_language, version)
+            print(i18n("前端处理后的文本(每句):"), norm_text2)
+            if not ref_free:
+                bert = torch.cat([bert1, bert2], 1)
+                all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
+            else:
+                bert = bert2
+                all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
+
+            bert = bert.to(device).unsqueeze(0)
+            all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
+
+            t2 = ttime()
+            # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
+            # print(cache.keys(),if_freeze)
+            if(i_text in cache and if_freeze==True):pred_semantic=cache[i_text]
+            else:
+                with torch.no_grad():
+                    pred_semantic, idx = t2s_model.model.infer_panel(
+                        all_phoneme_ids,
+                        all_phoneme_len,
+                        None if ref_free else prompt,
+                        bert,
+                        # prompt_phone_len=ph_offset,
+                        top_k=top_k,
+                        top_p=top_p,
+                        temperature=temperature,
+                        early_stop_num=hz * max_sec,
+                    )
+                    pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
+                    cache[i_text]=pred_semantic
+            t3 = ttime()
+            refers=[]
+            if(inp_refs):
+                for path in inp_refs:
+                    try:
+                        refer = get_spepc(hps, path.name).to(dtype).to(device)
+                        refers.append(refer)
+                    except:
+                        traceback.print_exc()
+            if(len(refers)==0):refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
+            audio = (vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers,speed=speed).detach().cpu().numpy()[0, 0])
+            max_audio=np.abs(audio).max()  # 简单防止16bit爆音 (simple guard against 16-bit clipping)
+            if max_audio>1:audio/=max_audio
+            audio_opt.append(audio)
+            audio_opt.append(zero_wav)
+            t4 = ttime()
+            t.extend([t2 - t1,t3 - t2, t4 - t3])
+            t1 = ttime()
+        print("%.3f\t%.3f\t%.3f\t%.3f" %
+              (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3]))
+              )
+    except:
+        print(traceback.format_exc())
     yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
         np.int16
     )
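
Note on the change: the commit wraps the whole body of get_tts_wav in a try/except, so a failure anywhere in synthesis is logged via traceback.format_exc() instead of surfacing as an unhandled exception from the Gradio generator, and the final yield still emits whatever audio was accumulated. Below is a minimal sketch of that pattern, not the project's code; synthesize_sentences and SAMPLE_RATE are hypothetical stand-ins for the per-sentence T2S + vocoder pipeline and the output rate.

    import traceback

    import numpy as np

    SAMPLE_RATE = 32000  # hypothetical output rate

    def synthesize_sentences(sentences):
        # Hypothetical stand-in for the per-sentence pipeline in get_tts_wav:
        # yields one float32 waveform chunk per sentence.
        for _ in sentences:
            yield np.random.uniform(-1.0, 1.0, SAMPLE_RATE // 2).astype(np.float32)

    def tts_generator(text):
        audio_opt = []  # accumulated chunks, like audio_opt in get_tts_wav
        try:
            for chunk in synthesize_sentences(text.split("\n")):
                peak = np.abs(chunk).max()
                if peak > 1:  # simple guard against 16-bit clipping, as in the commit
                    chunk = chunk / peak
                audio_opt.append(chunk)
        except Exception:
            # Mirrors the new `except: print(traceback.format_exc())` in get_tts_wav:
            # log the failure, then fall through to the final yield.
            print(traceback.format_exc())
        if audio_opt:  # an early failure can leave nothing to concatenate
            # Scale float audio in [-1, 1] to int16 for a Gradio Audio output.
            yield SAMPLE_RATE, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)

One caveat on the committed version: the final yield sits outside the try/except, so an exception raised before the line audio_opt = [] is reached would still produce a NameError at the yield; the sketch sidesteps that by initializing audio_opt before entering the try block.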