File size: 9,384 Bytes
0b32ad6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 |
import argparse
import logging
import math
from pathlib import Path
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from s3prl import Logs, Object, Output
from s3prl.nn import S3PRLUpstream, UpstreamDownstreamModel
from s3prl.sampler import DistributedBatchSamplerWrapper
from s3prl.superb import sid as problem
device = "cuda" if torch.cuda.is_available() else "cpu"
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("voxceleb1", help="The root directory of VoxCeleb1")
parser.add_argument("save_to", help="The directory to save checkpoint")
parser.add_argument("--total_steps", type=int, default=200000)
parser.add_argument("--log_step", type=int, default=100)
parser.add_argument("--eval_step", type=int, default=5000)
parser.add_argument("--save_step", type=int, default=100)
parser.add_argument("--resume", action="store_true")
args = parser.parse_args()
return args
def main():
logging.basicConfig()
logger.setLevel(logging.INFO)
args = parse_args()
voxceleb1 = Path(args.voxceleb1)
assert voxceleb1.is_dir()
save_to = Path(args.save_to)
save_to.mkdir(exist_ok=True, parents=True)
logger.info("Preparing preprocessor")
preprocessor = problem.Preprocessor(voxceleb1)
logger.info("Preparing train dataloader")
train_dataset = problem.TrainDataset(**preprocessor.train_data())
train_sampler = problem.TrainSampler(
train_dataset, max_timestamp=16000 * 200, shuffle=True
)
train_sampler = DistributedBatchSamplerWrapper(
train_sampler, num_replicas=1, rank=0
)
train_dataloader = DataLoader(
train_dataset,
batch_sampler=train_sampler,
num_workers=12,
collate_fn=train_dataset.collate_fn,
)
logger.info("Preparing valid dataloader")
valid_dataset = problem.ValidDataset(
**preprocessor.valid_data(),
**train_dataset.statistics(),
)
valid_dataset.save_checkpoint(save_to / "valid_dataset.ckpt")
valid_sampler = problem.ValidSampler(valid_dataset, 8)
valid_sampler = DistributedBatchSamplerWrapper(
valid_sampler, num_replicas=1, rank=0
)
valid_dataloader = DataLoader(
valid_dataset,
batch_sampler=valid_sampler,
num_workers=12,
collate_fn=valid_dataset.collate_fn,
)
logger.info("Preparing test dataloader")
test_dataset = problem.TestDataset(
**preprocessor.test_data(),
**train_dataset.statistics(),
)
test_dataset.save_checkpoint(save_to / "test_dataset.ckpt")
test_sampler = problem.TestSampler(test_dataset, 8)
test_sampler = DistributedBatchSamplerWrapper(test_sampler, num_replicas=1, rank=0)
test_dataloader = DataLoader(
test_dataset, batch_size=8, num_workers=12, collate_fn=test_dataset.collate_fn
)
latest_task = save_to / "task.ckpt"
if args.resume and latest_task.is_file():
logger.info("Last checkpoint found. Load model and optimizer from checkpoint")
# Object.load_checkpoint() from a checkpoint path and
# Object.from_checkpoint() from a loaded checkpoint dictionary
# are like AutoModel in Huggingface which you only need to
# provide the checkpoint for restoring the module.
#
# Note that source code definition should be importable, since this
# auto loading mechanism is just automating the model re-initialization
# steps instead of scriptify (torch.jit) all the source code in the
# checkpoint
task = Object.load_checkpoint(latest_task).to(device)
else:
logger.info("No last checkpoint found. Create new model")
# Model creation block which can be fully customized
upstream = S3PRLUpstream("wav2vec2")
downstream = problem.DownstreamModel(
upstream.output_size, len(preprocessor.statistics().category)
)
model = UpstreamDownstreamModel(upstream, downstream)
# After customize your own model, simply put it into task object
task = problem.Task(model, preprocessor.statistics().category)
task = task.to(device)
# We do not handle optimizer/scheduler in any special way in S3PRL, since
# there are lots of dedicated package for this. Hence, we also do not handle
# the checkpointing for optimizer/scheduler. Depends on what training pipeline
# the user prefer, either Lightning or SpeechBrain, these frameworks will
# provide different solutions on how to save these objects. By not handling
# these objects in S3PRL we are making S3PRL more flexible and agnostic to training pipeline
# The following optimizer codeblock aims to align with the standard usage
# of PyTorch which is the standard way to save it.
optimizer = optim.Adam(task.parameters(), lr=1e-3)
latest_optimizer = save_to / "optimizer.ckpt"
if args.resume and latest_optimizer.is_file():
optimizer.load_state_dict(torch.load(save_to / "optimizer.ckpt"))
else:
optimizer = optim.Adam(task.parameters(), lr=1e-3)
# The following code block demonstrate how to train with your own training loop
# This entire block can be easily replaced with Lightning/SpeechBrain Trainer as
#
# Trainer(task)
# Trainer.fit(train_dataloader, valid_dataloader, test_dataloader)
#
# As you can see, there is a huge similarity among train/valid/test loops below,
# so it is a natural step to share these logics with a generic Trainer class
# as done in Lightning/SpeechBrain
pbar = tqdm(total=args.total_steps, desc="Total")
while True:
batch_results = []
for batch in tqdm(train_dataloader, desc="Train", total=len(train_dataloader)):
pbar.update(1)
global_step = pbar.n
assert isinstance(batch, Output)
optimizer.zero_grad()
# An Output object can more all its direct
# attributes/values to the device
batch = batch.to(device)
# An Output object is an OrderedDict so we
# can use dict decomposition here
task.train()
result = task.train_step(**batch)
assert isinstance(result, Output)
# The output of train step must contain
# at least a loss key
result.loss.backward()
# gradient clipping
grad_norm = torch.nn.utils.clip_grad_norm_(task.parameters(), max_norm=1.0)
if math.isnan(grad_norm):
logger.warning(f"Grad norm is NaN at step {global_step}")
else:
optimizer.step()
# Detach from GPU, remove large logging (to Tensorboard or local files)
# objects like logits and leave only small data like loss scalars / prediction
# strings, so that these objects can be safely cached in a list in the MEM,
# and become useful for calculating metrics later
# The Output class can do these with self.cacheable()
cacheable_result = result.cacheable()
# Cache these small data for later metric calculation
batch_results.append(cacheable_result)
if (global_step + 1) % args.log_step == 0:
logs: Logs = task.train_reduction(batch_results).logs
logger.info(f"[Train] step {global_step}")
for log in logs.values():
logger.info(f"{log.name}: {log.data}")
batch_results = []
if (global_step + 1) % args.eval_step == 0:
with torch.no_grad():
task.eval()
# valid
valid_results = []
for batch in tqdm(
valid_dataloader, desc="Valid", total=len(valid_dataloader)
):
batch = batch.to(device)
result = task.valid_step(**batch)
cacheable_result = result.cacheable()
valid_results.append(cacheable_result)
logs: Logs = task.valid_reduction(valid_results).logs
logger.info(f"[Valid] step {global_step}")
for log in logs.values():
logger.info(f"{log.name}: {log.data}")
# test
test_results = []
for batch in tqdm(
test_dataloader, desc="Test", total=len(test_dataloader)
):
batch = batch.to(device)
result = task.test_step(**batch)
cacheable_result = result.cacheable()
test_results.append(cacheable_result)
logs: Logs = task.test_reduction(test_results).logs
logger.info(f"[Test] step {global_step}")
for log in logs.values():
logger.info(f"{log.name}: {log.data}")
if (global_step + 1) % args.save_step == 0:
task.save_checkpoint(save_to / "task.ckpt")
torch.save(optimizer.state_dict(), save_to / "optimizer.ckpt")
if __name__ == "__main__":
main()
|