"""Gradio web demo that synthesizes speech with a multi-speaker VITS model."""
import gradio as gr
# import matplotlib.pyplot as plt
import logging
# logger = logging.getLogger(__name__)
import os
import json
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import commons
import utils
from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence
import time

def get_text(text, hps):
    """Convert raw text into a tensor of symbol ids.

    Args:
        text: The input string to convert.
        hps: Hyper-parameter object; ``hps.data.text_cleaners`` is forwarded
            to ``text_to_sequence`` and ``hps.data.add_blank`` controls
            whether a 0 is interspersed between the resulting ids.

    Returns:
        A ``torch.LongTensor`` holding the (optionally interspersed) id sequence.
    """
    sequence = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        sequence = commons.intersperse(sequence, 0)
    return torch.LongTensor(sequence)

def load_model(config_path, pth_path):
    """Build the synthesizer, load its checkpoint and publish it as module globals.

    Sets three module-level names used by ``infer``: ``dev`` (``cuda:0`` when
    CUDA is available, otherwise ``cpu``), ``hps_ms`` (hyper-parameters read
    from ``config_path``) and ``net_g`` (the network, moved to ``dev`` and put
    in eval mode before the checkpoint is loaded into it).

    Args:
        config_path: Path of the hyper-parameter file.
        pth_path: Path of the checkpoint to load into the network.

    Returns:
        A status message that includes ``pth_path``.
    """
    global dev, hps_ms, net_g

    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    dev = torch.device(device_name)
    hps_ms = utils.get_hparams_from_file(config_path)

    # NOTE(review): the speaker count is not passed explicitly; presumably it
    # arrives through **hps_ms.model — confirm against the config file.
    net_g = SynthesizerTrn(
        len(symbols),
        hps_ms.data.filter_length // 2 + 1,
        hps_ms.train.segment_size // hps_ms.data.hop_length,
        **hps_ms.model,
    ).to(dev)
    net_g.eval()
    utils.load_checkpoint(pth_path, net_g)

    return f"{pth_path}加载成功!"
    
def infer(c_id, text):
    """Run the loaded network on ``text`` for speaker ``c_id``.

    Relies on the module globals ``hps_ms``, ``dev`` and ``net_g`` that
    ``load_model`` sets, so ``load_model`` must have been called first.

    Args:
        c_id: Integer speaker id passed to the network as ``sid``.
        text: Text to synthesize.

    Returns:
        A float NumPy array taken from element ``[0][0, 0]`` of the network's
        inference output (presumably the waveform of the single batch item).
    """
    tokens = get_text(text, hps_ms)
    with torch.no_grad():
        # Add a batch dimension of one; the length tensor describes that item.
        batch = tokens.to(dev).unsqueeze(0)
        batch_lengths = torch.LongTensor([tokens.size(0)]).to(dev)
        speaker = torch.LongTensor([c_id]).to(dev)
        outputs = net_g.infer(
            batch,
            batch_lengths,
            sid=speaker,
            noise_scale=.667,
            noise_scale_w=0.8,
            length_scale=1,
        )
        waveform = outputs[0][0, 0].data.cpu().float().numpy()
    return waveform

# Relative paths of the hyper-parameter file and the generator checkpoint.
config_path = "configs/config.json"
pth_path = "model/G_70000.pth"

# Character display names in speaker-id order; ids are consecutive from 1.
_character_names = (
    "十香",
    "折纸",
    "狂三",
    "四糸乃",
    "琴里",
    "夕弦",
    "耶俱矢",
    "美九",
    "凛祢",
    "凛绪",
    "鞠亚",
    "鞠奈",
    "真那",
)
# Maps each display name to the speaker id handed to the model.
character_dict = {
    name: speaker_id
    for speaker_id, name in enumerate(_character_names, start=1)
}

# Dropdown labels and the speaker id behind each label, kept in the same order
# so the index reported by the dropdown can be mapped back to a speaker id.
_speaker_names = list(character_dict)
_speaker_ids = list(character_dict.values())


def _load_default_model():
    """Load the checkpoint named by the module-level paths; return the status text."""
    return load_model(config_path, pth_path)


def _synthesize(speaker_index, text):
    """Synthesize ``text`` with the speaker selected in the dropdown.

    Args:
        speaker_index: Zero-based position of the selected dropdown entry
            (the dropdown is configured with ``type="index"``).
        text: Text to synthesize.

    Returns:
        A ``(sample_rate, samples)`` tuple, the form ``gr.Audio`` accepts as
        an output value.
    """
    # load_model() publishes the network as a module global; load on demand so
    # synthesis also works when the load button was never pressed.
    if "net_g" not in globals():
        load_model(config_path, pth_path)
    samples = infer(_speaker_ids[speaker_index], text)
    # NOTE(review): assumes the config provides data.sampling_rate, as the
    # upstream VITS configs do — confirm against configs/config.json.
    return hps_ms.data.sampling_rate, samples


app = gr.Blocks()
with app:
    gr.HTML("""
<div
    style="width: 100%;padding-top:116px;background-image: url('https://huggingface.co/spaces/tumuyan/vits-miki/resolve/main/bg.webp');;background-size:cover">
    <div>
                <div>
                    <h4 class="h-sign" style="font-size: 12px;">
                        这是一个使用<a href="https://github.com/thesupersonic16/DALTools" target="_blank">thesupersonic16/DALTools</a>提供的解包音频作为数据集,
                        使用<a href="https://github.com/jaywalnut310/vits" target="_blank">VITS</a>技术训练的语音合成demo。
                    </h4>
                </div>
            </div>
</div>
    """)
    tmp = gr.Markdown("")
    with gr.Tabs():
        with gr.TabItem("Basic"):
            with gr.Row():
                model_submit = gr.Button("加载/重载模型", variant="primary")
                output_1 = gr.Markdown("")
            with gr.Row():
                tts_input1 = gr.TextArea(
                    label="请输入文本(仅支持日语)", value="你好,世界!")
                # Pre-select the first character so the callback always
                # receives a valid index.
                tts_input2 = gr.Dropdown(
                    choices=_speaker_names,
                    value=_speaker_names[0],
                    type="index",
                    label="选择角色",
                )
                tts_submit = gr.Button("用文本合成", variant="primary")
                tts_output2 = gr.Audio(label="Output")
        # Event inputs must be components, so the fixed paths and the
        # index-to-id mapping live in the wrapper callbacks instead.
        model_submit.click(_load_default_model, [], [output_1])
        tts_submit.click(_synthesize, [tts_input2, tts_input1], [tts_output2])
    gr.HTML("""
<div style="text-align:center"> 
    仅供学习交流,不可用于商业或非法用途
    <br/>
    使用本项目模型直接或间接生成的音频,必须声明由AI技术或VITS技术合成
</div>
    """)

# Launch once the layout context has been closed.
app.launch()