Spaces:
Runtime error
Runtime error
File size: 3,242 Bytes
cb092c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
# separation_utils.py
import numpy as np
import os
import librosa
import soundfile as sf
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.special import rel_entr
import nussl
import types
# Replacement for nussl's BinaryMask mask validation; it is installed on the class below.
def _validate_mask_patched(self, mask_):
assert isinstance(mask_, np.ndarray), 'Mask must be a numpy array!'
if mask_.dtype == bool:
return mask_
mask_ = mask_ > 0.5
if not np.all(np.logical_or(mask_, np.logical_not(mask_))):
raise ValueError('All mask entries must be 0 or 1.')
return mask_
# Install the patched validator on nussl's BinaryMask. MethodType binds the
# class itself as the first argument, so the replacement is invoked with just
# the mask, whether it is looked up on the class or on an instance.
setattr(
    nussl.core.masks.binary_mask.BinaryMask,
    "_validate_mask",
    types.MethodType(
        _validate_mask_patched,
        nussl.core.masks.binary_mask.BinaryMask,
    ),
)
# Separation methods
def Repet(mix):
    """Run nussl's ``Repet`` separator on ``mix``.

    Args:
        mix: The mixture handed to ``nussl.separation.primitive.Repet``.

    Returns:
        Whatever calling the constructed separator with no arguments returns
        (``process_audio`` unpacks it as a background/foreground pair).
    """
    separator = nussl.separation.primitive.Repet(mix)
    return separator()
def Repet_Sim(mix):
    """Run nussl's ``RepetSim`` separator on ``mix``.

    Args:
        mix: The mixture handed to ``nussl.separation.primitive.RepetSim``.

    Returns:
        Whatever calling the constructed separator with no arguments returns
        (``process_audio`` unpacks it as a background/foreground pair).
    """
    separator = nussl.separation.primitive.RepetSim(mix)
    return separator()
def Two_DFT(mix):
    """Run nussl's ``FT2D`` separator on ``mix``.

    Args:
        mix: The mixture handed to ``nussl.separation.primitive.FT2D``.

    Returns:
        Whatever calling the constructed separator with no arguments returns
        (``process_audio`` unpacks it as a background/foreground pair).
    """
    separator = nussl.separation.primitive.FT2D(mix)
    return separator()
# Audio metrics
def calculate_psnr(clean_signal, separated_signal):
    """Return the peak signal-to-noise ratio, in dB, against a reference.

    Both inputs are trimmed to their common extent before comparison. The
    peak is the largest absolute value of the trimmed reference.

    Args:
        clean_signal: Reference samples (array-like).
        separated_signal: Samples to compare against the reference.

    Returns:
        ``10 * log10(peak ** 2 / mse)``. ``inf`` when the trimmed signals are
        identical; ``-inf`` when they differ but the reference is all zeros.

    Raises:
        ValueError: If nothing is left to compare after trimming.
    """
    clean = np.asarray(clean_signal)
    separated = np.asarray(separated_signal)

    if clean.ndim == separated.ndim:
        # Trim every axis, not only the first: for multi-dimensional input
        # (e.g. channels x samples) ``len()`` only sees the leading axis, so a
        # mismatch in sample count would otherwise break the subtraction.
        common = tuple(
            slice(min(a, b)) for a, b in zip(clean.shape, separated.shape)
        )
        clean = clean[common]
        separated = separated[common]
    else:
        # Ranks differ: keep the leading-axis trim and let NumPy broadcast.
        min_length = min(len(clean), len(separated))
        clean = clean[:min_length]
        separated = separated[:min_length]

    if clean.size == 0 or separated.size == 0:
        raise ValueError("cannot compute PSNR of empty signals")

    mse = np.mean((clean - separated) ** 2)
    if mse == 0:
        return float('inf')

    max_val = np.max(np.abs(clean))
    if max_val == 0:
        # Same value log10(0) would give, without the divide-by-zero warning.
        return float('-inf')
    return 10 * np.log10((max_val ** 2) / mse)
