Spaces:
Runtime error
Runtime error
Create separation_utils.py
Browse files- separation_utils.py +85 -0
separation_utils.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# separation_utils.py
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
import librosa
|
5 |
+
import soundfile as sf
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import seaborn as sns
|
8 |
+
from scipy.special import rel_entr
|
9 |
+
import nussl
|
10 |
+
import types
|
11 |
+
|
12 |
+
# Required to patch mask validation for nussl
|
13 |
+
|
14 |
+
def _validate_mask_patched(self, mask_):
|
15 |
+
assert isinstance(mask_, np.ndarray), 'Mask must be a numpy array!'
|
16 |
+
if mask_.dtype == bool:
|
17 |
+
return mask_
|
18 |
+
mask_ = mask_ > 0.5
|
19 |
+
if not np.all(np.logical_or(mask_, np.logical_not(mask_))):
|
20 |
+
raise ValueError('All mask entries must be 0 or 1.')
|
21 |
+
return mask_
|
22 |
+
|
23 |
+
nussl.core.masks.binary_mask.BinaryMask._validate_mask = types.MethodType(
|
24 |
+
_validate_mask_patched, nussl.core.masks.binary_mask.BinaryMask)
|
25 |
+
|
26 |
+
# Separation methods
|
27 |
+
def Repet(mix):
|
28 |
+
return nussl.separation.primitive.Repet(mix)( )
|
29 |
+
|
30 |
+
def Repet_Sim(mix):
|
31 |
+
return nussl.separation.primitive.RepetSim(mix)( )
|
32 |
+
|
33 |
+
def Two_DFT(mix):
|
34 |
+
return nussl.separation.primitive.FT2D(mix)( )
|
35 |
+
|
36 |
+
# Audio metrics
|
37 |
+
def calculate_psnr(clean_signal, separated_signal):
|
38 |
+
min_length = min(len(clean_signal), len(separated_signal))
|
39 |
+
clean_signal = clean_signal[:min_length]
|
40 |
+
separated_signal = separated_signal[:min_length]
|
41 |
+
mse = np.mean((clean_signal - separated_signal) ** 2)
|
42 |
+
if mse == 0:
|
43 |
+
return float('inf')
|
44 |
+
max_val = np.max(np.abs(clean_signal))
|
45 |
+
return 10 * np.log10((max_val ** 2) / mse)
|
46 |
+
|
47 |
+
def calculate_melspectrogram_kl_divergence(clean_signal, separated_signal, sr):
|
48 |
+
clean_mel = compute_mel_spectrogram(clean_signal, sr)
|
49 |
+
separated_mel = compute_mel_spectrogram(separated_signal, sr)
|
50 |
+
clean_mel_norm = clean_mel / np.sum(clean_mel)
|
51 |
+
separated_mel_norm = separated_mel / np.sum(separated_mel)
|
52 |
+
return np.sum(rel_entr(np.clip(clean_mel_norm, 1e-10, None), np.clip(separated_mel_norm, 1e-10, None)))
|
53 |
+
|
54 |
+
def compute_mel_spectrogram(signal, sr, n_fft=2048, hop_length=512, n_mels=128):
|
55 |
+
return librosa.feature.melspectrogram(
|
56 |
+
y=signal, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, power=2.0
|
57 |
+
)
|
58 |
+
|
59 |
+
# Main function used in Gradio
|
60 |
+
|
61 |
+
def process_audio(file_path):
|
62 |
+
signal = nussl.AudioSignal(file_path)
|
63 |
+
mix_signal, sr1 = librosa.load(file_path, sr=None)
|
64 |
+
|
65 |
+
ft2d_bg, ft2d_fg = Two_DFT(signal)
|
66 |
+
repet_bg, repet_fg = Repet(signal)
|
67 |
+
rsim_bg, rsim_fg = Repet_Sim(signal)
|
68 |
+
|
69 |
+
output_file1 = "output_foreground_2dft.wav"
|
70 |
+
output_file2 = "output_foreground_repet.wav"
|
71 |
+
output_file3 = "output_foreground_rsim.wav"
|
72 |
+
|
73 |
+
ft2d_fg.write_audio_to_file(output_file1)
|
74 |
+
repet_fg.write_audio_to_file(output_file2)
|
75 |
+
rsim_fg.write_audio_to_file(output_file3)
|
76 |
+
|
77 |
+
output_snr1 = calculate_psnr(signal.audio_data, ft2d_fg.audio_data)
|
78 |
+
output_snr2 = calculate_psnr(signal.audio_data, repet_fg.audio_data)
|
79 |
+
output_snr3 = calculate_psnr(signal.audio_data, rsim_fg.audio_data)
|
80 |
+
|
81 |
+
output_kl1 = calculate_melspectrogram_kl_divergence(signal.audio_data, ft2d_fg.audio_data, sr1)
|
82 |
+
output_kl2 = calculate_melspectrogram_kl_divergence(signal.audio_data, repet_fg.audio_data, sr1)
|
83 |
+
output_kl3 = calculate_melspectrogram_kl_divergence(signal.audio_data, rsim_fg.audio_data, sr1)
|
84 |
+
|
85 |
+
return output_file1, output_snr1, output_kl1, output_file2, output_snr2, output_kl2, output_file3, output_snr3, output_kl3
|