audio_denoiser / app.py
wrice's picture
handle audio with different shapes
c13028f
raw
history blame
1.77 kB
"""Gradio demo for denoisers."""
import gradio as gr
import numpy as np
import torch
import torchaudio
from denoisers import UNet1DModel, WaveUNetModel
from tqdm import tqdm
# Pretrained model identifiers: offered as the dropdown choices in the UI and
# passed unchanged to ``from_pretrained`` inside ``denoise``. The first entry
# is the dropdown's default selection.
MODELS = [
    "wrice/unet1d-vctk-48khz",
    "wrice/waveunet-vctk-48khz",
    "wrice/waveunet-vctk-24khz",
]
def denoise(model_name, inputs):
    """Denoise an audio clip with the selected pretrained model.

    Args:
        model_name: Identifier passed to ``from_pretrained``. Names containing
            ``"unet1d"`` load a ``UNet1DModel``; every other name loads a
            ``WaveUNetModel``.
        inputs: ``(sample_rate, samples)`` pair. ``samples`` is a NumPy array
            that is either 1-D ``(n_samples,)`` or 2-D
            ``(n_samples, n_channels)`` -- the layout Gradio documents for its
            numpy audio input. Integer arrays are treated as PCM and scaled by
            the dtype's full-scale value; float arrays are used as-is
            (assumed to already lie in ``[-1, 1]``).

    Returns:
        ``(model.config.sample_rate, samples)`` where ``samples`` is an int16
        NumPy array, transposed so that time is the first axis.

    Raises:
        ValueError: If no audio was supplied, the clip is empty, or the sample
            array has more than two dimensions.
    """
    if inputs is None:
        raise ValueError("No audio provided.")
    sr, audio = inputs

    # Normalise to float32 without hard-coding the int16 scale: int16 input
    # still divides by 32768, while float input is no longer attenuated.
    samples = np.asarray(audio)
    if np.issubdtype(samples.dtype, np.integer):
        full_scale = float(np.iinfo(samples.dtype).max) + 1.0
        samples = samples.astype(np.float32) / full_scale
    else:
        samples = samples.astype(np.float32)
    audio = torch.from_numpy(samples)

    # Bring the tensor to (channels, samples).
    if audio.ndim == 1:
        audio = audio.unsqueeze(0)
    elif audio.ndim == 2:
        # Incoming layout is (samples, channels); averaging axis 0 without
        # this transpose would average over time instead of over channels.
        audio = audio.transpose(0, 1).contiguous()
    else:
        raise ValueError(
            f"Expected 1-D or 2-D audio, got {audio.ndim} dimensions."
        )
    print(f"Audio shape: {audio.shape}")
    print(f"Sample rate: {sr}")

    if audio.shape[0] > 1:
        # Downmix to mono.
        audio = audio.mean(0, keepdim=True)
        print(f"Audio shape: {audio.shape}")

    if audio.shape[-1] == 0:
        raise ValueError("Audio clip is empty.")

    # Load the model only after the input has been validated.
    if "unet1d" in model_name:
        model = UNet1DModel.from_pretrained(model_name)
    else:
        model = WaveUNetModel.from_pretrained(model_name)

    if sr != model.config.sample_rate:
        audio = torchaudio.functional.resample(audio, sr, model.config.sample_rate)

    chunk_size = model.config.max_length

    # Pad up to the next multiple of chunk_size; adds nothing when the length
    # is already aligned (the old formula appended a whole extra chunk).
    padding = (-audio.size(-1)) % chunk_size
    padded = torch.nn.functional.pad(audio, (0, padding))

    clean = []
    with torch.no_grad():
        for start in tqdm(range(0, padded.shape[-1], chunk_size)):
            audio_chunk = padded[:, start : start + chunk_size]
            # NOTE(review): the model is fed (batch, channels, samples), as
            # implied by the original three-axis slice -- confirm against the
            # denoisers package.
            clean_chunk = model(audio_chunk.unsqueeze(0)).audio
            clean.append(clean_chunk.squeeze(0))

    # Drop the padding again and keep values inside the int16-safe range.
    denoised = torch.concat(clean, 1)[:, : audio.shape[-1]].clamp(-1.0, 1.0)
    denoised = (denoised * 32767.0).numpy().astype(np.int16)
    print(f"Denoised shape: {denoised.shape}")
    return model.config.sample_rate, denoised.transpose()
# Build the demo UI: a model dropdown (defaulting to the first entry of
# MODELS) plus an audio input, wired to ``denoise`` with an audio output.
_model_picker = gr.Dropdown(choices=MODELS, value=MODELS[0])
iface = gr.Interface(fn=denoise, inputs=[_model_picker, "audio"], outputs="audio")

# Start serving the interface.
iface.launch()