File size: 4,697 Bytes
88b5dc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import argparse
import random
from functools import partial
from pathlib import Path

import soundfile
import torch
from deepspeed import DeepSpeedConfig
from torch import Tensor
from tqdm import tqdm

from ..data import create_dataloaders, mix_fg_bg
from ..utils import Engine, TrainLoop, save_mels, setup_logging, tree_map
from ..utils.distributed import is_local_leader
from .enhancer import Enhancer
from .hparams import HParams
from .univnet.discriminator import Discriminator


def load_G(run_dir: Path, hp: HParams | None = None, training=True):
    """Create the generator (Enhancer) engine and restore its checkpoint.

    Args:
        run_dir: Run directory; the checkpoint directory is ``run_dir / "ds" / "G"``.
        hp: Hyper-parameters. When ``None`` they are loaded from ``run_dir``.
        training: When falsy, the checkpoint is loaded without optimizer and
            LR-scheduler states.

    Returns:
        The ``Engine`` wrapping the ``Enhancer`` model.
    """
    if hp is None:
        hp = HParams.load(run_dir)
        assert isinstance(hp, HParams)

    engine = Engine(
        model=Enhancer(hp),
        config_class=DeepSpeedConfig(hp.deepspeed_config),
        ckpt_dir=run_dir / "ds" / "G",
    )

    # Outside of training only the weights are needed, so the optimizer and
    # scheduler states are explicitly not requested.
    ckpt_options = {} if training else {"load_optimizer_states": False, "load_lr_scheduler_states": False}
    engine.load_checkpoint(**ckpt_options)
    return engine


def load_D(run_dir: Path, hp: HParams | None = None):
    """Create the discriminator engine and restore its checkpoint.

    Args:
        run_dir: Run directory; the checkpoint directory is ``run_dir / "ds" / "D"``.
        hp: Hyper-parameters. When ``None`` they are loaded from ``run_dir``,
            matching the behaviour and signature of ``load_G``.

    Returns:
        The ``Engine`` wrapping the ``Discriminator`` model, after
        ``load_checkpoint()`` has been called on it.
    """
    # The body has always accepted ``hp=None``; the signature now says so too.
    if hp is None:
        hp = HParams.load(run_dir)
        assert isinstance(hp, HParams)
    model = Discriminator(hp)
    engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "D")
    engine.load_checkpoint()
    return engine


def save_wav(path: Path, wav: Tensor, rate: int):
    """Write a waveform tensor to ``path`` as an audio file.

    The tensor is detached, moved to the CPU and converted to a NumPy array
    before being handed to ``soundfile.write`` with ``samplerate=rate``.
    """
    host_wav = wav.detach().cpu()
    soundfile.write(path, host_wav.numpy(), samplerate=rate)


def main():
    """Command-line entry point for enhancer training.

    Parses ``run_dir``, ``--yaml`` and ``--device``, loads the hyper-parameters,
    creates the dataloaders, and passes the generator/discriminator feed
    functions and the evaluation function to ``TrainLoop``, which is then run
    with ``max_steps=hp.max_steps``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("run_dir", type=Path)
    cli.add_argument("--yaml", type=Path, default=None)
    cli.add_argument("--device", type=str, default="cuda")
    args = cli.parse_args()

    setup_logging(args.run_dir)
    hp = HParams.load(args.run_dir, yaml=args.yaml)

    # Saving and printing the hyper-parameters is restricted to the local leader.
    if is_local_leader():
        hp.save_if_not_exists(args.run_dir)
        hp.print()

    train_dl, val_dl = create_dataloaders(hp, mode="enhancer")

    def feed_G(engine: Engine, batch: dict[str, Tensor]):
        """Run the generator engine on one batch and return ``(pred, losses)``."""
        mode = hp.lcfm_training_mode
        if mode == "ae":
            # Autoencoder mode: the foreground waveform is both arguments.
            target = batch["fg_wavs"]
            pred = engine(target, target)
        elif mode == "cfm":

            def sample_alpha():
                return random.uniform(*hp.mix_alpha_range)

            mixed_dwavs = mix_fg_bg(batch["fg_dwavs"], batch["bg_dwavs"], alpha=sample_alpha)
            pred = engine(mixed_dwavs, batch["fg_wavs"], batch["fg_dwavs"])
        else:
            raise ValueError(f"Unknown training mode: {hp.lcfm_training_mode}")
        return pred, engine.gather_attribute("losses")

    def feed_D(engine: Engine, batch: dict | None, fake: Tensor):
        """Run the discriminator engine; ``real`` is passed only when a batch is given."""
        if batch is None:
            return engine(fake=fake)
        return engine(fake=fake, real=batch["fg_wavs"])

    @torch.no_grad()
    def eval_fn(engine: Engine, eval_dir, n_saved=10):
        """Write input/prediction/target wavs and a mel image per validation batch.

        Stops after the batch whose 1-based index reaches ``n_saved``.
        """
        assert isinstance(hp, HParams)

        net = engine.module
        net.eval()

        step = engine.global_step

        def to_device(item):
            return item.to(args.device) if isinstance(item, Tensor) else item

        def out_path(index, suffix):
            return eval_dir / f"step_{step:08}_{index:03}{suffix}"

        for i, batch in enumerate(tqdm(val_dl), 1):
            batch = tree_map(to_device, batch)

            fg_wavs = batch["fg_wavs"]  # 1 t

            mode = hp.lcfm_training_mode
            if mode == "ae":
                in_dwavs = fg_wavs
            elif mode == "cfm":
                in_dwavs = mix_fg_bg(fg_wavs, batch["bg_dwavs"])
            else:
                raise ValueError(f"Unknown training mode: {hp.lcfm_training_mode}")

            pred_fg_wavs = net(in_dwavs)  # 1 t

            in_mels = net.to_mel(in_dwavs)  # 1 c t
            fg_mels = net.to_mel(fg_wavs)  # 1 c t
            pred_fg_mels = net.to_mel(pred_fg_wavs)  # 1 c t

            rate = net.hp.wav_rate

            # Only the first item of each batch is written out.
            wav_outputs = (
                ("_input.wav", in_dwavs),
                ("_predict.wav", pred_fg_wavs),
                ("_target.wav", fg_wavs),
            )
            for suffix, wavs in wav_outputs:
                save_wav(out_path(i, suffix), wavs[0], rate=rate)

            save_mels(
                out_path(i, ".png"),
                cond_mel=in_mels[0].cpu().numpy(),
                pred_mel=pred_fg_mels[0].cpu().numpy(),
                targ_mel=fg_mels[0].cpu().numpy(),
            )

            if i >= n_saved:
                break

    # Bind the already-loaded hyper-parameters into both loaders before
    # handing them to TrainLoop.
    generator_loader = partial(load_G, hp=hp)
    discriminator_loader = partial(load_D, hp=hp)

    train_loop = TrainLoop(
        run_dir=args.run_dir,
        train_dl=train_dl,
        load_G=generator_loader,
        load_D=discriminator_loader,
        device=args.device,
        feed_G=feed_G,
        feed_D=feed_D,
        eval_fn=eval_fn,
        gan_training_start_step=hp.gan_training_start_step,
    )

    train_loop.run(max_steps=hp.max_steps)


# Run the training entry point only when this module is executed as a script.
if __name__ == "__main__":
    main()