Spaces:
Running
on
Zero
Running
on
Zero
| # -*- coding: UTF-8 -*- | |
| '''================================================= | |
| @Project -> File pram -> train | |
| @IDE PyCharm | |
| @Author [email protected] | |
| @Date 29/01/2024 14:26 | |
| ==================================================''' | |
| import argparse | |
| import os | |
| import os.path as osp | |
| import torch | |
| import torchvision.transforms.transforms as tvt | |
| import yaml | |
| import torch.utils.data as Data | |
| import torch.multiprocessing as mp | |
| import torch.distributed as dist | |
| from nets.segnet import SegNet | |
| from nets.segnetvit import SegNetViT | |
| from dataset.utils import collect_batch | |
| from dataset.get_dataset import compose_datasets | |
| from tools.common import torch_set_gpu | |
| from trainer import Trainer | |
| from nets.sfd2 import ResNet4x, DescriptorCompressor | |
| from nets.superpoint import SuperPoint | |
# Globally switch gradient computation on for everything that follows.
torch.set_grad_enabled(True)

# Command-line interface of this script: a mandatory YAML config file and an
# optional landmark path.
parser = argparse.ArgumentParser(
    description='PRAM',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    '--config',
    type=str,
    required=True,
    help='config of specifications',
)
parser.add_argument(
    '--landmark_path',
    type=str,
    default=None,
    help='path of landmarks',
)
def load_feat_network(config):
    """Create the local-feature network and the optional descriptor compressor.

    Args:
        config: mapping; the keys ``'feature'`` and ``'feat_dim'`` are read.

    Returns:
        Tuple ``(net, desc_compressor)``.

        * ``net`` is a ``SuperPoint`` in eval mode for ``'spp'``, a
          ``ResNet4x`` with checkpoint weights in eval mode for
          ``'resnet4x'``, and ``None`` (after printing a hint) for any other
          feature name.
        * ``desc_compressor`` is a ``DescriptorCompressor`` with loaded weights
          when ``feat_dim`` is 64 or 32, and ``None`` for every other value
          (including 128).
    """
    feature_name = config['feature']
    if feature_name == 'spp':
        spp_config = {
            'weight_path': '/scratches/flyer_2/fx221/Research/Code/third_weights/superpoint_v1.pth',
        }
        net = SuperPoint(config=spp_config).eval()
    elif feature_name == 'resnet4x':
        net = ResNet4x(inputdim=3, outdim=128)
        checkpoint = torch.load('weights/sfd2_20230511_210205_resnet4x.79.pth', map_location='cpu')
        net.load_state_dict(checkpoint['state_dict'], strict=True)
        net.eval()
    else:
        print('Please input correct feature {:s}'.format(config['feature']))
        net = None

    # Descriptors are used as-is when the requested dimension is 128.
    if not (config['feat_dim'] != 128):
        return net, None

    compressor = DescriptorCompressor(inputdim=128, outdim=config['feat_dim']).eval()
    if config['feat_dim'] == 64:
        compressor_weights = 'weights/20230511_210205_resnet4x_B6_R512_I3_O128_pho_resnet4x_e79_to_O64.pth'
    elif config['feat_dim'] == 32:
        compressor_weights = 'weights/20230511_210205_resnet4x_B6_R512_I3_O128_pho_resnet4x_e79_to_O32.pth'
    else:
        # No checkpoint is listed for this dimension, so no compressor is returned.
        return net, None

    compressor.load_state_dict(torch.load(compressor_weights, map_location='cpu'), strict=True)
    return net, compressor
