# audio_denoiser / app.py — by wrice
# Last commit: 8c12fe6 "Update to denoisers 0.2.0" (2.04 kB)
"""Gradio demo for denoisers."""
import functools
import tempfile

import gradio as gr
import torch
import torchaudio
from denoisers import UNet1DModel, WaveUNetModel
from tqdm import tqdm
# Pretrained denoiser checkpoints offered in the demo, as Hugging Face Hub ids.
# Order matters: the first entry is the dropdown's default selection.
MODELS = [
    "wrice/unet1d-vctk-48khz",  # name contains "unet1d" -> loaded with UNet1DModel
    "wrice/waveunet-vctk-48khz",  # otherwise -> loaded with WaveUNetModel
    "wrice/waveunet-vctk-24khz",
]
@functools.lru_cache(maxsize=None)
def _load_model(model_name: str):
    """Load a pretrained denoiser once per checkpoint name and reuse it.

    Only names from ``MODELS`` reach this helper (``denoise`` validates them),
    so the cache holds at most one model per entry of that list.
    """
    if "unet1d" in model_name:
        model = UNet1DModel.from_pretrained(model_name)
    else:
        model = WaveUNetModel.from_pretrained(model_name)
    if torch.cuda.is_available():
        model = model.cuda()
    # Inference only: put layers such as dropout/normalisation in eval mode.
    # (assumes the model is a torch.nn.Module, as its .cuda() call suggests)
    model.eval()
    return model


def denoise(model_name: str, audio_path: str) -> str:
    """Denoise an audio file with the selected pretrained model.

    The input is decoded by ``torchaudio.io.StreamReader``, resampled to the
    model's sample rate and converted to a single channel. It is then fed
    through the model in chunks of ``model.config.max_length`` samples. The
    denoised chunks are written to a fresh temporary WAV file.

    Args:
        model_name: Hub id of the checkpoint; must be one of ``MODELS``.
        audio_path: Path of the uploaded audio file (Gradio ``filepath``
            input), or ``None`` when nothing was uploaded.

    Returns:
        Path of the denoised WAV file.

    Raises:
        gr.Error: If no audio was provided or ``model_name`` is not in
            ``MODELS``.
    """
    if audio_path is None:
        raise gr.Error("Please upload an audio file to denoise.")
    # Only allow the curated checkpoints: arbitrary names would make the app
    # download and load any Hub repository on request.
    if model_name not in MODELS:
        raise gr.Error(f"Unknown model: {model_name!r}")

    model = _load_model(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    chunk_size = model.config.max_length
    sample_rate = model.config.sample_rate

    stream_reader = torchaudio.io.StreamReader(audio_path)
    stream_reader.add_basic_audio_stream(
        frames_per_chunk=chunk_size,
        sample_rate=sample_rate,
        num_channels=1,
    )

    # Use a unique output file per request: a fixed file name would let
    # concurrent Gradio requests overwrite each other's results.
    with tempfile.NamedTemporaryFile(
        prefix="denoised_", suffix=".wav", delete=False
    ) as tmp:
        output_path = tmp.name

    stream_writer = torchaudio.io.StreamWriter(output_path)
    stream_writer.add_audio_stream(sample_rate=sample_rate, num_channels=1)

    with stream_writer.open(), torch.no_grad():
        for (audio_chunk,) in tqdm(stream_reader.stream()):
            if audio_chunk is None:
                # Defensive: no decoded frames for this stream in this step.
                continue
            # StreamReader chunks are (frames, channels); transpose to
            # (channels, frames) before adding the batch dimension below.
            audio_chunk = audio_chunk.permute(1, 0)
            num_samples = audio_chunk.size(-1)
            # The last chunk may be shorter than the model's fixed input
            # length, so zero-pad it on the right.
            if num_samples < chunk_size:
                audio_chunk = torch.nn.functional.pad(
                    audio_chunk, (0, chunk_size - num_samples)
                )
            denoised_chunk = model(audio_chunk[None].to(device)).audio
            # Drop the padded tail so the output length matches the input.
            denoised_chunk = denoised_chunk[:, :, :num_samples]
            stream_writer.write_audio_chunk(
                0, denoised_chunk.squeeze(0).permute(1, 0).cpu()
            )

    return output_path
# Gradio UI: choose a checkpoint, upload noisy audio, get the denoised file back.
iface = gr.Interface(
    fn=denoise,
    inputs=[
        gr.Dropdown(choices=MODELS, value=MODELS[0], label="Model"),
        gr.Audio(type="filepath", label="Noisy audio"),
    ],
    outputs=gr.Audio(type="filepath", label="Denoised audio"),
)

# Launch only when run as a script, so importing this module (e.g. for
# tests or reload tooling) does not start a server as a side effect.
if __name__ == "__main__":
    iface.launch()