lmzjms's picture
Upload 1162 files
0b32ad6 verified
raw
history blame
10.5 kB
import argparse
import logging
import math
import os
from copy import deepcopy
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from s3prl import Container, Logs, Object, Output
from s3prl.dataset.base import AugmentedDynamicItemDataset, DataLoader
from s3prl.sampler import DistributedBatchSamplerWrapper
from s3prl.util.configuration import parse_override, qualname_to_cls
from s3prl.util.seed import fix_random_seeds
# Run on the GPU whenever one is visible to torch; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

logger = logging.getLogger(__name__)

# Trainer overrides applied on top of the problem's default config when the
# script is launched with --dryrun (see parse_args).
# NOTE(review): these step counts look large for a "dry run" -- confirm that
# they are the intended values.
DRYRUN_CONFIG = {
    "Trainer": {
        "total_steps": 200000,
        "log_step": 5000,
        "valid_step": 5000,
        "save_step": 5000,
        "eval_batch": 8,
    },
}
def parse_args():
    """Parse the command line and assemble the run configuration.

    The random seeds are fixed from ``--seed`` as a side effect. The returned
    config starts as a deep copy of the problem's ``default_config``; every
    command-line option except ``override`` is then stored in it under its own
    name. ``DRYRUN_CONFIG`` is applied when ``--dryrun`` is given, and finally
    the ``--override`` string (if non-empty) is parsed and applied last.

    Returns:
        A ``(problem, config)`` pair: the class resolved from the ``problem``
        qualified name, and the ``Container`` holding the merged config.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "problem", help="The problem module. E.g. `s3prl.problem.ssl.tera.Tera`"
    )
    cli.add_argument("dataset_root", help="The dataset root for pretrain.")
    cli.add_argument("save_to", help="The directory to save checkpoint")
    cli.add_argument("--n_jobs", type=int, default=8)
    cli.add_argument(
        "--override",
        default=None,
        help=(
            "Override the default_config of the problem module. "
            "E.g. --override ValidSampler.batch_size=4,,TestSampler.batch_size=4"
        ),
    )
    cli.add_argument("--resume", action="store_true")
    cli.add_argument("--dryrun", action="store_true")
    cli.add_argument("--seed", type=int, default=1337)
    options = cli.parse_args()

    fix_random_seeds(options.seed)
    problem = qualname_to_cls(options.problem)
    config = Container(deepcopy(problem.default_config))

    # Mirror every CLI option into the config; the raw override string is
    # handled separately below instead of being stored.
    for option_name, option_value in vars(options).items():
        if option_name == "override":
            continue
        config[option_name] = option_value

    if options.dryrun:
        config.override(DRYRUN_CONFIG)

    # User overrides go last so they win over both defaults and dry-run values.
    if isinstance(options.override, str) and options.override:
        config.override(parse_override(options.override))

    return problem, config
def _collect_eval_results(task, dataloader, step_fn, desc, max_batches):
    """Run ``step_fn`` over ``dataloader`` in eval mode with gradients disabled.

    Args:
        task: The task module; it is switched to eval mode before iterating.
        dataloader: Loader whose batches support ``.to(device)`` and ``**``
            unpacking.
        step_fn: The step callable to apply to each batch, e.g.
            ``task.valid_step`` or ``task.test_step``.
        desc: Label shown on the progress bar.
        max_batches: Stop as soon as this many batches were processed. ``-1``
            never matches a batch index, so it means "no limit".

    Returns:
        A list holding ``result.cacheable()`` for every processed batch.
    """
    task.eval()
    results = []
    with torch.no_grad():
        for batch_idx, batch in enumerate(
            tqdm(dataloader, desc=desc, total=len(dataloader))
        ):
            if batch_idx == max_batches:
                break
            batch = batch.to(device)
            results.append(step_fn(**batch).cacheable())
    return results


def main():
    """Pre-train the selected problem, then evaluate it on the test split.

    Steps: parse the configuration, build body/head/loss, build the
    train/valid/test dataloaders, create the task (or restore it from the most
    recently modified ``*.ckpts`` directory under ``save_to`` when ``--resume``
    is set), then train for ``Trainer.total_steps`` steps with periodic
    logging, validation and checkpointing, and finally run the test split.

    Raises:
        AssertionError: If the train dataset's ``feat_dim`` tool differs from
            ``problem.input_size``.
        ValueError: If the training dataloader yields no batches.
    """
    logging.basicConfig(level=logging.INFO)

    problem, config = parse_args()
    save_to = Path(config.save_to)
    save_to.mkdir(exist_ok=True, parents=True)

    # configure any upstream
    body = problem.Body(**config.Body)
    head = problem.Head(**config.Head)
    loss = problem.Loss(**config.Loss)
    stats = Container()

    logger.info("Preparing corpus")
    corpus = problem.Corpus(config.dataset_root, **config.Corpus)
    train_data, valid_data, test_data, corpus_stats = corpus().split(3)
    stats.add(corpus_stats)

    logger.info("Preparing train data")
    train_dataset = AugmentedDynamicItemDataset(train_data, tools=stats)
    train_dataset = problem.TrainData(**config.TrainData)(train_dataset)
    assert train_dataset.get_tool("feat_dim") == problem.input_size
    train_sampler = DistributedBatchSamplerWrapper(
        problem.TrainSampler(train_dataset, **config.TrainSampler),
        num_replicas=1,
        rank=0,
    )
    train_dataloader = DataLoader(
        train_dataset,
        train_sampler,
        num_workers=config.n_jobs,
    )
    # Tools produced by the train pipeline are shared with valid/test below.
    stats.add(train_dataset.all_tools())

    logger.info("Preparing valid data")
    valid_dataset = AugmentedDynamicItemDataset(valid_data, tools=stats)
    valid_dataset = problem.ValidData(**config.ValidData)(valid_dataset)
    valid_sampler = DistributedBatchSamplerWrapper(
        problem.ValidSampler(valid_dataset, **config.ValidSampler),
        num_replicas=1,
        rank=0,
    )
    valid_dataloader = DataLoader(
        valid_dataset,
        valid_sampler,
        # Honor --n_jobs like the train loader instead of a hard-coded count.
        num_workers=config.n_jobs,
    )

    logger.info("Preparing test data")
    test_dataset = AugmentedDynamicItemDataset(test_data, tools=stats)
    test_dataset = problem.TestData(**config.TestData)(test_dataset)
    test_sampler = DistributedBatchSamplerWrapper(
        # The test split uses the problem's TestSampler together with the
        # TestSampler config section.
        problem.TestSampler(test_dataset, **config.TestSampler),
        num_replicas=1,
        rank=0,
    )
    test_dataloader = DataLoader(
        test_dataset,
        test_sampler,
        num_workers=config.n_jobs,
    )

    # Checkpoint directories ordered oldest -> newest by modification time.
    sorted_ckpt_dirs = sorted(
        [
            file
            for file in save_to.iterdir()
            if file.is_dir() and str(file).endswith(".ckpts")
        ],
        key=os.path.getmtime,
    )
    resuming = config.resume and len(sorted_ckpt_dirs) > 0
    # Task, optimizer and best score must all come from the same (newest)
    # checkpoint directory.
    latest_ckpt_dir = sorted_ckpt_dirs[-1] if resuming else None

    if resuming:
        logger.info("Last checkpoint found. Load model and optimizer from checkpoint")
        task = Object.load_checkpoint(latest_ckpt_dir / "task.ckpt").to(device)
        # NOTE(review): `body` and `head` above are freshly constructed and are
        # what save_checkpoint() hands to problem.save_checkpoint; they are not
        # taken from the task restored here -- confirm this is intended.
    else:
        logger.info("Create a new model")
        task = problem.Task(body, head, loss, **stats)
        task = task.to(device)

    # ALL THE FOLLOWING CODES ARE FOR TRAINER
    # WHICH CAN BE LARGELY SIMPLIFIED WHEN USING OTHER TRAINER PACKAGES

    opt_cls_qualname, opt_cfgs = config.Optimizer.split(1)
    optimizer = qualname_to_cls(opt_cls_qualname)(task.parameters(), **opt_cfgs)
    if resuming:
        optimizer.load_state_dict(torch.load(latest_ckpt_dir / "optimizer.ckpt"))

    # Always bind the name: save_checkpoint() closes over it, and leaving it
    # unbound when validation is disabled would raise NameError on first save.
    valid_best_score = None
    if config.Trainer.use_valid:
        if resuming:
            valid_best_score = torch.load(latest_ckpt_dir / "valid_best_score.ckpt")[
                config.Trainer.valid_metric
            ]
        else:
            valid_best_score = -100000 if config.Trainer.valid_higher_better else 100000

    def save_checkpoint(name):
        """Write task, optimizer and (if validating) best score to ``<name>.ckpts``."""
        ckpt_dir: Path = save_to / f"{name}.ckpts"
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Save checkpoint to: {ckpt_dir}")

        if hasattr(problem, "save_checkpoint"):
            logger.info(f"Save upstream checkpoint to: {ckpt_dir}")
            problem.save_checkpoint(config, body, head, ckpt_dir / "upstream.ckpt")
        task.save_checkpoint(ckpt_dir / "task.ckpt")
        torch.save(optimizer.state_dict(), ckpt_dir / "optimizer.ckpt")
        # The best score only exists (and is only read back on resume) when
        # validation is enabled.
        if config.Trainer.use_valid:
            torch.save(
                {config.Trainer.valid_metric: valid_best_score},
                ckpt_dir / "valid_best_score.ckpt",
            )

    # Without this guard an empty loader would make the loop below spin forever,
    # because the step counter could never reach total_steps.
    if len(train_dataloader) == 0:
        raise ValueError("The training dataloader is empty; nothing to train on.")

    # NOTE: the step counter always starts from 0, also when resuming.
    pbar = tqdm(total=config.Trainer.total_steps, desc="Total")
    train_completed = False
    accum_grad_steps = 0
    while not train_completed:
        batch_results = []
        for batch in tqdm(train_dataloader, desc="Train", total=len(train_dataloader)):
            pbar.update(1)
            global_step = pbar.n

            assert isinstance(batch, Output)
            batch = batch.to(device)

            task.train()
            result = task.train_step(**batch)
            assert isinstance(result, Output)

            # Scale so that the accumulated gradient matches one large batch.
            result.loss /= config.Trainer.gradient_accumulate_steps
            result.loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(
                task.parameters(), max_norm=config.Trainer.gradient_clipping
            )

            if math.isnan(grad_norm):
                # Drop everything accumulated so far and start a new cycle.
                logger.warning(f"Grad norm is NaN at step {global_step}")
                optimizer.zero_grad()
                accum_grad_steps = 0
            else:
                accum_grad_steps += 1
                if accum_grad_steps == config.Trainer.gradient_accumulate_steps:
                    optimizer.step()
                    optimizer.zero_grad()
                    accum_grad_steps = 0
            batch_results.append(result.cacheable())

            if global_step % config.Trainer.log_step == 0:
                logs: Logs = task.train_reduction(batch_results).logs
                logger.info(f"[Train] step {global_step}")
                for name, value in logs.Scalar.items():
                    if name == "loss":
                        # Undo the accumulation scaling for reporting.
                        value *= config.Trainer.gradient_accumulate_steps
                    logger.info(f"{name}: {value}")
                batch_results = []

            if global_step % config.Trainer.valid_step == 0:
                with torch.no_grad():
                    if config.Trainer.use_valid:
                        valid_results = _collect_eval_results(
                            task,
                            valid_dataloader,
                            task.valid_step,
                            "Valid",
                            config.Trainer.get("eval_batch", -1),
                        )

                        logs: Logs = task.valid_reduction(valid_results).slice(1)
                        logger.info(f"[Valid] step {global_step}")
                        for name, value in logs.Scalar.items():
                            logger.info(f"{name}: {value}")
                            if name == config.Trainer.valid_metric:
                                cond1 = config.Trainer.valid_higher_better and (
                                    value > valid_best_score
                                )
                                cond2 = (not config.Trainer.valid_higher_better) and (
                                    value < valid_best_score
                                )
                                if cond1 or cond2:
                                    valid_best_score = value
                                    save_checkpoint("valid_best")

            if (
                global_step % config.Trainer.save_step == 0
                or global_step == config.Trainer.total_steps
            ):
                save_checkpoint(f"global_step_{global_step}")

            if global_step == config.Trainer.total_steps:
                train_completed = True
                break

    # Final evaluation: eval mode and no autograd, like the validation pass.
    with torch.no_grad():
        test_results = _collect_eval_results(
            task,
            test_dataloader,
            task.test_step,
            "Test",
            config.Trainer.get("eval_batch", -1),
        )
        logs: Logs = task.test_reduction(test_results).slice(1)
    logger.info(f"[Test] step {global_step}")
    for name, value in logs.Scalar.items():
        logger.info(f"{name}: {value}")
# Run the pre-training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()