def get_model(config):
    """Build the recognition network described by ``config`` and load weights.

    Args:
        config: mapping with the keys ``'feature'``, ``'use_mid_feature'``,
            ``'layers'``, ``'ac_fn'``, ``'norm_fn'``, ``'n_class'``,
            ``'output_dim'``, ``'with_cls'``, ``'with_sc'``, ``'with_score'``
            and ``'network'``. When a model is built, ``'local_rank'``,
            ``'weight_path'``, ``'resume_path'``, ``'save_path'`` and
            ``'eval'`` are read as well.

    Returns:
        The constructed ``SegNet`` or ``SegNetViT`` instance.

    Raises:
        ValueError: if ``config['network']`` is neither ``'segnet'`` nor
            ``'segnetvit'``.

    Side effects:
        ``config['with_cls']`` is set to ``False`` after the model has been
        constructed.
    """
    # Descriptor width: 256 for SuperPoint or when mid-level features are
    # used, otherwise 128.
    desc_dim = 256 if config['feature'] == 'spp' else 128
    if config['use_mid_feature']:
        desc_dim = 256

    network_config = {
        'descriptor_dim': desc_dim,
        'n_layers': config['layers'],
        'ac_fn': config['ac_fn'],
        'norm_fn': config['norm_fn'],
        'n_class': config['n_class'],
        'output_dim': config['output_dim'],
        'with_cls': config['with_cls'],
        'with_sc': config['with_sc'],
        'with_score': config['with_score'],
    }

    network_name = config['network']
    if network_name == 'segnet':
        model = SegNet(network_config)
    elif network_name == 'segnetvit':
        model = SegNetViT(network_config)
    else:
        # Raising a bare string is itself a TypeError in Python 3 and would
        # hide this message; raise a real exception instead.
        raise ValueError('ERROR! {} model does not exist'.format(network_name))
    # The model was built with the caller's 'with_cls'; the flag is cleared
    # in the shared config afterwards.
    config['with_cls'] = False

    # NOTE(review): the resume-load is taken to be nested under the
    # local_rank check together with the weight-load -- confirm against the
    # upstream file.
    if config['local_rank'] == 0:
        if config['weight_path'] is not None:
            weight_file = osp.join(config['save_path'], config['weight_path'])
            state_dict = torch.load(weight_file, map_location='cpu')['model']
            model.load_state_dict(state_dict, strict=True)
            print('Load weight from {:s}'.format(weight_file))

        if config['resume_path'] is not None and not config['eval']:  # only for training
            resume_file = osp.join(config['save_path'], config['resume_path'])
            model.load_state_dict(torch.load(resume_file, map_location='cpu')['model'], strict=True)
            print('Load resume weight from {:s}'.format(resume_file))

    return model
def setup(rank, world_size):
    """Initialise the ``nccl`` process group for this worker.

    Args:
        rank: rank of the calling process.
        world_size: total number of participating processes.

    Side effects:
        Sets ``MASTER_ADDR`` to ``localhost`` and ``MASTER_PORT`` to ``12355``
        in ``os.environ`` before the process group is initialised.
    """
    # Rendezvous address and port for the process group.
    os.environ.update(MASTER_ADDR='localhost', MASTER_PORT='12355')

    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
def train_DDP(rank, world_size, model, config, train_set, test_set, feat_model, img_transforms):
    """Per-process entry point for distributed training.

    Moves the networks to ``cuda:<rank>``, converts batch-norm layers to
    ``SyncBatchNorm``, joins the process group, wraps the model in
    ``DistributedDataParallel``, builds a sharded training loader and runs
    ``Trainer.train()``.

    Args:
        rank: index of this process; also used as the CUDA device index.
        world_size: total number of processes.
        model: network to train.
        config: configuration mapping. ``'batch_size'`` and ``'workers'`` are
            read and divided by ``world_size``; ``'local_rank'`` is
            overwritten with ``rank``.
        train_set: training dataset.
        test_set: evaluation dataset or ``None``; only rank 0 forwards it.
        feat_model: optional feature network (may be ``None``).
        img_transforms: forwarded unchanged to ``Trainer``.
    """
    print('In train_DDP..., rank: ', rank)
    torch.cuda.set_device(rank)

    device = torch.device(f'cuda:{rank}')
    if feat_model is not None:
        feat_model.to(device)
    model.to(device)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    setup(rank=rank, world_size=world_size)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_set,
        shuffle=True,
        rank=rank,
        num_replicas=world_size,
        drop_last=True,  # important?
    )

    workers_per_rank = config['workers'] // world_size
    loader_kwargs = dict(
        batch_size=config['batch_size'] // world_size,
        num_workers=workers_per_rank,
        pin_memory=True,
        shuffle=False,  # must be False: shuffling is delegated to the sampler
        drop_last=True,
        collate_fn=collect_batch,
        sampler=train_sampler,
    )
    # DataLoader rejects an explicit prefetch_factor when num_workers == 0,
    # which happens whenever config['workers'] < world_size.
    if workers_per_rank > 0:
        loader_kwargs['prefetch_factor'] = 4
    train_loader = torch.utils.data.DataLoader(train_set, **loader_kwargs)

    config['local_rank'] = rank

    # Only rank 0 receives the evaluation data.
    # NOTE(review): the dataset itself (not a DataLoader) is passed as
    # eval_loader here, unlike the single-process path -- confirm that
    # Trainer accepts both.
    eval_set = test_set if rank == 0 else None

    trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=eval_set,
                      config=config, img_transforms=img_transforms)
    trainer.train()
