|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import torch |
|
import soundfile as sf |
|
import logging |
|
import argparse |
|
import gradio as gr |
|
from datetime import datetime |
|
from cli.SparkTTS import SparkTTS |
|
from sparktts.utils.token_parser import LEVELS_MAP_UI |
|
from huggingface_hub import snapshot_download |
|
import spaces |
|
|
|
# Process-wide cache for the loaded SparkTTS instance. It stays None until
# generate() or build_ui() first calls initialize_model().
MODEL = None
|
|
|
def initialize_model(model_dir=None, device="cpu"):
    """Build a SparkTTS model, downloading the checkpoint when no path is given.

    Args:
        model_dir: Directory holding the model files. When ``None``, the
            ``SparkAudio/Spark-TTS-0.5B`` snapshot is downloaded into
            ``pretrained_models/Spark-TTS-0.5B`` and that location is used.
        device: Any value accepted by ``torch.device`` (default ``"cpu"``).

    Returns:
        The ``SparkTTS`` instance constructed from the model directory and
        the resolved ``torch.device``.
    """
    if model_dir is None:
        # model_dir is still None here, so log the real destination instead
        # of interpolating model_dir (which would always print "None").
        local_dir = "pretrained_models/Spark-TTS-0.5B"
        logging.info(f"Downloading model to: {local_dir}")
        model_dir = snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir=local_dir)

    logging.info(f"Loading model from: {model_dir}")
    device = torch.device(device)
    model = SparkTTS(model_dir, device)
    return model
|
|
|
@spaces.GPU
def generate(text, prompt_speech, prompt_text, gender, pitch, speed):
    """Run SparkTTS inference and return the waveform it produces.

    The shared ``MODEL`` is created on first use (on CUDA when available,
    otherwise on CPU). All six arguments are forwarded positionally, in
    order, to ``MODEL.inference``.
    """
    global MODEL

    # Lazily build the shared model the first time a request arrives.
    if MODEL is None:
        target = "cuda" if torch.cuda.is_available() else "cpu"
        MODEL = initialize_model(device=target)

    tts = MODEL

    # Whenever CUDA is visible, ask the model to move there before inference.
    if torch.cuda.is_available():
        print("Moving model to GPU")
        tts.to("cuda")

    # Gradient tracking is disabled for the duration of the inference call.
    with torch.no_grad():
        return tts.inference(
            text,
            prompt_speech,
            prompt_text,
            gender,
            pitch,
            speed,
        )
|
|
|
|
|
def run_tts(
    text,
    prompt_text=None,
    prompt_speech=None,
    gender=None,
    pitch=None,
    speed=None,
    save_dir="example/results",
    sample_rate=16000,
):
    """Synthesise ``text`` with ``generate`` and write the result to a WAV file.

    Args:
        text: Text to synthesise.
        prompt_text: Transcript of the prompt audio. A value of length 0 or 1
            is treated as absent and replaced by ``None``.
        prompt_speech: Reference audio, forwarded unchanged to ``generate``.
        gender: Forwarded unchanged to ``generate``.
        pitch: Forwarded unchanged to ``generate``.
        speed: Forwarded unchanged to ``generate``.
        save_dir: Output directory; created if it does not exist.
        sample_rate: Sample rate passed to ``sf.write``. Defaults to 16000,
            the value that used to be hard-coded.

    Returns:
        Path of the WAV file that was written.
    """
    logging.info(f"Saving audio to: {save_dir}")

    # A missing or one-character transcript is not a usable prompt.
    if prompt_text is not None and len(prompt_text) <= 1:
        prompt_text = None

    os.makedirs(save_dir, exist_ok=True)

    # Microsecond resolution (%f) keeps requests that arrive within the same
    # second from overwriting each other's output file.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    save_path = os.path.join(save_dir, f"{timestamp}.wav")

    logging.info("Starting inference...")
    wav = generate(text, prompt_speech, prompt_text, gender, pitch, speed)

    sf.write(save_path, wav, samplerate=sample_rate)
    logging.info(f"Audio saved at: {save_path}")

    return save_path
