Spaces:
Running
on
Zero
Running
on
Zero
File size: 11,069 Bytes
96fe5d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
import os
import contextlib
from functools import partial
from tqdm import tqdm
import pickle
import numpy as np
import librosa
from hear21passt.base import get_basic_model
import pyloudnorm as pyln
import torch
import torch.nn.functional as F
# Sample rate (Hz) that every audio file is resampled to (via librosa.load) before
# being fed to PaSST: mono audio at 32 kHz.
SAMPLING_RATE = 32000
class _patch_passt_stft:
    """
    Context manager that temporarily makes torch.stft default to return_complex=False.

    PaSST calls torch.stft without passing return_complex. From torch 1.8.0 on, that
    argument must be given explicitly for real inputs (and return_complex=False is
    deprecated), so while this manager is active torch.stft is replaced by a
    functools.partial that supplies it. On exit, the function that was current when
    the manager was constructed is put back, whether or not an exception occurred
    (exceptions are not suppressed).
    Adapted from: https://github.com/facebookresearch/audiocraft/blob/a2b96756956846e194c9255d0cdadc2b47c93f1b/audiocraft/metrics/kld.py
    """

    def __init__(self):
        # Function to restore on exit, captured once at construction time.
        self.old_stft = torch.stft

    def __enter__(self):
        # Recent torch versions raise RuntimeError when return_complex is not set.
        # see: https://pytorch.org/docs/1.7.1/generated/torch.stft.html?highlight=stft#torch.stft
        # see: https://github.com/kkoutini/passt_hear21/commit/dce83183674e559162b49924d666c0a916dc967a
        stft_with_real_output = partial(torch.stft, return_complex=False)
        torch.stft = stft_with_real_output

    def __exit__(self, *exc):
        # Always restore the original function; returning None lets exceptions propagate.
        torch.stft = self.old_stft
def return_probabilities(model, audio_path, window_size=10, overlap=5, collect='mean'):
    """
    Given an audio and the PaSST model, return the probabilities of each AudioSet class.

    Audio is loaded as mono at SAMPLING_RATE (32 kHz) and peak-normalized with
    pyloudnorm (target -1.0).
    PaSST is trained with 10 sec inputs. We refer to this parameter as the window_size
    and set it to 10 sec for consistency with PaSST training.
    Longer audios are split into overlapping analysis windows of window_size seconds
    that advance by (window_size - overlap) seconds (10 and 5 seconds by default).
    PaSST supports 10, 20 or 30 sec inputs, not longer: https://github.com/kkoutini/PaSST/issues/19
    Note that AudioSet taggers normally use sigmoid output layers. Yet, to compute the
    KL we work with normalized probabilities by running a softmax over the collected
    logits, as in MusicGen:
    https://github.com/facebookresearch/audiocraft/blob/a2b96756956846e194c9255d0cdadc2b47c93f1b/audiocraft/metrics/kld.py
    Params:
    -- model: PaSST model returning logits. Each window is moved to the device of the
       model's first parameter (CUDA if the model has no parameters).
    -- audio_path: path to the audio to be loaded with librosa.
    -- window_size (default=10 sec): analysis window (and receptive field) of PaSST.
    -- overlap (default=5 sec): overlap between consecutive analysis windows; must be
       smaller than window_size.
    -- collect (default='mean'): aggregate/collect the per-window logits via 'mean' or
       'max' pooling across windows.
    Returns:
    -- 527 probabilities (after softmax, no logarithm), as a CPU tensor.
    Raises:
    -- ValueError: if collect is not 'mean' or 'max', or if window_size/overlap do not
       give a positive window length and step.
    """
    # Any other value used to skip pooling, so the softmax ran over the window axis.
    if collect not in ('mean', 'max'):
        raise ValueError(f"collect must be 'mean' or 'max', got {collect!r}")

    window_len = int(window_size * SAMPLING_RATE)
    # step between the starts of consecutive analysis windows
    step_size = int((window_size - overlap) * SAMPLING_RATE)
    if window_len <= 0 or step_size <= 0:
        raise ValueError(
            f"need window_size > 0 and overlap < window_size, "
            f"got window_size={window_size}, overlap={overlap}"
        )
    # Windows of at most this many samples are not zero-padded (see below).
    min_pad_len = int(window_size * SAMPLING_RATE * 0.15)

    # Feed inputs to wherever the model lives instead of the current default CUDA device.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    # load the audio using librosa
    audio, _ = librosa.load(audio_path, sr=SAMPLING_RATE, mono=True)
    audio = pyln.normalize.peak(audio, -1.0)

    window_logits = []
    # PaSST prints during inference; silence it once for the whole loop.
    with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
        for start in range(0, max(step_size, len(audio) - step_size), step_size):
            window = audio[start:start + window_len]
            # Zero-pad short windows to the full window length. Windows of at most 15% of
            # window_len are NOT padded (mostly-zero input risks being predicted as
            # silence, cf. MusicGen link above) but are still fed to the model as-is.
            # With the default window_size/overlap, this only happens when the whole clip
            # is shorter than 1.5 sec.
            if min_pad_len < len(window) < window_len:
                padded = np.zeros(window_len)
                padded[:len(window)] = window
                window = padded
            audio_wave = torch.from_numpy(window.astype(np.float32)).unsqueeze(0).to(device)
            with torch.no_grad(), _patch_passt_stft():
                logits = model(audio_wave)
            window_logits.append(torch.squeeze(logits))

    # (n_windows, n_classes) -> (n_classes,) by pooling across windows.
    stacked = torch.stack(window_logits)
    if collect == 'mean':
        collected = torch.mean(stacked, dim=0)
    else:
        collected = torch.max(stacked, dim=0).values
    return F.softmax(collected, dim=0).squeeze().cpu()
