lyimo committed on
Commit
17508a1
·
verified ·
1 Parent(s): 59d91ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -74
app.py CHANGED
@@ -5,84 +5,133 @@ import os
5
  from pydub import AudioSegment
6
  import tempfile
7
  from speechbrain.pretrained.separation import SepformerSeparation
 
 
 
 
8
 
9
- class AudioDenoiser:
10
  def __init__(self):
11
- # Initialize the SepFormer model for audio enhancement
12
  self.model = SepformerSeparation.from_hparams(
13
  source="speechbrain/sepformer-dns4-16k-enhancement",
14
  savedir='pretrained_models/sepformer-dns4-16k-enhancement'
15
  )
16
 
17
- # Create output directory if it doesn't exist
18
- os.makedirs("enhanced_audio", exist_ok=True)
19
-
20
- def convert_audio_to_wav(self, input_path):
21
- """
22
- Convert any audio format to WAV with proper settings
23
 
24
- Args:
25
- input_path (str): Path to input audio file
26
-
27
- Returns:
28
- str: Path to converted WAV file
29
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  try:
31
- # Create a temporary file for the converted audio
32
- temp_wav = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
33
- temp_wav_path = temp_wav.name
34
-
35
- # Load audio using pydub (supports multiple formats)
36
- audio = AudioSegment.from_file(input_path)
37
-
38
- # Convert to mono if stereo
39
- if audio.channels > 1:
40
- audio = audio.set_channels(1)
41
 
42
- # Export as WAV with proper settings
43
- audio.export(
44
- temp_wav_path,
45
- format='wav',
46
- parameters=[
47
- '-ar', '16000', # Set sample rate to 16kHz
48
- '-ac', '1' # Set channels to mono
49
- ]
50
- )
51
 
52
- return temp_wav_path
53
 
54
  except Exception as e:
55
- raise gr.Error(f"Error converting audio format: {str(e)}")
56
-
57
- def enhance_audio(self, audio_path):
 
58
  """
59
- Process the input audio file and return the enhanced version
60
-
61
- Args:
62
- audio_path (str): Path to the input audio file
63
-
64
- Returns:
65
- str: Path to the enhanced audio file
66
  """
67
  try:
68
- # Convert input audio to proper WAV format
69
- wav_path = self.convert_audio_to_wav(audio_path)
 
 
70
 
71
- # Separate and enhance the audio
72
- est_sources = self.model.separate_file(path=wav_path)
 
73
 
74
- # Generate output filename
75
- output_path = os.path.join("enhanced_audio", "enhanced_audio.wav")
 
 
 
 
 
 
 
 
 
76
 
77
- # Save the enhanced audio
78
- torchaudio.save(
79
- output_path,
80
- est_sources[:, :, 0].detach().cpu(),
81
- 16000 # Sample rate
82
- )
83
 
84
- # Clean up temporary file
85
- os.unlink(wav_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  return output_path
88
 
@@ -91,11 +140,11 @@ class AudioDenoiser:
91
 
92
  def create_gradio_interface():
93
  # Initialize the denoiser
94
- denoiser = AudioDenoiser()
95
 
96
  # Create the Gradio interface
97
  interface = gr.Interface(
98
- fn=denoiser.enhance_audio,
99
  inputs=gr.Audio(
100
  type="filepath",
101
  label="Upload Noisy Audio"
@@ -104,21 +153,10 @@ def create_gradio_interface():
104
  label="Enhanced Audio",
105
  type="filepath"
106
  ),
107
- title="Audio Denoising using SepFormer",
108
  description="""
