hzrr committed on
Commit
7c27bc6
·
1 Parent(s): a87d03d
Files changed (1) hide show
  1. app.py +14 -17
app.py CHANGED
@@ -29,30 +29,32 @@ def get_text(text, hps):
29
  return text_norm
30
 
31
  def load_model(config_path, pth_path):
32
- global dev, hps_ms, net_g
33
  dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
34
- hps_ms = utils.get_hparams_from_file(config_path)
35
 
36
  net_g = SynthesizerTrn(
37
  len(symbols),
38
- hps_ms.data.filter_length // 2 + 1,
39
- hps_ms.train.segment_size // hps_ms.data.hop_length,
40
- **hps_ms.model).to(dev)
 
41
  _ = net_g.eval()
 
42
  _ = utils.load_checkpoint(pth_path, net_g)
43
 
44
  print(f"{pth_path}加载成功!")
45
 
46
  def infer(text):
47
  c_id = 2
48
- stn_tst = get_text(text, hps_ms)
49
  with torch.no_grad():
50
  x_tst = stn_tst.to(dev).unsqueeze(0)
51
  x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(dev)
52
  sid = torch.LongTensor([c_id]).to(dev)
53
  audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()
54
 
55
- return (hps_ms.data.sampling_rate, audio)
56
 
57
  pth_path = "model/G_70000.pth"
58
  config_path = "configs/config.json"
@@ -77,9 +79,6 @@ load_model(config_path, pth_path)
77
  app = gr.Blocks()
78
  with app:
79
  gr.HTML("""
80
- <div
81
- style="width: 100%;padding-top:116px;background-image: url('https://huggingface.co/spaces/tumuyan/vits-miki/resolve/main/bg.webp');;background-size:cover">
82
- <div>
83
  <div>
84
  <h4 class="h-sign" style="font-size: 12px;">
85
  这是一个使用<a href="https://github.com/thesupersonic16/DALTools" target="_blank">thesupersonic16/DALTools</a>提供的解包音频作为数据集,
@@ -87,19 +86,17 @@ with app:
87
  </h4>
88
  </div>
89
  </div>
90
- </div>
91
  """)
92
  tmp = gr.Markdown("")
93
  with gr.Tabs():
94
- with gr.TabItem("Basic"):
95
  # with gr.Row():
96
  # model_submit = gr.Button("加载/重载模型", variant="primary")
97
 
98
- with gr.Row():
99
- tts_input1 = gr.TextArea(
100
- label="请输入文本(仅支持日语)", value="你好,世界!")
101
- tts_submit = gr.Button("用文本合成", variant="primary")
102
- tts_output2 = gr.Audio(label="Output")
103
  # model_submit.click(load_model, [config_path, pth_path])
104
  tts_submit.click(infer, [tts_input1], [tts_output2])
105
  gr.HTML("""
 
29
  return text_norm
30
 
31
  def load_model(config_path, pth_path):
32
+ global dev, hps, net_g
33
  dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
34
+ hps = utils.get_hparams_from_file(config_path)
35
 
36
  net_g = SynthesizerTrn(
37
  len(symbols),
38
+ hps.data.filter_length // 2 + 1,
39
+ hps.train.segment_size // hps.data.hop_length,
40
+ n_speakers=hps.data.n_speakers,
41
+ **hps.model).to(dev)
42
  _ = net_g.eval()
43
+
44
  _ = utils.load_checkpoint(pth_path, net_g)
45
 
46
  print(f"{pth_path}加载成功!")
47
 
48
  def infer(text):
49
  c_id = 2
50
+ stn_tst = get_text(text, hps)
51
  with torch.no_grad():
52
  x_tst = stn_tst.to(dev).unsqueeze(0)
53
  x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(dev)
54
  sid = torch.LongTensor([c_id]).to(dev)
55
  audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()
56
 
57
+ return (hps.data.sampling_rate, audio)
58
 
59
  pth_path = "model/G_70000.pth"
60
  config_path = "configs/config.json"
 
79
  app = gr.Blocks()
80
  with app:
81
  gr.HTML("""
 
 
 
82
  <div>
83
  <h4 class="h-sign" style="font-size: 12px;">
84
  这是一个使用<a href="https://github.com/thesupersonic16/DALTools" target="_blank">thesupersonic16/DALTools</a>提供的解包音频作为数据集,
 
86
  </h4>
87
  </div>
88
  </div>
 
89
  """)
90
  tmp = gr.Markdown("")
91
  with gr.Tabs():
 
92
  # with gr.Row():
93
  # model_submit = gr.Button("加载/重载模型", variant="primary")
94
 
95
+ with gr.Row():
96
+ tts_input1 = gr.TextArea(
97
+ label="请输入文本(仅支持日语)", value="你好,世界!")
98
+ tts_submit = gr.Button("用文本合成", variant="primary")
99
+ tts_output2 = gr.Audio(label="Output")
100
  # model_submit.click(load_model, [config_path, pth_path])
101
  tts_submit.click(infer, [tts_input1], [tts_output2])
102
  gr.HTML("""