hzrr committed on
Commit
06441c0
·
1 Parent(s): 0ba5c8c
Files changed (2) hide show
  1. app.py +63 -9
  2. inference.py +0 -60
app.py CHANGED
@@ -1,10 +1,60 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
2
 
3
- from inference import load_model, local_run, get_text
 
 
 
 
 
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  pth_path = "model/G_70000.pth"
7
- config_json = "configs/config.json"
8
  character_dict = {
9
  "十香": 1,
10
  "折纸": 2,
@@ -39,13 +89,17 @@ with app:
39
  tmp = gr.Markdown("")
40
  with gr.Tabs():
41
  with gr.TabItem("Basic"):
42
- with gr.Row():
43
- choice_model = gr.Dropdown(
44
- choices=[character_dict.keys()], label="模型", value=[character_dict.values()], visible=False)
45
-
46
- with gr.TabItem("Audios"):
47
-
48
- pass
 
 
 
 
49
  gr.HTML("""
50
  <div style="text-align:center">
51
  仅供学习交流,不可用于商业或非法用途
 
1
  import gradio as gr
2
+ # import matplotlib.pyplot as plt
3
+ import logging
4
+ # logger = logging.getLogger(__name__)
5
+ import os
6
+ import json
7
+ import math
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ from torch.utils.data import DataLoader
12
 
13
+ import commons
14
+ import utils
15
+ from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
16
+ from models import SynthesizerTrn
17
+ from text.symbols import symbols
18
+ from text import text_to_sequence
19
+ import time
20
 
21
def get_text(text, hps):
    """Turn raw text into the id tensor handed to the synthesizer.

    Args:
        text: The string to convert.
        hps: Hyper-parameter object; ``hps.data.text_cleaners`` and
            ``hps.data.add_blank`` are read from it.

    Returns:
        A ``torch.LongTensor`` built from the id sequence produced by
        ``text_to_sequence``.
    """
    sequence = text_to_sequence(text, hps.data.text_cleaners)
    if not hps.data.add_blank:
        return torch.LongTensor(sequence)
    # The config asked for blanks: run the ids through commons.intersperse
    # with 0 as the inserted item (presumably the blank token id).
    return torch.LongTensor(commons.intersperse(sequence, 0))
30
+
31
def load_model(config_path, pth_path):
    """Build the synthesizer, load its checkpoint and publish it as globals.

    Sets the module-level names ``dev``, ``hps_ms`` and ``net_g`` that
    ``infer`` reads afterwards.

    Args:
        config_path: Path of the hyper-parameter file given to
            ``utils.get_hparams_from_file``.
        pth_path: Path of the checkpoint given to ``utils.load_checkpoint``.

    Returns:
        A status string naming the checkpoint that was loaded.
    """
    global dev, hps_ms, net_g

    # Use the first CUDA device when one is available, otherwise the CPU.
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    dev = torch.device(device_name)
    hps_ms = utils.get_hparams_from_file(config_path)

    # Positional constructor arguments, all derived from the symbol table
    # and the loaded hyper-parameters.
    positional = (
        len(symbols),
        hps_ms.data.filter_length // 2 + 1,
        hps_ms.train.segment_size // hps_ms.data.hop_length,
    )
    net_g = SynthesizerTrn(*positional, **hps_ms.model).to(dev)

    # Switch to evaluation mode, then restore the weights from disk.
    net_g.eval()
    utils.load_checkpoint(pth_path, net_g)

    return f"{pth_path}加载成功!"
45
+
46
def infer(c_id, text):
    """Synthesize audio for ``text`` using speaker id ``c_id``.

    Reads the module-level globals ``hps_ms``, ``net_g`` and ``dev`` that
    ``load_model`` creates, so a model has to be loaded before calling this.

    Args:
        c_id: Speaker id; forwarded to the synthesizer as ``sid``.
        text: Text to synthesize.

    Returns:
        The generated audio as a float numpy array on the CPU.

    Raises:
        RuntimeError: If ``load_model`` has not been run yet.
    """
    # Without this guard a request made before any model was loaded dies
    # with an opaque NameError on the first global lookup below.
    module_state = globals()
    if any(name not in module_state for name in ("hps_ms", "net_g", "dev")):
        raise RuntimeError("模型尚未加载,请先点击“加载/重载模型”。")

    stn_tst = get_text(text, hps_ms)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        x_tst = stn_tst.to(dev).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(dev)
        sid = torch.LongTensor([c_id]).to(dev)
        generated = net_g.infer(
            x_tst,
            x_tst_lengths,
            sid=sid,
            noise_scale=.667,
            noise_scale_w=0.8,
            length_scale=1,
        )
        # Take element [0, 0] of the first returned item and move it to a
        # CPU float numpy array.
        audio = generated[0][0, 0].data.cpu().float().numpy()

    return audio
55
 
56
  pth_path = "model/G_70000.pth"
57
+ config_path = "configs/config.json"
58
  character_dict = {
59
  "十香": 1,
60
  "折纸": 2,
 
89
  tmp = gr.Markdown("")
90
  with gr.Tabs():
91
  with gr.TabItem("Basic"):
92
+ with gr.Raw():
93
+ model_submit = gr.Button("加载/重载模型", variant="primary")
94
+ output_1 = gr.Markdown("")
95
+ with gr.Raw():
96
+ tts_input1 = gr.TextArea(
97
+ label="请输入文本(仅支持日语)", value="你好,世界!")
98
+ tts_input2 = gr.Dropdown(choices=[character_dict.keys], type="index",label="选择角色", optional=False)
99
+ tts_submit = gr.Button("用文本合成", variant="primary")
100
+ tts_output2 = gr.Audio(label="Output")
101
+ model_submit.click(load_model, [config_path, pth_path], [output_1])
102
+ tts_submit.click(infer, [tts_input2+1, tts_input1], [tts_output2])
103
  gr.HTML("""
104
  <div style="text-align:center">
105
  仅供学习交流,不可用于商业或非法用途
inference.py CHANGED
@@ -1,60 +0,0 @@
1
- # import matplotlib.pyplot as plt
2
- import logging
3
- # logger = logging.getLogger(__name__)
4
- import os
5
- import json
6
- import math
7
- import torch
8
- from torch import nn
9
- from torch.nn import functional as F
10
- from torch.utils.data import DataLoader
11
-
12
- import commons
13
- import utils
14
- from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
15
- from models import SynthesizerTrn
16
- from text.symbols import symbols
17
- from text import text_to_sequence
18
- import time
19
-
20
- def get_text(text, hps):
21
- # text_norm = requests.post("http://121.5.171.42:39001/texttosequence?text="+text).json()["text_norm"]
22
- text_norm = text_to_sequence(text, hps.data.text_cleaners)
23
- # print(hps.data.text_cleaners)
24
- # print(text_norm)
25
- if hps.data.add_blank:
26
- text_norm = commons.intersperse(text_norm, 0)
27
- text_norm = torch.LongTensor(text_norm)
28
- return text_norm
29
-
30
- def load_model(config_json, pth_path):
31
- dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
32
- hps_ms = utils.get_hparams_from_file(f"./configs/{config_json}")
33
-
34
- global net_g
35
- net_g = SynthesizerTrn(
36
- len(symbols),
37
- hps_ms.data.filter_length // 2 + 1,
38
- hps_ms.train.segment_size // hps_ms.data.hop_length,
39
- **hps_ms.model).to(dev)
40
- _ = net_g.eval()
41
- _ = utils.load_checkpoint(pth_path, net_g)
42
-
43
- print("load_model:"+pth_path)
44
- return net_g
45
-
46
- def local_run(c_id, text):
47
- stn_tst = get_text(text, hps)
48
- with torch.no_grad():
49
- x_tst = stn_tst.to(dev).unsqueeze(0)
50
- x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(dev)
51
- sid = torch.LongTensor([c_id]).to(dev)
52
- audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()
53
-
54
- return audio
55
-
56
- CONFIG_FILE = "configs/config.json"
57
-
58
- dev = torch.device("cpu")
59
- hps = utils.get_hparams_from_file(CONFIG_FILE)
60
-