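"""Evaluate generated audio conditions against a baseline condition.

For each condition subdirectory in an experiment folder, this script computes a
Frechet Audio Distance (FAD) score against the baseline directory and a per-file
mel spectrogram loss, then writes per-metric summary CSVs and a CSV of all
per-file metrics back into the experiment folder.
"""
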
from pathlib import Path

from frechet_audio_distance import FrechetAudioDistance
import pandas
import argbind
import torch
from tqdm import tqdm

import audiotools
from audiotools import AudioSignal

@argbind.bind(without_prefix=True)
def eval(
    exp_dir: str = None,
    baseline_key: str = "baseline", 
    audio_ext: str = ".wav",
):
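    """Compute metrics comparing each condition in `exp_dir` against the baseline.

    `exp_dir` is expected to contain one subdirectory per condition (including
    `baseline_key`), each holding audio files whose stems are integers
    (e.g. 0.wav, 1.wav) so that outputs can be paired across conditions.
    """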
    assert exp_dir is not None, "exp_dir must be provided"
    exp_dir = Path(exp_dir)
    assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"

    # set up our metrics
    # sisdr_loss = audiotools.metrics.distance.SISDRLoss()
    # stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
    frechet = FrechetAudioDistance(
        use_pca=False, 
        use_activation=False,
        verbose=True, 
        audio_load_worker=4,
    )
    frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")

    # figure out what conditions we have
    conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]

    assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
    conditions.remove(baseline_key)

    print(f"Found {len(conditions)} conditions in {exp_dir}")
    print(f"conditions: {conditions}")

    baseline_dir = exp_dir / baseline_key 
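    # files are expected to have integer stems (e.g. 0.wav, 1.wav); sort
    # numerically so pairs stay aligned across conditions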
    baseline_files = sorted(list(baseline_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))

    metrics = []
    for condition in tqdm(conditions):
        cond_dir = exp_dir / condition
        cond_files = sorted(list(cond_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))

        print(f"computing fad for {baseline_dir} and {cond_dir}")
        frechet_score = frechet.score(baseline_dir, cond_dir)

        # pair up files; use local copies so that truncating for a condition
        # with fewer files does not shrink baseline_files for later conditions
        num_files = min(len(baseline_files), len(cond_files))
        paired_baseline = baseline_files[:num_files]
        paired_cond = cond_files[:num_files]
        assert len(paired_baseline) == len(paired_cond), f"number of files in {baseline_dir} and {cond_dir} do not match: {len(paired_baseline)} vs {len(paired_cond)}"

        def process(baseline_file, cond_file):
            # make sure the files match (same name)
            assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"

            # load the files
            baseline_sig = AudioSignal(str(baseline_file))
            cond_sig = AudioSignal(str(cond_file))

            cond_sig.resample(baseline_sig.sample_rate)
            cond_sig.truncate_samples(baseline_sig.length)

            # if our condition is inpainting, we need to trim the conditioning off
            if "inpaint" in condition:
                ctx_amt = float(condition.split("_")[-1])
                ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
                print(f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}")
                cond_sig.trim(ctx_samples, ctx_samples)
                baseline_sig.trim(ctx_samples, ctx_samples)

            return {
                # "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
                # "stft": stft_loss(baseline_sig, cond_sig).item(),
                "mel": mel_loss(baseline_sig, cond_sig).item(),
                "frechet": frechet_score,
                # "visqol": vsq,
                "condition": condition,
                "file": baseline_file.stem,
            }

        print(f"processing {len(baseline_files)} files in {baseline_dir} and {cond_dir}")
        metrics.extend(tqdm(map(process, baseline_files, cond_files), total=len(baseline_files)))

    df = pandas.DataFrame(metrics)

    # write a per-condition summary (mean / count / std) for each metric
    metric_keys = [k for k in df.columns if k not in ("condition", "file")]
    for mk in metric_keys:
        stat = df.groupby(["condition"])[mk].agg(["mean", "count", "std"])
        stat.to_csv(exp_dir / f"stats-{mk}.csv")

    # write the raw per-file metrics as well
    df.to_csv(exp_dir / "metrics-all.csv", index=False)


if __name__ == "__main__":
    args = argbind.parse_args()

    with argbind.scope(args):
        eval()
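

# Example invocation (a sketch; assumes this script is saved as eval.py and that
# exp_dir contains a baseline/ directory plus one directory per condition, with
# flags following argbind's without_prefix binding of the eval() arguments):
#
#   python eval.py --exp_dir runs/my_experiment --baseline_key baseline --audio_ext .wav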