|
r"""PyTorch Detection Training.

To run in a multi-gpu environment, use the distributed launcher::

    python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env \
        train.py ... --world-size $NGPU

The default hyperparameters are tuned for training on 8 gpus and 2 images per gpu.
    --lr 0.02 --batch-size 2 --world-size 8
If you use a different number of gpus, the learning rate should be changed to 0.02/8*$NGPU.

On top of that, for training Faster/Mask R-CNN, the default hyperparameters are
    --epochs 26 --lr-steps 16 22 --aspect-ratio-group-factor 3

Also, if you train Keypoint R-CNN, the default hyperparameters are
    --epochs 46 --lr-steps 36 43 --aspect-ratio-group-factor 3
Because the number of images is smaller in the person keypoint subset of COCO,
the number of epochs should be adapted so that we have the same number of iterations.
"""
|
import datetime |
|
import os |
|
import time |
|
|
|
import detection.presets |
|
import torch |
|
import torch.utils.data |
|
import torchvision |
|
import torchvision.models.detection |
|
import torchvision.models.detection.mask_rcnn |
|
import detection.utils as utils |
|
from detection.coco_utils import get_coco, get_coco_kp |
|
from detection.engine import train_one_epoch, evaluate |
|
from detection.group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups |
|
|
|
|
|
# The prototype models API may be missing from the installed torchvision.
# Record that as PM = None instead of failing at import time; main() raises a
# descriptive ImportError if --weights is requested while PM is None.
try:
    from torchvision.prototype import models as PM
except ImportError:
    PM = None
|
|
|
|
|
def get_dataset(name, image_set, transform, data_path):
    """Build one split of a supported dataset and report its class count.

    Args:
        name: Dataset key, either ``"coco"`` or ``"coco_kp"``. Any other value
            raises ``KeyError`` from the registry lookup.
        image_set: Split identifier forwarded to the dataset builder as
            ``image_set`` (``main`` passes ``"train"`` and ``"val"``).
        transform: Transform object forwarded to the builder as ``transforms``.
        data_path: Dataset root directory, passed as the builder's first
            positional argument.

    Returns:
        tuple: ``(dataset, num_classes)`` where ``num_classes`` is 91 for
        ``"coco"`` and 2 for ``"coco_kp"``.
    """
    # name -> (root directory, builder function, number of classes)
    registry = {
        "coco": (data_path, get_coco, 91),
        "coco_kp": (data_path, get_coco_kp, 2),
    }
    root, builder, num_classes = registry[name]
    return builder(root, image_set=image_set, transforms=transform), num_classes
|
|
|
|
|
def get_transform(train, args):
    """Select the transform pipeline for the training or evaluation split.

    Args:
        train: Truthy to build the training pipeline, falsy for evaluation.
        args: Parsed command-line namespace. ``args.data_augmentation`` is read
            for training; ``args.weights`` is read for evaluation.

    Returns:
        For training, ``DetectionPresetTrain(args.data_augmentation)``. For
        evaluation, ``DetectionPresetEval()`` when ``args.weights`` is empty,
        otherwise the result of ``transforms()`` on the prototype weights entry
        looked up by ``PM.get_weight(args.weights)``.
    """
    # The file imports the presets with ``import detection.presets``, which
    # binds only the top-level name ``detection``. A bare ``presets`` name is
    # therefore undefined, so the presets are reached through the package.
    if train:
        return detection.presets.DetectionPresetTrain(args.data_augmentation)
    if not args.weights:
        return detection.presets.DetectionPresetEval()
    # PM is None when the prototype module failed to import; main() rejects a
    # non-empty ``args.weights`` in that case before this function is called.
    weights = PM.get_weight(args.weights)
    return weights.transforms()
