Aashiue committed on
Commit
cb092c7
·
verified ·
1 Parent(s): ed3942b

Create separation_utils.py

Browse files
Files changed (1) hide show
  1. separation_utils.py +85 -0
separation_utils.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # separation_utils.py
2
+ import numpy as np
3
+ import os
4
+ import librosa
5
+ import soundfile as sf
6
+ import matplotlib.pyplot as plt
7
+ import seaborn as sns
8
+ from scipy.special import rel_entr
9
+ import nussl
10
+ import types
11
+
12
+ # Required to patch mask validation for nussl
13
+
14
+ def _validate_mask_patched(self, mask_):
15
+ assert isinstance(mask_, np.ndarray), 'Mask must be a numpy array!'
16
+ if mask_.dtype == bool:
17
+ return mask_
18
+ mask_ = mask_ > 0.5
19
+ if not np.all(np.logical_or(mask_, np.logical_not(mask_))):
20
+ raise ValueError('All mask entries must be 0 or 1.')
21
+ return mask_
22
+
23
# Install the relaxed validator on nussl's BinaryMask. The function is bound
# to the class object itself, so its first parameter receives the class.
setattr(
    nussl.core.masks.binary_mask.BinaryMask,
    "_validate_mask",
    types.MethodType(
        _validate_mask_patched,
        nussl.core.masks.binary_mask.BinaryMask,
    ),
)
25
+
26
+ # Separation methods
27
def Repet(mix):
    """Separate ``mix`` with nussl's ``Repet`` primitive.

    Builds the separator for ``mix`` and invokes it immediately, returning
    whatever that call yields. ``process_audio`` in this module unpacks the
    result as a ``(background, foreground)`` pair.
    """
    separator = nussl.separation.primitive.Repet(mix)
    return separator()
29
+
30
def Repet_Sim(mix):
    """Separate ``mix`` with nussl's ``RepetSim`` primitive.

    Builds the separator for ``mix`` and invokes it immediately, returning
    whatever that call yields. ``process_audio`` in this module unpacks the
    result as a ``(background, foreground)`` pair.
    """
    separator = nussl.separation.primitive.RepetSim(mix)
    return separator()
32
+
33
def Two_DFT(mix):
    """Separate ``mix`` with nussl's ``FT2D`` primitive.

    Builds the separator for ``mix`` and invokes it immediately, returning
    whatever that call yields. ``process_audio`` in this module unpacks the
    result as a ``(background, foreground)`` pair.
    """
    separator = nussl.separation.primitive.FT2D(mix)
    return separator()
35
+
36
+ # Audio metrics
37
def calculate_psnr(clean_signal, separated_signal):
    """Return the peak signal-to-noise ratio of an estimate, in decibels.

    The peak is the largest absolute value of the (cropped) reference signal
    and the noise is the mean squared difference between the two signals.

    Args:
        clean_signal: Reference samples; any array-like of real numbers.
        separated_signal: Estimated samples; any array-like of real numbers.

    Returns:
        ``10 * log10(peak ** 2 / mse)`` as a float. ``inf`` when the signals
        are identical over their common extent, ``-inf`` when the reference
        is entirely zero but the estimate is not.

    Raises:
        ValueError: If the signals share no samples (empty input).
    """
    # Work in float64 so integer PCM input cannot overflow when the
    # difference is squared, and so plain lists are accepted.
    clean = np.atleast_1d(np.asarray(clean_signal, dtype=np.float64))
    separated = np.atleast_1d(np.asarray(separated_signal, dtype=np.float64))

    if clean.ndim == separated.ndim:
        # Crop every axis to the shared extent. For 1-D input this is the
        # usual length alignment; for multi-dimensional input it also aligns
        # the trailing axis, which `len()` alone would not.
        window = tuple(
            slice(min(a, b)) for a, b in zip(clean.shape, separated.shape)
        )
    else:
        # Ranks differ: align the trailing axis only and let NumPy
        # broadcasting pair up the remaining axes.
        window = (Ellipsis, slice(min(clean.shape[-1], separated.shape[-1])))
    clean = clean[window]
    separated = separated[window]

    if clean.size == 0 or separated.size == 0:
        raise ValueError('Cannot compute PSNR of empty signals.')

    mse = np.mean((clean - separated) ** 2)
    if mse == 0:
        return float('inf')

    max_val = np.max(np.abs(clean))
    if max_val == 0:
        # log10(0) is -inf; return it directly instead of via a NumPy
        # divide-by-zero warning.
        return float('-inf')
    return float(10 * np.log10((max_val ** 2) / mse))
