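"""Train the SUPERB ASR downstream task on LibriSpeech with an S3PRL upstream.

The upstream ("apc", hard-coded below) is combined with the SUPERB ASR
downstream model and trained with PyTorch Lightning.

Example invocation (script name and paths are illustrative placeholders):

    python train_superb_asr.py /path/to/LibriSpeech ./exp/superb_asr --total_steps 200000
"""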
import argparse
import logging
from pathlib import Path

import torch
import torch.optim as optim
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from s3prl.nn import S3PRLUpstream, UpstreamDownstreamModel
from s3prl.superb import asr as problem
from s3prl.wrapper import LightningModuleSimpleWrapper

device = "cuda" if torch.cuda.is_available() else "cpu"
logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("librispeech", help="The root directory of LibriSpeech")
    parser.add_argument("save_to", help="The directory to save checkpoints")
    parser.add_argument("--total_steps", type=int, default=200000)
    parser.add_argument("--log_step", type=int, default=100)
    parser.add_argument("--eval_step", type=int, default=5000)
    parser.add_argument("--save_step", type=int, default=100)
    parser.add_argument(
        "--not_resume",
        action="store_true",
        help="Don't resume from the last checkpoint",
    )

    # Debugging options, passed straight through to the Lightning Trainer
    parser.add_argument("--limit_train_batches", type=int)
    parser.add_argument("--limit_val_batches", type=int)
    parser.add_argument("--fast_dev_run", action="store_true")
    args = parser.parse_args()
    return args


def main():
    logging.basicConfig(level=logging.INFO)

    args = parse_args()
    librispeech = Path(args.librispeech)
    save_to = Path(args.save_to)
    save_to.mkdir(exist_ok=True, parents=True)

    logger.info("Preparing preprocessor")
    preprocessor = problem.Preprocessor(librispeech)

    logger.info("Preparing train dataloader")
    train_dataset = problem.TrainDataset(**preprocessor.train_data())
    train_dataloader = train_dataset.to_dataloader(
        batch_size=8,
        num_workers=6,
        shuffle=True,
    )

    # The valid/test splits reuse the statistics computed on the training data
    logger.info("Preparing valid dataloader")
    valid_dataset = problem.ValidDataset(
        **preprocessor.valid_data(),
        **train_dataset.statistics(),
    )
    valid_dataloader = valid_dataset.to_dataloader(batch_size=8, num_workers=6)

    logger.info("Preparing test dataloader")
    test_dataset = problem.TestDataset(
        **preprocessor.test_data(),
        **train_dataset.statistics(),
    )
    test_dataloader = test_dataset.to_dataloader(batch_size=8, num_workers=6)

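    # Save the valid/test datasets to disk (e.g. so a later evaluation run can
    # reload them without re-running the preprocessor).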
    valid_dataset.save_checkpoint(save_to / "valid_dataset.ckpt")
    test_dataset.save_checkpoint(save_to / "test_dataset.ckpt")

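    # "apc" is the hard-coded upstream; other upstream names registered in
    # s3prl (e.g. "wav2vec2", "hubert") should be drop-in replacements here.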
    upstream = S3PRLUpstream("apc")
    downstream = problem.DownstreamModel(
        upstream.output_size, preprocessor.statistics().output_size
    )
    model = UpstreamDownstreamModel(upstream, downstream)
    task = problem.Task(model, preprocessor.statistics().label_loader)

    optimizer = optim.Adam(task.parameters(), lr=1e-3)
    lightning_task = LightningModuleSimpleWrapper(task, optimizer)

    # Keep the 3 best checkpoints by validation WER, plus the latest (last.ckpt)
    checkpoint_callback = ModelCheckpoint(
        dirpath=str(save_to),
        filename="superb-asr-{step:02d}-{valid_0_wer:.2f}",
        monitor="valid_0_wer",
        save_last=True,
        save_top_k=3,
        mode="min",
        every_n_train_steps=args.save_step,
    )

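    # Single-GPU training via the Lightning 1.x `gpus` argument; on Lightning
    # >= 2.0 that argument was removed and `devices=1` is the equivalent.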
    trainer = Trainer(
        callbacks=[checkpoint_callback],
        accelerator="gpu",
        gpus=1,
        max_steps=args.total_steps,
        log_every_n_steps=args.log_step,
        val_check_interval=args.eval_step,
        limit_val_batches=args.limit_val_batches or 1.0,
        limit_train_batches=args.limit_train_batches or 1.0,
        fast_dev_run=args.fast_dev_run,
    )

    # Resume from the last checkpoint if one exists, unless --not_resume is given
    last_ckpt = save_to / "last.ckpt"
    if args.not_resume or not last_ckpt.is_file():
        last_ckpt = None

    trainer.fit(
        lightning_task,
        train_dataloader,
        # The first entry (valid) drives the monitored valid_0_wer metric;
        # the test set is evaluated alongside it during validation.
        val_dataloaders=[valid_dataloader, test_dataloader],
        ckpt_path=last_ckpt,
    )
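
    # After training, `save_to` should contain last.ckpt plus the top-3
    # "superb-asr-*" checkpoints selected by valid_0_wer, alongside the
    # valid/test dataset snapshots saved above.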


if __name__ == "__main__":
    main()