Spaces:
Runtime error
Runtime error
File size: 3,788 Bytes
260b46d 9fbfaa6 260b46d 9fbfaa6 260b46d 84d4ed6 260b46d 84d4ed6 260b46d 9fbfaa6 260b46d 84d4ed6 260b46d ac059f4 c1b9ba0 260b46d c1b9ba0 260b46d 3815be3 260b46d ac059f4 1fc9757 c1b9ba0 260b46d ac059f4 260b46d ac059f4 260b46d 275afd0 260b46d 275afd0 260b46d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
from pathlib import Path
import os
from functools import partial
from frechet_audio_distance import FrechetAudioDistance
import pandas
import argbind
from tqdm import tqdm
import audiotools
from audiotools import AudioSignal
@argbind.bind(without_prefix=True)
def eval(
    exp_dir: str = None,
    baseline_key: str = "baseline",
    audio_ext: str = ".wav",
):
    """Score every condition folder in ``exp_dir`` against the baseline folder.

    Every sub-directory of ``exp_dir`` is treated as one condition; the one
    named ``baseline_key`` is the reference. Audio files are sorted by the
    integer value of their stem (``0.wav``, ``1.wav``, ...) and paired with
    the baseline files by position.

    For each condition one Frechet Audio Distance is computed over the two
    directories, and for each file pair the negated SI-SDR loss, the
    multi-scale STFT loss and the mel-spectrogram loss are computed.
    Results are written into ``exp_dir``:

    * ``metrics-all.csv`` -- one row per (condition, file) pair.
    * ``stats-<metric>.csv`` -- mean / count / std of that metric per
      condition.

    Nothing is written when no file pair could be evaluated.

    Args:
        exp_dir: Directory holding one sub-directory per condition. Required.
        baseline_key: Name of the sub-directory holding the reference audio.
        audio_ext: File extension (including the dot) of the audio files.

    Raises:
        AssertionError: If ``exp_dir`` is missing or does not exist, if the
            baseline directory is not found, or if a paired baseline and
            condition file do not share the same stem.
        ValueError: If an audio file's stem cannot be parsed as an integer.
    """
    assert exp_dir is not None
    exp_dir = Path(exp_dir)
    assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"

    # set up our metrics
    sisdr_loss = audiotools.metrics.distance.SISDRLoss()
    stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
    frechet = FrechetAudioDistance(
        use_pca=False,
        use_activation=False,
        verbose=True,
    )
    # NOTE: ViSQOL scoring is currently disabled; the callable is kept so the
    # metric can be re-enabled inside `compare_pair` without further setup.
    visqol = partial(audiotools.metrics.quality.visqol, mode="audio")

    def numbered_audio_files(folder):
        # Numeric sort so that "10" comes after "9" rather than after "1".
        return sorted(folder.glob(f"*{audio_ext}"), key=lambda p: int(p.stem))

    def compare_pair(baseline_file, cond_file):
        # make sure the files match (same name)
        assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"

        # load the files
        baseline_sig = AudioSignal(str(baseline_file))
        cond_sig = AudioSignal(str(cond_file))

        # Bring the condition signal to the baseline's sample rate and length
        # before comparing. The return values are unused, so these calls are
        # presumably in-place -- confirm against the audiotools docs.
        cond_sig.resample(baseline_sig.sample_rate)
        cond_sig.truncate_samples(baseline_sig.length)

        return {
            # The loss is negated so the reported value is the SI-SDR itself.
            "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
            "stft": stft_loss(baseline_sig, cond_sig).item(),
            "mel": mel_loss(baseline_sig, cond_sig).item(),
        }

    # figure out what conditions we have
    conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
    assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
    conditions.remove(baseline_key)

    print(f"Found {len(conditions)} conditions in {exp_dir}")
    print(f"conditions: {conditions}")

    baseline_dir = exp_dir / baseline_key
    baseline_files = numbered_audio_files(baseline_dir)

    metrics = []
    for condition in conditions:
        cond_dir = exp_dir / condition
        cond_files = numbered_audio_files(cond_dir)

        # FAD is computed on the two directories as a whole, not on the
        # truncated pair lists below.
        print(f"computing fad for {baseline_dir} and {cond_dir}")
        frechet_score = frechet.score(baseline_dir, cond_dir)

        # Compare only as many files as both folders have. The truncation is
        # kept local to this condition: rebinding `baseline_files` here would
        # permanently shrink the baseline for every later condition.
        num_files = min(len(baseline_files), len(cond_files))
        paired_baseline = baseline_files[:num_files]
        paired_cond = cond_files[:num_files]

        print(f"processing {num_files} files in {baseline_dir} and {cond_dir}")
        for baseline_file, cond_file in tqdm(
            zip(paired_baseline, paired_cond), total=num_files
        ):
            row = compare_pair(baseline_file, cond_file)
            # Key order determines the column order of metrics-all.csv.
            row["frechet"] = frechet_score
            row["condition"] = condition
            row["file"] = baseline_file.stem
            metrics.append(row)

    if not metrics:
        # No conditions, or no audio files to pair: there is nothing to
        # aggregate, so skip the CSV output instead of failing on metrics[0].
        print(f"no file pairs were evaluated in {exp_dir}; nothing to write")
        return

    # Build the frame once and reuse it for every per-metric summary.
    df = pandas.DataFrame(metrics)
    metric_keys = [k for k in df.columns if k not in ("condition", "file")]

    by_condition = df.groupby(['condition'])
    for mk in metric_keys:
        stat = by_condition[mk].agg(['mean', 'count', 'std'])
        stat.to_csv(exp_dir / f"stats-{mk}.csv")

    df.to_csv(exp_dir / "metrics-all.csv", index=False)
if __name__ == "__main__":
    # Parse the command-line arguments with argbind and run the evaluation
    # inside that argument scope, so `eval` is called without explicit
    # arguments here.
    with argbind.scope(argbind.parse_args()):
        eval()