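# NOTE: this file was copied from a "Spaces" page whose status read "Runtime error";
# the module itself starts at the separation_utils.py header below.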
# separation_utils.py | |
import numpy as np | |
import os | |
import librosa | |
import soundfile as sf | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from scipy.special import rel_entr | |
import nussl | |
import types | |
import pandas as pd | |
from sklearn.ensemble import RandomForestClassifier | |
from sklearn.model_selection import train_test_split | |
from sklearn.metrics import classification_report | |
import collections | |
import collections.abc | |
# Compatibility shim: Python 3.10 removed the ABC aliases that used to live
# directly on `collections` (e.g. `collections.Mapping`). Re-expose the three
# names from `collections.abc`, presumably for nussl or one of its dependencies.
# NOTE(review): this runs after `import nussl` above, so it only helps code
# that looks these names up at call time, not at import time. Consider moving
# it above the third-party imports.
for _abc_name in ("MutableMapping", "Sequence", "Mapping"):
    setattr(collections, _abc_name, getattr(collections.abc, _abc_name))
del _abc_name
def _validate_mask_patched(self, mask_): | |
assert isinstance(mask_, np.ndarray), 'Mask must be a numpy array!' | |
if mask_.dtype == bool: | |
return mask_ | |
mask_ = mask_ > 0.5 | |
if not np.all(np.logical_or(mask_, np.logical_not(mask_))): | |
raise ValueError('All mask entries must be 0 or 1.') | |
return mask_ | |
# Install the patched validator on nussl's BinaryMask. types.MethodType binds
# the function to the BinaryMask *class*, so its `self` argument is the class
# itself; the patched function never reads `self`.
# NOTE(review): the reason for the patch is not visible here (possibly
# numpy's removal of the `np.bool` alias). Confirm before touching it.
_binary_mask_cls = nussl.core.masks.binary_mask.BinaryMask
_binary_mask_cls._validate_mask = types.MethodType(
    _validate_mask_patched, _binary_mask_cls)
del _binary_mask_cls
def Repet(mix, **kwargs):
    """Separate *mix* with nussl's REPET and return the estimated sources.

    Args:
        mix: The mixture, passed straight to
            ``nussl.separation.primitive.Repet``.
        **kwargs: Optional separator settings forwarded to the nussl
            constructor (see the nussl docs). Omitting them keeps nussl's
            defaults, exactly as before.

    Returns:
        The result of calling the separator; ``process_audio`` in this module
        unpacks it as ``(background, foreground)``.
    """
    separator = nussl.separation.primitive.Repet(mix, **kwargs)
    return separator()
def Repet_Sim(mix, **kwargs):
    """Separate *mix* with nussl's REPET-SIM and return the estimated sources.

    Args:
        mix: The mixture, passed straight to
            ``nussl.separation.primitive.RepetSim``.
        **kwargs: Optional separator settings forwarded to the nussl
            constructor (see the nussl docs). Omitting them keeps nussl's
            defaults, exactly as before.

    Returns:
        The result of calling the separator; ``process_audio`` in this module
        unpacks it as ``(background, foreground)``.
    """
    separator = nussl.separation.primitive.RepetSim(mix, **kwargs)
    return separator()
def Two_DFT(mix, **kwargs):
    """Separate *mix* with nussl's 2DFT (``FT2D``) method and return the sources.

    Args:
        mix: The mixture, passed straight to
            ``nussl.separation.primitive.FT2D``.
        **kwargs: Optional separator settings forwarded to the nussl
            constructor (see the nussl docs). Omitting them keeps nussl's
            defaults, exactly as before.

    Returns:
        The result of calling the separator; ``process_audio`` in this module
        unpacks it as ``(background, foreground)``.
    """
    separator = nussl.separation.primitive.FT2D(mix, **kwargs)
    return separator()
def calculate_psnr(clean_signal, separated_signal):
    """Return the peak signal-to-noise ratio (dB) of *separated_signal* vs *clean_signal*.

    PSNR = 10 * log10(max|clean|**2 / MSE), computed over the overlapping part
    of the two signals. Both are truncated to the shorter length along their
    LAST axis, which is the sample axis for 1-D and channels-first
    ``(channels, samples)`` arrays. Leading axes broadcast against each other,
    so a ``(1, n)`` reference can be compared with an ``(n,)`` estimate.
    Inputs are converted to float64 first, so integer PCM arrays cannot
    overflow when squared.

    Returns:
        ``inf`` when the signals are identical over the overlap, ``-inf`` when
        the reference is all zeros but the estimate is not, otherwise the PSNR
        in dB.

    Raises:
        ValueError: If the overlap is empty.
    """
    clean = np.asarray(clean_signal, dtype=np.float64)
    separated = np.asarray(separated_signal, dtype=np.float64)
    # Truncate along the sample axis, not len() (which is the channel count
    # for channels-first arrays such as nussl's audio_data).
    n = min(clean.shape[-1], separated.shape[-1])
    if n == 0:
        raise ValueError("Cannot compute PSNR of empty signals.")
    clean = clean[..., :n]
    separated = separated[..., :n]
    mse = np.mean((clean - separated) ** 2)
    if mse == 0:
        return float('inf')
    max_val = np.max(np.abs(clean))
    if max_val == 0:
        # Silent reference: log10(0) is -inf; return it without a warning.
        return float('-inf')
    return 10 * np.log10((max_val ** 2) / mse)
def calculate_melspectrogram_kl_divergence(clean_signal, separated_signal, sr):
    """Return KL(clean || separated) between normalised power mel spectrograms.

    Both signals are first trimmed to their common length along the last
    (sample) axis, so their mel spectrograms have the same number of frames.
    Without this, signals of different length made ``rel_entr`` fail to
    broadcast. Each spectrogram is divided by its total energy and clipped to
    ``>= 1e-10`` before the element-wise relative entropy is summed.

    Args:
        clean_signal: Reference samples (1-D, or channels-first).
        separated_signal: Estimated samples, in the same layout.
        sr: Sample rate in Hz, used for the mel filterbank.

    Returns:
        The summed KL divergence. It is NaN (with numpy warnings) if either
        signal's mel spectrogram sums to zero, e.g. a silent signal.
    """
    clean_signal = np.asarray(clean_signal)
    separated_signal = np.asarray(separated_signal)
    n = min(clean_signal.shape[-1], separated_signal.shape[-1])
    clean_mel = compute_mel_spectrogram(np.ascontiguousarray(clean_signal[..., :n]), sr)
    separated_mel = compute_mel_spectrogram(np.ascontiguousarray(separated_signal[..., :n]), sr)
    clean_mel_norm = clean_mel / np.sum(clean_mel)
    separated_mel_norm = separated_mel / np.sum(separated_mel)
    return np.sum(rel_entr(np.clip(clean_mel_norm, 1e-10, None),
                           np.clip(separated_mel_norm, 1e-10, None)))
def compute_mel_spectrogram(signal, sr, n_fft=2048, hop_length=512, n_mels=128):
    """Return the power (``power=2.0``) mel spectrogram of *signal*.

    A thin wrapper over ``librosa.feature.melspectrogram`` that fixes the
    STFT size, hop, number of mel bands and power in one place.
    """
    mel_params = dict(n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, power=2.0)
    return librosa.feature.melspectrogram(y=signal, sr=sr, **mel_params)
def extract_features(audio, sr, frame_size=5046, hop_length=2048):
    """Return a ``(n_frames, 3)`` matrix of [ZCR, RMS, spectral centroid] per frame.

    ZCR and RMS use *frame_size* as their frame length. The spectral centroid
    passes only *hop_length*, so it uses librosa's default FFT size. All three
    share *hop_length*, and ``np.vstack`` requires them to yield the same
    number of frames.

    NOTE(review): 5046 is an unusual window length (not a power of two);
    confirm it was not meant to be 4096.
    """
    framing = {"frame_length": frame_size, "hop_length": hop_length}
    per_feature = [
        librosa.feature.zero_crossing_rate(audio, **framing),
        librosa.feature.rms(y=audio, **framing),
        librosa.feature.spectral_centroid(y=audio, sr=sr, hop_length=hop_length),
    ]
    # Each entry is (1, n_frames); stack to (3, n_frames), then put frames on rows.
    return np.vstack(per_feature).T
def process_pipeline(fg_path, bg_path, sr):
    """Build a frame-level training set from a foreground and a background file.

    Each file is loaded with ``librosa.load`` at sample rate *sr* and turned
    into frame features with ``extract_features``.

    Returns:
        ``(features, labels)``. *features* stacks the foreground rows followed
        by the background rows. *labels* is 1.0 for every foreground row and
        0.0 for every background row.
    """
    feature_blocks = []
    label_blocks = []
    for path, label in ((fg_path, 1.0), (bg_path, 0.0)):
        audio, _ = librosa.load(path, sr=sr)
        feats = extract_features(audio, sr)
        feature_blocks.append(feats)
        label_blocks.append(np.full(feats.shape[0], label))
    return np.vstack(feature_blocks), np.hstack(label_blocks)
def train_rf_model(X, y, *, test_size=0.2, n_estimators=100, random_state=42):
    """Fit a random-forest classifier and print a hold-out classification report.

    A *test_size* share of the rows is held out with ``train_test_split``. The
    forest is fitted on the remaining rows only, and scikit-learn's
    ``classification_report`` for the held-out rows is printed to stdout.

    Args:
        X: Feature matrix, one row per sample.
        y: Class labels aligned with the rows of *X*.
        test_size: Fraction (or absolute count) of rows held out for validation.
        n_estimators: Number of trees in the forest.
        random_state: Seed used for both the split and the forest, so results
            are reproducible for a given input.

    Returns:
        The fitted ``RandomForestClassifier``, trained on the training split
        only and not on all of *X*.
    """
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=test_size, random_state=random_state)
    clf = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state)
    clf.fit(X_train, y_train)
    print(classification_report(y_val, clf.predict(X_val)))
    return clf
def reconstruct_audio(mixed_audio, labels, sr, frame_size=2048, hop_length=512):
    """Split *mixed_audio* into foreground/background using per-frame labels.

    Label ``i`` is taken to describe the analysis frame centred on sample
    ``i * hop_length``. That is librosa's default ``center=True`` framing,
    which ``extract_features`` does not override. The labels become a
    sample-level gain by linear interpolation between frame centres (held
    constant before the first and after the last centre) and are applied as::

        foreground = mixed_audio * gain
        background = mixed_audio * (1 - gain)

    so ``foreground + background == mixed_audio`` and both outputs keep the
    mixture's length and timing. A label counts as foreground only when it
    equals ``1.0``.

    This replaces an implementation that passed time-domain frames to
    ``librosa.istft`` (which expects an STFT matrix) and kept only the
    selected frames, losing their timing.

    Args:
        mixed_audio: Mixture samples; the last axis is time.
        labels: Per-frame classifier outputs (1.0 = foreground).
        sr: Unused; kept for interface compatibility.
        frame_size: Unused; kept for interface compatibility.
        hop_length: Samples between consecutive label frames. This must equal
            the hop used to compute the features the labels were predicted from.

    Returns:
        ``(foreground, background)`` arrays with the same shape as *mixed_audio*.

    Raises:
        ValueError: If *hop_length* is not positive.
    """
    if hop_length <= 0:
        raise ValueError("hop_length must be positive.")
    mix = np.asarray(mixed_audio)
    if not np.issubdtype(mix.dtype, np.floating):
        mix = mix.astype(np.float64)
    is_fg = (np.asarray(labels).ravel() == 1.0).astype(np.float64)
    if is_fg.size == 0:
        # No decisions at all: treat the whole mixture as background.
        return np.zeros_like(mix), mix.copy()
    n_samples = mix.shape[-1]
    centres = np.arange(is_fg.size) * hop_length
    # Linear crossfade between frame decisions avoids clicks at label changes.
    gain = np.interp(np.arange(n_samples), centres, is_fg).astype(mix.dtype)
    return mix * gain, mix * (1.0 - gain)
def _to_mono(audio):
    """Return *audio* as a 1-D array, averaging over any leading (channel) axes."""
    arr = np.asarray(audio)
    if arr.ndim <= 1:
        return arr
    return arr.reshape(-1, arr.shape[-1]).mean(axis=0)


def _score_against(reference, estimate, sr):
    """Return ``(psnr, mel_kl)`` of *estimate* vs *reference* over their common mono span."""
    ref = _to_mono(reference)
    est = _to_mono(estimate)
    n = min(ref.shape[-1], est.shape[-1])
    ref = ref[:n]
    est = est[:n]
    return (
        calculate_psnr(ref, est),
        calculate_melspectrogram_kl_divergence(ref, est, sr),
    )


def process_audio(file_path):
    """Run three nussl separators plus a random-forest frame classifier on *file_path*.

    Steps:
      1. Separate the mixture with 2DFT, REPET and REPET-SIM. Write each
         foreground estimate, and the REPET background, to WAV files in the
         current working directory.
      2. Train a random forest on frame features. The REPET-SIM foreground is
         the positive class and the REPET background is the negative class.
      3. Classify the mixture's frames and rebuild foreground/background audio
         from those decisions, writing both to WAV files.
      4. Score every foreground estimate against the *mixture* with PSNR and
         mel-spectrogram KL divergence. No clean reference is available here.
         Signals are down-mixed to mono and trimmed to a common length first,
         so nussl's ``(channels, samples)`` data and the 1-D RF output are
         compared consistently.

    Returns:
        A flat tuple with ``(foreground_path, psnr, kl)`` for 2DFT, REPET,
        REPET-SIM and the RF model in turn, followed by the RF background path.
    """
    signal = nussl.AudioSignal(file_path)
    mix_signal, sr = librosa.load(file_path, sr=None)

    _, ft2d_fg = Two_DFT(signal)
    repet_bg, repet_fg = Repet(signal)
    _, rsim_fg = Repet_Sim(signal)

    # Save the three foreground estimates.
    fg_paths = {
        "2dft": "output_foreground_2dft.wav",
        "repet": "output_foreground_repet.wav",
        "rsim": "output_foreground_rsim.wav",
    }
    ft2d_fg.write_audio_to_file(fg_paths["2dft"])
    repet_fg.write_audio_to_file(fg_paths["repet"])
    rsim_fg.write_audio_to_file(fg_paths["rsim"])

    # Training data: REPET-SIM foreground (label 1) vs REPET *background*
    # (label 0). Previously REPET's foreground file was passed as the
    # background class, so both classes were foreground estimates.
    repet_bg_path = "output_background_repet.wav"
    repet_bg.write_audio_to_file(repet_bg_path)
    features, labels = process_pipeline(fg_paths["rsim"], repet_bg_path, sr)
    clf = train_rf_model(features, labels)

    # Use a single hop for feature extraction and reconstruction so that each
    # predicted label lines up with the audio it was computed from.
    feature_hop = 2048
    test_features = extract_features(mix_signal, sr, hop_length=feature_hop)
    predicted_labels = clf.predict(test_features)
    fg_rec, bg_rec = reconstruct_audio(mix_signal, predicted_labels, sr,
                                       hop_length=feature_hop)

    fg_rf_path = "output_foreground_rf.wav"
    bg_rf_path = "output_background_rf.wav"
    sf.write(fg_rf_path, fg_rec, sr)
    sf.write(bg_rf_path, bg_rec, sr)

    mixture = signal.audio_data
    psnr_2dft, kl_2dft = _score_against(mixture, ft2d_fg.audio_data, sr)
    psnr_repet, kl_repet = _score_against(mixture, repet_fg.audio_data, sr)
    psnr_rsim, kl_rsim = _score_against(mixture, rsim_fg.audio_data, sr)
    psnr_rf, kl_rf = _score_against(mixture, fg_rec, sr)
    return (
        fg_paths["2dft"], psnr_2dft, kl_2dft,
        fg_paths["repet"], psnr_repet, kl_repet,
        fg_paths["rsim"], psnr_rsim, kl_rsim,
        fg_rf_path, psnr_rf, kl_rf,
        bg_rf_path,
    )