# Scraped page header (Hugging Face file view) -- not part of the program:
#   lyimo's picture
#   Update app.py
#   17508a1 verified
#   raw
#   history blame
#   6.05 kB
import gradio as gr
import torchaudio
import torch
import os
from pydub import AudioSegment
import tempfile
from speechbrain.pretrained.separation import SepformerSeparation
import numpy as np
import threading
from queue import Queue
import time
class RealtimeAudioDenoiser:
    """Chunked speech enhancement using the SepFormer DNS4 16 kHz model.

    Audio is resampled to 16 kHz mono, split into fixed-size chunks and
    enhanced one chunk at a time by a single background worker thread.
    Chunks travel through two FIFO queues: ``processing_queue`` (input) and
    ``output_buffer`` (results, in the same order as the inputs).
    """

    # Directory and file name of the enhanced result written by process_stream.
    _OUTPUT_DIR = "enhanced_audio"
    _OUTPUT_NAME = "enhanced_realtime.wav"

    def __init__(self):
        """Load the model, create the queues and start the worker thread."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # The device is handed to SpeechBrain through ``run_opts`` instead of
        # calling ``model.to(device)`` afterwards: the inference wrapper moves
        # its inputs to the device it was created with, so moving only the
        # weights would leave model and input on different devices.
        self.model = SepformerSeparation.from_hparams(
            source="speechbrain/sepformer-dns4-16k-enhancement",
            savedir='pretrained_models/sepformer-dns4-16k-enhancement',
            run_opts={"device": str(self.device)},
        )

        # Inference only: no dropout, no autograd bookkeeping.
        self.model.eval()
        torch.set_grad_enabled(False)

        # Streaming chunk size (500 ms at 16 kHz).
        self.chunk_duration = 0.5  # seconds
        self.sample_rate = 16000
        self.chunk_size = int(self.sample_rate * self.chunk_duration)

        # Work queue (raw chunks) and result queue (enhanced chunks).
        self.processing_queue = Queue()
        self.output_buffer = Queue()
        self.is_processing = False

        # Both queues are shared by every caller, so only one stream may be
        # in flight at a time; otherwise results of two requests interleave.
        self._stream_lock = threading.Lock()

        os.makedirs(self._OUTPUT_DIR, exist_ok=True)

        # Daemon thread so it never blocks interpreter shutdown.
        self.processing_thread = threading.Thread(target=self._process_queue)
        self.processing_thread.daemon = True
        self.processing_thread.start()

    def _optimize_model(self):
        """Apply optional, device-specific inference optimisations.

        On CUDA, cuDNN autotuning is enabled. On CPU, ``torch.nn.Linear``
        layers are dynamically quantized to int8. Dynamic quantization is
        not applied on CUDA because PyTorch's dynamically quantized linear
        kernels target CPU backends only.

        NOTE(review): nothing in this file calls this method.
        """
        if self.device.type == 'cuda':
            torch.backends.cudnn.benchmark = True
        else:
            self.model = torch.quantization.quantize_dynamic(
                self.model, {torch.nn.Linear}, dtype=torch.qint8
            )

    def _process_queue(self):
        """Worker loop: enhance queued chunks and publish them in order.

        Blocks on the input queue instead of polling it. ``None`` items are
        skipped and produce no output.
        """
        while True:
            chunk = self.processing_queue.get()
            if chunk is None:
                continue
            self.output_buffer.put(self._enhance_chunk(chunk))

    def _enhance_chunk(self, audio_chunk):
        """Enhance one chunk of mono audio.

        Args:
            audio_chunk: 1-D sequence of float samples.

        Returns:
            A 1-D float32 numpy array with the enhanced samples. If the model
            raises, the error is printed and the unprocessed chunk is
            returned so the stream keeps its length (best effort).
        """
        try:
            chunk_tensor = torch.as_tensor(
                np.asarray(audio_chunk), dtype=torch.float32
            ).to(self.device)
            chunk_tensor = chunk_tensor.unsqueeze(0)  # add batch dimension

            with torch.inference_mode():
                enhanced = self.model.separate_batch(chunk_tensor)

            # separate_batch yields (batch, time, n_sources); an enhancement
            # model has a single source, so keep source 0 to get (batch, time).
            if enhanced.dim() == 3:
                enhanced = enhanced[..., 0]
            return enhanced.squeeze(0).cpu().numpy()
        except Exception as e:
            print(f"Error processing chunk: {str(e)}")
            return np.asarray(audio_chunk, dtype=np.float32)

    @staticmethod
    def _peak_normalize(samples):
        """Scale ``samples`` so the largest magnitude is 1.0.

        Empty or all-zero input is returned unchanged, which avoids the
        division by zero (NaN output) a silent recording would cause.
        """
        peak = float(np.max(np.abs(samples))) if samples.size else 0.0
        return samples / peak if peak > 0 else samples

    @staticmethod
    def _to_pcm16(audio):
        """Convert float audio nominally in [-1, 1] to int16 PCM.

        Values are clipped first: the model output is not guaranteed to stay
        within [-1, 1], and an unclipped cast to int16 would wrap around.
        """
        audio = np.nan_to_num(np.asarray(audio, dtype=np.float32))
        return (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)

    def _load_samples(self, audio_path):
        """Decode ``audio_path`` to peak-normalised 16 kHz mono float32 samples.

        Raises:
            ValueError: if the decoded audio contains no samples.
        """
        audio = AudioSegment.from_file(audio_path)
        audio = audio.set_frame_rate(self.sample_rate)
        audio = audio.set_channels(1)

        samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
        if samples.size == 0:
            raise ValueError("input audio contains no samples")
        return self._peak_normalize(samples)

    def _enhance_samples(self, samples):
        """Run ``samples`` through the worker thread chunk by chunk.

        Returns:
            The enhanced signal as a 1-D array of the same length as
            ``samples`` (padding added to the last chunk is trimmed off).

        Raises:
            RuntimeError: if the worker thread is no longer running.
        """
        from queue import Empty

        total = len(samples)
        submitted = 0
        for start in range(0, total, self.chunk_size):
            chunk = samples[start:start + self.chunk_size]
            # The last chunk is zero-padded to the fixed chunk size.
            if len(chunk) < self.chunk_size:
                chunk = np.pad(chunk, (0, self.chunk_size - len(chunk)))
            self.processing_queue.put(chunk)
            submitted += 1

        # Collect exactly as many results as chunks were submitted. Checking
        # whether both queues are empty is not enough: while the worker is
        # busy with the last chunk both queues are empty, and that chunk
        # would be lost (and leak into the next request).
        enhanced_chunks = []
        while len(enhanced_chunks) < submitted:
            try:
                enhanced_chunks.append(self.output_buffer.get(timeout=1.0))
            except Empty:
                if not self.processing_thread.is_alive():
                    raise RuntimeError(
                        "audio processing thread stopped unexpectedly"
                    )

        return np.concatenate(enhanced_chunks)[:total]

    def _save_wav(self, enhanced_audio):
        """Write ``enhanced_audio`` as 16-bit mono WAV and return its path."""
        os.makedirs(self._OUTPUT_DIR, exist_ok=True)
        output_path = os.path.join(self._OUTPUT_DIR, self._OUTPUT_NAME)
        pcm = self._to_pcm16(enhanced_audio)

        # The temp file lives in the destination directory so os.replace is
        # a same-filesystem atomic rename; the descriptor is closed before
        # saving so the writer can reopen the path.
        fd, tmp_path = tempfile.mkstemp(suffix='.wav', dir=self._OUTPUT_DIR)
        os.close(fd)
        try:
            torchaudio.save(
                tmp_path,
                torch.from_numpy(pcm).unsqueeze(0),  # (channels=1, time)
                self.sample_rate
            )
            os.replace(tmp_path, output_path)
        except Exception:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        return output_path

    def process_stream(self, audio_path):
        """
        Process audio in streaming fashion

        Args:
            audio_path: path of the audio file to enhance.

        Returns:
            Path of the enhanced WAV file (16 kHz, mono, 16-bit).

        Raises:
            gr.Error: if loading, enhancing or saving the audio fails.
        """
        try:
            samples = self._load_samples(audio_path)
            with self._stream_lock:
                enhanced_audio = self._enhance_samples(samples)
            return self._save_wav(enhanced_audio)
        except Exception as e:
            raise gr.Error(f"Error processing audio: {str(e)}") from e
def create_gradio_interface():
    """Build the Gradio UI wired to a freshly created RealtimeAudioDenoiser.

    Returns:
        A ``gr.Interface`` taking an uploaded audio file path and returning
        the path of the enhanced audio file.
    """
    audio_denoiser = RealtimeAudioDenoiser()

    # Input and output both exchange file paths with the denoiser.
    noisy_input = gr.Audio(type="filepath", label="Upload Noisy Audio")
    enhanced_output = gr.Audio(label="Enhanced Audio", type="filepath")

    summary = (
        "\n"
        "Optimized for real-time processing with low latency.\n"
        "Processes audio in 500ms chunks for streaming applications.\n"
    )

    return gr.Interface(
        fn=audio_denoiser.process_stream,
        inputs=noisy_input,
        outputs=enhanced_output,
        title="Real-time Audio Denoising using SepFormer",
        description=summary,
    )
if __name__ == "__main__":
    # Script entry point: build the interface and start serving it.
    create_gradio_interface().launch()