46
+
47
def calculate_melspectrogram_kl_divergence(clean_signal, separated_signal, sr):
    """Return the KL divergence between two signals' mel spectrograms.

    Each signal is converted to a mel spectrogram with
    ``compute_mel_spectrogram``, normalised so that it sums to one, and
    clipped from below at 1e-10. The result is the sum of
    ``scipy.special.rel_entr(clean, separated)`` over all bins.

    Args:
        clean_signal: Reference audio samples.
        separated_signal: Estimated audio samples.
        sr: Sample rate passed to ``compute_mel_spectrogram`` for both
            signals.

    Returns:
        The summed relative entropy as a NumPy scalar. A spectrogram whose
        total energy is zero makes the normalisation divide by zero, so the
        result is NaN in that case.
    """
    clean_mel = compute_mel_spectrogram(clean_signal, sr)
    separated_mel = compute_mel_spectrogram(separated_signal, sr)

    # Signals of different length produce different frame counts, which
    # would make the element-wise comparison fail. Keep only the frames both
    # spectrograms have (librosa puts time frames on the last axis).
    frames = min(clean_mel.shape[-1], separated_mel.shape[-1])
    clean_mel = clean_mel[..., :frames]
    separated_mel = separated_mel[..., :frames]

    clean_dist = np.clip(clean_mel / np.sum(clean_mel), 1e-10, None)
    separated_dist = np.clip(separated_mel / np.sum(separated_mel), 1e-10, None)
    return np.sum(rel_entr(clean_dist, separated_dist))
53
+
54
def compute_mel_spectrogram(signal, sr, n_fft=2048, hop_length=512, n_mels=128):
    """Compute a power mel spectrogram of ``signal`` with librosa.

    Args:
        signal: Audio samples, forwarded to librosa as ``y``.
        sr: Sample rate of ``signal``.
        n_fft: FFT window size.
        hop_length: Hop between successive frames, in samples.
        n_mels: Number of mel bands.

    Returns:
        The value returned by ``librosa.feature.melspectrogram`` called with
        ``power=2.0``.
    """
    mel_options = {
        "y": signal,
        "sr": sr,
        "n_fft": n_fft,
        "hop_length": hop_length,
        "n_mels": n_mels,
        "power": 2.0,
    }
    return librosa.feature.melspectrogram(**mel_options)
58
+
59
+ # Main function used in Gradio
60
+
61
def process_audio(file_path):
    """Separate an audio file with three nussl methods and score each result.

    The file is loaded as a ``nussl.AudioSignal`` and run through
    ``Two_DFT``, ``Repet`` and ``Repet_Sim``. Each foreground estimate is
    written to a fixed-name WAV file in the current working directory and
    scored with ``calculate_psnr`` and
    ``calculate_melspectrogram_kl_divergence``. Both metrics use the input
    mixture itself as the reference signal.

    Args:
        file_path: Path of the audio file to separate.

    Returns:
        A flat 9-tuple of ``(output_path, psnr, kl_divergence)`` for 2DFT,
        then REPET, then REPET-SIM.
    """
    signal = nussl.AudioSignal(file_path)
    reference = signal.audio_data
    # Take the sample rate from the already-loaded signal instead of decoding
    # the whole file a second time just to read it.
    sample_rate = signal.sample_rate

    separators = (
        ("output_foreground_2dft.wav", Two_DFT),
        ("output_foreground_repet.wav", Repet),
        ("output_foreground_rsim.wav", Repet_Sim),
    )

    # Run every separation before touching the disk, so a failing method
    # leaves no partial set of output files behind.
    foregrounds = []
    for output_file, separate in separators:
        _background, foreground = separate(signal)
        foregrounds.append((output_file, foreground))

    for output_file, foreground in foregrounds:
        foreground.write_audio_to_file(output_file)

    results = []
    for output_file, foreground in foregrounds:
        estimate = foreground.audio_data
        results.append(output_file)
        results.append(calculate_psnr(reference, estimate))
        results.append(
            calculate_melspectrogram_kl_divergence(reference, estimate, sample_rate)
        )
    return tuple(results)