def _load_id2path(list_path):
    """
    Parse a list file that maps ids to audio paths.

    Each non-blank line is "<id> <audio path>". The path is everything after the first
    run of whitespace, so paths containing spaces are kept intact. Blank lines are ignored.
    Returns a dict {id (str): audio path (str)}.
    Raises ValueError for a non-blank line without a path field.
    """
    id2path = {}
    with open(list_path, "r") as fp:
        for line_no, line in enumerate(fp, start=1):
            fields = line.split(maxsplit=1)
            if not fields:
                continue
            if len(fields) < 2:
                raise ValueError(
                    f"{list_path}:{line_no}: expected '<id> <audio path>', got {line.strip()!r}"
                )
            id2path[fields[0]] = fields[1].strip()
    return id2path


def _id2path_for(path, name):
    """
    Validate an eval/ref location and return its id->path mapping.

    Returns None when path is a directory (audio is then looked up as
    <path>/<id><extension>) and the parsed mapping when path is a list file.
    Raises ValueError('<name> does not exist') when path is neither.
    """
    if os.path.isdir(path):
        return None
    if os.path.isfile(path):
        return _load_id2path(path)
    raise ValueError(f'{name} does not exist')


def _resolve_audio_path(utt_id, root, extension, id2path):
    """
    Return the audio path of utt_id under root.

    With id2path None (root is a directory), this is <root>/<utt_id><extension>.
    Otherwise the id (as a string) is looked up in id2path.
    Raises ValueError if id2path is given and has no entry for the id.
    """
    if id2path is None:
        return os.path.join(root, str(utt_id) + extension)
    audio_path = id2path.get(str(utt_id))
    if audio_path is None:
        raise ValueError(f"id: {utt_id} not in {root}!")
    return audio_path


