| r"""PyTorch Detection Training. | |
| To run in a multi-gpu environment, use the distributed launcher:: | |
| python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env \ | |
| train.py ... --world-size $NGPU | |
| The default hyperparameters are tuned for training on 8 gpus and 2 images per gpu. | |
| --lr 0.02 --batch-size 2 --world-size 8 | |
| If you use different number of gpus, the learning rate should be changed to 0.02/8*$NGPU. | |
| On top of that, for training Faster/Mask R-CNN, the default hyperparameters are | |
| --epochs 26 --lr-steps 16 22 --aspect-ratio-group-factor 3 | |
| Also, if you train Keypoint R-CNN, the default hyperparameters are | |
| --epochs 46 --lr-steps 36 43 --aspect-ratio-group-factor 3 | |
| Because the number of images is smaller in the person keypoint subset of COCO, | |
| the number of epochs should be adapted so that we have the same number of iterations. | |
| """ | |

import datetime
import os
import time

import torch
import torch.utils.data
import torchvision
import torchvision.models.detection
import torchvision.models.detection.mask_rcnn

import detection.utils as utils
from detection import presets
from detection.coco_utils import get_coco, get_coco_kp
from detection.engine import train_one_epoch, evaluate
from detection.group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups

try:
    from torchvision.prototype import models as PM
except ImportError:
    PM = None
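# On stable torchvision builds the prototype API is unavailable, so PM stays
# None and main() raises a clear ImportError when --weights is passed.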


def get_dataset(name, image_set, transform, data_path):
    paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)}
    p, ds_fn, num_classes = paths[name]
    ds = ds_fn(p, image_set=image_set, transforms=transform)
    return ds, num_classes
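
# Illustrative direct call (a sketch; "/path/to/coco" is a placeholder, not a
# path from this script):
#
#   dataset, num_classes = get_dataset("coco", "train", transform, "/path/to/coco")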


def get_transform(train, args):
    if train:
        return presets.DetectionPresetTrain(args.data_augmentation)
    elif not args.weights:
        return presets.DetectionPresetEval()
    else:
        weights = PM.get_weight(args.weights)
        return weights.transforms()
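
# When --weights names a prototype weights enum, evaluation uses the transforms
# bundled with those weights rather than the local evaluation preset.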


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
    parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
    parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device to use (cuda or cpu; default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=2, type=int, help="images per GPU; the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=26, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
    )
    parser.add_argument(
        "--lr",
        default=0.02,
        type=float,
        help="initial learning rate; 0.02 is the default for training on 8 GPUs with 2 images per GPU",
    )
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "--lr-scheduler", default="multisteplr", type=str, help="name of lr scheduler (default: multisteplr)"
    )
    parser.add_argument(
        "--lr-step-size", default=8, type=int, help="decrease lr every step-size epochs (multisteplr scheduler only)"
    )
    parser.add_argument(
        "--lr-steps",
        default=[16, 22],
        nargs="+",
        type=int,
        help="epochs at which to decrease lr (multisteplr scheduler only)",
    )
    parser.add_argument(
        "--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma (multisteplr scheduler only)"
    )
    parser.add_argument("--print-freq", default=20, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start_epoch", default=0, type=int, help="start epoch")
    parser.add_argument("--aspect-ratio-group-factor", default=3, type=int)
    parser.add_argument("--rpn-score-thresh", default=None, type=float, help="rpn score threshold for faster-rcnn")
    parser.add_argument(
        "--trainable-backbone-layers", default=None, type=int, help="number of trainable layers of backbone"
    )
    parser.add_argument(
        "--data-augmentation", default="hflip", type=str, help="data augmentation policy (default: hflip)"
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        help="Use pre-trained models from the modelzoo",
        action="store_true",
    )

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    # Prototype models only
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    return parser
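
# Example invocation matching the defaults above (a sketch; the data path is a
# placeholder):
#
#   python train.py --data-path /path/to/coco --model maskrcnn_resnet50_fpn \
#       --epochs 26 --lr-steps 16 22 --aspect-ratio-group-factor 3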


def main(args):
    if args.weights and PM is None:
        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Data loading code
    print("Loading data")

    dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path)
    dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)
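
    # Batching images of similar aspect ratio together reduces the padding
    # needed when they are collated into a single tensor.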
    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
    )
| print("Creating model") | |
| kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers} | |
| if "rcnn" in args.model: | |
| if args.rpn_score_thresh is not None: | |
| kwargs["rpn_score_thresh"] = args.rpn_score_thresh | |
| if not args.weights: | |
| model = torchvision.models.detection.__dict__[args.model]( | |
| pretrained=args.pretrained, num_classes=num_classes, **kwargs | |
| ) | |
| else: | |
| model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs) | |
| model.to(device) | |
| if args.distributed and args.sync_bn: | |
| model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) | |
| model_without_ddp = model | |
| if args.distributed: | |
| model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) | |
| model_without_ddp = model.module | |
| params = [p for p in model.parameters() if p.requires_grad] | |
| optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) | |
| scaler = torch.cuda.amp.GradScaler() if args.amp else None | |
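    # Under --amp the scaler applies dynamic loss scaling so half-precision
    # gradients do not underflow; it stays None for full-precision runs.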

    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "multisteplr":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only MultiStepLR and CosineAnnealingLR are supported."
        )

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        evaluate(model, data_loader_test, device=device)
        return
| print("Start training") | |
| start_time = time.time() | |
| for epoch in range(args.start_epoch, args.epochs): | |
| if args.distributed: | |
| train_sampler.set_epoch(epoch) | |
| train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq, scaler) | |
| lr_scheduler.step() | |
| if args.output_dir: | |
| checkpoint = { | |
| "model": model_without_ddp.state_dict(), | |
| "optimizer": optimizer.state_dict(), | |
| "lr_scheduler": lr_scheduler.state_dict(), | |
| "args": args, | |
| "epoch": epoch, | |
| } | |
| if args.amp: | |
| checkpoint["scaler"] = scaler.state_dict() | |
| utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth")) | |
| utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) | |
| # evaluate after every epoch | |
| evaluate(model, data_loader_test, device=device) | |
| total_time = time.time() - start_time | |
| total_time_str = str(datetime.timedelta(seconds=int(total_time))) | |
| print(f"Training time {total_time_str}") | |


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)