# Spaces: Runtime error
# (Page-status text captured together with the source; not part of the program.)
import gradio as gr | |
import torchaudio | |
import torch | |
import os | |
from pydub import AudioSegment | |
import tempfile | |
from speechbrain.pretrained.separation import SepformerSeparation | |
import numpy as np | |
import threading | |
from queue import Queue | |
import time | |
class RealtimeAudioDenoiser:
    """Chunked speech enhancement built on the SepFormer DNS4 16 kHz model.

    Audio is resampled to ``sample_rate``, mixed down to mono, split into
    fixed-size chunks and pushed through a single background worker thread
    that runs the model. The enhanced chunks are stitched back together and
    written to ``enhanced_audio/enhanced_realtime.wav``.
    """

    def __init__(self):
        """Load the model, configure chunking and start the worker thread."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # The device is handed to SpeechBrain through ``run_opts`` so that the
        # interface's own device setting (used when it moves inputs inside
        # ``separate_batch``) agrees with where the weights live.
        self.model = SepformerSeparation.from_hparams(
            source="speechbrain/sepformer-dns4-16k-enhancement",
            savedir='pretrained_models/sepformer-dns4-16k-enhancement',
            run_opts={"device": str(self.device)},
        )
        # Inference only: switch off dropout etc. and gradient tracking.
        self.model.eval()
        torch.set_grad_enabled(False)

        # Streaming chunk size (500 ms at 16 kHz).
        self.chunk_duration = 0.5  # seconds
        self.sample_rate = 16000
        self.chunk_size = int(self.sample_rate * self.chunk_duration)

        # Work queue feeding the worker, and the queue it writes results to.
        self.processing_queue = Queue()
        self.output_buffer = Queue()
        self.is_processing = False
        # Both queues are shared by every call; this lock keeps the chunks of
        # two overlapping requests from being interleaved.
        self._stream_lock = threading.Lock()

        # Daemon thread so it never blocks interpreter shutdown.
        self.processing_thread = threading.Thread(target=self._process_queue)
        self.processing_thread.daemon = True
        self.processing_thread.start()

        os.makedirs("enhanced_audio", exist_ok=True)

    def _optimize_model(self):
        """Apply optional, device-appropriate inference optimisations.

        Not invoked by ``__init__``; callers opt in explicitly.
        """
        if self.device.type == 'cuda':
            # Let cuDNN auto-tune kernels for the fixed chunk shape.
            torch.backends.cudnn.benchmark = True
        else:
            # Dynamic int8 quantization of Linear layers is a CPU technique in
            # PyTorch, so it is only applied when the model runs on CPU.
            self.model = torch.quantization.quantize_dynamic(
                self.model, {torch.nn.Linear}, dtype=torch.qint8
            )

    def _process_queue(self):
        """Worker loop: enhance queued chunks in FIFO order, forever."""
        while True:
            # Blocking get: the thread sleeps until work arrives instead of
            # polling the queue.
            chunk = self.processing_queue.get()
            if chunk is None:
                continue
            self.output_buffer.put(self._enhance_chunk(chunk))

    def _enhance_chunk(self, audio_chunk):
        """Enhance one mono chunk and return it as a 1-D float32 array.

        Args:
            audio_chunk: 1-D sequence of samples.

        Returns:
            The enhanced samples as a 1-D ``float32`` array. If the model
            raises, the error is printed and the input chunk is returned
            (as a 1-D ``float32`` array) so the stream is not interrupted.
        """
        try:
            chunk_tensor = torch.as_tensor(audio_chunk, dtype=torch.float32)
            chunk_tensor = chunk_tensor.to(self.device).unsqueeze(0)  # add batch dim
            with torch.inference_mode():
                enhanced = self.model.separate_batch(chunk_tensor)
            enhanced = enhanced.squeeze(0)
            # SpeechBrain documents the result as (batch, time, n_sources);
            # keep the first source so every chunk comes back 1-D.
            if enhanced.dim() > 1:
                enhanced = enhanced[..., 0]
            return enhanced.cpu().numpy().astype(np.float32)
        except Exception as e:
            print(f"Error processing chunk: {str(e)}")
            return np.asarray(audio_chunk, dtype=np.float32).reshape(-1)

    def _load_samples(self, audio_path):
        """Decode ``audio_path`` to mono, peak-normalised float32 samples."""
        audio = AudioSegment.from_file(audio_path)
        audio = audio.set_frame_rate(self.sample_rate)
        audio = audio.set_channels(1)
        samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
        if samples.size == 0:
            raise ValueError("input audio contains no samples")
        peak = float(np.max(np.abs(samples)))
        # Pure silence has a zero peak; dividing would fill the array with NaN.
        if peak > 0:
            samples = samples / peak
        return samples

    def _enhance_samples(self, samples):
        """Run ``samples`` through the worker chunk by chunk.

        Returns a 1-D array with the same number of samples as the input.
        """
        total = len(samples)
        with self._stream_lock:
            queued = 0
            for start in range(0, total, self.chunk_size):
                chunk = samples[start:start + self.chunk_size]
                # Zero-pad the final chunk up to the fixed chunk size.
                if len(chunk) < self.chunk_size:
                    chunk = np.pad(chunk, (0, self.chunk_size - len(chunk)))
                self.processing_queue.put(chunk)
                queued += 1
            # Collect exactly one result per queued chunk. Checking whether
            # both queues are empty is not enough: they are both empty while
            # the worker is still busy with the last chunk.
            enhanced_chunks = [self.output_buffer.get() for _ in range(queued)]
        if not enhanced_chunks:
            return np.zeros(0, dtype=np.float32)
        # Drop the zero padding that was added to the final chunk.
        return np.concatenate(enhanced_chunks)[:total]

    def _save_wav(self, enhanced_audio):
        """Write ``enhanced_audio`` as 16-bit PCM and return the file path."""
        output_dir = "enhanced_audio"
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, "enhanced_realtime.wav")

        # Clip before scaling so out-of-range samples saturate instead of
        # wrapping around in the int16 cast.
        pcm = np.clip(np.nan_to_num(enhanced_audio), -1.0, 1.0) * 32767
        pcm = pcm.astype(np.int16)

        # The temp file lives next to the destination so os.replace never has
        # to cross a filesystem boundary.
        fd, tmp_path = tempfile.mkstemp(suffix='.wav', dir=output_dir)
        os.close(fd)
        try:
            torchaudio.save(
                tmp_path,
                torch.from_numpy(pcm).unsqueeze(0),  # (channels=1, time)
                self.sample_rate
            )
            os.replace(tmp_path, output_path)
        except Exception:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        return output_path

    def process_stream(self, audio_path):
        """Enhance the audio file at ``audio_path`` in streaming fashion.

        Args:
            audio_path: Path of the input audio file.

        Returns:
            Path of the enhanced WAV file.

        Raises:
            gr.Error: If loading, enhancing or saving the audio fails.
        """
        try:
            samples = self._load_samples(audio_path)
            enhanced_audio = self._enhance_samples(samples)
            return self._save_wav(enhanced_audio)
        except Exception as e:
            raise gr.Error(f"Error processing audio: {str(e)}") from e
def create_gradio_interface():
    """Build the Gradio UI that wraps a fresh ``RealtimeAudioDenoiser``.

    Returns:
        The configured ``gr.Interface``; the caller is responsible for
        launching it.
    """
    enhancer = RealtimeAudioDenoiser()

    # Both ends of the interface exchange audio as file paths.
    noisy_input = gr.Audio(type="filepath", label="Upload Noisy Audio")
    clean_output = gr.Audio(label="Enhanced Audio", type="filepath")

    blurb = (
        "\n"
        "Optimized for real-time processing with low latency.\n"
        "Processes audio in 500ms chunks for streaming applications.\n"
    )

    return gr.Interface(
        fn=enhancer.process_stream,
        inputs=noisy_input,
        outputs=clean_output,
        title="Real-time Audio Denoising using SepFormer",
        description=blurb,
    )
if __name__ == "__main__":
    # Script entry point: build the interface once, then start serving it.
    demo = create_gradio_interface()
    demo.launch()