def calculate_melspectrogram_kl_divergence(clean_signal, separated_signal, sr):
    """Return the KL divergence between two signals' mel spectrograms.

    Each mel spectrogram (from ``compute_mel_spectrogram``) is turned into a
    distribution over all of its bins, and the result is the sum of
    ``rel_entr(P_clean, Q_separated)`` over those bins.

    Args:
        clean_signal: Reference audio samples.
        separated_signal: Audio samples to compare against the reference.
        sr: Sample rate passed through to ``compute_mel_spectrogram``.

    Returns:
        The summed relative entropy of the clean distribution with respect to
        the separated one.
    """
    floor = 1e-10

    def _as_distribution(mel):
        # Normalise, floor tiny bins at the same 1e-10 level as before, then
        # renormalise so the result still sums to 1 after flooring. A silent
        # input (zero total energy) becomes a uniform distribution instead of
        # NaN from a division by zero.
        total = np.sum(mel)
        if total > 0:
            mel = mel / total
        floored = np.clip(mel, floor, None)
        return floored / np.sum(floored)

    clean_mel = compute_mel_spectrogram(clean_signal, sr)
    separated_mel = compute_mel_spectrogram(separated_signal, sr)

    # Signals of different length give different frame counts; compare only
    # the frames both have. NOTE(review): assumes the last axis is time
    # frames, as librosa documents for melspectrogram output.
    frames = min(clean_mel.shape[-1], separated_mel.shape[-1])
    clean_mel = clean_mel[..., :frames]
    separated_mel = separated_mel[..., :frames]

    return np.sum(
        rel_entr(_as_distribution(clean_mel), _as_distribution(separated_mel))
    )
def compute_mel_spectrogram(signal, sr, n_fft=2048, hop_length=512, n_mels=128):
    """Compute a power (``power=2.0``) mel spectrogram with librosa.

    Args:
        signal: Audio samples, passed to librosa as ``y``.
        sr: Sample rate of ``signal``.
        n_fft: FFT window size forwarded to librosa.
        hop_length: Hop length forwarded to librosa.
        n_mels: Number of mel bands forwarded to librosa.

    Returns:
        The array returned by ``librosa.feature.melspectrogram``.
    """
    mel_power = librosa.feature.melspectrogram(
        y=signal,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        power=2.0,
    )
    return mel_power
# Main function used in Gradio
def process_audio(file_path):
    """Separate an audio file with three nussl methods and score each result.

    Runs ``Two_DFT``, ``Repet`` and ``Repet_Sim`` on the file, writes each
    foreground estimate to a fixed-name WAV file in the current working
    directory, and computes a PSNR and a mel-spectrogram KL divergence for
    each estimate.

    NOTE(review): both metrics use the input mixture itself as the reference
    signal, since no separate ground-truth source is loaded here.

    Args:
        file_path: Path of the audio file to load with ``nussl.AudioSignal``.

    Returns:
        A 9-tuple ``(file, psnr, kl)`` x 3, ordered 2DFT, REPET, REPET-SIM,
        where ``file`` is the path of the written foreground WAV.
    """
    signal = nussl.AudioSignal(file_path)
    # Take the sample rate from the signal that was just loaded instead of
    # decoding the whole file a second time with librosa only to read it.
    sample_rate = signal.sample_rate

    methods = (
        (Two_DFT, "output_foreground_2dft.wav"),
        (Repet, "output_foreground_repet.wav"),
        (Repet_Sim, "output_foreground_rsim.wav"),
    )

    # Run every separation first so nothing is written if one of them fails.
    separated = []
    for separate, output_file in methods:
        _background, foreground = separate(signal)
        separated.append((output_file, foreground))

    for output_file, foreground in separated:
        foreground.write_audio_to_file(output_file)

    results = []
    for output_file, foreground in separated:
        psnr = calculate_psnr(signal.audio_data, foreground.audio_data)
        kl = calculate_melspectrogram_kl_divergence(
            signal.audio_data, foreground.audio_data, sample_rate
        )
        results.extend((output_file, psnr, kl))
    return tuple(results)
|