Spaces:
Runtime error
Runtime error
Update separation_utils.py
Browse files- separation_utils.py +76 -22
separation_utils.py
CHANGED
@@ -8,8 +8,11 @@ import seaborn as sns
|
|
8 |
from scipy.special import rel_entr
|
9 |
import nussl
|
10 |
import types
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
# Required to patch mask validation for nussl
|
13 |
|
14 |
def _validate_mask_patched(self, mask_):
|
15 |
assert isinstance(mask_, np.ndarray), 'Mask must be a numpy array!'
|
@@ -23,7 +26,7 @@ def _validate_mask_patched(self, mask_):
|
|
23 |
nussl.core.masks.binary_mask.BinaryMask._validate_mask = types.MethodType(
|
24 |
_validate_mask_patched, nussl.core.masks.binary_mask.BinaryMask)
|
25 |
|
26 |
-
|
27 |
def Repet(mix):
|
28 |
return nussl.separation.primitive.Repet(mix)( )
|
29 |
|
@@ -33,7 +36,7 @@ def Repet_Sim(mix):
|
|
33 |
def Two_DFT(mix):
|
34 |
return nussl.separation.primitive.FT2D(mix)( )
|
35 |
|
36 |
-
|
37 |
def calculate_psnr(clean_signal, separated_signal):
|
38 |
min_length = min(len(clean_signal), len(separated_signal))
|
39 |
clean_signal = clean_signal[:min_length]
|
@@ -56,30 +59,81 @@ def compute_mel_spectrogram(signal, sr, n_fft=2048, hop_length=512, n_mels=128):
|
|
56 |
y=signal, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, power=2.0
|
57 |
)
|
58 |
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
def process_audio(file_path):
|
62 |
signal = nussl.AudioSignal(file_path)
|
63 |
-
mix_signal,
|
64 |
|
65 |
ft2d_bg, ft2d_fg = Two_DFT(signal)
|
66 |
repet_bg, repet_fg = Repet(signal)
|
67 |
rsim_bg, rsim_fg = Repet_Sim(signal)
|
68 |
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from scipy.special import rel_entr
|
9 |
import nussl
|
10 |
import types
|
11 |
+
import pandas as pd
|
12 |
+
from sklearn.ensemble import RandomForestClassifier
|
13 |
+
from sklearn.model_selection import train_test_split
|
14 |
+
from sklearn.metrics import classification_report
|
15 |
|
|
|
16 |
|
17 |
def _validate_mask_patched(self, mask_):
|
18 |
assert isinstance(mask_, np.ndarray), 'Mask must be a numpy array!'
|
|
|
26 |
nussl.core.masks.binary_mask.BinaryMask._validate_mask = types.MethodType(
|
27 |
_validate_mask_patched, nussl.core.masks.binary_mask.BinaryMask)
|
28 |
|
29 |
+
|
30 |
def Repet(mix):
|
31 |
return nussl.separation.primitive.Repet(mix)( )
|
32 |
|
|
|
36 |
def Two_DFT(mix):
|
37 |
return nussl.separation.primitive.FT2D(mix)( )
|
38 |
|
39 |
+
|
40 |
def calculate_psnr(clean_signal, separated_signal):
|
41 |
min_length = min(len(clean_signal), len(separated_signal))
|
42 |
clean_signal = clean_signal[:min_length]
|
|
|
59 |
y=signal, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, power=2.0
|
60 |
)
|
61 |
|
62 |
+
def extract_features(audio, sr, frame_size=5046, hop_length=2048):
|
63 |
+
zcr = librosa.feature.zero_crossing_rate(audio, frame_length=frame_size, hop_length=hop_length)
|
64 |
+
rms = librosa.feature.rms(y=audio, frame_length=frame_size, hop_length=hop_length)
|
65 |
+
spectral_centroid = librosa.feature.spectral_centroid(y=audio, sr=sr, hop_length=hop_length)
|
66 |
+
features = np.vstack((zcr, rms, spectral_centroid)).T
|
67 |
+
return features
|
68 |
+
|
69 |
+
def process_pipeline(fg_path, bg_path, sr):
|
70 |
+
fg_audio, _ = librosa.load(fg_path, sr=sr)
|
71 |
+
bg_audio, _ = librosa.load(bg_path, sr=sr)
|
72 |
+
fg_features = extract_features(fg_audio, sr)
|
73 |
+
bg_features = extract_features(bg_audio, sr)
|
74 |
+
fg_labels = np.ones(fg_features.shape[0])
|
75 |
+
bg_labels = np.zeros(bg_features.shape[0])
|
76 |
+
features = np.vstack((fg_features, bg_features))
|
77 |
+
labels = np.hstack((fg_labels, bg_labels))
|
78 |
+
return features, labels
|
79 |
+
|
80 |
+
def train_rf_model(X, y):
|
81 |
+
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
|
82 |
+
clf = RandomForestClassifier(n_estimators=100, random_state=42)
|
83 |
+
clf.fit(X_train, y_train)
|
84 |
+
y_pred = clf.predict(X_val)
|
85 |
+
print(classification_report(y_val, y_pred))
|
86 |
+
return clf
|
87 |
+
|
88 |
+
def reconstruct_audio(mixed_audio, labels, sr, frame_size=2048, hop_length=512):
|
89 |
+
frames = librosa.util.frame(mixed_audio, frame_length=frame_size, hop_length=hop_length).T
|
90 |
+
labels = labels[:frames.shape[0]]
|
91 |
+
fg_frames = frames[labels == 1.0] if np.any(labels == 1.0) else np.zeros_like(frames[:1])
|
92 |
+
bg_frames = frames[labels == 0.0] if np.any(labels == 0.0) else np.zeros_like(frames[:1])
|
93 |
+
fg_audio = librosa.istft(fg_frames.T, hop_length=hop_length) if fg_frames.shape[0] > 0 else np.zeros_like(mixed_audio)
|
94 |
+
bg_audio = librosa.istft(bg_frames.T, hop_length=hop_length) if bg_frames.shape[0] > 0 else np.zeros_like(mixed_audio)
|
95 |
+
return fg_audio, bg_audio
|
96 |
|
97 |
def process_audio(file_path):
|
98 |
signal = nussl.AudioSignal(file_path)
|
99 |
+
mix_signal, sr = librosa.load(file_path, sr=None)
|
100 |
|
101 |
ft2d_bg, ft2d_fg = Two_DFT(signal)
|
102 |
repet_bg, repet_fg = Repet(signal)
|
103 |
rsim_bg, rsim_fg = Repet_Sim(signal)
|
104 |
|
105 |
+
# Save the 3 outputs
|
106 |
+
fg_paths = {
|
107 |
+
"2dft": "output_foreground_2dft.wav",
|
108 |
+
"repet": "output_foreground_repet.wav",
|
109 |
+
"rsim": "output_foreground_rsim.wav"
|
110 |
+
}
|
111 |
+
ft2d_fg.write_audio_to_file(fg_paths["2dft"])
|
112 |
+
repet_fg.write_audio_to_file(fg_paths["repet"])
|
113 |
+
rsim_fg.write_audio_to_file(fg_paths["rsim"])
|
114 |
+
|
115 |
+
# Select best for training
|
116 |
+
fg_path, bg_path = fg_paths["rsim"], fg_paths["repet"] # Use RepetSim FG and Repet BG
|
117 |
+
|
118 |
+
features, labels = process_pipeline(fg_path, bg_path, sr)
|
119 |
+
clf = train_rf_model(features, labels)
|
120 |
+
|
121 |
+
test_features = extract_features(mix_signal, sr)
|
122 |
+
predicted_labels = clf.predict(test_features)
|
123 |
+
fg_rec, bg_rec = reconstruct_audio(mix_signal, predicted_labels, sr)
|
124 |
+
|
125 |
+
fg_rf_path = "output_foreground_rf.wav"
|
126 |
+
bg_rf_path = "output_background_rf.wav"
|
127 |
+
sf.write(fg_rf_path, fg_rec, sr)
|
128 |
+
sf.write(bg_rf_path, bg_rec, sr)
|
129 |
+
|
130 |
+
psnr_rf = calculate_psnr(signal.audio_data, fg_rec)
|
131 |
+
kl_rf = calculate_melspectrogram_kl_divergence(signal.audio_data, fg_rec, sr)
|
132 |
+
|
133 |
+
return (
|
134 |
+
fg_paths["2dft"], calculate_psnr(signal.audio_data, ft2d_fg.audio_data), calculate_melspectrogram_kl_divergence(signal.audio_data, ft2d_fg.audio_data, sr),
|
135 |
+
fg_paths["repet"], calculate_psnr(signal.audio_data, repet_fg.audio_data), calculate_melspectrogram_kl_divergence(signal.audio_data, repet_fg.audio_data, sr),
|
136 |
+
fg_paths["rsim"], calculate_psnr(signal.audio_data, rsim_fg.audio_data), calculate_melspectrogram_kl_divergence(signal.audio_data, rsim_fg.audio_data, sr),
|
137 |
+
fg_rf_path, psnr_rf, kl_rf,
|
138 |
+
bg_rf_path
|
139 |
+
)
|