import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from utils import load_ckpt, print_colored
from tokenizer import make_tokenizer
from model import get_hertz_dev_config
import matplotlib.pyplot as plt
import spaces
import gradio as gr
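
##############
# Device and model setup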
device = 'cuda' if T.cuda.is_available() else 'cpu'
#T.cuda.set_device(0)
print(f"Using device: {device}")
audio_tokenizer = make_tokenizer(device)
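
# TWO_SPEAKER toggles the split-channel (two-speaker) variant of the model;
# this demo serves the single-speaker model on mono audio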
TWO_SPEAKER = False
model_config = get_hertz_dev_config(is_split=TWO_SPEAKER)
generator = model_config()
generator = generator.eval().to(T.bfloat16).to(device)

##############
# Load audio
def load_and_preprocess_audio(audio_path):
    gr.Info("Loading and preprocessing audio...")
    # Load audio file
    audio_tensor, sr = torchaudio.load(audio_path)
    gr.Info(f"Loaded audio shape: {audio_tensor.shape}")
    if TWO_SPEAKER:
        if audio_tensor.shape[0] == 1:
            gr.Info("Converting mono to stereo...")
            audio_tensor = audio_tensor.repeat(2, 1)
            gr.Info(f"Stereo audio shape: {audio_tensor.shape}")
    else:
        if audio_tensor.shape[0] == 2:
            gr.Info("Converting stereo to mono...")
            audio_tensor = audio_tensor.mean(dim=0).unsqueeze(0)
            gr.Info(f"Mono audio shape: {audio_tensor.shape}")
    # Resample to 16kHz if needed
    if sr != 16000:
        gr.Info(f"Resampling from {sr}Hz to 16000Hz...")
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        audio_tensor = resampler(audio_tensor)
    # Reject prompts longer than 5 minutes
    max_samples = 16000 * 60 * 5
    if audio_tensor.shape[1] > max_samples:
        raise gr.Error("Maximum prompt is 5 minutes")
        # Alternatively, clip instead of rejecting:
        # audio_tensor = audio_tensor[:, :max_samples]
    duration_seconds = audio_tensor.shape[1] / 16000  # audio is at 16 kHz here
    gr.Info("Audio preprocessing complete!")
    return audio_tensor.unsqueeze(0), duration_seconds

##############
# Return audio to gradio
def display_audio(audio_tensor):
    audio_tensor = audio_tensor.cpu().squeeze()
    if audio_tensor.ndim == 1:
        audio_tensor = audio_tensor.unsqueeze(0)
    audio_tensor = audio_tensor.float()
    # Make a waveform plot
    # plt.figure(figsize=(4, 1))
    # plt.plot(audio_tensor.numpy()[0], linewidth=0.5)
    # plt.axis('off')
    # plt.show()
    # Make an audio player: gr.Audio expects (sample_rate, data) with data
    # shaped (samples,) or (samples, channels), so transpose (channels, samples)
    return (16000, audio_tensor.numpy().T)
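
# Generate a continuation of the encoded prompt with the hertz-dev generator
# and decode the resulting latents back to a 16 kHz waveform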
def get_completion(encoded_prompt_audio, prompt_len):
    prompt_len_seconds = prompt_len / 8  # 8 latent frames per second
    gr.Info(f"Prompt length: {prompt_len_seconds:.2f}s")
    with T.autocast(device_type='cuda', dtype=T.bfloat16):
        completed_audio_batch = generator.completion(
            encoded_prompt_audio,
            temps=(.8, (0.5, 0.1)),  # (token_temp, (categorical_temp, gaussian_temp))
            use_cache=True)
        completed_audio = completed_audio_batch
        print_colored("Decoding completion...", "blue")
        if TWO_SPEAKER:
            # Split-channel model: first 32 latent dims are speaker 1, the rest speaker 2
            decoded_completion_ch1 = audio_tokenizer.data_from_latent(completed_audio[:, :, :32].bfloat16())
            decoded_completion_ch2 = audio_tokenizer.data_from_latent(completed_audio[:, :, 32:].bfloat16())
            decoded_completion = T.cat([decoded_completion_ch1, decoded_completion_ch2], dim=0)
        else:
            decoded_completion = audio_tokenizer.data_from_latent(completed_audio.bfloat16())
        gr.Info(f"Decoded completion shape: {decoded_completion.shape}")

    gr.Info("Preparing audio for playback...")
    audio_tensor = decoded_completion.cpu().squeeze()
    if audio_tensor.ndim == 1:
        audio_tensor = audio_tensor.unsqueeze(0)
    audio_tensor = audio_tensor.float()
    # Peak-normalize if the decoded audio exceeds [-1, 1]
    if audio_tensor.abs().max() > 1:
        audio_tensor = audio_tensor / audio_tensor.abs().max()
    # Drop the prompt from the returned audio, keeping its last second as
    # context: each latent frame covers 2000 samples at 16 kHz
    return audio_tensor[:, max(prompt_len * 2000 - 16000, 0):]
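
# On Hugging Face Spaces (ZeroGPU), @spaces.GPU attaches a GPU to the worker
# for the duration of each call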
@spaces.GPU
def run(audio_path):
    prompt_audio, prompt_len_seconds = load_and_preprocess_audio(audio_path)
    prompt_len = int(prompt_len_seconds * 8)  # seconds -> latent frames
    gr.Info("Encoding prompt...")
    with T.autocast(device_type='cuda', dtype=T.bfloat16):
        if TWO_SPEAKER:
            # Encode each channel separately, then concatenate the latents
            encoded_prompt_audio_ch1 = audio_tokenizer.latent_from_data(prompt_audio[:, 0:1].to(device))
            encoded_prompt_audio_ch2 = audio_tokenizer.latent_from_data(prompt_audio[:, 1:2].to(device))
            encoded_prompt_audio = T.cat([encoded_prompt_audio_ch1, encoded_prompt_audio_ch2], dim=-1)
        else:
            encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))
    gr.Info(f"Encoded prompt shape: {encoded_prompt_audio.shape}")
    gr.Info("Prompt encoded successfully!")
    # num_completions = 10
    completion = get_completion(encoded_prompt_audio, prompt_len)
    return display_audio(completion)
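
# Minimal Gradio UI: upload a prompt recording, click Continue, and play the
# generated continuation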
with gr.Blocks() as demo:
    gr.Markdown("# hertz-dev")
    inp = gr.Audio(label="Input Audio", type="filepath", interactive=True)
    btn = gr.Button("Continue", variant="primary")
    out = gr.Audio(label="Output", interactive=False)
    btn.click(run, inputs=inp, outputs=out)

demo.queue().launch()