File size: 3,274 Bytes
260b46d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275afd0
260b46d
 
 
275afd0
260b46d
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from pathlib import Path
import os
from functools import partial

from frechet_audio_distance import FrechetAudioDistance
import pandas
import argbind
from tqdm import tqdm

import audiotools
from audiotools import AudioSignal

def _to_scalar(value):
    """Unwrap a single-element tensor-like value into a plain Python number.

    Anything exposing a callable ``item`` attribute is converted through it;
    every other value (e.g. the ``0.0`` ViSQOL fallback) is returned unchanged.
    """
    item = getattr(value, "item", None)
    return item() if callable(item) else value


@argbind.bind(without_prefix=True)
def eval(
    exp_dir: str = None,
    baseline_key: str = "reconstructed",
    audio_ext: str = ".wav",
):
    """Compare every condition folder in ``exp_dir`` against a baseline folder.

    Each sub-directory of ``exp_dir`` is treated as one condition. The folder
    named ``baseline_key`` is the reference; every other folder is scored
    against it file by file (SI-SDR, multi-scale STFT, mel, ViSQOL) and once
    per folder (Frechet audio distance).

    Writes into ``exp_dir``:
      * ``stats-<metric>.csv`` -- per-condition mean / count / std of a metric.
      * ``metrics-all.csv``    -- one row per (condition, file) pair.

    Args:
        exp_dir: Path of the experiment directory. Required.
        baseline_key: Name of the sub-directory holding the reference audio.
        audio_ext: File extension (with leading dot) of the audio files.

    Raises:
        AssertionError: if ``exp_dir`` is missing or does not exist, if the
            baseline folder is absent, if a condition has a different number
            of files than the baseline, or if paired file stems differ.
    """
    assert exp_dir is not None
    exp_dir = Path(exp_dir)
    assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"

    # set up our metrics
    sisdr_loss = audiotools.metrics.distance.SISDRLoss()
    stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
    frechet = FrechetAudioDistance(
        use_pca=False,
        use_activation=False,
        verbose=False
    )
    visqol = partial(audiotools.metrics.quality.visqol, mode="audio")

    # figure out what conditions we have (sorted so runs are reproducible)
    conditions = sorted(d.name for d in exp_dir.iterdir() if d.is_dir())

    assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
    conditions.remove(baseline_key)

    print(f"Found {len(conditions)} conditions in {exp_dir}")
    print(f"conditions: {conditions}")

    baseline_dir = exp_dir / baseline_key
    # glob() yields files in arbitrary order; sort so that the baseline and
    # condition listings line up when they are zipped together below.
    baseline_files = sorted(baseline_dir.glob(f"*{audio_ext}"))

    metrics = []
    for condition in conditions:
        cond_dir = exp_dir / condition
        cond_files = sorted(cond_dir.glob(f"*{audio_ext}"))

        # make sure we have the same number of files -- checked before the
        # (expensive) FAD computation so a bad folder fails fast
        assert len(baseline_files) == len(cond_files), f"number of files in {baseline_dir} and {cond_dir} do not match. {len(baseline_files)} vs {len(cond_files)}"

        print(f"computing fad")
        frechet_score = frechet.score(baseline_dir, cond_dir)

        pbar = tqdm(zip(baseline_files, cond_files), total=len(baseline_files))
        for baseline_file, cond_file in pbar:
            assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
            pbar.set_description(baseline_file.stem)

            # load the files
            baseline_sig = AudioSignal(baseline_file)
            cond_sig = AudioSignal(cond_file)

            # compute the metrics
            # ViSQOL is best-effort: a failure falls back to 0.0, but only for
            # ordinary exceptions so Ctrl-C / SystemExit still propagate.
            try:
                vsq = _to_scalar(visqol(baseline_sig, cond_sig))
            except Exception as err:
                tqdm.write(f"visqol failed for {baseline_file.stem} ({condition}): {err!r}")
                vsq = 0.0

            # NOTE(review): the loss objects presumably return single-element
            # tensors; unwrap them so the CSVs hold plain numbers and the
            # mean/std aggregation below works on a numeric column -- confirm.
            metrics.append({
                "sisdr": _to_scalar(sisdr_loss(baseline_sig, cond_sig)),
                "stft": _to_scalar(stft_loss(baseline_sig, cond_sig)),
                "mel": _to_scalar(mel_loss(baseline_sig, cond_sig)),
                "frechet": frechet_score,
                "visqol": vsq,
                "condition": condition,
                "file": baseline_file.stem,
            })

    if not metrics:
        # Nothing was scored (no condition folders or no audio files);
        # there is no row to derive metric names from, so stop here.
        print(f"no metrics computed for {exp_dir}; nothing to write")
        return

    # Build the table once and reuse it for every per-metric summary.
    df = pandas.DataFrame(metrics)
    metric_keys = [k for k in metrics[0] if k not in ("condition", "file")]

    for mk in metric_keys:
        stat = df.groupby(['condition'])[mk].agg(['mean', 'count', 'std'])
        stat.to_csv(exp_dir / f"stats-{mk}.csv")

    df.to_csv(exp_dir / "metrics-all.csv", index=False)

if __name__ == "__main__":
    # Script entry point: parse the command-line arguments with argbind and
    # call eval() inside the scope created from them.
    with argbind.scope(argbind.parse_args()):
        eval()