import os
os.environ['TMPDIR'] = './temps' # use a local temp folder in case the system default temp folder lacks access permissions
# os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' # use the Hugging Face mirror for users who cannot access huggingface.co
import re

import numpy as np
import matplotlib.pyplot as plt

import torch
import gradio as gr

from api import StableTTSAPI

device = 'cuda' if torch.cuda.is_available() else 'cpu'

tts_model_path = './checkpoints/checkpoint_0.pt'
vocoder_model_path = './vocoders/pretrained/firefly-gan-base-generator.ckpt'
vocoder_type = 'ffgan'

model = StableTTSAPI(tts_model_path, vocoder_model_path, vocoder_type).to(device)
def inference(text, ref_audio, language, step, temperature, length_scale, solver, cfg):
    text = remove_newlines_after_punctuation(text)
    if language == 'chinese':
        text = text.replace(' ', '') # remove spaces for Chinese text
    audio, mel = model.inference(text, ref_audio, language, step, temperature, length_scale, solver, cfg)

    # peak-normalize to avoid clipping
    max_val = torch.max(torch.abs(audio))
    if max_val > 1:
        audio = audio / max_val

    audio_output = (model.mel_config.sample_rate, (audio.cpu().squeeze(0).numpy() * 32767).astype(np.int16)) # (sample_rate, int16 audio) tuple for gr.Audio
    mel_output = plot_mel_spectrogram(mel.cpu().squeeze(0).numpy()) # render the mel spectrogram as a matplotlib figure
    return audio_output, mel_output
def plot_mel_spectrogram(mel_spectrogram):
    plt.close() # prevent memory leak
    fig, ax = plt.subplots(figsize=(20, 8))
    ax.imshow(mel_spectrogram, aspect='auto', origin='lower')
    plt.axis('off')
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0) # remove white edges
    return fig
def remove_newlines_after_punctuation(text):
    pattern = r'([,。!?、“”‘’《》【】;:,.!?\'\"<>()\[\]{}])\n'
    return re.sub(pattern, r'\1', text)
def main():
    from pathlib import Path
    examples = list(Path('./audios').rglob('*.wav'))

    # gradio webui, reference: https://huggingface.co/spaces/fishaudio/fish-speech-1
    gui_title = 'StableTTS'
    gui_description = """Next-generation TTS model using flow-matching and DiT, inspired by Stable Diffusion 3."""
    example_text = """你指尖跳动的电光,是我永恒不变的信仰。唯我超电磁炮永世长存!"""
    with gr.Blocks(theme=gr.themes.Base()) as demo:
        # force the light theme unless the user already set __theme in the URL
        demo.load(None, None, js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}")
        with gr.Row():
            with gr.Column():
                gr.Markdown(f"# {gui_title}")
                gr.Markdown(gui_description)

        with gr.Row():
            with gr.Column():
                input_text_gr = gr.Textbox(
                    label="Input Text",
                    info="Put your text here",
                    value=example_text,
                )
                ref_audio_gr = gr.Audio(
                    label="Reference Audio",
                    type="filepath"
                )
                language_gr = gr.Dropdown(
                    label='Language',
                    choices=list(model.supported_languages),
                    value='chinese'
                )
                step_gr = gr.Slider(
                    label='Step',
                    minimum=1,
                    maximum=100,
                    value=25,
                    step=1
                )
                temperature_gr = gr.Slider(
                    label='Temperature',
                    minimum=0,
                    maximum=2,
                    value=1,
                )
                length_scale_gr = gr.Slider(
                    label='Length_Scale',
                    minimum=0,
                    maximum=5,
                    value=1,
                )
                solver_gr = gr.Dropdown(
                    label='ODE Solver',
                    choices=['euler', 'midpoint', 'dopri5', 'rk4', 'implicit_adams', 'bosh3', 'fehlberg2', 'adaptive_heun'],
                    value='dopri5'
                )
                cfg_gr = gr.Slider(
                    label='CFG',
                    minimum=0,
                    maximum=10,
                    value=3,
                )
            with gr.Column():
                mel_gr = gr.Plot(label="Mel Visual")
                audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
                tts_button = gr.Button("\U0001F3A7 Generate / 合成", elem_id="send-btn", visible=True, variant="primary")
                examples = gr.Examples(examples, ref_audio_gr)

        tts_button.click(inference, [input_text_gr, ref_audio_gr, language_gr, step_gr, temperature_gr, length_scale_gr, solver_gr, cfg_gr], outputs=[audio_gr, mel_gr])

    demo.queue()
    demo.launch(debug=True, show_api=True)

if __name__ == '__main__':
    main()
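
For reference, the same StableTTSAPI can be driven without the Gradio front end. The sketch below reuses only what the script above shows: the StableTTSAPI constructor, the model.inference argument order from the callback, and model.mel_config.sample_rate. The input text, the reference wav path, and saving the result with scipy.io.wavfile are illustrative assumptions, not part of the original webui.

# Minimal headless sketch, assuming the same checkpoints and API as the webui above.
import numpy as np
import torch
from scipy.io import wavfile  # illustrative choice for saving; not used by the original script

from api import StableTTSAPI

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = StableTTSAPI('./checkpoints/checkpoint_0.pt',
                     './vocoders/pretrained/firefly-gan-base-generator.ckpt',
                     'ffgan').to(device)

# Same argument order as the Gradio callback:
# text, reference audio path, language, step, temperature, length_scale, solver, cfg
audio, mel = model.inference('你指尖跳动的电光,是我永恒不变的信仰。',  # placeholder text
                             './audios/reference.wav',                  # placeholder reference wav
                             'chinese', 25, 1.0, 1.0, 'dopri5', 3.0)

# Convert to int16 PCM and write a wav file at the model's sample rate.
audio = audio.cpu().squeeze(0).numpy()
wavfile.write('output.wav', model.mel_config.sample_rate, (audio * 32767).astype(np.int16))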