# separation_utils.py
"""Audio source-separation helpers built on nussl, plus simple similarity metrics.

Importing this module patches nussl's ``BinaryMask`` mask validation. It then
wraps three nussl separation algorithms (2DFT, REPET, REPET-SIM) and exposes
``process_audio`` as the entry point used by the Gradio app.
"""
import os
import types

import numpy as np
import librosa
import soundfile as sf  # NOTE(review): not referenced in this module.
import matplotlib.pyplot as plt  # NOTE(review): not referenced in this module.
import seaborn as sns  # NOTE(review): not referenced in this module.
from scipy.special import rel_entr
import nussl


# Required to patch mask validation for nussl
def _validate_mask_patched(self, mask_):
    """Replacement for ``BinaryMask._validate_mask`` that accepts soft masks.

    Boolean masks pass through unchanged. Any other numeric mask is binarized
    by thresholding at 0.5 rather than being rejected.

    Args:
        self: Unused. ``types.MethodType`` below binds this function to the
            ``BinaryMask`` class object, so it receives the class, not an
            instance.
        mask_: Mask array to validate.

    Returns:
        A boolean ``np.ndarray`` with the same shape as ``mask_``.

    Raises:
        AssertionError: If ``mask_`` is not a numpy array (only while
            assertions are enabled).
    """
    assert isinstance(mask_, np.ndarray), 'Mask must be a numpy array!'
    if mask_.dtype == bool:
        return mask_
    # Thresholding always produces a boolean (0/1) array, so no further
    # "entries must be 0 or 1" check is needed -- it could never fail.
    return mask_ > 0.5


# NOTE(review): why nussl's stock validation must be relaxed is not visible
# here; presumably some separators hand non-0/1 values to BinaryMask.
nussl.core.masks.binary_mask.BinaryMask._validate_mask = types.MethodType(
    _validate_mask_patched, nussl.core.masks.binary_mask.BinaryMask)


# Separation methods
def Repet(mix):
    """Run nussl's REPET separator on ``mix`` and return its output.

    ``process_audio`` unpacks the result as ``(background, foreground)``.
    """
    return nussl.separation.primitive.Repet(mix)()


def Repet_Sim(mix):
    """Run nussl's REPET-SIM separator on ``mix`` and return its output.

    ``process_audio`` unpacks the result as ``(background, foreground)``.
    """
    return nussl.separation.primitive.RepetSim(mix)()


def Two_DFT(mix):
    """Run nussl's 2D-Fourier-transform (FT2D) separator on ``mix``.

    ``process_audio`` unpacks the result as ``(background, foreground)``.
    """
    return nussl.separation.primitive.FT2D(mix)()


# Audio metrics

# Floor applied to normalized spectrogram bins so log terms in the KL stay finite.
_KL_EPS = 1e-10


def _trim_to_common_length(first, second):
    """Return both inputs as arrays truncated to their shared last-axis length."""
    first = np.asarray(first)
    second = np.asarray(second)
    # Time is assumed to be the last axis. That holds for 1-D signals and for
    # nussl's channels-first ``audio_data``.
    n = min(first.shape[-1], second.shape[-1])
    return first[..., :n], second[..., :n]


def calculate_psnr(clean_signal, separated_signal):
    """Return the PSNR, in dB, of ``separated_signal`` relative to ``clean_signal``.

    Both inputs are truncated to their common length along the last (time)
    axis, so 1-D ``(n,)`` and multichannel ``(..., n)`` arrays are accepted.
    The peak value is the largest absolute sample of ``clean_signal``.

    Returns:
        ``inf`` if the (trimmed) signals are identical; ``-inf`` if the
        reference is all zeros but the signals differ; otherwise
        ``10 * log10(peak**2 / mse)``.

    Raises:
        ValueError: If the common length is zero.
    """
    clean, separated = _trim_to_common_length(clean_signal, separated_signal)
    if clean.size == 0 or separated.size == 0:
        raise ValueError('Signals must contain at least one sample.')
    # Compute in float64 so integer PCM samples cannot overflow when squared.
    clean = clean.astype(np.float64)
    separated = separated.astype(np.float64)
    mse = np.mean((clean - separated) ** 2)
    if mse == 0:
        return float('inf')
    max_val = np.max(np.abs(clean))
    if max_val == 0:
        return float('-inf')
    return float(10 * np.log10((max_val ** 2) / mse))


