# separation_utils.py
import numpy as np | |
import os | |
import librosa | |
import soundfile as sf | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from scipy.special import rel_entr | |
import nussl | |
import types | |
# Patch required for nussl: replace BinaryMask's mask validation.
def _validate_mask_patched(self, mask_):
    """Replacement for nussl's ``BinaryMask._validate_mask``.

    Boolean masks are returned unchanged. Any other numeric mask is
    binarised by thresholding: entries strictly greater than 0.5 become
    ``True``, everything else (including NaN) becomes ``False``.

    Args:
        self: Unused; present only so the function can be installed as a
            method on ``BinaryMask``.
        mask_: The candidate mask (must be a numpy array).

    Returns:
        A boolean array with the same shape as ``mask_``.

    Raises:
        AssertionError: If ``mask_`` is not a ``np.ndarray``.
    """
    assert isinstance(mask_, np.ndarray), 'Mask must be a numpy array!'
    if mask_.dtype == bool:
        return mask_
    # The comparison always yields a purely boolean array. A follow-up
    # "all entries must be 0 or 1" check could therefore never fail, so it
    # was dropped as dead code.
    return mask_ > 0.5
# Install the replacement validator on nussl's BinaryMask. types.MethodType
# pre-binds the class object as the function's ``self`` argument, which the
# patched function never reads, so the exact receiver does not matter.
setattr(
    nussl.core.masks.binary_mask.BinaryMask,
    '_validate_mask',
    types.MethodType(
        _validate_mask_patched,
        nussl.core.masks.binary_mask.BinaryMask,
    ),
)
# Separation methods
def Repet(mix, **kwargs):
    """Run nussl's REPET separator on *mix* and return its output.

    Args:
        mix: The audio to separate, passed as the separator's input signal
            (callers in this module pass a ``nussl.AudioSignal``).
        **kwargs: Forwarded unchanged to
            ``nussl.separation.primitive.Repet`` so callers can tune the
            separator; when omitted, nussl's defaults apply as before.

    Returns:
        Whatever calling the separator returns. Callers in this module
        unpack it as ``(background, foreground)``.
    """
    separator = nussl.separation.primitive.Repet(mix, **kwargs)
    return separator()
def Repet_Sim(mix, **kwargs):
    """Run nussl's REPET-SIM separator on *mix* and return its output.

    Args:
        mix: The audio to separate, passed as the separator's input signal
            (callers in this module pass a ``nussl.AudioSignal``).
        **kwargs: Forwarded unchanged to
            ``nussl.separation.primitive.RepetSim`` so callers can tune the
            separator; when omitted, nussl's defaults apply as before.

    Returns:
        Whatever calling the separator returns. Callers in this module
        unpack it as ``(background, foreground)``.
    """
    separator = nussl.separation.primitive.RepetSim(mix, **kwargs)
    return separator()
def Two_DFT(mix, **kwargs):
    """Run nussl's 2D Fourier transform (FT2D) separator on *mix*.

    Args:
        mix: The audio to separate, passed as the separator's input signal
            (callers in this module pass a ``nussl.AudioSignal``).
        **kwargs: Forwarded unchanged to
            ``nussl.separation.primitive.FT2D`` so callers can tune the
            separator; when omitted, nussl's defaults apply as before.

    Returns:
        Whatever calling the separator returns. Callers in this module
        unpack it as ``(background, foreground)``.
    """
    separator = nussl.separation.primitive.FT2D(mix, **kwargs)
    return separator()
# Audio metrics
def calculate_psnr(clean_signal, separated_signal):
    """Return the peak signal-to-noise ratio, in dB, of *separated_signal*.

    The peak is the largest absolute value in *clean_signal*. Both inputs are
    cropped to their common length along the LAST axis, so 1-D signals and
    channel-first ``(n_channels, n_samples)`` arrays are both aligned in time
    (``len()`` would have measured channels for 2-D input). Inputs are
    converted to float64 first, so integer PCM cannot overflow when squared.

    Args:
        clean_signal: Reference signal (array-like).
        separated_signal: Estimated signal (array-like); after cropping it
            must broadcast against *clean_signal*.

    Returns:
        PSNR in decibels. ``inf`` when the cropped signals are identical;
        ``-inf`` (with a numpy divide warning) when *clean_signal* is all
        zeros but the signals differ.

    Raises:
        ValueError: If the common length is zero.
    """
    clean = np.asarray(clean_signal, dtype=np.float64)
    separated = np.asarray(separated_signal, dtype=np.float64)

    # Crop along the time (last) axis, not the first axis.
    min_length = min(clean.shape[-1], separated.shape[-1])
    if min_length == 0:
        raise ValueError('Cannot compute PSNR of an empty signal.')
    clean = clean[..., :min_length]
    separated = separated[..., :min_length]

    mse = np.mean((clean - separated) ** 2)
    if mse == 0:
        return float('inf')
    max_val = np.max(np.abs(clean))
    return 10 * np.log10((max_val ** 2) / mse)