|
|
|
|
|
def build_ui(model_dir, device=0):
    """Create the Gradio Blocks app with "Voice Clone" and "Voice Creation" tabs.

    Args:
        model_dir: Passed to ``initialize_model`` when no model is cached in
            the module-level ``MODEL`` yet.
        device: Requested device. Any value other than the string ``"cpu"``
            selects ``"cuda"`` when CUDA is available; otherwise ``"cpu"``
            is used.

    Returns:
        The assembled ``gr.Blocks`` instance. It is not launched here.
    """
    global MODEL

    # Collapse the request to plain "cuda" / "cpu"; a device index such as
    # "cuda:0" is not preserved.
    device = "cuda" if torch.cuda.is_available() and device != "cpu" else "cpu"
    if MODEL is None:
        MODEL = initialize_model(model_dir, device=device)
        if device == "cuda":
            # NOTE(review): this rebinds MODEL to the return value of .to(),
            # whereas generate() calls .to("cuda") and ignores the result.
            # Confirm that SparkTTS.to returns the model instance.
            MODEL = MODEL.to(device)

    def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
        """
        Gradio callback that clones a voice from reference audio.

        - text: The input text to be synthesised.
        - prompt_text: Transcript of the reference audio (optional). Missing
          values and values shorter than two characters are ignored.
        - prompt_wav_upload/prompt_wav_record: Reference audio file paths;
          the uploaded file is used when present, otherwise the recording.
        """
        prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
        # Guard against None as well as too-short strings so that a missing
        # value does not raise TypeError from len().
        if prompt_text and len(prompt_text) >= 2:
            prompt_text_clean = prompt_text
        else:
            prompt_text_clean = None

        audio_output_path = run_tts(
            text,
            prompt_text=prompt_text_clean,
            prompt_speech=prompt_speech,
        )
        return audio_output_path

    def voice_creation(text, gender, pitch, speed):
        """
        Gradio callback that synthesises speech with a parameterised voice.

        - text: The input text for synthesis.
        - gender: 'male' or 'female' (the choices offered by the radio input).
        - pitch/speed: Slider values, converted with int() and looked up in
          LEVELS_MAP_UI before being passed to run_tts.
        """
        pitch_val = LEVELS_MAP_UI[int(pitch)]
        speed_val = LEVELS_MAP_UI[int(speed)]
        audio_output_path = run_tts(
            text,
            gender=gender,
            pitch=pitch_val,
            speed=speed_val,
        )
        return audio_output_path

    with gr.Blocks() as demo:
        # Centered page title.
        gr.HTML('<h1 style="text-align: center;">(Official) Spark-TTS by SparkAudio</h1>')
        with gr.Tabs():
            # Tab 1: clone a voice from uploaded or recorded reference audio.
            with gr.TabItem("Voice Clone"):
                gr.Markdown(
                    "### Upload reference audio or recording (上传参考音频或者录音)"
                )

                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources="upload",
                        type="filepath",
                        label="Choose the prompt audio file, ensuring the sampling rate is no lower than 16kHz.",
                    )
                    prompt_wav_record = gr.Audio(
                        sources="microphone",
                        type="filepath",
                        label="Record the prompt audio file.",
                    )

                with gr.Row():
                    text_input = gr.Textbox(
                        label="Text", lines=3, placeholder="Enter text here"
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of prompt speech (Optional; recommended for cloning in the same language.)",
                        lines=3,
                        placeholder="Enter text of the prompt speech.",
                    )

                audio_output = gr.Audio(
                    label="Generated Audio", autoplay=True, streaming=True
                )

                generate_button_clone = gr.Button("Generate")

                generate_button_clone.click(
                    voice_clone,
                    inputs=[
                        text_input,
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[audio_output],
                )

            # Tab 2: synthesise with a voice described by gender/pitch/speed.
            with gr.TabItem("Voice Creation"):
                gr.Markdown(
                    "### Create your own voice based on the following parameters"
                )

                with gr.Row():
                    with gr.Column():
                        gender = gr.Radio(
                            choices=["male", "female"], value="male", label="Gender"
                        )
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed"
                        )
                    with gr.Column():
                        text_input_creation = gr.Textbox(
                            label="Input Text",
                            lines=3,
                            placeholder="Enter text here",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )
                        create_button = gr.Button("Create Voice")

                # Rebinding audio_output is safe: the clone tab's click handler
                # was already wired to the earlier component above.
                audio_output = gr.Audio(
                    label="Generated Audio", autoplay=True, streaming=True
                )
                create_button.click(
                    voice_creation,
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output],
                )

    return demo
|
|
|
|
|
def parse_arguments():
    """
    Read the Gradio server options from the command line.

    Recognised flags are ``--model_dir``, ``--device`` (default ``"cpu"``),
    ``--server_name`` and ``--server_port`` (parsed as an integer). Every
    flag except ``--device`` defaults to ``None``.

    Returns the ``argparse.Namespace`` produced by ``parse_args()``.
    """
    # One row per option: (flag, type, default, help), in --help order.
    option_table = (
        ("--model_dir", str, None, "Path to the model directory."),
        ("--device", str, "cpu", "Device to use (e.g., 'cpu' or 'cuda:0')."),
        ("--server_name", str, None, "Server host/IP for Gradio app."),
        ("--server_port", int, None, "Server port for Gradio app."),
    )

    cli = argparse.ArgumentParser(description="Spark TTS Gradio server.")
    for flag, value_type, fallback, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)
    return cli.parse_args()
|
|
|
if __name__ == "__main__":
    # Parse the CLI options, then build the UI with the requested model
    # directory and device.
    args = parse_arguments()
    demo = build_ui(model_dir=args.model_dir, device=args.device)

    # Start the Gradio server, passing through the parsed host and port
    # (either may be None when the flag was not given).
    demo.launch(server_name=args.server_name, server_port=args.server_port)