from dataclasses import asdict
from text import symbols
import torch
import torchaudio

from utils.audio import LogMelSpectrogram
from config import ModelConfig, VocosConfig, MelConfig
from models.model import StableTTS
from vocos_pytorch.models.model import Vocos
from text.english import english_to_ipa2
from text import cleaned_text_to_sequence
from datas.dataset import intersperse

import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# inference runs on the CPU by default; set device = 'cuda' to use a GPU
device = 'cpu'

@torch.inference_mode()
def inference(text: str, ref_audio: str, checkpoint_path: str, step: int = 10):
    """Synthesise speech for `text` in the style of `ref_audio` using the selected checkpoint."""
    # reload TTS weights only when the checkpoint selection changes
    global last_checkpoint_path
    if checkpoint_path != last_checkpoint_path:
        tts_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        last_checkpoint_path = checkpoint_path
        
    phonemizer = english_to_ipa2
    
    # prepare input for the TTS model: phonemize, map to token ids, and intersperse blank tokens (id 0) between phonemes
    x = torch.tensor(intersperse(cleaned_text_to_sequence(phonemizer(text)), item=0), dtype=torch.long, device=device).unsqueeze(0)
    x_len = torch.tensor([x.size(-1)], dtype=torch.long, device=device)
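    # load the reference audio, resample to the model's rate if needed, and extract its mel spectrogram as the reference prompt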
    waveform, sr = torchaudio.load(ref_audio)
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
    y = mel_extractor(waveform).to(device)
    
    # inference: sample a mel spectrogram, then convert it to a waveform with the vocoder
    mel = tts_model.synthesise(x, x_len, step, y=y, temperature=0.667, length_scale=1)['decoder_outputs']
    audio = vocoder(mel)
    
    # process output for gradio
    audio_output = (sample_rate, (audio.cpu().squeeze(0).numpy() * 32767).astype(np.int16)) # (samplerate, int16 audio) for gr.Audio
    mel_output = plot_mel_spectrogram(mel.cpu().squeeze(0).numpy()) # get the plot of mel
    return audio_output, mel_output
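
# Example call outside the webui (paths are illustrative and assume such files exist):
#   audio, fig = inference('Hello world.', './audios/ref.wav', './checkpoints/checkpoint_0.pt', step=10)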

def get_pipeline(n_vocab: int, tts_model_config: ModelConfig, mel_config: MelConfig, vocoder_config: VocosConfig, tts_checkpoint_path, vocoder_checkpoint_path):
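    """Build the TTS model, mel extractor, and vocoder; only the vocoder weights are loaded here."""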
    tts_model = StableTTS(n_vocab, mel_config.n_mels, **asdict(tts_model_config))
    mel_extractor = LogMelSpectrogram(mel_config)
    vocoder = Vocos(vocoder_config, mel_config)
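    # TTS weights are loaded lazily in inference() when a checkpoint is selected in the UI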
    # tts_model.load_state_dict(torch.load(tts_checkpoint_path, map_location='cpu'))
    tts_model.to(device)
    tts_model.eval()
    vocoder.load_state_dict(torch.load(vocoder_checkpoint_path, map_location='cpu'))
    vocoder.to(device)
    vocoder.eval()
    return tts_model, mel_extractor, vocoder

def plot_mel_spectrogram(mel_spectrogram):
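    """Render a mel spectrogram as a borderless matplotlib figure for gr.Plot."""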
    fig, ax = plt.subplots(figsize=(20, 8))
    ax.imshow(mel_spectrogram, aspect='auto', origin='lower')
    ax.axis('off')
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0) # remove white edges
    return fig


def main():
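    """Build the pipeline, scan for checkpoints and reference audios, and launch the gradio demo."""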
    tts_model_config = ModelConfig()
    mel_config = MelConfig()
    vocoder_config = VocosConfig()

    tts_checkpoint_path = './checkpoints' # the folder that contains StableTTS checkpoints
    vocoder_checkpoint_path = './checkpoints/vocoder.pt'

    # share the pipeline and settings with inference() via module-level globals
    global tts_model, mel_extractor, vocoder, sample_rate, last_checkpoint_path
    sample_rate = mel_config.sample_rate
    last_checkpoint_path = None
    tts_model, mel_extractor, vocoder = get_pipeline(len(symbols), tts_model_config, mel_config, vocoder_config, tts_checkpoint_path, vocoder_checkpoint_path)
    
    # collect selectable TTS checkpoints, skipping optimizer states and the vocoder
    tts_checkpoint_paths = [str(path) for path in Path(tts_checkpoint_path).rglob('*.pt') if 'optimizer' not in path.name and 'vocoder' not in path.name]
    audios = [str(path) for path in list(Path('./audios').rglob('*.wav')) + list(Path('./audios').rglob('*.flac'))]

    # gradio webui
    gui_title = 'StableTTS'
    gui_description = """Next-generation TTS model using flow-matching and DiT, inspired by Stable Diffusion 3."""
    with gr.Blocks(analytics_enabled=False) as demo:

        with gr.Row():
            with gr.Column():
                gr.Markdown(f"# {gui_title}")
                gr.Markdown(gui_description)

        with gr.Row():
            with gr.Column():
                input_text_gr = gr.Textbox(
                    label="Input Text",
                    info="One or two sentences at a time is better. Up to 200 text characters.",
                    value="Today I want to tell you three stories from my life. That's it. No big deal. Just three stories.",
                )
             
                ref_audio_gr = gr.Dropdown(
                    label='Reference Audio',
                    choices=audios,
                    value=audios[0] if audios else None
                )

                checkpoint_gr = gr.Dropdown(
                    label='Checkpoint',
                    choices=tts_checkpoint_paths,
                    value=tts_checkpoint_paths[0] if tts_checkpoint_paths else None
                )
                
                # number of sampling steps passed to synthesise(); more steps trade speed for quality
                step_gr = gr.Slider(
                    label='Step',
                    minimum=1,
                    maximum=40,
                    value=8,
                    step=1
                )

                tts_button = gr.Button("Send", elem_id="send-btn", visible=True)
                
            with gr.Column():
                mel_gr = gr.Plot(label="Mel Spectrogram")
                audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)

        # wire the button to inference(); the input list order matches the function's parameters
        tts_button.click(inference, [input_text_gr, ref_audio_gr, checkpoint_gr, step_gr], outputs=[audio_gr, mel_gr])

    demo.queue()  
    demo.launch(debug=True, show_api=True)


if __name__ == '__main__':
    main()