def passt_kld(ids, eval_path, eval_files_extension='.wav', ref_path=None, ref_files_extension='.wav', load_ref_probabilities=None, no_ids=None, collect='mean'):
    """
    Compute the KL divergence between PaSST label probabilities of generated and reference audio.

    Generated audio (eval_path) and reference audio (ref_path) for the same prompt/description
    share an id. The id is either the file name (<id><extension>) inside a directory, or the
    first field of a "<id> <audio path>" line when the location is a list file.
    For inputs longer than the 10 sec PaSST was trained on, the audio is split into overlapping
    analysis windows. The per-window logits are aggregated/collected across windows via 'mean'
    (default) or 'max' pooling, and then a softmax is applied (see return_probabilities).
    Per id, the score is KL(eval_p || ref_p) (see the comment at the F.kl_div call). The result
    is the average over all ids for which both the generated audio and the reference
    probabilities are available.
    Labels are predicted with the PaSST model: https://github.com/kkoutini/PaSST
    The model is moved to the GPU with .cuda(), so CUDA is required.
    Extracting probabilities is time-consuming, so reference probabilities computed from ref_path
    are pickled under load/passt_kld/. Pass that file as load_ref_probabilities to reuse them; in
    that case ref_path is not required.
    Params:
    -- ids: ids to evaluate.
    -- eval_path: directory with the generated audio files, or a list file of "<id> <audio path>" lines.
    -- eval_files_extension: file extension (default .wav) used when eval_path is a directory.
    -- ref_path: directory with the reference audio files, or a list file (instead of load_ref_probabilities).
    -- ref_files_extension: file extension (default .wav) used when ref_path is a directory.
    -- load_ref_probabilities: path to pickled reference probabilities (instead of ref_path).
    -- no_ids: ids to skip, e.g. references that are corrupted or missing (default: none).
    -- collect (default='mean'): for longer inputs, aggregate/collect via 'mean' or 'max' pooling along the logits vector.
    Returns:
    -- mean KL divergence over the evaluated ids (a scalar tensor), or 0 if no id could be evaluated.
    Raises:
    -- ValueError: if eval_path / ref_path / load_ref_probabilities does not exist, if neither
       ref_path nor load_ref_probabilities is given, or if a list file has a malformed line.
    """
    skip_ids = set(no_ids) if no_ids else set()

    # Validate every input before the (slow) model load.
    eval_id2path = _id2path_for(eval_path, 'eval_path')
    ref_id2path = None
    if load_ref_probabilities:
        if not os.path.exists(load_ref_probabilities):
            raise ValueError('load_ref_probabilities does not exist')
    elif ref_path:
        ref_id2path = _id2path_for(ref_path, 'ref_path')
    else:
        raise ValueError('Must specify ref_path or load_ref_probabilities')

    with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):  # capturing all useless outputs from passt
        model = get_basic_model(mode="logits")
        model.eval()
        model = model.cuda()

    if load_ref_probabilities:
        print('[LOADING REFERENCE PROBABILITIES] ', load_ref_probabilities)
        # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
        with open(load_ref_probabilities, 'rb') as fp:
            ref_p = pickle.load(fp)
    else:
        print('[EXTRACTING REFERENCE PROBABILITIES] ', ref_path)
        ref_p = {}
        for utt_id in tqdm(ids):
            if utt_id in skip_ids:
                continue
            try:
                audio_path = _resolve_audio_path(utt_id, ref_path, ref_files_extension, ref_id2path)
                if os.path.isfile(audio_path):
                    ref_p[utt_id] = return_probabilities(model, audio_path, collect=collect)
            except Exception as e:  # best effort: report and move on to the next id
                print(f"An unexpected error occurred with {utt_id}: {e}\nIf you failed to download it you can add it to no_ids list.")
        # store reference probabilities to load later on
        os.makedirs('load/passt_kld/', exist_ok=True)
        save_ref_probabilities_path = 'load/passt_kld/'+ref_path.replace('/', '_')+'_collect'+str(collect)+'__reference_probabilities.pkl'
        with open(save_ref_probabilities_path, 'wb') as fp:
            pickle.dump(ref_p, fp)
        print('[REFERENCE EMBEDDINGS][SAVED] ', save_ref_probabilities_path)

    print('[EVALUATING GENERATIONS] ', eval_path)
    passt_kl = 0
    count = 0
    for utt_id in tqdm(ids):
        if utt_id in skip_ids:
            continue
        try:
            audio_path = _resolve_audio_path(utt_id, eval_path, eval_files_extension, eval_id2path)
            if not os.path.isfile(audio_path):
                continue
            # Check before running the model, so no inference is wasted on unusable ids.
            if utt_id not in ref_p:
                print(f"No reference probabilities for {utt_id}; skipping it.")
                continue
            eval_p = return_probabilities(model, audio_path, collect=collect)
            # F.kl_div(input, target) = sum(target * (log(target) - input)) = KL(target || exp(input)),
            # so this is KL(eval_p || ref_p); 1e-6 keeps log() finite for zero reference probabilities.
            # see: https://github.com/pytorch/pytorch/issues/7337
            # see: https://discuss.pytorch.org/t/kl-divergence-different-results-from-tf/56903/2
            passt_kl += F.kl_div((ref_p[utt_id] + 1e-6).log(), eval_p, reduction='sum', log_target=False)
            count += 1
        except Exception as e:  # best effort: report and move on to the next id
            print(f"An unexpected error occurred with {utt_id}: {e}\nIf you failed to download it you can add it to no_ids list.")

    if count == 0:
        print('[WARNING] no (generated, reference) pair could be evaluated; returning 0')
        return 0
    return passt_kl / count
|