hzrr committed on
Commit f7c1f1e · Parent(s): 69c7b60
Files changed (4):
  1. api.py +0 -176
  2. app.py +53 -4
  3. inference.py +16 -35
  4. test.py +39 -0
api.py DELETED
@@ -1,176 +0,0 @@
-#encoding=utf-8
-from inference import api_run, load_model
-from flask import Flask, request
-# from werkzeug.middleware.proxy_fix import ProxyFix
-import json
-import logging
-import datetime
-import requests
-import random
-import hashlib
-
-app = Flask(__name__)
-# app.wsgi_app = ProxyFix(app.wsgi_app, num_proxies=1)
-
-def JP_mode(text):
-
-    return text
-
-def ZH_mode(text):
-
-    salt = str(random.randint(0, 10))
-    content = appid + text + salt + key
-    md5hash = hashlib.md5(content.encode("utf8"))
-    md5 = md5hash.hexdigest()
-
-    params = {
-        "q": text,
-        "from": "zh",
-        "to": "jp",
-        "appid": appid,
-        "salt": salt,
-        "sign": md5
-    }
-
-    resp = requests.get(url, params=params).json()
-
-    return resp["trans_result"][0]["dst"]
-
-def PY_mode(text):
-
-    pass
-
-@app.route("/inference", methods=["POST", "GET"])
-def check():
-    # Default response payload
-    return_dict = {"code": 200, "return_info": "NULL"}
-    # Collect the request parameters
-    request_data = {}
-    for i in request.values:
-        request_data.update({i: request.values.get(i)})
-    ip = request.remote_addr
-    logger.info(f"[{ip}]Post Data: {str(request_data)}")
-    # Validate the parameters
-    if request_data == {}:
-        return_dict["code"] = 500
-        return_dict["return_info"] = "Parameters must not be empty!"
-        return json.dumps(return_dict, ensure_ascii=False)
-    try:
-        # print(request_data)
-        text = request_data["text"]
-        c_id = int(request_data["id"])
-        mode = request_data["mode"]
-        if not (mode in mode_dict.keys() and c_id in c_id_dict.keys()):
-            return_dict["code"] = 500
-            return_dict["return_info"] = "Invalid parameters! See the docs for the correct parameters"
-            return json.dumps(return_dict, ensure_ascii=False)
-    except KeyError:
-        return_dict["code"] = 500
-        return_dict["return_info"] = "Missing parameters! See the docs for the correct parameters"
-        return json.dumps(return_dict, ensure_ascii=False)
-    if text == "":
-        return_dict["code"] = 500
-        return_dict["return_info"] = "Text must not be empty!"
-        return json.dumps(return_dict, ensure_ascii=False)
-    if len("".join(text.split())) > 22:
-        return_dict["code"] = 500
-        return_dict["return_info"] = "Text too long!"
-        return json.dumps(return_dict, ensure_ascii=False)
-
-    for i in replace_dict:
-        text = text.replace(i, replace_dict[i])
-
-    try:
-        text = "." + mode_dict[mode](text) + "."
-        url = api_run(c_id_dict[c_id], text)
-        logger.info("Audio Url:" + url)
-        return json.dumps({"code": 200, "url": url}, ensure_ascii=False)
-    except Exception as e:
-        return json.dumps({"code": 500, "return_info": repr(e)}, ensure_ascii=False)
-
-if __name__ == "__main__":
-
-    load_model()
-    mode_dict = {
-        "JP": JP_mode,
-        "ZH": ZH_mode,
-        # "PY": PY_mode
-    }
-    c_id_dict = {
-        1: 1,
-        2: 2,
-        3: 3,
-        4: 4,
-        5: 5,
-        6: 6,
-        7: 7,
-        8: 8,
-        9: 9,
-        10: 11,
-        11: 12,
-        12: 13,
-        13: 14,
-    }
-    replace_dict = {
-        "鸢一": "とびいち",
-        "折纸": "おりがみ",
-        "本条": "ほんじょう",
-        "二亚": "にあ",
-        "时崎": "ときさき",
-        "狂三": "くるみ",
-        "冰芽川": "ひめかわ",
-        "四糸乃": "よしの",
-        "五河": "いつか",
-        "琴里": "ことり",
-        "士道": "しどう",
-        "星宫": "ほしみや",
-        "六喰": "むくろ",
-        "镜野": "きょうの",
-        "七罪": "なつみ",
-        "风待": "かざまち",
-        "八舞": "やまい",
-        "夕弦": "ゆづる",
-        "耶俱矢": "かぐや",
-        "诱宵": "いざよい",
-        "美九": "みく",
-        "夜刀神": "やとがみ",
-        "十香": "とおか",
-        "天香": "てんか",
-        "園神": "そのがみ",
-        "园神": "そのがみ",
-        "凛祢": "りんね",
-        "凛绪": "りお",
-        "或守": "あるす",
-        "鞠奈": "まりな",
-        "鞠亜": "まりあ",
-        "鞠亚": "まりあ",
-    }
-
-    # Baidu Translate API
-    url = "http://api.fanyi.baidu.com/api/trans/vip/translate"
-    appid = "20221004001369403"
-    key = "2366SRyKMe4HDAfcD4a9"
-
-    now = datetime.datetime.now().strftime("%Y-%m-%d-%H")
-    logger = logging.getLogger("InferenceAPI")
-    # logging.basicConfig(filename=str(now)+".log", filemode="a", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", level=logging.DEBUG)
-    handler = logging.FileHandler(filename=f"logs/{str(now)}.log")  # , encoding="utf-8"
-    handler.setLevel(logging.DEBUG)
-    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s:%(message)s")
-    handler.setFormatter(formatter)
-    console = logging.StreamHandler()
-    console.setLevel(logging.DEBUG)
-
-    logger.addHandler(handler)
-    logger.addHandler(console)
-
-    logger.info("successfully loaded model...")
-
-    app.config["JSON_AS_ASCII"] = False
-    app.config["DEBUG"] = False
-    app.config["ENV"] = "development"
-    app.run(port=39000, host="0.0.0.0")
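Note: the deleted ZH_mode signs requests according to Baidu Translate's documented scheme, sign = MD5(appid + q + salt + key). A minimal standalone sketch of that signing, with placeholder credentials rather than the repo's:

    import hashlib
    import random

    import requests

    def baidu_translate(text, appid, key, from_lang="zh", to_lang="jp"):
        # Baidu's documented signature: MD5 over appid + query + salt + key.
        salt = str(random.randint(32768, 65536))
        sign = hashlib.md5((appid + text + salt + key).encode("utf8")).hexdigest()
        params = {"q": text, "from": from_lang, "to": to_lang,
                  "appid": appid, "salt": salt, "sign": sign}
        resp = requests.get("http://api.fanyi.baidu.com/api/trans/vip/translate",
                            params=params).json()
        return resp["trans_result"][0]["dst"]

The deleted code drew its salt from randint(0, 10); any salt string works as long as the same value goes into both the hash and the request.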
app.py CHANGED
@@ -1,7 +1,56 @@
 import gradio as gr
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-iface.launch()
+from inference import load_model, local_run, get_text
+
+
+pth_path = "model/G_70000.pth"
+config_json = "configs/config.json"
+character_dict = {
+    "十香": 1,
+    "折纸": 2,
+    "狂三": 3,
+    "四糸乃": 4,
+    "琴里": 5,
+    "夕弦": 6,
+    "耶俱矢": 7,
+    "美九": 8,
+    "凛祢": 9,
+    "凛绪": 10,
+    "鞠亚": 11,
+    "鞠奈": 12,
+    "真那": 13,
+}
+
+app = gr.Blocks()
+with app:
+    gr.HTML("""
+        <div
+            style="width: 100%;padding-top:116px;background-image: url('https://huggingface.co/spaces/tumuyan/vits-miki/resolve/main/bg.webp');background-size:cover">
+            <div>
+                <div>
+                    <h4 class="h-sign" style="font-size: 12px;">
+                        This is a speech synthesis demo that uses audio unpacked with <a href="https://github.com/thesupersonic16/DALTools" target="_blank">thesupersonic16/DALTools</a> as its dataset,
+                        trained with the <a href="https://github.com/jaywalnut310/vits" target="_blank">VITS</a> architecture.
+                    </h4>
+                </div>
+            </div>
+        </div>
+        """)
+    tmp = gr.Markdown("")
+    with gr.Tabs():
+        with gr.TabItem("Basic"):
+            with gr.Row():
+                choice_model = gr.Dropdown(
+                    choices=list(character_dict.keys()), label="Model", visible=False)
+
+        with gr.TabItem("Audios"):
+
+            pass
+    gr.HTML("""
+        <div style="text-align:center">
+            For study and exchange only; not for commercial or illegal use.
+            <br/>
+            Audio generated directly or indirectly with this project's models must be declared as synthesized with AI / VITS technology.
+        </div>
+        """)
+app.launch()
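The Basic tab above is still a stub: the dropdown is hidden and nothing calls the model yet. A sketch of how it might be wired up — the component and handler names (input_text, btn, out_audio, synthesize) are illustrative, not part of this commit, and 22050 Hz is an assumption standing in for the real hps.data.sampling_rate:

    # Assumes load_model(config_json, pth_path) has already populated the
    # global net_g that local_run reads.
    with gr.TabItem("Basic"):
        with gr.Row():
            choice_model = gr.Dropdown(choices=list(character_dict.keys()),
                                       label="Model", value="十香")
            input_text = gr.Textbox(label="Text")
        btn = gr.Button("Synthesize")
        out_audio = gr.Audio(label="Output")

        def synthesize(character, text):
            # local_run returns a float32 numpy array; gr.Audio accepts
            # a (sample_rate, array) tuple.
            audio = local_run(character_dict[character], text)
            return (22050, audio)  # assumed rate; read hps.data.sampling_rate instead

        btn.click(synthesize, [choice_model, input_text], out_audio)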
inference.py CHANGED
@@ -16,7 +16,6 @@ from models import SynthesizerTrn
 from text.symbols import symbols
 from text import text_to_sequence
 import time
-from scipy.io.wavfile import write
 
 def get_text(text, hps):
     # text_norm = requests.post("http://121.5.171.42:39001/texttosequence?text="+text).json()["text_norm"]
@@ -28,52 +27,34 @@ def get_text(text, hps):
     text_norm = torch.LongTensor(text_norm)
     return text_norm
 
-def load_model():
+def load_model(config_json, pth_path):
+    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    hps_ms = utils.get_hparams_from_file(config_json)
 
     global net_g
     net_g = SynthesizerTrn(
         len(symbols),
-        hps.data.filter_length // 2 + 1,
-        hps.train.segment_size // hps.data.hop_length,
-        n_speakers=hps.data.n_speakers,
-        **hps.model).to(device)
+        hps_ms.data.filter_length // 2 + 1,
+        hps_ms.train.segment_size // hps_ms.data.hop_length,
+        **hps_ms.model).to(dev)
     _ = net_g.eval()
-
-    _ = utils.load_checkpoint(MODEL_FILE, net_g, None)
+    _ = utils.load_checkpoint(pth_path, net_g)
+
+    print("load_model: " + pth_path)
+    return net_g
 
 def local_run(c_id, text):
     stn_tst = get_text(text, hps)
     with torch.no_grad():
-        x_tst = stn_tst.to(device).unsqueeze(0)
-        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
-        sid = torch.LongTensor([c_id]).to(device)
-        audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()
-        file_name = str(time.time()).split(".")[0] + ".wav"
-        out_path = WAVSPATH + "/" + file_name
-        write(out_path, hps.data.sampling_rate, audio)
-        return "http://datealive.xyz/vits/wavs/" + file_name
-
-def api_run(c_id, text):
-
-    stn_tst = get_text(text, hps)
-    with torch.no_grad():
-        x_tst = stn_tst.to(device).unsqueeze(0)
-        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
-        sid = torch.LongTensor([c_id]).to(device)
+        x_tst = stn_tst.to(dev).unsqueeze(0)
+        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(dev)
+        sid = torch.LongTensor([c_id]).to(dev)
         audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()
-        file_name = str(time.time()).split(".")[0] + ".wav"
-        out_path = WAVSPATH + "/" + file_name
-        write(out_path, hps.data.sampling_rate, audio)
-        write("/www/wwwroot/main-website/wavs/"+file_name, hps.data.sampling_rate, audio)
-
-        # return "http://datealive.xyz/vits/wavs/" + file_name
-
-        return "http://hzrr.xyz/wavs/" + file_name
+
+        return audio
 
 CONFIG_FILE = "configs/config.json"
-MODEL_FILE = "model/DAL.pth"
-WAVSPATH = "/www/wwwroot/datealive.xyz/vits/wavs"
 
-device = torch.device("cpu")
+dev = torch.device("cpu")
 hps = utils.get_hparams_from_file(CONFIG_FILE)
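With this refactor, local_run returns a raw numpy array instead of writing a .wav to the server and returning a URL, so serialization is now the caller's job. A minimal usage sketch, reusing the scipy import this commit removed (paths are the repo's defaults; the input text must suit the cleaners configured in config.json):

    from scipy.io.wavfile import write

    from inference import hps, load_model, local_run

    load_model("configs/config.json", "model/G_70000.pth")  # sets the global net_g
    audio = local_run(1, "こんにちは。")  # c_id selects the speaker embedding
    write("out.wav", hps.data.sampling_rate, audio)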
test.py ADDED
@@ -0,0 +1,39 @@
+import gradio as gr
+
+#from inference import load_model, local_run
+
+
+pth_path = "model/G_70000.pth"
+config_json = "configs/config.json"
+
+app = gr.Blocks()
+with app:
+    gr.HTML("""
+        <div
+            style="width: 100%;padding-top:116px;background-image: url('https://huggingface.co/spaces/tumuyan/vits-miki/resolve/main/bg.webp');background-size:cover">
+            <div>
+                <div>
+                    <h4 class="h-sign" style="font-size: 12px;">
+                        This is a speech synthesis demo that uses audio unpacked with <a href="https://github.com/thesupersonic16/DALTools" target="_blank">thesupersonic16/DALTools</a> as its dataset,
+                        trained with the <a href="https://github.com/jaywalnut310/vits" target="_blank">VITS</a> architecture.
+                    </h4>
+                </div>
+            </div>
+        </div>
+        """)
+    tmp = gr.Markdown("")
+    with gr.Tabs():
+        with gr.TabItem("Basic"):
+            pass
+
+        with gr.TabItem("Audios"):
+            pass
+
+    gr.HTML("""
+        <div style="text-align:center">
+            For study and exchange only; not for commercial or illegal use.
+            <br/>
+            Audio generated directly or indirectly with this project's models must be declared as synthesized with AI / VITS technology.
+        </div>
+        """)
+app.launch()