Aashiue committed on
Commit
490fae8
·
verified ·
1 Parent(s): 4af631f

Update separation_utils.py

Browse files
Files changed (1) hide show
  1. separation_utils.py +76 -22
separation_utils.py CHANGED
@@ -8,8 +8,11 @@ import seaborn as sns
8
  from scipy.special import rel_entr
9
  import nussl
10
  import types
 
 
 
 
11
 
12
- # Required to patch mask validation for nussl
13
 
14
  def _validate_mask_patched(self, mask_):
15
  assert isinstance(mask_, np.ndarray), 'Mask must be a numpy array!'
@@ -23,7 +26,7 @@ def _validate_mask_patched(self, mask_):
23
  nussl.core.masks.binary_mask.BinaryMask._validate_mask = types.MethodType(
24
  _validate_mask_patched, nussl.core.masks.binary_mask.BinaryMask)
25
 
26
- # Separation methods
27
  def Repet(mix):
28
  return nussl.separation.primitive.Repet(mix)( )
29
 
@@ -33,7 +36,7 @@ def Repet_Sim(mix):
33
  def Two_DFT(mix):
34
  return nussl.separation.primitive.FT2D(mix)( )
35
 
36
- # Audio metrics
37
  def calculate_psnr(clean_signal, separated_signal):
38
  min_length = min(len(clean_signal), len(separated_signal))
39
  clean_signal = clean_signal[:min_length]
@@ -56,30 +59,81 @@ def compute_mel_spectrogram(signal, sr, n_fft=2048, hop_length=512, n_mels=128):
56
  y=signal, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, power=2.0
57
  )
58
 
59
- # Main function used in Gradio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  def process_audio(file_path):
62
  signal = nussl.AudioSignal(file_path)
63
- mix_signal, sr1 = librosa.load(file_path, sr=None)
64
 
65
  ft2d_bg, ft2d_fg = Two_DFT(signal)
66
  repet_bg, repet_fg = Repet(signal)
67
  rsim_bg, rsim_fg = Repet_Sim(signal)
68
 
69
- output_file1 = "output_foreground_2dft.wav"
70
- output_file2 = "output_foreground_repet.wav"
71
- output_file3 = "output_foreground_rsim.wav"
72
-
73
- ft2d_fg.write_audio_to_file(output_file1)
74
- repet_fg.write_audio_to_file(output_file2)
75
- rsim_fg.write_audio_to_file(output_file3)
76
-
77
- output_snr1 = calculate_psnr(signal.audio_data, ft2d_fg.audio_data)
78
- output_snr2 = calculate_psnr(signal.audio_data, repet_fg.audio_data)
79
- output_snr3 = calculate_psnr(signal.audio_data, rsim_fg.audio_data)
80
-
81
- output_kl1 = calculate_melspectrogram_kl_divergence(signal.audio_data, ft2d_fg.audio_data, sr1)
82
- output_kl2 = calculate_melspectrogram_kl_divergence(signal.audio_data, repet_fg.audio_data, sr1)
83
- output_kl3 = calculate_melspectrogram_kl_divergence(signal.audio_data, rsim_fg.audio_data, sr1)
84
-
85
- return output_file1, output_snr1, output_kl1, output_file2, output_snr2, output_kl2, output_file3, output_snr3, output_kl3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from scipy.special import rel_entr
9
  import nussl
10
  import types
11
+ import pandas as pd
12
+ from sklearn.ensemble import RandomForestClassifier
13
+ from sklearn.model_selection import train_test_split
14
+ from sklearn.metrics import classification_report
15
 
 
16
 
17
  def _validate_mask_patched(self, mask_):
18
  assert isinstance(mask_, np.ndarray), 'Mask must be a numpy array!'
 
26
  nussl.core.masks.binary_mask.BinaryMask._validate_mask = types.MethodType(
27
  _validate_mask_patched, nussl.core.masks.binary_mask.BinaryMask)
28
 
29
+
30
  def Repet(mix):
31
  return nussl.separation.primitive.Repet(mix)( )
32
 
 
36
  def Two_DFT(mix):
