File size: 3,242 Bytes
cb092c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# separation_utils.py
import numpy as np
import os
import librosa
import soundfile as sf
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.special import rel_entr
import nussl
import types

# Required to patch mask validation for nussl

def _validate_mask_patched(self, mask_):
    """Coerce ``mask_`` into a boolean numpy array.

    Boolean arrays are handed back untouched; any other numpy array is
    binarised by thresholding its entries at 0.5.

    Args:
        self: Object the function is bound to; it is not used in the body.
        mask_: Candidate mask. Must be a ``numpy.ndarray``.

    Returns:
        ``mask_`` itself when it already has boolean dtype, otherwise a new
        boolean array that is True wherever the entry is greater than 0.5.

    Raises:
        AssertionError: If ``mask_`` is not a numpy array.
        ValueError: If a thresholded entry is neither True nor False.
    """
    assert isinstance(mask_, np.ndarray), 'Mask must be a numpy array!'

    if mask_.dtype != bool:
        mask_ = mask_ > 0.5
        # The comparison above already yields booleans, so this check is not
        # expected to fail; it is kept so the validator still has a
        # ValueError path for non-binary masks.
        entry_is_binary = np.logical_or(mask_, np.logical_not(mask_))
        if not np.all(entry_is_binary):
            raise ValueError('All mask entries must be 0 or 1.')

    return mask_

# Install the replacement validator on nussl's BinaryMask at import time.
# types.MethodType binds the BinaryMask class itself as the function's first
# argument, so callers keep passing only the mask.
nussl.core.masks.binary_mask.BinaryMask._validate_mask = types.MethodType(
    _validate_mask_patched, nussl.core.masks.binary_mask.BinaryMask)

# Separation methods
def Repet(mix):
    """Build nussl's ``Repet`` primitive separator on ``mix`` and run it.

    Returns whatever calling the separator yields; ``process_audio`` in this
    module unpacks it as a ``(background, foreground)`` pair.
    """
    separator = nussl.separation.primitive.Repet(mix)
    return separator()

def Repet_Sim(mix):
    """Build nussl's ``RepetSim`` primitive separator on ``mix`` and run it.

    Returns whatever calling the separator yields; ``process_audio`` in this
    module unpacks it as a ``(background, foreground)`` pair.
    """
    separator = nussl.separation.primitive.RepetSim(mix)
    return separator()

def Two_DFT(mix):
    """Build nussl's ``FT2D`` primitive separator on ``mix`` and run it.

    Returns whatever calling the separator yields; ``process_audio`` in this
    module unpacks it as a ``(background, foreground)`` pair.
    """
    separator = nussl.separation.primitive.FT2D(mix)
    return separator()

# Audio metrics
def calculate_psnr(clean_signal, separated_signal):
    """Return the peak signal-to-noise ratio (in dB) of a separated signal.

    Both inputs are truncated along their first axis to the shorter of the
    two lengths before being compared. The peak is the largest absolute
    value of the (truncated) clean signal.

    Args:
        clean_signal: Reference signal; a numpy array or any array-like.
        separated_signal: Estimate to score against the reference; a numpy
            array or any array-like.

    Returns:
        ``10 * log10(peak ** 2 / mse)``. ``inf`` when the two signals are
        identical (zero error) and ``-inf`` when the reference is all zeros
        but the error is not.

    Raises:
        ValueError: If there are no samples left to compare after truncation.
    """
    clean = np.asarray(clean_signal)
    separated = np.asarray(separated_signal)

    # Integer PCM data (e.g. int16) would silently wrap around in the
    # subtraction and squaring below, so promote it to float64 first.
    # Floating-point inputs keep their own dtype.
    if clean.dtype.kind in "biu":
        clean = clean.astype(np.float64)
    if separated.dtype.kind in "biu":
        separated = separated.astype(np.float64)

    min_length = min(len(clean), len(separated))
    clean = clean[:min_length]
    separated = separated[:min_length]
    if clean.size == 0 or separated.size == 0:
        raise ValueError('Cannot compute PSNR of empty signals.')

    mse = np.mean((clean - separated) ** 2)
    if mse == 0:
        return float('inf')

    max_val = np.max(np.abs(clean))
    if max_val == 0:
        # A silent reference has zero peak power: the ratio is 0, whose
        # logarithm is -inf. Return it directly instead of triggering a
        # divide-by-zero warning inside log10.
        return float('-inf')
    return 10 * np.log10((max_val ** 2) / mse)

