# Spaces: Build error  (status banner captured from the hosting page; not part of the program)
import gradio as gr | |
# import matplotlib.pyplot as plt | |
import logging | |
# logger = logging.getLogger(__name__) | |
import os | |
import json | |
import math | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from torch.utils.data import DataLoader | |
import commons | |
import utils | |
from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate | |
from models import SynthesizerTrn | |
from text.symbols import symbols | |
from text import text_to_sequence | |
import time | |
def get_text(text, hps):
    """Turn raw input text into the id tensor consumed by the synthesizer.

    Args:
        text: The text to convert.
        hps: Hyper-parameter object; ``hps.data.text_cleaners`` is forwarded
            to ``text_to_sequence`` and ``hps.data.add_blank`` toggles blank
            interspersing.

    Returns:
        A ``torch.LongTensor`` built from the resulting id sequence.
    """
    sequence = text_to_sequence(text, hps.data.text_cleaners)
    # When the config asks for it, weave the id 0 between the ids.
    padded = commons.intersperse(sequence, 0) if hps.data.add_blank else sequence
    return torch.LongTensor(padded)
def load_model(config_path, pth_path):
    """Build the synthesizer, load its checkpoint and publish it as globals.

    Sets three module-level names:
      * ``dev``   - ``cuda:0`` when CUDA is available, otherwise ``cpu``.
      * ``hps``   - hyper-parameters read from ``config_path``.
      * ``net_g`` - the ``SynthesizerTrn`` instance, moved to ``dev``, switched
        to eval mode, with the checkpoint at ``pth_path`` loaded into it.

    Args:
        config_path: Path handed to ``utils.get_hparams_from_file``.
        pth_path: Checkpoint path handed to ``utils.load_checkpoint``.
    """
    global dev, hps, net_g

    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    dev = torch.device(device_name)
    hps = utils.get_hparams_from_file(config_path)

    model = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    net_g = model.to(dev)
    net_g.eval()
    utils.load_checkpoint(pth_path, net_g)
    print(f"{pth_path}加载成功!")
def infer(c_id, text):
    """Synthesize ``text`` with the voice of speaker ``c_id``.

    Args:
        c_id: Speaker id in the inclusive range 1-13. The UI passes the
            contents of a text box, so a numeric string is accepted as well
            as an int.
        text: The text to synthesize.

    Returns:
        A ``(sampling_rate, audio)`` tuple: ``hps.data.sampling_rate`` and the
        generated waveform as a float NumPy array.

    Raises:
        gr.Error: If ``c_id`` is not a number or lies outside 1-13.
    """
    # The id comes from a free-form text input, so convert it before the
    # range check; a non-numeric value is reported the same way as an
    # out-of-range one.
    try:
        speaker_id = int(c_id)
    except (TypeError, ValueError):
        raise gr.Error("角色id超出范围!") from None
    # range(1, 14) covers ids 1..13; membership on a range is O(1).
    if speaker_id not in range(1, 14):
        raise gr.Error("角色id超出范围!")
    print(speaker_id)

    stn_tst = get_text(text, hps)
    with torch.no_grad():
        x_tst = stn_tst.to(dev).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(dev)
        sid = torch.LongTensor([speaker_id]).to(dev)
        generated = net_g.infer(
            x_tst,
            x_tst_lengths,
            sid=sid,
            noise_scale=.667,
            noise_scale_w=0.8,
            length_scale=1,
        )
        audio = generated[0][0, 0].data.cpu().float().numpy()
    return (hps.data.sampling_rate, audio)
# Locations of the generator checkpoint and its hyper-parameter file.
pth_path = "model/G_70000.pth"
config_path = "configs/config.json"

# Character names in speaker-id order; ids start at 1.
_character_names = (
    "十香",
    "折纸",
    "狂三",
    "四糸乃",
    "琴里",
    "夕弦",
    "耶俱矢",
    "美九",
    "凛祢",
    "凛绪",
    "鞠亚",
    "鞠奈",
    "真那",
)
# Maps character name -> speaker id (1..13).
character_dict = {
    name: speaker_id for speaker_id, name in enumerate(_character_names, start=1)
}

# Load the model once at start-up so request handlers can use the globals.
load_model(config_path, pth_path)
# Static page content. Kept as flush-left module constants so that source
# indentation never leaks into the rendered Markdown/HTML.
CHARACTER_TABLE_MD = """
| id | 角色名 |
|--|--|
| 1 | 夜刀神十香 |
| 2 | 鸢一折纸 |
| 3 | 时崎狂三 |
| 4 | 冰芽川四糸乃 |
| 5 | 五河琴里 |
| 6 | 八舞夕弦 |
| 7 | 八舞耶俱矢 |
| 8 | 诱宵美九 |
| 9 | 园神凛祢 |
| 10 | 园神凛绪 |
| 11 | 或守鞠亚 |
| 12 | 或守鞠奈 |
| 13 | 崇宫真那 |
"""

# The original markup had one unmatched closing </div> after the first
# section; it is removed here so the tags balance.
FOOTER_HTML = """
<div style="text-align:center">
<h4 class="h-sign" style="font-size: 12px;">
这是一个使用<a href="https://github.com/thesupersonic16/DALTools" target="_blank">thesupersonic16/DALTools</a>提供的解包音频作为数据集,
使用<a href="https://github.com/jaywalnut310/vits" target="_blank">VITS</a>技术训练的语音合成demo。
</h4>
</div>
<div style="text-align:center">
仅供学习交流,不可用于商业或非法用途
<br/>
使用本项目模型直接或间接生成的音频,必须声明由AI技术或VITS技术合成
</div>
"""

# NOTE(review): the nesting below is reconstructed; confirm that the input
# rows belong inside gr.Tabs() and that the table/footer belong at the Blocks
# level rather than inside the tabs container.
app = gr.Blocks()
with app:
    with gr.Tabs():
        with gr.Row():
            # Text to synthesize.
            tts_input1 = gr.TextArea(
                label="请输入文本(仅支持日语)", value="こんにちは,世界!")
        with gr.Row():
            # Speaker id, entered as free text.
            tts_input2 = gr.TextArea(
                label="请输入角色id(参考文档或者页面下方表格)")
        with gr.Row():
            tts_submit = gr.Button("用文本合成", variant="primary")
        with gr.Row():
            tts_output2 = gr.Audio(label="Output")
        # Handler argument order is (speaker id, text), matching infer(c_id, text).
        tts_submit.click(infer, [tts_input2, tts_input1], [tts_output2])
    gr.Markdown(CHARACTER_TABLE_MD)
    gr.HTML(FOOTER_HTML)
app.launch()