import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from utils import load_ckpt, print_colored
from tokenizer import make_tokenizer
from model import get_hertz_dev_config
import matplotlib.pyplot as plt
import spaces
import gradio as gr

device = 'cuda' if T.cuda.is_available() else 'cpu'
# T.cuda.set_device(0)
print(f"Using device: {device}")

audio_tokenizer = make_tokenizer(device)

TWO_SPEAKER = False

model_config = get_hertz_dev_config(is_split=TWO_SPEAKER)
generator = model_config()
generator = generator.eval().to(T.bfloat16).to(device)
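# Optional, illustrative: fix the RNG seed for repeatable generations.
# The demo samples stochastically, so this stays commented out; the seed
# value is arbitrary.
# T.manual_seed(0)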
##############
# Load audio

def load_and_preprocess_audio(audio_path):
    gr.Info("Loading and preprocessing audio...")
    # Load audio file
    audio_tensor, sr = torchaudio.load(audio_path)
    gr.Info(f"Loaded audio shape: {audio_tensor.shape}")
    if TWO_SPEAKER:
        if audio_tensor.shape[0] == 1:
            gr.Info("Converting mono to stereo...")
            audio_tensor = audio_tensor.repeat(2, 1)
            gr.Info(f"Stereo audio shape: {audio_tensor.shape}")
    else:
        if audio_tensor.shape[0] == 2:
            gr.Info("Converting stereo to mono...")
            audio_tensor = audio_tensor.mean(dim=0).unsqueeze(0)
            gr.Info(f"Mono audio shape: {audio_tensor.shape}")
    # Resample to 16kHz if needed
    if sr != 16000:
        gr.Info(f"Resampling from {sr}Hz to 16000Hz...")
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        audio_tensor = resampler(audio_tensor)
    # Reject prompts longer than 5 minutes (silent clipping is kept below as
    # a commented-out alternative)
    max_samples = 16000 * 60 * 5
    if audio_tensor.shape[1] > max_samples:
        # gr.Info("Clipping audio to 5 minutes...")
        # audio_tensor = audio_tensor[:, :max_samples]
        raise gr.Error("Maximum prompt length is 5 minutes")
    duration_seconds = audio_tensor.shape[1] / 16000  # audio is at 16kHz by this point
    gr.Info("Audio preprocessing complete!")
    return audio_tensor.unsqueeze(0), duration_seconds
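# Hypothetical usage outside Gradio (the path is illustrative):
#   audio, secs = load_and_preprocess_audio("prompt.wav")
#   print(audio.shape, secs)  # e.g. torch.Size([1, 1, 80000]), 5.0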
##############
# Return audio to gradio

def display_audio(audio_tensor):
    audio_tensor = audio_tensor.cpu().squeeze()
    if audio_tensor.ndim == 1:
        audio_tensor = audio_tensor.unsqueeze(0)
    audio_tensor = audio_tensor.float()

    # Make a waveform plot
    # plt.figure(figsize=(4, 1))
    # plt.plot(audio_tensor.numpy()[0], linewidth=0.5)
    # plt.axis('off')
    # plt.show()

    # Make an audio player: Gradio expects (sample_rate, array) with the
    # array shaped (samples,) or (samples, channels), so transpose from
    # torch's (channels, samples) layout
    return (16000, audio_tensor.numpy().T)
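# Optional sketch: the same tensor could also be written to disk.
# torchaudio.save takes a (channels, samples) float tensor and a sample
# rate; the filename is illustrative.
#   torchaudio.save("completion.wav", audio_tensor, 16000)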
def get_completion(encoded_prompt_audio, prompt_len):
    # The model runs at 8 latents per second of audio
    prompt_len_seconds = prompt_len / 8
    gr.Info(f"Prompt length: {prompt_len_seconds:.2f}s")

    with T.autocast(device_type=device, dtype=T.bfloat16):
        completed_audio_batch = generator.completion(
            encoded_prompt_audio,
            temps=(.8, (0.5, 0.1)),  # (token_temp, (categorical_temp, gaussian_temp))
            use_cache=True)
        completed_audio = completed_audio_batch
        print_colored("Decoding completion...", "blue")
        if TWO_SPEAKER:
            # Split the latent: first 32 dims are channel 1, the rest channel 2
            decoded_completion_ch1 = audio_tokenizer.data_from_latent(completed_audio[:, :, :32].bfloat16())
            decoded_completion_ch2 = audio_tokenizer.data_from_latent(completed_audio[:, :, 32:].bfloat16())
            decoded_completion = T.cat([decoded_completion_ch1, decoded_completion_ch2], dim=0)
        else:
            decoded_completion = audio_tokenizer.data_from_latent(completed_audio.bfloat16())
        gr.Info(f"Decoded completion shape: {decoded_completion.shape}")

    gr.Info("Preparing audio for playback...")
    audio_tensor = decoded_completion.cpu().squeeze()
    if audio_tensor.ndim == 1:
        audio_tensor = audio_tensor.unsqueeze(0)
    audio_tensor = audio_tensor.float()

    # Normalize to [-1, 1] if the decoder output clips
    if audio_tensor.abs().max() > 1:
        audio_tensor = audio_tensor / audio_tensor.abs().max()

    # Each latent covers 2000 samples at 16kHz (8 latents/s * 2000 = 16000),
    # so this trims the prompt but keeps its final second for context
    return audio_tensor[:, max(prompt_len * 2000 - 16000, 0):]
# Assumption: `import spaces` above suggests this demo targets a ZeroGPU
# Space, where the GPU-using entry point is decorated with @spaces.GPU
@spaces.GPU
def run(audio_path):
    prompt_audio, prompt_len_seconds = load_and_preprocess_audio(audio_path)
    prompt_len = int(prompt_len_seconds * 8)  # 8 latents per second; int so it can index
    gr.Info("Encoding prompt...")
    with T.autocast(device_type=device, dtype=T.bfloat16):
        if TWO_SPEAKER:
            encoded_prompt_audio_ch1 = audio_tokenizer.latent_from_data(prompt_audio[:, 0:1].to(device))
            encoded_prompt_audio_ch2 = audio_tokenizer.latent_from_data(prompt_audio[:, 1:2].to(device))
            encoded_prompt_audio = T.cat([encoded_prompt_audio_ch1, encoded_prompt_audio_ch2], dim=-1)
        else:
            encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))
        gr.Info(f"Encoded prompt shape: {encoded_prompt_audio.shape}")
        gr.Info("Prompt encoded successfully!")
    # num_completions = 10
    completion = get_completion(encoded_prompt_audio, prompt_len)
    return display_audio(completion)
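# Sketch, following the num_completions hint above: several continuations of
# the same prompt could be sampled by calling the generator repeatedly, e.g.
#   completions = [get_completion(encoded_prompt_audio, prompt_len)
#                  for _ in range(10)]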
with gr.Blocks() as demo:
    gr.Markdown("# hertz-dev")
    inp = gr.Audio(label="Input Audio", type="filepath", interactive=True)
    btn = gr.Button("Continue", variant="primary")
    out = gr.Audio(label="Output", interactive=False)
    btn.click(run, inputs=inp, outputs=out)

demo.queue().launch()