109
- This application uses the SepFormer model from SpeechBrain to enhance audio quality
110
- by removing background noise. Supports various audio formats including MP3 and WAV.
111
- """,
112
- article="""
113
- Supported audio formats:
114
- - MP3
115
- - WAV
116
- - OGG
117
- - FLAC
118
- - M4A
119
- and more...
120
-
121
- The audio will automatically be converted to the correct format for processing.
122
  """
123
  )
124
 
 
5
  from pydub import AudioSegment
6
  import tempfile
7
  from speechbrain.pretrained.separation import SepformerSeparation
8
+ import numpy as np
9
+ import threading
10
+ from queue import Queue
11
+ import time
12
 
13
+ class RealtimeAudioDenoiser:
14
  def __init__(self):
15
+ # Initialize the model
16
  self.model = SepformerSeparation.from_hparams(
17
  source="speechbrain/sepformer-dns4-16k-enhancement",
18
  savedir='pretrained_models/sepformer-dns4-16k-enhancement'
19
  )
20
 
21
+ # Move model to GPU if available
22
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ self.model.to(self.device)
 
 
 
24
 
25
+ # Enable inference mode for better performance
26
+ self.model.eval()
27
+ torch.set_grad_enabled(False)
28
+
29
+ # Set chunk size for streaming (500ms chunks)
30
+ self.chunk_duration = 0.5 # seconds
31
+ self.sample_rate = 16000
32
+ self.chunk_size = int(self.sample_rate * self.chunk_duration)
33
+
34
+ # Initialize processing queue and buffer
35
+ self.processing_queue = Queue()
36
+ self.output_buffer = Queue()
37
+ self.is_processing = False
38
+
39
+ # Start processing thread
40
+ self.processing_thread = threading.Thread(target=self._process_queue)
41
+ self.processing_thread.daemon = True
42
+ self.processing_thread.start()
43
+
44
+ # Create output directory
45
+ os.makedirs("enhanced_audio", exist_ok=True)
46
+
47
+ def _optimize_model(self):
48
+ """Optimize model for inference"""
49
+ if self.device.type == 'cuda':
50
+ # Use mixed precision for faster processing
51
+ self.model = torch.quantization.quantize_dynamic(
52
+ self.model, {torch.nn.Linear}, dtype=torch.qint8
53
+ )
54
+ torch.backends.cudnn.benchmark = True
55
+
56
+ def _process_queue(self):
57
+ """Background thread for processing audio chunks"""
58
+ while True:
59
+ if not self.processing_queue.empty():
60
+ chunk = self.processing_queue.get()
61
+ if chunk is None:
62
+ continue
63
+
64
+ # Process audio chunk
65
+ enhanced_chunk = self._enhance_chunk(chunk)
66
+ self.output_buffer.put(enhanced_chunk)
67
+ else:
68
+ time.sleep(0.01) # Small delay to prevent CPU overuse
69
+
70
+ def _enhance_chunk(self, audio_chunk):
71
+ """Process a single chunk of audio"""
72
  try:
73
+ # Convert to tensor and move to device
74
+ chunk_tensor = torch.FloatTensor(audio_chunk).to(self.device)
75
+ chunk_tensor = chunk_tensor.unsqueeze(0) # Add batch dimension
 
 
 
 
 
 
 
76
 
77
+ # Process with model
78
+ with torch.inference_mode():
79
+ enhanced = self.model.separate_batch(chunk_tensor)
80
+ enhanced = enhanced.squeeze(0).cpu().numpy()
 
 
 
 
 
81
 
82
+ return enhanced
83
 
84
  except Exception as e:
85
+ print(f"Error processing chunk: {str(e)}")
86
+ return audio_chunk
87
+
88
+ def process_stream(self, audio_path):
89
  """
90
+ Process audio in streaming fashion
 
 
 
 
 
 
91
  """
92
  try:
93
+ # Convert input audio to proper format
94
+ audio = AudioSegment.from_file(audio_path)
95
+ audio = audio.set_frame_rate(self.sample_rate)
96
+ audio = audio.set_channels(1)
97
 
98
+ # Convert to numpy array
99
+ samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
100
+ samples = samples / np.max(np.abs(samples)) # Normalize
101
 
102
+ # Process in chunks
103
+ enhanced_chunks = []
104
+ for i in range(0, len(samples), self.chunk_size):
105
+ chunk = samples[i:i + self.chunk_size]
106
+
107
+ # Pad last chunk if necessary
108
+ if len(chunk) < self.chunk_size:
109
+ chunk = np.pad(chunk, (0, self.chunk_size - len(chunk)))
110
+
111
+ # Add to processing queue
112
+ self.processing_queue.put(chunk)
113
 
114
+ # Wait for all chunks to be processed
115
+ while self.processing_queue.qsize() > 0 or self.output_buffer.qsize() > 0:
116
+ if not self.output_buffer.empty():
117
+ enhanced_chunks.append(self.output_buffer.get())
118
+ time.sleep(0.01)
 
119
 
120
+ # Combine chunks
121
+ enhanced_audio = np.concatenate(enhanced_chunks)
122
+
123
+ # Save enhanced audio
124
+ output_path = os.path.join("enhanced_audio", "enhanced_realtime.wav")
125
+ enhanced_audio = enhanced_audio * 32767 # Convert to int16 range
126
+ enhanced_audio = enhanced_audio.astype(np.int16)
127
+
128
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
129
+ torchaudio.save(
130
+ f.name,
131
+ torch.tensor(enhanced_audio).unsqueeze(0),
132
+ self.sample_rate
133
+ )
134
+ os.replace(f.name, output_path)
135
 
136
  return output_path
137
 
 
140
 
141
  def create_gradio_interface():
142
  # Initialize the denoiser
143
+ denoiser = RealtimeAudioDenoiser()
144
 
145
  # Create the Gradio interface
146
  interface = gr.Interface(
147
+ fn=denoiser.process_stream,
148
  inputs=gr.Audio(
149
  type="filepath",
150
  label="Upload Noisy Audio"
 
153
  label="Enhanced Audio",
154
  type="filepath"
155
  ),
156
+ title="Real-time Audio Denoising using SepFormer",
157
  description="""
158
+ Optimized for real-time processing with low latency.
159
+ Processes audio in 500ms chunks for streaming applications.
 
 
 
 
 
 
 
 
 
 
 
160
  """
161
  )
162