37
  return nussl.separation.primitive.FT2D(mix)( )
38
 
39
+
40
  def calculate_psnr(clean_signal, separated_signal):
41
  min_length = min(len(clean_signal), len(separated_signal))
42
  clean_signal = clean_signal[:min_length]
 
59
  y=signal, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, power=2.0
60
  )
61
 
62
def extract_features(audio, sr, frame_size=5046, hop_length=2048):
    """Build a frame-wise feature matrix from an audio signal.

    Three librosa descriptors are computed -- zero-crossing rate, RMS energy
    and spectral centroid -- then stacked vertically and transposed.

    Args:
        audio: Audio samples, passed straight to the librosa feature calls.
        sr: Sampling rate, used by the spectral-centroid computation.
        frame_size: Frame length for the zero-crossing-rate and RMS features.
        hop_length: Hop between successive frames for all three features.

    Returns:
        The transposed stack of the three feature arrays (presumably one row
        per frame with three columns for 1-D input -- confirm against caller).
    """
    framing = {"frame_length": frame_size, "hop_length": hop_length}
    crossing_rate = librosa.feature.zero_crossing_rate(audio, **framing)
    energy = librosa.feature.rms(y=audio, **framing)
    # NOTE(review): frame_size is not forwarded here, so the centroid is
    # computed with librosa's default analysis window -- confirm intended.
    centroid = librosa.feature.spectral_centroid(
        y=audio, sr=sr, hop_length=hop_length
    )
    return np.vstack((crossing_rate, energy, centroid)).T
68
+
69
def process_pipeline(fg_path, bg_path, sr):
    """Assemble a labelled training set from two audio files.

    Each file is loaded with librosa at rate *sr* and run through
    ``extract_features``. Rows from *fg_path* are labelled 1 and rows from
    *bg_path* are labelled 0.

    Args:
        fg_path: Path of the audio file whose frames get label 1.
        bg_path: Path of the audio file whose frames get label 0.
        sr: Sampling rate used both for loading and for feature extraction.

    Returns:
        ``(features, labels)`` with the foreground rows first, followed by
        the background rows; ``labels`` has one entry per feature row.
    """
    feature_sets = []
    for path in (fg_path, bg_path):
        samples, _ = librosa.load(path, sr=sr)
        feature_sets.append(extract_features(samples, sr))

    fg_rows, bg_rows = feature_sets
    labels = np.hstack((np.ones(fg_rows.shape[0]), np.zeros(bg_rows.shape[0])))
    return np.vstack(feature_sets), labels
79
+
80
def train_rf_model(X, y):
    """Fit a random-forest classifier and print a validation report.

    The data is split 80/20 with a fixed seed, a 100-tree forest (also
    seeded) is fitted on the larger part, and a classification report for
    the held-out part is printed to stdout.

    Args:
        X: Feature matrix passed to ``train_test_split``.
        y: Labels aligned with the rows of *X*.

    Returns:
        The fitted ``RandomForestClassifier``.
    """
    split = train_test_split(X, y, test_size=0.2, random_state=42)
    X_fit, X_holdout, y_fit, y_holdout = split

    forest = RandomForestClassifier(n_estimators=100, random_state=42)
    forest.fit(X_fit, y_fit)

    print(classification_report(y_holdout, forest.predict(X_holdout)))
    return forest