if __name__ == '__main__':
    # Script entry point: read the YAML config, build the feature network,
    # then either run evaluation/localization and exit, or start training
    # (single process or one spawned process per configured GPU).
    args = parser.parse_args()
    with open(args.config, 'rt') as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects.
        # Only point --config at trusted files, or switch to yaml.safe_load
        # if the configs contain plain data only.
        config = yaml.load(f, Loader=yaml.Loader)
    torch_set_gpu(gpus=config['gpu'])
    if config['local_rank'] == 0:
        print(config)

    # SuperPoint input is left un-normalised; every other feature gets a
    # Normalize transform.
    if config['feature'] == 'spp':
        img_transforms = None
    else:
        img_transforms = tvt.Compose([
            tvt.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    # desc_compressor is returned by the loader but not used below.
    feat_model, desc_compressor = load_feat_network(config=config)

    dataset = config['dataset']
    if config['eval'] or config['loc']:
        if feat_model is None:
            # Fail with a clear message instead of AttributeError on None.cuda().
            raise ValueError(
                "Unsupported feature '{}': a feature network is required for "
                "evaluation/localization".format(config['feature']))
        if not config['online']:
            from localization.loc_by_rec_eval import loc_by_rec_eval

            test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=1)
            config['n_class'] = test_set.n_class
            model = get_model(config=config)
            loc_by_rec_eval(rec_model=model.cuda().eval(),
                            loader=test_set,
                            local_feat=feat_model.cuda().eval(),
                            config=config, img_transforms=img_transforms)
        else:
            from localization.loc_by_rec_online import loc_by_rec_online

            model = get_model(config=config)
            loc_by_rec_online(rec_model=model.cuda().eval(),
                              local_feat=feat_model.cuda().eval(),
                              config=config, img_transforms=img_transforms)
        # `exit` is injected by the `site` module for interactive use;
        # SystemExit is always available.
        raise SystemExit(0)

    train_set = compose_datasets(datasets=dataset, config=config, train=True, sample_ratio=None)
    if config['do_eval']:
        test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=None)
    else:
        test_set = None
    config['n_class'] = train_set.n_class
    model = get_model(config=config)

    n_gpus = len(config['gpu'])
    if not config['with_dist'] or n_gpus == 1:
        # Single-process training.
        config['with_dist'] = False
        model = model.cuda()
        train_loader = Data.DataLoader(dataset=train_set,
                                       shuffle=True,
                                       batch_size=config['batch_size'],
                                       drop_last=True,
                                       collate_fn=collect_batch,
                                       num_workers=config['workers'])
        if test_set is not None:
            test_loader = Data.DataLoader(dataset=test_set,
                                          shuffle=False,
                                          batch_size=1,
                                          drop_last=False,
                                          collate_fn=collect_batch,
                                          num_workers=4)
        else:
            test_loader = None
        trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=test_loader,
                          config=config, img_transforms=img_transforms)
        trainer.train()
    else:
        # One process per configured GPU; mp.spawn passes the process index
        # as the first argument (rank) of train_DDP.
        mp.spawn(train_DDP, nprocs=n_gpus,
                 args=(n_gpus, model, config, train_set, test_set, feat_model, img_transforms),
                 join=True)