|
|
|
|
|
def get_args_parser(add_help=True):
    """Build the command-line parser for the detection training script.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser`` as ``add_help``
            (controls whether ``-h/--help`` is registered).

    Returns:
        argparse.ArgumentParser: Parser exposing the dataset, model, optimizer,
        learning-rate schedule, checkpointing, distributed and mixed-precision
        options consumed by ``main``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)

    # Data, model and device selection.
    parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
    parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
    parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=2, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=26, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
    )

    # Optimizer.
    parser.add_argument(
        "--lr",
        default=0.02,
        type=float,
        help="initial learning rate, 0.02 is the default value for training on 8 gpus and 2 images_per_gpu",
    )
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )

    # Learning-rate schedule.
    parser.add_argument(
        "--lr-scheduler", default="multisteplr", type=str, help="name of lr scheduler (default: multisteplr)"
    )
    parser.add_argument(
        "--lr-step-size", default=8, type=int, help="decrease lr every step-size epochs (multisteplr scheduler only)"
    )
    # The help text used to repeat the --lr-step-size wording; this option is a
    # list of milestone epochs, which is what main() hands to MultiStepLR.
    parser.add_argument(
        "--lr-steps",
        default=[16, 22],
        nargs="+",
        type=int,
        help="epochs at which to decrease lr by a factor of lr-gamma (multisteplr scheduler only)",
    )
    parser.add_argument(
        "--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma (multisteplr scheduler only)"
    )

    # Logging, output and checkpointing.
    parser.add_argument("--print-freq", default=20, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    # Every other flag is hyphenated; accept that spelling too while keeping the
    # historical underscore form and the same destination attribute.
    parser.add_argument(
        "--start-epoch", "--start_epoch", default=0, type=int, dest="start_epoch", help="start epoch"
    )

    # Batching, model construction and augmentation.
    parser.add_argument(
        "--aspect-ratio-group-factor",
        default=3,
        type=int,
        help="aspect ratio grouping factor for batching; a negative value disables grouping",
    )
    parser.add_argument("--rpn-score-thresh", default=None, type=float, help="rpn score threshold for faster-rcnn")
    parser.add_argument(
        "--trainable-backbone-layers", default=None, type=int, help="number of trainable layers of backbone"
    )
    parser.add_argument(
        "--data-augmentation", default="hflip", type=str, help="data augmentation policy (default: hflip)"
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        help="Use pre-trained models from the modelzoo",
        action="store_true",
    )

    # Distributed training parameters.
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    # Prototype models only.
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

    # Mixed precision training parameters.
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    return parser
|
|
|
|
|
def main(args):
    """Train a detection model, or only evaluate it, as configured by ``args``.

    In order: prepare the output directory and distributed mode, build the
    train/val datasets and data loaders, create the model, SGD optimizer and
    learning-rate scheduler, optionally restore a checkpoint, then either
    evaluate once (``args.test_only``) or train from ``args.start_epoch`` up to
    ``args.epochs``, saving a checkpoint and evaluating after every epoch.

    Args:
        args: Namespace as produced by ``get_args_parser()``. It is handed to
            ``utils.init_distributed_mode`` before ``args.distributed`` and
            ``args.gpu`` are read (presumably set there -- confirm in utils).
            The namespace is mutated: ``lr_scheduler`` is lower-cased and
            ``start_epoch`` is overwritten when resuming.

    Raises:
        ImportError: ``args.weights`` is set but the prototype models module
            could not be imported.
        RuntimeError: ``args.lr_scheduler`` is neither ``multisteplr`` nor
            ``cosineannealinglr`` (compared case-insensitively).
    """
    if PM is None and args.weights:
        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # ---- Datasets -------------------------------------------------------
    print("Loading data")

    train_transform = get_transform(True, args)
    train_set, num_classes = get_dataset(args.dataset, "train", train_transform, args.data_path)
    eval_transform = get_transform(False, args)
    val_set, _ = get_dataset(args.dataset, "val", eval_transform, args.data_path)

    # ---- Samplers and loaders -------------------------------------------
    print("Creating data loaders")
    if not args.distributed:
        train_sampler = torch.utils.data.RandomSampler(train_set)
        val_sampler = torch.utils.data.SequentialSampler(val_set)
    else:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_set)

    # A negative factor selects plain batching; otherwise batches are formed
    # from images that fall in the same aspect-ratio group.
    if args.aspect_ratio_group_factor < 0:
        batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)
    else:
        group_ids = create_aspect_ratio_groups(train_set, k=args.aspect_ratio_group_factor)
        batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_sampler=batch_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn,
    )
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=1,
        sampler=val_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn,
    )

    # ---- Model ----------------------------------------------------------
    print("Creating model")
    model_kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
    if "rcnn" in args.model and args.rpn_score_thresh is not None:
        model_kwargs["rpn_score_thresh"] = args.rpn_score_thresh

    if args.weights:
        # Prototype API: the weights are selected by enum name.
        model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **model_kwargs)
    else:
        model = torchvision.models.detection.__dict__[args.model](
            pretrained=args.pretrained, num_classes=num_classes, **model_kwargs
        )
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Keep a handle on the unwrapped module so checkpoints are saved and
    # restored through the bare model rather than the DDP wrapper.
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    else:
        model_without_ddp = model

    # ---- Optimizer, AMP scaler and LR schedule --------------------------
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.SGD(
        trainable_params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
    )

    scaler = None
    if args.amp:
        scaler = torch.cuda.amp.GradScaler()

    args.lr_scheduler = args.lr_scheduler.lower()
    scheduler_name = args.lr_scheduler
    if scheduler_name == "cosineannealinglr":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    elif scheduler_name == "multisteplr":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only MultiStepLR and CosineAnnealingLR are supported."
        )

    # ---- Resume ---------------------------------------------------------
    if args.resume:
        # NOTE(review): torch.load unpickles the file, and the checkpoints
        # written below embed the argparse namespace; only resume from
        # checkpoints you trust.
        state = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])
        lr_scheduler.load_state_dict(state["lr_scheduler"])
        args.start_epoch = state["epoch"] + 1
        if args.amp:
            scaler.load_state_dict(state["scaler"])

    if args.test_only:
        evaluate(model, val_loader, device=device)
        return

    # ---- Training loop --------------------------------------------------
    print("Start training")
    started_at = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, optimizer, train_loader, device, epoch, args.print_freq, scaler)
        lr_scheduler.step()

        if args.output_dir:
            snapshot = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "args": args,
                "epoch": epoch,
            }
            if args.amp:
                snapshot["scaler"] = scaler.state_dict()
            # One per-epoch file plus a rolling "latest" checkpoint.
            for filename in (f"model_{epoch}.pth", "checkpoint.pth"):
                utils.save_on_master(snapshot, os.path.join(args.output_dir, filename))

        # Evaluate after every epoch.
        evaluate(model, val_loader, device=device)

    elapsed_seconds = int(time.time() - started_at)
    elapsed_str = str(datetime.timedelta(seconds=elapsed_seconds))
    print(f"Training time {elapsed_str}")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, then run training/evaluation.
    cli_parser = get_args_parser()
    args = cli_parser.parse_args()
    main(args)
|
|