def _to_distribution(spectrogram):
    """Scale ``spectrogram`` to sum to 1; an all-zero input maps to all zeros."""
    total = np.sum(spectrogram)
    if total == 0:
        # Avoid 0/0 -> NaN for silent signals.
        return np.zeros_like(spectrogram)
    return spectrogram / total


def calculate_melspectrogram_kl_divergence(clean_signal, separated_signal, sr):
    """Return KL(clean || separated) between normalized mel power spectrograms.

    The signals are first truncated to their common last-axis length, so both
    spectrograms have the same shape. Each spectrogram is scaled to sum to 1
    and clipped from below at ``_KL_EPS``, which keeps the logarithm finite.
    A silent (all-zero) spectrogram becomes uniform ``_KL_EPS`` rather than NaN.

    Args:
        clean_signal: Reference audio samples.
        separated_signal: Audio samples to compare against the reference.
        sr: Sample rate of both signals, in Hz.
    """
    clean, separated = _trim_to_common_length(clean_signal, separated_signal)
    p = np.clip(_to_distribution(compute_mel_spectrogram(clean, sr)), _KL_EPS, None)
    q = np.clip(_to_distribution(compute_mel_spectrogram(separated, sr)), _KL_EPS, None)
    return np.sum(rel_entr(p, q))


def compute_mel_spectrogram(signal, sr, n_fft=2048, hop_length=512, n_mels=128):
    """Return the mel power spectrogram (``power=2.0``) of ``signal`` via librosa."""
    return librosa.feature.melspectrogram(
        y=signal, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, power=2.0
    )


# (separator, output file name) pairs, in the order process_audio returns them.
_SEPARATORS = (
    (Two_DFT, "output_foreground_2dft.wav"),
    (Repet, "output_foreground_repet.wav"),
    (Repet_Sim, "output_foreground_rsim.wav"),
)


# Main function used in Gradio
def process_audio(file_path, output_dir=None):
    """Extract the foreground of an audio file with three methods and score each.

    For 2DFT, REPET and REPET-SIM, in that order, the foreground estimate is
    written to a WAV file and compared with the input mixture using
    ``calculate_psnr`` and ``calculate_melspectrogram_kl_divergence``. There is
    no clean reference, so these scores measure how far each foreground
    departs from the mixture, not true separation quality.

    Args:
        file_path: Path of the audio file to separate.
        output_dir: Optional directory for the output WAV files. If ``None``
            (the default), the fixed file names are used as-is, relative to
            the current working directory. In that case concurrent calls
            overwrite each other's files; pass a per-request directory to
            avoid this.

    Returns:
        A 9-tuple ``(path_2dft, psnr_2dft, kl_2dft, path_repet, psnr_repet,
        kl_repet, path_rsim, psnr_rsim, kl_rsim)``.
    """
    signal = nussl.AudioSignal(file_path)
    mix = signal.audio_data
    # Use the sample rate of the loaded samples themselves rather than
    # decoding the file a second time just to read its rate.
    sr = signal.sample_rate

    results = []
    for separate, file_name in _SEPARATORS:
        _background, foreground = separate(signal)
        out_path = file_name if output_dir is None else os.path.join(output_dir, file_name)
        foreground.write_audio_to_file(out_path)
        results.extend((
            out_path,
            calculate_psnr(mix, foreground.audio_data),
            calculate_melspectrogram_kl_divergence(mix, foreground.audio_data, sr),
        ))
    return tuple(results)