File size: 4,174 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import argparse
import logging
from pathlib import Path

import torch
import torch.optim as optim
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from s3prl.nn import S3PRLUpstream, UpstreamDownstreamModel
from s3prl.superb import sid as problem
from s3prl.wrapper import LightningModuleSimpleWrapper

# Prefer CUDA when available; falls back to CPU otherwise.
# NOTE(review): `device` is not referenced anywhere else in this view — possibly unused.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Module-level logger, standard `logging.getLogger(__name__)` convention.
logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Parse command-line options for the SUPERB-SID training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            means parse ``sys.argv[1:]``, preserving the original behavior;
            passing an explicit list makes the function unit-testable.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("voxceleb1", help="The root directory of VoxCeleb1")
    parser.add_argument("save_to", help="The directory to save checkpoint")
    parser.add_argument("--total_steps", type=int, default=200000)
    parser.add_argument("--log_step", type=int, default=100)
    parser.add_argument("--eval_step", type=int, default=5000)
    parser.add_argument("--save_step", type=int, default=100)
    parser.add_argument(
        "--not_resume",
        action="store_true",
        help="Don't resume from the last checkpoint",
    )

    # for debugging
    parser.add_argument("--limit_train_batches", type=int)
    parser.add_argument("--limit_val_batches", type=int)
    parser.add_argument("--fast_dev_run", action="store_true")
    return parser.parse_args(argv)


def main():
    """Train a SUPERB speaker-identification (SID) model with pytorch-lightning.

    Reads the VoxCeleb1 root and an output directory from the command line,
    builds train/valid/test dataloaders, wraps an "apc" upstream plus a SID
    downstream head into a lightning module, and runs ``Trainer.fit``.
    Resumes from ``<save_to>/last.ckpt`` unless ``--not_resume`` is given.
    """
    logging.basicConfig(level=logging.INFO)

    args = parse_args()
    voxceleb1 = Path(args.voxceleb1)
    save_to = Path(args.save_to)
    save_to.mkdir(exist_ok=True, parents=True)

    logger.info("Preparing preprocessor")
    preprocessor = problem.Preprocessor(voxceleb1)

    logger.info("Preparing train dataloader")
    train_dataset = problem.TrainDataset(**preprocessor.train_data())
    train_dataloader = train_dataset.to_dataloader(
        batch_size=8,
        num_workers=6,
        shuffle=True,
    )

    logger.info("Preparing valid dataloader")
    # Valid/test datasets reuse the train-split statistics so every split
    # shares the same label space / normalization.
    valid_dataset = problem.ValidDataset(
        **preprocessor.valid_data(),
        **train_dataset.statistics(),
    )
    valid_dataloader = valid_dataset.to_dataloader(batch_size=8, num_workers=6)

    logger.info("Preparing test dataloader")
    test_dataset = problem.TestDataset(
        **preprocessor.test_data(),
        **train_dataset.statistics(),
    )
    test_dataloader = test_dataset.to_dataloader(batch_size=8, num_workers=6)

    # Persist dataset state so later evaluation can reload it directly.
    valid_dataset.save_checkpoint(save_to / "valid_dataset.ckpt")
    test_dataset.save_checkpoint(save_to / "test_dataset.ckpt")

    upstream = S3PRLUpstream("apc")
    downstream = problem.DownstreamModel(
        upstream.output_size, len(preprocessor.statistics().category)
    )
    model = UpstreamDownstreamModel(upstream, downstream)
    task = problem.Task(model, preprocessor.statistics().category)

    optimizer = optim.Adam(task.parameters(), lr=1e-3)
    lightning_task = LightningModuleSimpleWrapper(task, optimizer)

    # The above is the usage of our library

    # The below is pytorch-lightning specific usage, which can be very simple
    # or very sophisticated, depending on how much you want to customized your
    # training loop

    checkpoint_callback = ModelCheckpoint(
        dirpath=str(save_to),
        filename="superb-sid-{step:02d}-{valid_0_accuracy:.2f}",
        monitor="valid_0_accuracy",  # since might have multiple valid dataloaders
        save_last=True,
        save_top_k=3,  # top 3 best ckpt on valid
        mode="max",  # higher, better
        every_n_train_steps=args.save_step,
    )

    # Fix: the accelerator was hard-coded to "gpu"/gpus=1, which makes the
    # Trainer raise on CPU-only machines even though the module-level `device`
    # already falls back to CPU. Select the accelerator at runtime instead.
    use_gpu = torch.cuda.is_available()
    trainer = Trainer(
        callbacks=[checkpoint_callback],
        accelerator="gpu" if use_gpu else "cpu",
        gpus=1 if use_gpu else None,
        max_steps=args.total_steps,
        log_every_n_steps=args.log_step,
        val_check_interval=args.eval_step,
        limit_val_batches=args.limit_val_batches or 1.0,
        limit_train_batches=args.limit_train_batches or 1.0,
        fast_dev_run=args.fast_dev_run,
    )

    # Resume from the most recent checkpoint unless disabled or absent.
    last_ckpt = save_to / "last.ckpt"
    if args.not_resume or not last_ckpt.is_file():
        last_ckpt = None

    trainer.fit(
        lightning_task,
        train_dataloader,
        val_dataloaders=[valid_dataloader, test_dataloader],
        ckpt_path=last_ckpt,
    )


# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()