File size: 6,049 Bytes
8bb1d29
 
 
 
5d191e9
 
0245417
17508a1
 
 
 
8bb1d29
17508a1
8bb1d29
17508a1
0245417
8bb1d29
 
 
 
17508a1
 
 
5d191e9
17508a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d191e9
17508a1
 
 
5d191e9
17508a1
 
 
 
5d191e9
17508a1
5d191e9
 
17508a1
 
 
 
8bb1d29
17508a1
8bb1d29
 
17508a1
 
 
 
5d191e9
17508a1
 
 
8bb1d29
17508a1
 
 
 
 
 
 
 
 
 
 
8bb1d29
17508a1
 
 
 
 
8bb1d29
17508a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d191e9
8bb1d29
 
 
 
 
 
 
17508a1
8bb1d29
 
 
17508a1
8bb1d29
 
0245417
8bb1d29
 
5d191e9
 
8bb1d29
17508a1
8bb1d29
17508a1
 
8bb1d29
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import gradio as gr
import torchaudio
import torch
import os
from pydub import AudioSegment
import tempfile
from speechbrain.pretrained.separation import SepformerSeparation
import numpy as np
import threading
from queue import Queue
import time

class RealtimeAudioDenoiser:
    """Chunk-based speech enhancer built on SpeechBrain's SepFormer DNS4 model.

    An input file is resampled to ``sample_rate`` mono, cut into fixed-size
    chunks and handed to a single background worker thread that runs the
    model. The enhanced chunks are stitched back together, trimmed to the
    original length and written to a 16-bit WAV file.
    """

    # Output location, relative to the current working directory.
    OUTPUT_DIR = "enhanced_audio"
    OUTPUT_FILENAME = "enhanced_realtime.wav"

    def __init__(self):
        """Load the model, start the worker thread and create the output dir."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Let SpeechBrain place the model itself via run_opts. Calling only
        # ``.to(device)`` afterwards moves the sub-modules but leaves the
        # pretrained wrapper's own device setting unchanged, so inputs and
        # weights can end up on different devices.
        self.model = SepformerSeparation.from_hparams(
            source="speechbrain/sepformer-dns4-16k-enhancement",
            savedir='pretrained_models/sepformer-dns4-16k-enhancement',
            run_opts={"device": str(self.device)},
        )

        # Inference only: no dropout / batch-norm updates, no autograd.
        self.model.eval()
        torch.set_grad_enabled(False)

        # Chunking parameters (500 ms chunks at 16 kHz).
        self.chunk_duration = 0.5  # seconds
        self.sample_rate = 16000
        self.chunk_size = int(self.sample_rate * self.chunk_duration)

        # Worker input / output queues. Both are FIFO and there is exactly
        # one worker, so chunk order is preserved end to end.
        self.processing_queue = Queue()
        self.output_buffer = Queue()
        self.is_processing = False

        # Serialises process_stream calls: the queues are shared, so two
        # overlapping requests would otherwise interleave their chunks.
        self._stream_lock = threading.Lock()

        # Daemon thread so it never blocks interpreter shutdown.
        self.processing_thread = threading.Thread(
            target=self._process_queue, daemon=True
        )
        self.processing_thread.start()

        os.makedirs(self.OUTPUT_DIR, exist_ok=True)

    def _optimize_model(self):
        """Apply optional, device-appropriate inference optimisations.

        Not invoked automatically; call it explicitly after construction.

        * CUDA: enable cuDNN autotuning.
        * CPU: dynamically quantise ``torch.nn.Linear`` layers to int8.
          PyTorch's dynamic quantisation targets CPU backends, so it is not
          applied to a model living on the GPU.
        """
        if self.device.type == 'cuda':
            torch.backends.cudnn.benchmark = True
        else:
            self.model = torch.quantization.quantize_dynamic(
                self.model, {torch.nn.Linear}, dtype=torch.qint8
            )

    def _process_queue(self):
        """Worker loop: enhance queued chunks and push results in order.

        Blocks on the input queue instead of polling. ``None`` items are
        skipped and produce no output. Every other item yields exactly one
        entry in ``output_buffer`` because ``_enhance_chunk`` never raises.
        """
        while True:
            chunk = self.processing_queue.get()
            if chunk is None:
                continue
            self.output_buffer.put(self._enhance_chunk(chunk))

    def _enhance_chunk(self, audio_chunk):
        """Run the model on one chunk and return the enhanced samples.

        Args:
            audio_chunk: 1-D float array of samples.

        Returns:
            A 1-D numpy array of enhanced samples. If the model fails, the
            error is printed and ``audio_chunk`` itself is returned so the
            stream keeps its length and ordering.
        """
        try:
            chunk_tensor = torch.as_tensor(audio_chunk, dtype=torch.float32)
            chunk_tensor = chunk_tensor.unsqueeze(0).to(self.device)  # (1, T)

            with torch.inference_mode():
                enhanced = self.model.separate_batch(chunk_tensor)
                # Drop the batch axis. SpeechBrain documents the result as
                # (batch, time, sources); keep the first source so the
                # output is 1-D like the fallback path below.
                enhanced = enhanced.squeeze(0)
                if enhanced.dim() > 1:
                    enhanced = enhanced[..., 0]
                enhanced = enhanced.cpu().numpy()

            return enhanced

        except Exception as e:
            print(f"Error processing chunk: {str(e)}")
            return audio_chunk

    def _load_samples(self, audio_path):
        """Decode ``audio_path`` to peak-normalised mono float32 samples."""
        audio = AudioSegment.from_file(audio_path)
        audio = audio.set_frame_rate(self.sample_rate)
        audio = audio.set_channels(1)

        samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
        if samples.size == 0:
            raise ValueError("audio file contains no samples")

        # Peak-normalise; pure silence is left as zeros rather than
        # dividing by zero and producing NaNs.
        peak = float(np.max(np.abs(samples)))
        if peak > 0.0:
            samples = samples / peak
        return samples

    def _enhance_samples(self, samples):
        """Push ``samples`` through the worker chunk by chunk and reassemble."""
        total = len(samples)
        with self._stream_lock:
            submitted = 0
            for start in range(0, total, self.chunk_size):
                chunk = samples[start:start + self.chunk_size]
                # Zero-pad the final chunk to the fixed chunk size.
                if len(chunk) < self.chunk_size:
                    chunk = np.pad(chunk, (0, self.chunk_size - len(chunk)))
                self.processing_queue.put(chunk)
                submitted += 1

            # Collect exactly one result per submitted chunk. Checking queue
            # sizes instead would miss a chunk the worker has already
            # dequeued but not finished.
            enhanced_chunks = [
                self.output_buffer.get() for _ in range(submitted)
            ]

        # Remove the padding that was added to the last chunk.
        return np.concatenate(enhanced_chunks)[:total]

    @staticmethod
    def _to_int16(audio):
        """Convert float audio to int16 PCM, clipping to [-1, 1] first."""
        return (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)

    def _save_wav(self, enhanced_audio):
        """Write ``enhanced_audio`` as 16-bit PCM and return the file path."""
        os.makedirs(self.OUTPUT_DIR, exist_ok=True)
        output_path = os.path.join(self.OUTPUT_DIR, self.OUTPUT_FILENAME)
        pcm = self._to_int16(enhanced_audio)

        # Write to a temp file in the destination directory so the final
        # os.replace stays on one filesystem; close the handle before
        # another writer opens the path.
        fd, tmp_path = tempfile.mkstemp(suffix='.wav', dir=self.OUTPUT_DIR)
        os.close(fd)
        try:
            torchaudio.save(
                tmp_path,
                torch.from_numpy(pcm).unsqueeze(0),  # (channels=1, time)
                self.sample_rate
            )
            os.replace(tmp_path, output_path)
        finally:
            # Only still present if saving or replacing failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        return output_path

    def process_stream(self, audio_path):
        """Enhance an audio file chunk by chunk and save the result.

        Args:
            audio_path: Path to the input audio file.

        Returns:
            Path of the written WAV file
            (``enhanced_audio/enhanced_realtime.wav``).

        Raises:
            gr.Error: If no path is given or any processing step fails.
        """
        if not audio_path:
            raise gr.Error("Error processing audio: no audio file was provided")

        try:
            samples = self._load_samples(audio_path)
            enhanced_audio = self._enhance_samples(samples)
            return self._save_wav(enhanced_audio)
        except Exception as e:
            raise gr.Error(f"Error processing audio: {str(e)}") from e

def create_gradio_interface():
    """Build the Gradio UI wired to a fresh ``RealtimeAudioDenoiser``.

    Returns:
        A ``gr.Interface`` that takes an uploaded audio file path, passes it
        to ``RealtimeAudioDenoiser.process_stream`` and plays back the file
        path that call returns.
    """
    # The denoiser is constructed first, before any UI components exist.
    engine = RealtimeAudioDenoiser()

    blurb = """
        Optimized for real-time processing with low latency.
        Processes audio in 500ms chunks for streaming applications.
        """

    noisy_input = gr.Audio(type="filepath", label="Upload Noisy Audio")
    clean_output = gr.Audio(label="Enhanced Audio", type="filepath")

    return gr.Interface(
        fn=engine.process_stream,
        inputs=noisy_input,
        outputs=clean_output,
        title="Real-time Audio Denoising using SepFormer",
        description=blurb,
    )

if __name__ == "__main__":
    # Script entry point: build the Gradio interface (which also constructs
    # the denoiser) and start serving it with Gradio's default launch
    # settings.
    demo = create_gradio_interface()
    demo.launch()