File size: 3,558 Bytes
88b5dc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import argparse
import random
from functools import partial
from pathlib import Path

import soundfile
import torch
from deepspeed import DeepSpeedConfig
from torch import Tensor
from tqdm import tqdm

from ..data import create_dataloaders, mix_fg_bg
from ..utils import Engine, TrainLoop, save_mels, setup_logging, tree_map
from ..utils.distributed import is_local_leader
from .denoiser import Denoiser
from .hparams import HParams


def load_G(run_dir: Path, hp: HParams | None = None, training=True):
    """Build the denoiser generator wrapped in an ``Engine`` and load its checkpoint.

    Args:
        run_dir: Run directory. Hyper-parameters are read from it when ``hp`` is
            omitted, and the checkpoint directory is ``run_dir / "ds" / "G"``.
        hp: Hyper-parameters to use; ``None`` means ``HParams.load(run_dir)``.
        training: If True, ``load_checkpoint`` is called with its defaults. If
            False, optimizer and LR-scheduler states are explicitly skipped, since
            they are not needed outside training.

    Returns:
        The ``Engine`` wrapping a ``Denoiser`` with the checkpoint loaded.
    """
    hparams = HParams.load(run_dir) if hp is None else hp
    assert isinstance(hparams, HParams)

    denoiser = Denoiser(hparams)
    ds_config = DeepSpeedConfig(hparams.deepspeed_config)
    engine = Engine(
        model=denoiser,
        config_class=ds_config,
        ckpt_dir=run_dir / "ds" / "G",
    )

    # An empty kwargs dict means "use load_checkpoint's own defaults".
    # Inference only needs the weights, so it opts out of the extra state.
    load_kwargs = (
        {}
        if training
        else {"load_optimizer_states": False, "load_lr_scheduler_states": False}
    )
    engine.load_checkpoint(**load_kwargs)
    return engine


def save_wav(path: Path, wav: Tensor, rate: int):
    """Write a waveform tensor to ``path`` with ``soundfile``.

    Args:
        path: Destination audio file.
        wav: Waveform tensor. It may be on any device and may carry autograd
            history; it is detached and copied to CPU before writing.
        rate: Sample rate passed to ``soundfile.write`` as ``samplerate``.
    """
    wav = wav.detach()
    # ``Tensor.numpy()`` cannot convert bfloat16, and ``soundfile`` only accepts
    # float64/float32/int32/int16 arrays. Upcast half-precision output (e.g. from a
    # mixed-precision model) to float32. Other dtypes, including integer PCM, are
    # passed through unchanged.
    if wav.dtype in (torch.float16, torch.bfloat16):
        wav = wav.float()
    soundfile.write(path, wav.cpu().numpy(), samplerate=rate)


def main():
    """Parse CLI arguments, set up logging and hyper-parameters, and train the denoiser.

    Positional ``run_dir`` is the run directory. ``--yaml`` is an optional config
    passed to ``HParams.load``. ``--device`` (default ``"cuda"``) is the device
    used for training and for moving validation batches.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("run_dir", type=Path)
    parser.add_argument("--yaml", type=Path, default=None)
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()

    setup_logging(args.run_dir)
    hp = HParams.load(args.run_dir, yaml=args.yaml)

    # Only the local-leader process persists and prints the hyper-parameters.
    if is_local_leader():
        hp.save_if_not_exists(args.run_dir)
        hp.print()

    train_dl, val_dl = create_dataloaders(hp, mode="denoiser")

    def feed_G(engine: Engine, batch: dict[str, Tensor]):
        """Run one training forward pass and return ``(prediction, losses)``.

        With probability ``hp.distort_prob`` the foreground is ``batch["fg_dwavs"]``,
        otherwise ``batch["fg_wavs"]``. It is mixed with ``batch["bg_dwavs"]``
        using a mixing weight drawn uniformly from ``hp.mix_alpha_range`` by the
        ``alpha`` callable. The chosen foreground is also passed to the engine
        alongside the mixture.
        """

        def alpha_fn():
            return random.uniform(*hp.mix_alpha_range)

        if random.random() < hp.distort_prob:
            fg_wavs = batch["fg_dwavs"]
        else:
            fg_wavs = batch["fg_wavs"]
        mx_dwavs = mix_fg_bg(fg_wavs, batch["bg_dwavs"], alpha=alpha_fn)
        pred = engine(mx_dwavs, fg_wavs)
        losses = engine.gather_attribute("losses", prefix="losses")
        return pred, losses

    def to_numpy(x: Tensor):
        # ``Tensor.numpy()`` cannot convert bfloat16; upcast to float32 first
        # (a no-op for float32 tensors).
        return x.float().cpu().numpy()

    @torch.no_grad()
    def eval_fn(engine: Engine, eval_dir, n_saved=10):
        """Save input/prediction/target audio and mel plots for validation batches.

        Stops after the ``n_saved``-th batch. Files are written to ``eval_dir``,
        named ``step_<global_step>_<batch index>``. ``eval_dir`` must support
        the ``/`` operator (presumably a ``Path``).
        """
        model = engine.module
        was_training = model.training
        model.eval()
        try:
            step = engine.global_step
            rate = model.hp.wav_rate

            for i, batch in enumerate(tqdm(val_dl), 1):
                batch = tree_map(lambda x: x.to(args.device) if isinstance(x, Tensor) else x, batch)

                fg_dwavs = batch["fg_dwavs"]  # 1 t
                mx_dwavs = mix_fg_bg(fg_dwavs, batch["bg_dwavs"])
                pred_fg_dwavs = model(mx_dwavs)  # 1 t

                mx_mels = model.to_mel(mx_dwavs)  # 1 c t
                fg_mels = model.to_mel(fg_dwavs)  # 1 c t
                pred_fg_mels = model.to_mel(pred_fg_dwavs)  # 1 c t

                prefix = f"step_{step:08}_{i:03}"
                save_wav(eval_dir / f"{prefix}_input.wav", mx_dwavs[0], rate=rate)
                save_wav(eval_dir / f"{prefix}_predict.wav", pred_fg_dwavs[0], rate=rate)
                save_wav(eval_dir / f"{prefix}_target.wav", fg_dwavs[0], rate=rate)

                save_mels(
                    eval_dir / f"{prefix}.png",
                    cond_mel=to_numpy(mx_mels[0]),
                    pred_mel=to_numpy(pred_fg_mels[0]),
                    targ_mel=to_numpy(fg_mels[0]),
                )

                if i >= n_saved:
                    break
        finally:
            # Restore the caller's mode. Otherwise training would continue with
            # the module left in eval mode after the first evaluation.
            model.train(was_training)

    train_loop = TrainLoop(
        run_dir=args.run_dir,
        train_dl=train_dl,
        load_G=partial(load_G, hp=hp),
        device=args.device,
        feed_G=feed_G,
        eval_fn=eval_fn,
    )

    train_loop.run(max_steps=hp.max_steps)


if __name__ == "__main__":
    main()