# separation_utils.py
import collections
import collections.abc

# Compatibility shim: Python 3.10 removed these collections aliases, but older
# nussl releases still reference them, so restore the aliases before nussl is imported.
collections.MutableMapping = collections.abc.MutableMapping
collections.Sequence = collections.abc.Sequence
collections.Mapping = collections.abc.Mapping

import numpy as np
import librosa
import soundfile as sf
from scipy.special import rel_entr
import nussl
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

def _validate_mask_patched(self, mask_):
    """Relaxed BinaryMask validation: accept soft masks and binarize at 0.5."""
    assert isinstance(mask_, np.ndarray), 'Mask must be a numpy array!'
    if mask_.dtype == bool:
        return mask_
    # Thresholding yields a boolean array, so every entry is already 0 or 1.
    return mask_ > 0.5


# Patch the class attribute directly so every BinaryMask instance picks up the
# relaxed validator.
nussl.core.masks.binary_mask.BinaryMask._validate_mask = _validate_mask_patched

def Repet(mix):
    # REPET separation; calling the separator runs it and returns its estimates.
    return nussl.separation.primitive.Repet(mix)()

def Repet_Sim(mix):
    # REPET-SIM variant, same return convention.
    return nussl.separation.primitive.RepetSim(mix)()

def Two_DFT(mix):
    # 2DFT separation via nussl's FT2D primitive.
    return nussl.separation.primitive.FT2D(mix)()
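
# Usage sketch (the filename is illustrative, not part of this module): each
# wrapper returns a [background, foreground] pair of nussl.AudioSignal objects.
#   mix = nussl.AudioSignal("mixture.wav")
#   bg, fg = Repet(mix)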

def calculate_psnr(clean_signal, separated_signal):
    # Flatten so (channels, samples) arrays from nussl compare cleanly with 1-D arrays.
    clean_signal = np.asarray(clean_signal).ravel()
    separated_signal = np.asarray(separated_signal).ravel()
    min_length = min(len(clean_signal), len(separated_signal))
    clean_signal = clean_signal[:min_length]
    separated_signal = separated_signal[:min_length]
    mse = np.mean((clean_signal - separated_signal) ** 2)
    if mse == 0:
        return float('inf')
    max_val = np.max(np.abs(clean_signal))
    return 10 * np.log10((max_val ** 2) / mse)
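
# Sanity check (sketch): a reference with peak amplitude 1.0 against an estimate
# with MSE 1e-4 gives PSNR = 10 * log10(1.0 / 1e-4) = 40 dB.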

def calculate_melspectrogram_kl_divergence(clean_signal, separated_signal, sr):
    clean_signal = np.asarray(clean_signal).ravel()
    separated_signal = np.asarray(separated_signal).ravel()
    clean_mel = compute_mel_spectrogram(clean_signal, sr)
    separated_mel = compute_mel_spectrogram(separated_signal, sr)
    # Truncate to a common frame count so the element-wise KL sum is defined.
    n_frames = min(clean_mel.shape[1], separated_mel.shape[1])
    clean_mel_norm = clean_mel[:, :n_frames] / np.sum(clean_mel[:, :n_frames])
    separated_mel_norm = separated_mel[:, :n_frames] / np.sum(separated_mel[:, :n_frames])
    # Clipping avoids log(0) terms in near-silent bins.
    return np.sum(rel_entr(np.clip(clean_mel_norm, 1e-10, None), np.clip(separated_mel_norm, 1e-10, None)))

def compute_mel_spectrogram(signal, sr, n_fft=2048, hop_length=512, n_mels=128):
    return librosa.feature.melspectrogram(
        y=signal, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, power=2.0
    )

def extract_features(audio, sr, frame_size=5046, hop_length=2048):
    # Frame-level features: zero-crossing rate, RMS energy, and spectral centroid.
    zcr = librosa.feature.zero_crossing_rate(audio, frame_length=frame_size, hop_length=hop_length)
    rms = librosa.feature.rms(y=audio, frame_length=frame_size, hop_length=hop_length)
    spectral_centroid = librosa.feature.spectral_centroid(y=audio, sr=sr, hop_length=hop_length)
    # Stack into an (n_frames, 3) matrix: one row per frame, one column per feature.
    features = np.vstack((zcr, rms, spectral_centroid)).T
    return features
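
# Shape sketch: a 10 s clip at sr=22050 has 220500 samples, so with
# hop_length=2048 there are about 1 + 220500 // 2048 = 108 frames and the
# returned feature matrix is roughly (108, 3).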

def process_pipeline(fg_path, bg_path, sr):
    # Build a labeled training set: foreground frames -> 1, background frames -> 0.
    fg_audio, _ = librosa.load(fg_path, sr=sr)
    bg_audio, _ = librosa.load(bg_path, sr=sr)
    fg_features = extract_features(fg_audio, sr)
    bg_features = extract_features(bg_audio, sr)
    fg_labels = np.ones(fg_features.shape[0])
    bg_labels = np.zeros(bg_features.shape[0])
    features = np.vstack((fg_features, bg_features))
    labels = np.hstack((fg_labels, bg_labels))
    return features, labels
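
# Note: in process_audio below, fg_path and bg_path point at nussl's own
# separated estimates, so the random forest is trained on pseudo-labels rather
# than ground-truth stems.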

def train_rf_model(X, y):
    # Hold out 20% of the frames for a quick validation report.
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
    clf = RandomForestClassifier(n_estimators=100, random_state=42)
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_val)
    print(classification_report(y_val, y_pred))
    return clf

def reconstruct_audio(mixed_audio, labels, sr, frame_size=5046, hop_length=2048):
    # Frame the mixture exactly as extract_features does, so each predicted label
    # lines up with one time-domain frame.
    frames = librosa.util.frame(mixed_audio, frame_length=frame_size, hop_length=hop_length).T
    n_frames = min(frames.shape[0], len(labels))
    fg_audio = np.zeros_like(mixed_audio)
    bg_audio = np.zeros_like(mixed_audio)
    # Route each frame's samples to the foreground or background track; frames come
    # straight from the mixture, so overwriting overlapping regions is harmless
    # except at label boundaries.
    for i in range(n_frames):
        start = i * hop_length
        target = fg_audio if labels[i] == 1.0 else bg_audio
        target[start:start + frame_size] = frames[i]
    return fg_audio, bg_audio
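
# Note: because frame_size > hop_length, frames overlap; samples near a label
# boundary may land in both tracks, while samples elsewhere are routed to
# exactly one source.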

def process_audio(file_path):
    signal = nussl.AudioSignal(file_path)
    mix_signal, sr = librosa.load(file_path, sr=None)
    # Run the three nussl primitives; each wrapper returns [background, foreground].
    ft2d_bg, ft2d_fg = Two_DFT(signal)
    repet_bg, repet_fg = Repet(signal)
    rsim_bg, rsim_fg = Repet_Sim(signal)
    # Save the three foreground estimates.
    fg_paths = {
        "2dft": "output_foreground_2dft.wav",
        "repet": "output_foreground_repet.wav",
        "rsim": "output_foreground_rsim.wav"
    }
    ft2d_fg.write_audio_to_file(fg_paths["2dft"])
    repet_fg.write_audio_to_file(fg_paths["repet"])
    rsim_fg.write_audio_to_file(fg_paths["rsim"])
    # Training exemplars: RepetSim foreground (positive class) and Repet
    # background (negative class).
    bg_repet_path = "output_background_repet.wav"
    repet_bg.write_audio_to_file(bg_repet_path)
    fg_path, bg_path = fg_paths["rsim"], bg_repet_path
    features, labels = process_pipeline(fg_path, bg_path, sr)
    clf = train_rf_model(features, labels)
    # Classify each frame of the mixture, then rebuild the two tracks.
    test_features = extract_features(mix_signal, sr)
    predicted_labels = clf.predict(test_features)
    fg_rec, bg_rec = reconstruct_audio(mix_signal, predicted_labels, sr)
    fg_rf_path = "output_foreground_rf.wav"
    bg_rf_path = "output_background_rf.wav"
    sf.write(fg_rf_path, fg_rec, sr)
    sf.write(bg_rf_path, bg_rec, sr)
    # Metrics are computed against the original mixture, since no clean
    # reference exists for an arbitrary input file.
    psnr_rf = calculate_psnr(signal.audio_data, fg_rec)
    kl_rf = calculate_melspectrogram_kl_divergence(signal.audio_data, fg_rec, sr)
    return (
        fg_paths["2dft"],
        calculate_psnr(signal.audio_data, ft2d_fg.audio_data),
        calculate_melspectrogram_kl_divergence(signal.audio_data, ft2d_fg.audio_data, sr),
        fg_paths["repet"],
        calculate_psnr(signal.audio_data, repet_fg.audio_data),
        calculate_melspectrogram_kl_divergence(signal.audio_data, repet_fg.audio_data, sr),
        fg_paths["rsim"],
        calculate_psnr(signal.audio_data, rsim_fg.audio_data),
        calculate_melspectrogram_kl_divergence(signal.audio_data, rsim_fg.audio_data, sr),
        fg_rf_path, psnr_rf, kl_rf,
        bg_rf_path,
    )
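
# Minimal usage sketch. "mixture.wav" is an illustrative local path, not a file
# shipped with this module; process_audio writes its intermediate and final WAVs
# to the current working directory.
if __name__ == "__main__":
    results = process_audio("mixture.wav")
    print(results)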