import argparse
import logging
import math
from pathlib import Path

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from s3prl import Logs, Object, Output
from s3prl.nn import S3PRLUpstream, UpstreamDownstreamModel
from s3prl.sampler import DistributedBatchSamplerWrapper
from s3prl.superb import asr as problem

device = "cuda" if torch.cuda.is_available() else "cpu"
logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("librispeech", help="The root directory of LibriSpeech")
    parser.add_argument("save_to", help="The directory to save checkpoints")
    parser.add_argument("--total_steps", type=int, default=200000)
    parser.add_argument("--log_step", type=int, default=100)
    parser.add_argument("--eval_step", type=int, default=5000)
    parser.add_argument("--save_step", type=int, default=100)
    args = parser.parse_args()
    return args


def main():
    logging.basicConfig()
    logger.setLevel(logging.INFO)

    args = parse_args()
    librispeech = Path(args.librispeech)
    assert librispeech.is_dir()
    save_to = Path(args.save_to)
    save_to.mkdir(exist_ok=True, parents=True)

    logger.info("Preparing preprocessor")
    preprocessor = problem.Preprocessor(
        librispeech, splits=["train-clean-100", "dev-clean", "test-clean"]
    )

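    # TrainSampler caps each batch at max_timestamp total audio samples
    # (16000 * 1000, roughly 1000 seconds of 16 kHz audio). With num_replicas=1
    # and rank=0 the distributed wrapper is effectively a pass-through; in
    # multi-process training these values would come from the DDP process group.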
    logger.info("Preparing train dataloader")
    train_dataset = problem.TrainDataset(**preprocessor.train_data())
    train_sampler = problem.TrainSampler(
        train_dataset, max_timestamp=16000 * 1000, shuffle=True
    )
    train_sampler = DistributedBatchSamplerWrapper(
        train_sampler, num_replicas=1, rank=0
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=4,
        collate_fn=train_dataset.collate_fn,
    )

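    # The valid/test datasets reuse the statistics computed on the training set
    # (e.g. the output vocabulary and input normalization), and each dataset is
    # checkpointed next to the model so it can be reloaded later.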
    logger.info("Preparing valid dataloader")
    valid_dataset = problem.ValidDataset(
        **preprocessor.valid_data(),
        **train_dataset.statistics(),
    )
    valid_dataset.save_checkpoint(save_to / "valid_dataset.ckpt")
    valid_sampler = problem.ValidSampler(valid_dataset, 8)
    valid_sampler = DistributedBatchSamplerWrapper(
        valid_sampler, num_replicas=1, rank=0
    )
    valid_dataloader = DataLoader(
        valid_dataset,
        batch_sampler=valid_sampler,
        num_workers=4,
        collate_fn=valid_dataset.collate_fn,
    )

    logger.info("Preparing test dataloader")
    test_dataset = problem.TestDataset(
        **preprocessor.test_data(),
        **train_dataset.statistics(),
    )
    test_dataset.save_checkpoint(save_to / "test_dataset.ckpt")
    test_sampler = problem.TestSampler(test_dataset, 8)
    test_sampler = DistributedBatchSamplerWrapper(test_sampler, num_replicas=1, rank=0)
    test_dataloader = DataLoader(
        test_dataset,
        batch_sampler=test_sampler,
        num_workers=4,
        collate_fn=test_dataset.collate_fn,
    )

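    # Resume from a saved task checkpoint when one exists; otherwise build a
    # fresh model from the pretrained "apc" upstream and a newly initialized
    # downstream ASR head.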
    latest_task = save_to / "task.ckpt"
    if latest_task.is_file():
        logger.info("Last task checkpoint found. Loading task from checkpoint")
        task = Object.load_checkpoint(latest_task).to(device)
    else:
        logger.info("No last checkpoint found. Creating a new model")
        upstream = S3PRLUpstream("apc")
        downstream = problem.DownstreamModel(
            upstream.output_size,
            preprocessor.statistics().output_size,
            hidden_size=[512],
            dropout=[0.2],
        )
        model = UpstreamDownstreamModel(upstream, downstream)
        task = problem.Task(model, preprocessor.statistics().label_loader)
        task = task.to(device)

    optimizer = optim.Adam(task.parameters(), lr=1e-3)
    latest_optimizer = save_to / "optimizer.ckpt"
    if latest_optimizer.is_file():
        # resume the optimizer state saved alongside the task checkpoint
        optimizer.load_state_dict(torch.load(latest_optimizer))

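    # Main training loop: run until --total_steps optimization steps have been
    # taken, logging every --log_step steps, validating every --eval_step steps,
    # and saving the task and optimizer state every --save_step steps.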
    pbar = tqdm(total=args.total_steps, desc="Total")
    while pbar.n < args.total_steps:
        batch_results = []
        for batch in tqdm(train_dataloader, desc="Train", total=len(train_dataloader)):
            pbar.update(1)
            global_step = pbar.n

            assert isinstance(batch, Output)
            optimizer.zero_grad()

            batch = batch.to(device)

            task.train()
            result = task.train_step(**batch)
            assert isinstance(result, Output)

            result.loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(task.parameters(), max_norm=1.0)

            if math.isnan(grad_norm):
                # skip the update when the gradients blew up
                logger.warning(f"Grad norm is NaN at step {global_step}")
            else:
                optimizer.step()

            cacheable_result = result.cacheable()
            batch_results.append(cacheable_result)

            if (global_step + 1) % args.log_step == 0:
                logs: Logs = task.train_reduction(batch_results).logs
                logger.info(f"[Train] step {global_step}")
                for log in logs.values():
                    logger.info(f"{log.name}: {log.data}")
                batch_results = []

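            # Periodic validation: evaluate the whole dev set under no_grad and
            # reduce the cached step results into summary logs.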
            if (global_step + 1) % args.eval_step == 0:
                with torch.no_grad():
                    task.eval()

                    valid_results = []
                    for batch in tqdm(
                        valid_dataloader, desc="Valid", total=len(valid_dataloader)
                    ):
                        batch = batch.to(device)
                        result = task.valid_step(**batch)
                        cacheable_result = result.cacheable()
                        valid_results.append(cacheable_result)

                    logs: Logs = task.valid_reduction(valid_results).logs
                    logger.info(f"[Valid] step {global_step}")
                    for log in logs.values():
                        logger.info(f"{log.name}: {log.data}")

            if (global_step + 1) % args.save_step == 0:
                task.save_checkpoint(save_to / "task.ckpt")
                torch.save(optimizer.state_dict(), save_to / "optimizer.ckpt")

            # stop mid-epoch once the requested number of steps has been reached
            if global_step >= args.total_steps:
                break

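    # Final evaluation on the held-out test set with the last task state.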
    with torch.no_grad():
        # make sure dropout and similar layers are in eval mode for the test pass
        task.eval()

        test_results = []
        for batch in tqdm(test_dataloader, desc="Test", total=len(test_dataloader)):
            batch = batch.to(device)
            result = task.test_step(**batch)
            cacheable_result = result.cacheable()
            test_results.append(cacheable_result)

        logs: Logs = task.test_reduction(test_results).logs
        logger.info("[Test] results")
        for log in logs.values():
            logger.info(f"{log.name}: {log.data}")


if __name__ == "__main__":
    main()