File size: 4,175 Bytes
0b32ad6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import argparse
import logging
from pathlib import Path
import torch
import torch.optim as optim
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from s3prl.nn import S3PRLUpstream, UpstreamDownstreamModel
from s3prl.superb import asr as problem
from s3prl.wrapper import LightningModuleSimpleWrapper
device = "cuda" if torch.cuda.is_available() else "cpu"
logger = logging.getLogger(__name__)
def parse_args(argv=None):
    """Parse command-line arguments for the SUPERB ASR training script.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse falls back to ``sys.argv[1:]`` — existing
            callers are unaffected; passing a list makes the function
            unit-testable.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("librispeech", help="The root directory of LibriSpeech")
    parser.add_argument("save_to", help="The directory to save checkpoint")
    parser.add_argument("--total_steps", type=int, default=200000)
    parser.add_argument("--log_step", type=int, default=100)
    parser.add_argument("--eval_step", type=int, default=5000)
    parser.add_argument("--save_step", type=int, default=100)
    parser.add_argument(
        "--not_resume",
        action="store_true",
        help="Don't resume from the last checkpoint",
    )
    # for debugging: unset (None) means "no limit" downstream in main()
    parser.add_argument("--limit_train_batches", type=int)
    parser.add_argument("--limit_val_batches", type=int)
    parser.add_argument("--fast_dev_run", action="store_true")
    return parser.parse_args(argv)
def main():
    """Train a SUPERB ASR downstream model on LibriSpeech with pytorch-lightning.

    Reads CLI arguments (see parse_args), prepares the preprocessor and the
    train/valid/test dataloaders, wraps the upstream+downstream model in a
    lightning task, and runs ``Trainer.fit``, resuming from ``last.ckpt``
    unless ``--not_resume`` is given.
    """
    logging.basicConfig(level=logging.INFO)
    args = parse_args()

    librispeech = Path(args.librispeech)
    save_to = Path(args.save_to)
    save_to.mkdir(exist_ok=True, parents=True)

    logger.info("Preparing preprocessor")
    preprocessor = problem.Preprocessor(librispeech)

    logger.info("Preparing train dataloader")
    train_dataset = problem.TrainDataset(**preprocessor.train_data())
    train_dataloader = train_dataset.to_dataloader(
        batch_size=8,
        num_workers=6,
        shuffle=True,
    )

    logger.info("Preparing valid dataloader")
    # Valid/test datasets reuse the train-set statistics so evaluation is
    # consistent with what the model saw during training.
    valid_dataset = problem.ValidDataset(
        **preprocessor.valid_data(),
        **train_dataset.statistics(),
    )
    valid_dataloader = valid_dataset.to_dataloader(batch_size=8, num_workers=6)

    logger.info("Preparing test dataloader")
    test_dataset = problem.TestDataset(
        **preprocessor.test_data(),
        **train_dataset.statistics(),
    )
    test_dataloader = test_dataset.to_dataloader(batch_size=8, num_workers=6)

    # Persist the evaluation dataset states next to the model checkpoints so
    # later evaluation runs can reload identical datasets.
    valid_dataset.save_checkpoint(save_to / "valid_dataset.ckpt")
    test_dataset.save_checkpoint(save_to / "test_dataset.ckpt")

    upstream = S3PRLUpstream("apc")
    downstream = problem.DownstreamModel(
        upstream.output_size, preprocessor.statistics().output_size
    )
    model = UpstreamDownstreamModel(upstream, downstream)
    task = problem.Task(model, preprocessor.statistics().label_loader)
    optimizer = optim.Adam(task.parameters(), lr=1e-3)
    lightning_task = LightningModuleSimpleWrapper(task, optimizer)

    # The above is the usage of our library.
    # The below is pytorch-lightning specific usage, which can be very simple
    # or very sophisticated, depending on how much you want to customize your
    # training loop.
    checkpoint_callback = ModelCheckpoint(
        dirpath=str(save_to),
        filename="superb-asr-{step:02d}-{valid_0_wer:.2f}",
        monitor="valid_0_wer",  # since might have multiple valid dataloaders
        save_last=True,
        save_top_k=3,  # keep the 3 best checkpoints on valid WER
        mode="min",  # lower WER is better
        every_n_train_steps=args.save_step,
    )

    # FIX: the Trainer previously hard-coded accelerator="gpu"/gpus=1 and
    # crashed on CPU-only machines, even though `device` was already computed
    # at module level from torch.cuda.is_available() (and was otherwise unused).
    trainer = Trainer(
        callbacks=[checkpoint_callback],
        accelerator="gpu" if device == "cuda" else "cpu",
        gpus=1 if device == "cuda" else None,
        max_steps=args.total_steps,
        log_every_n_steps=args.log_step,
        val_check_interval=args.eval_step,
        # FIX: `x or 1.0` treated an explicit 0 as "no limit" (0 is falsy);
        # compare against None so --limit_val_batches 0 is honored when
        # debugging.
        limit_val_batches=1.0 if args.limit_val_batches is None else args.limit_val_batches,
        limit_train_batches=1.0 if args.limit_train_batches is None else args.limit_train_batches,
        fast_dev_run=args.fast_dev_run,
    )

    last_ckpt = save_to / "last.ckpt"
    if args.not_resume or not last_ckpt.is_file():
        # Start fresh when resuming is disabled or no checkpoint exists yet.
        last_ckpt = None
    trainer.fit(
        lightning_task,
        train_dataloader,
        val_dataloaders=[valid_dataloader, test_dataloader],
        ckpt_path=last_ckpt,
    )
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
|