def calculate_melspectrogram_kl_divergence(clean_signal, separated_signal, sr):
    """Return the KL divergence between two signals' mel spectrograms.

    Each signal is converted to a mel power spectrogram with
    ``compute_mel_spectrogram``. Each spectrogram is divided by its total
    sum, floored at 1e-10, and the element-wise relative entropy of the
    clean spectrogram with respect to the separated one is summed.

    Args:
        clean_signal: Reference audio samples.
        separated_signal: Estimated audio samples to compare against it.
        sr: Sample rate passed through to the mel spectrogram computation.

    Returns:
        The summed relative entropy, or NaN when either spectrogram has zero
        total energy (its normalised form is undefined).
    """
    clean_mel = compute_mel_spectrogram(clean_signal, sr)
    separated_mel = compute_mel_spectrogram(separated_signal, sr)

    # Signals of different length give different frame counts on the last
    # axis, which would not broadcast. Compare only the frames both have,
    # mirroring the truncation calculate_psnr applies to raw samples.
    n_frames = min(clean_mel.shape[-1], separated_mel.shape[-1])
    clean_mel = clean_mel[..., :n_frames]
    separated_mel = separated_mel[..., :n_frames]

    clean_total = np.sum(clean_mel)
    separated_total = np.sum(separated_mel)
    if clean_total == 0 or separated_total == 0:
        # Dividing by a zero total would produce NaNs plus runtime warnings;
        # return the same NaN result explicitly.
        return float('nan')

    clean_mel_norm = clean_mel / clean_total
    separated_mel_norm = separated_mel / separated_total

    # Floor both terms so rel_entr never sees an exact zero.
    return np.sum(rel_entr(np.clip(clean_mel_norm, 1e-10, None),
                           np.clip(separated_mel_norm, 1e-10, None)))

def compute_mel_spectrogram(signal, sr, n_fft=2048, hop_length=512, n_mels=128):
    """Return the mel power spectrogram of ``signal`` computed by librosa.

    Args:
        signal: Audio samples, forwarded as librosa's ``y`` argument.
        sr: Sample rate of ``signal``.
        n_fft: FFT size forwarded to librosa (default 2048).
        hop_length: Hop length forwarded to librosa (default 512).
        n_mels: Number of mel bands forwarded to librosa (default 128).

    Returns:
        The result of ``librosa.feature.melspectrogram`` with ``power=2.0``.
    """
    mel_options = dict(
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        power=2.0,
    )
    return librosa.feature.melspectrogram(y=signal, **mel_options)

# Main function used in Gradio

def process_audio(file_path):
    """Separate an audio file with three methods and score each foreground.

    The file is loaded once as a ``nussl.AudioSignal`` and separated with
    ``Two_DFT``, ``Repet`` and ``Repet_Sim`` (in that order). Each foreground
    estimate is written to a fixed file name relative to the current working
    directory, so repeated calls overwrite the previous outputs.

    Both metrics compare the foreground estimate against the input mixture
    (``signal.audio_data``), not against an isolated reference source.

    Args:
        file_path: Path of the audio file to separate.

    Returns:
        A flat 9-tuple ``(file, psnr, kl)`` repeated for 2DFT, REPET and
        REPET-SIM: the written file name, the PSNR from ``calculate_psnr``
        and the mel-spectrogram KL divergence from
        ``calculate_melspectrogram_kl_divergence``.
    """
    signal = nussl.AudioSignal(file_path)

    # Each separator call must yield exactly (background, foreground); only
    # the foreground is used.
    _, ft2d_fg = Two_DFT(signal)
    _, repet_fg = Repet(signal)
    _, rsim_fg = Repet_Sim(signal)

    outputs = (
        ("output_foreground_2dft.wav", ft2d_fg),
        ("output_foreground_repet.wav", repet_fg),
        ("output_foreground_rsim.wav", rsim_fg),
    )

    for output_file, foreground in outputs:
        foreground.write_audio_to_file(output_file)

    # Take the sample rate from the already-loaded AudioSignal instead of
    # decoding the whole file a second time with librosa just to read it.
    sample_rate = signal.sample_rate
    mixture = signal.audio_data

    results = []
    for output_file, foreground in outputs:
        estimate = foreground.audio_data
        psnr = calculate_psnr(mixture, estimate)
        kl_divergence = calculate_melspectrogram_kl_divergence(
            mixture, estimate, sample_rate)
        results.extend((output_file, psnr, kl_divergence))

    return tuple(results)