87
+
88
def reconstruct_audio(mixed_audio, labels, sr, frame_size=2048, hop_length=512):
    """Split a mixture into foreground/background using per-frame labels.

    Each label governs one hop of samples: label ``i`` applies to samples
    ``[i * hop_length, (i + 1) * hop_length)``. Samples past the last label
    reuse the last label, and surplus labels are ignored. Samples whose
    label is 1.0 are kept in the foreground output, samples whose label is
    0.0 are kept in the background output, and everything else is zeroed.

    The previous implementation fed time-domain frames to ``librosa.istft``
    as if they were an STFT matrix and indexed frames with a label array of
    a different length; masking in the time domain avoids both problems and
    keeps the outputs the same length as the input.

    Args:
        mixed_audio: Mixture samples; the last axis is treated as time.
        labels: Per-frame class labels (1.0 = foreground, 0.0 = background).
        sr: Sampling rate. Unused; kept for interface compatibility.
        frame_size: Unused; kept for interface compatibility.
        hop_length: Number of samples covered by each label.

    Returns:
        ``(fg_audio, bg_audio)``, each with the shape and dtype of
        *mixed_audio*.

    Raises:
        ValueError: If *hop_length* is not a positive integer.
    """
    if hop_length < 1:
        raise ValueError("hop_length must be a positive integer")

    mixed_audio = np.asarray(mixed_audio)
    labels = np.asarray(labels).ravel()
    silence = np.zeros_like(mixed_audio)

    n_samples = mixed_audio.shape[-1] if mixed_audio.ndim else 0
    if n_samples == 0 or labels.size == 0:
        # Nothing to assign: return silent signals instead of failing.
        return silence, silence.copy()

    # Map every sample to the label of the hop it falls in, clamping so that
    # trailing samples beyond the last label inherit that last label.
    frame_index = np.minimum(np.arange(n_samples) // hop_length, labels.size - 1)
    sample_labels = labels[frame_index]

    fg_audio = np.where(sample_labels == 1.0, mixed_audio, silence)
    bg_audio = np.where(sample_labels == 0.0, mixed_audio, silence)
    return fg_audio, bg_audio
96
 
97
  def process_audio(file_path):
98
  signal = nussl.AudioSignal(file_path)
99
+ mix_signal, sr = librosa.load(file_path, sr=None)
100
 
101
  ft2d_bg, ft2d_fg = Two_DFT(signal)
102
  repet_bg, repet_fg = Repet(signal)
103
  rsim_bg, rsim_fg = Repet_Sim(signal)
104
 
105
+ # Save the 3 outputs
106
+ fg_paths = {
107
+ "2dft": "output_foreground_2dft.wav",
108
+ "repet": "output_foreground_repet.wav",
109
+ "rsim": "output_foreground_rsim.wav"
110
+ }
111
+ ft2d_fg.write_audio_to_file(fg_paths["2dft"])
112
+ repet_fg.write_audio_to_file(fg_paths["repet"])
113
+ rsim_fg.write_audio_to_file(fg_paths["rsim"])
114
+
115
+ # Select best for training
116
+ fg_path, bg_path = fg_paths["rsim"], fg_paths["repet"] # Use RepetSim FG and Repet BG
117
+
118
+ features, labels = process_pipeline(fg_path, bg_path, sr)
119
+ clf = train_rf_model(features, labels)
120
+
121
+ test_features = extract_features(mix_signal, sr)
122
+ predicted_labels = clf.predict(test_features)
123
+ fg_rec, bg_rec = reconstruct_audio(mix_signal, predicted_labels, sr)
124
+
125
+ fg_rf_path = "output_foreground_rf.wav"
126
+ bg_rf_path = "output_background_rf.wav"
127
+ sf.write(fg_rf_path, fg_rec, sr)
128
+ sf.write(bg_rf_path, bg_rec, sr)
129
+
130
+ psnr_rf = calculate_psnr(signal.audio_data, fg_rec)
131
+ kl_rf = calculate_melspectrogram_kl_divergence(signal.audio_data, fg_rec, sr)
132
+
133
+ return (
134
+ fg_paths["2dft"], calculate_psnr(signal.audio_data, ft2d_fg.audio_data), calculate_melspectrogram_kl_divergence(signal.audio_data, ft2d_fg.audio_data, sr),
135
+ fg_paths["repet"], calculate_psnr(signal.audio_data, repet_fg.audio_data), calculate_melspectrogram_kl_divergence(signal.audio_data, repet_fg.audio_data, sr),
136
+ fg_paths["rsim"], calculate_psnr(signal.audio_data, rsim_fg.audio_data), calculate_melspectrogram_kl_divergence(signal.audio_data, rsim_fg.audio_data, sr),
137
+ fg_rf_path, psnr_rf, kl_rf,
138
+ bg_rf_path
139
+ )