def calculate_melspectrogram_kl_divergence(clean_signal, separated_signal, sr):
    """Return KL(clean || separated) between normalised mel spectrograms.

    Both signals are first cropped to their common length along the last
    (time) axis, so their spectrograms have the same number of frames and
    can be compared element-wise. Each spectrogram is then divided by its
    total energy so it sums to 1. Both are floored at 1e-10 before
    ``scipy.special.rel_entr`` is applied and the terms are summed.

    Args:
        clean_signal: Reference time-domain signal (array-like).
        separated_signal: Estimated time-domain signal (array-like).
        sr: Sample rate passed through to the mel spectrogram computation.

    Returns:
        The summed relative entropy. It is ``nan`` if either spectrogram is
        all zeros (for example, a silent signal), because normalising then
        divides 0 by 0.
    """
    clean = np.asarray(clean_signal)
    separated = np.asarray(separated_signal)

    # Align lengths the same way calculate_psnr does, so the spectrogram
    # shapes match.
    min_length = min(clean.shape[-1], separated.shape[-1])
    clean_mel = compute_mel_spectrogram(clean[..., :min_length], sr)
    separated_mel = compute_mel_spectrogram(separated[..., :min_length], sr)

    clean_mel_norm = clean_mel / np.sum(clean_mel)
    separated_mel_norm = separated_mel / np.sum(separated_mel)
    # Floor both distributions so empty bins in the estimate do not give
    # infinite terms.
    return np.sum(rel_entr(np.clip(clean_mel_norm, 1e-10, None),
                           np.clip(separated_mel_norm, 1e-10, None)))
def compute_mel_spectrogram(signal, sr, n_fft=2048, hop_length=512, n_mels=128):
    """Compute a mel power spectrogram of *signal* with librosa.

    Args:
        signal: Time-domain audio, passed to librosa as ``y``.
        sr: Sample rate of *signal*.
        n_fft: FFT window length.
        hop_length: Number of samples between successive frames.
        n_mels: Number of mel bands.

    Returns:
        The array produced by ``librosa.feature.melspectrogram``, with
        ``power=2.0`` (a power rather than magnitude spectrogram).
    """
    stft_settings = {'n_fft': n_fft, 'hop_length': hop_length}
    return librosa.feature.melspectrogram(
        y=signal,
        sr=sr,
        n_mels=n_mels,
        power=2.0,
        **stft_settings,
    )
# Main entry point used by the Gradio interface.
def process_audio(file_path):
    """Separate an audio file three ways and score each separated foreground.

    Runs the 2D-FT, REPET and REPET-SIM separators on the file. Each
    separated foreground is written to a fixed WAV filename (a relative
    path) and then scored against the original mixture with
    ``calculate_psnr`` and ``calculate_melspectrogram_kl_divergence``.

    Args:
        file_path: Path of the audio file to process.

    Returns:
        tuple: ``(output_path, psnr, kl)`` for 2D-FT, then REPET, then
        REPET-SIM, flattened into nine values in that order.
    """
    signal = nussl.AudioSignal(file_path)
    # Use the sample rate nussl already read. The original code decoded the
    # file a second time with librosa only to get it, and never used the
    # decoded samples.
    sr = signal.sample_rate

    methods = (
        (Two_DFT, "output_foreground_2dft.wav"),
        (Repet, "output_foreground_repet.wav"),
        (Repet_Sim, "output_foreground_rsim.wav"),
    )

    # Phase 1: run every separator before writing or scoring anything.
    foregrounds = []
    for separate, _ in methods:
        _background, foreground = separate(signal)
        foregrounds.append(foreground)

    # Phase 2: write each foreground to its output file.
    # NOTE(review): the paths are fixed, so if this function can run
    # concurrently, calls will overwrite each other's files.
    for foreground, (_, output_file) in zip(foregrounds, methods):
        foreground.write_audio_to_file(output_file)

    # Phase 3: score each foreground.
    # NOTE(review): no clean reference is available, so the mixture itself
    # serves as the "clean" signal. The scores measure how far each
    # foreground departs from the mix, not true separation quality.
    results = []
    for foreground, (_, output_file) in zip(foregrounds, methods):
        results.extend((
            output_file,
            calculate_psnr(signal.audio_data, foreground.audio_data),
            calculate_melspectrogram_kl_divergence(
                signal.audio_data, foreground.audio_data, sr),
        ))
    return tuple(results)