|
|
|
import argparse |
|
class BaseOptions():
    """Command-line options shared by the train / val / test entry points.

    Subclasses extend :meth:`initialize` with their own arguments and set
    ``self.isTrain`` / ``self.isTest`` there; :meth:`parse` copies those two
    flags onto the parsed namespace.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        # Tracks whether initialize() already ran: argparse raises on duplicate
        # option strings, so arguments must be registered exactly once.
        self.initialized = False

    @staticmethod
    def _str2bool(value):
        """Convert a command-line token such as 'true'/'false'/'1'/'0' to bool.

        ``type=bool`` must not be used with argparse: ``bool('False')`` is
        ``True``, so such a flag could never be switched off.

        Raises:
            argparse.ArgumentTypeError: if *value* is not a recognised boolean.
        """
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ('yes', 'true', 't', 'y', '1'):
            return True
        if lowered in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))

    def initialize(self):
        """Register the options common to every mode on ``self.parser``."""
        # Output / data loading
        self.parser.add_argument('--model_save_path', type=str, default='snapshot/version1/', help='path')
        self.parser.add_argument('--num_threads', type=int, default=2, help='number of threads')
        self.parser.add_argument('--max_dataset_size', type=int, default=150000, help='max dataset size')

        # Training loop / model configuration
        self.parser.add_argument('--n_epochs', type=int, default=40000, help='number of iterations')
        self.parser.add_argument('--dropout', type=float, default=0.0, help='dropout')
        self.parser.add_argument('--init_type', type=str, default='uniform', help='[uniform | xavier]')
        self.parser.add_argument('--frequency_print_batch', type=int, default=1000, help='print messages every set iter')
        self.parser.add_argument('--frequency_save_model', type=int, default=2000, help='save model every set iter')
        # Boolean options accept true/false/yes/no/1/0; a bare "--small" means True.
        self.parser.add_argument('--small', type=self._str2bool, nargs='?', const=True, default=True,
                                 help='use small model')
        self.parser.add_argument('--use_batch_norm', action='store_true', help='')
        self.parser.add_argument('--smooth_2nd', type=self._str2bool, nargs='?', const=True, default=True, help='')

        # Loss weights
        self.parser.add_argument('--lambda_2d', type=float, default=0.001, help='weight of 2D projection loss')
        self.parser.add_argument('--lambda_depth', type=float, default=1.0, help='weight of depth loss')
        self.parser.add_argument('--lambda_reg', type=float, default=1.0, help='weight of regularization loss')

        # Graph / correspondence / image geometry
        self.parser.add_argument('--num_adja', type=int, default=6, help='number of nodes who affect a point')
        self.parser.add_argument('--num_corres', type=int, default=20000, help='number of corres')
        self.parser.add_argument('--iter_num', type=int, default=3, help='GN iter num')
        self.parser.add_argument('--width', type=int, default=512, help='image width')
        self.parser.add_argument('--height', type=int, default=512, help='image height')
        self.parser.add_argument('--crop_width', type=int, default=240, help='crop width')
        self.parser.add_argument('--crop_height', type=int, default=320, help='crop height')
        self.parser.add_argument('--max_num_edges', type=int, default=30000, help='number of edges')
        self.parser.add_argument('--max_num_nodes', type=int, default=1500, help='number of nodes')
        self.parser.add_argument('--fdim', type=int, default=128)

        # More loss weights
        self.parser.add_argument('--lambda_weights', type=float, default=0.0, help='weight of weights loss')
        self.parser.add_argument('--lambda_corres', type=float, default=1.0, help='weight of corres loss')
        self.parser.add_argument('--lambda_graph', type=float, default=10.0, help='weight of graph loss')
        self.parser.add_argument('--lambda_warp', type=float, default=10.0, help='weight of warp loss')

    def parse(self, args=None):
        """Parse command-line arguments and return the resulting namespace.

        Args:
            args: optional list of argument strings. ``None`` (the default)
                parses ``sys.argv[1:]``, exactly as before.

        Returns:
            The parsed ``argparse.Namespace`` (also stored as ``self.opt``). It
            also carries ``isTrain`` and ``isTest`` copied from this object;
            both are ``False`` when the subclass did not set them.
        """
        if not self.initialized:
            self.initialize()
            # Mark as done so a second parse() does not re-register options.
            self.initialized = True

        self.opt = self.parser.parse_args(args)
        self.opt.isTrain = getattr(self, 'isTrain', False)
        self.opt.isTest = getattr(self, 'isTest', False)
        return self.opt
|
|
|
class TrainOptions(BaseOptions):
    """Options for training: data paths, pretrained weights, optimizers and schedules."""

    @staticmethod
    def _str2bool(value):
        """Convert a command-line token such as 'true'/'false'/'1'/'0' to bool.

        ``type=bool`` must not be used with argparse: ``bool('False')`` is
        ``True``, so such a flag could never be switched off.

        Raises:
            argparse.ArgumentTypeError: if *value* is not a recognised boolean.
        """
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ('yes', 'true', 't', 'y', '1'):
            return True
        if lowered in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))

    def initialize(self):
        """Register the shared options plus the training-specific ones, and set mode flags."""
        BaseOptions.initialize(self)

        # Data and pretrained weights
        self.parser.add_argument('--datapath', type=str, default='./data/train_data.txt', help='path')
        self.parser.add_argument('--pretrain_model_path', type=str, default='./pretrain_model/raft-small.pth', help='path')

        # Per-network optimizers; the _C / _W / _BSW suffixes appear to match the
        # --train_corres / --train_weight / --train_bsw switches below (TODO confirm).
        self.parser.add_argument('--lr_C', type=float, default=0.00001, help='initial learning rate')
        self.parser.add_argument('--optimizer_C', type=str, default='sgd', help='[sgd | adam]')
        self.parser.add_argument('--lr_W', type=float, default=0.00001, help='initial learning rate')
        self.parser.add_argument('--lr_BSW', type=float, default=0.00001, help='initial learning rate')
        self.parser.add_argument('--optimizer_W', type=str, default='sgd', help='[sgd | adam]')
        self.parser.add_argument('--optimizer_BSW', type=str, default='sgd', help='[sgd | adam]')

        # Learning-rate schedule and regularisation
        self.parser.add_argument('--lr_decay_epoch', type=int, default=8000, help='multiply by a gamma every set iter')
        self.parser.add_argument('--lr_decay', type=float, default=0.1, help='coefficient of lr decay')
        self.parser.add_argument('--weight_decay', type=float, default=1e-4, help='coefficient of weight decay')

        # Batching
        self.parser.add_argument('--batch_size', type=int, default=4, help='batch size')
        self.parser.add_argument('--shuffle', type=self._str2bool, nargs='?', const=True, default=True,
                                 help='whether to shuffle data')
        self.parser.add_argument('--validation', type=str, nargs='+')

        # Runtime
        self.parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
        self.parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
        self.parser.add_argument('--iters', type=int, default=12)
        self.parser.add_argument('--clip', type=float, default=1.0)
        self.parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
        self.parser.add_argument('--add_noise', action='store_true')

        # Which sub-networks are trained (accept true/false/yes/no/1/0)
        self.parser.add_argument('--train_bsw', type=self._str2bool, nargs='?', const=True, default=True,
                                 help='whether to train bsw network')
        self.parser.add_argument('--train_weight', type=self._str2bool, nargs='?', const=True, default=True,
                                 help='whether to train weight network')
        self.parser.add_argument('--train_corres', type=self._str2bool, nargs='?', const=True, default=True,
                                 help='whether to train corresPred network')

        self.isTrain = True
        self.isTest = False
|
|
|
class ValOptions(BaseOptions):
    """Options for validation runs."""

    @staticmethod
    def _str2bool(value):
        """Convert a command-line token such as 'true'/'false'/'1'/'0' to bool.

        ``type=bool`` must not be used with argparse: ``bool('False')`` is
        ``True``, so such a flag could never be switched off.

        Raises:
            argparse.ArgumentTypeError: if *value* is not a recognised boolean.
        """
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ('yes', 'true', 't', 'y', '1'):
            return True
        if lowered in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))

    def initialize(self):
        """Register the shared options plus the validation-specific ones, and set mode flags."""
        BaseOptions.initialize(self)
        self.parser.add_argument('--batch_size', type=int, default=4, help='batch size')
        self.parser.add_argument('--datapath', type=str, default='./data/val_data.txt', help='path')
        self.parser.add_argument('--shuffle', type=self._str2bool, nargs='?', const=True, default=True,
                                 help='whether to shuffle data')
        self.parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
        self.parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
        self.parser.add_argument('--iters', type=int, default=12)

        # NOTE(review): validation reports isTrain=True; kept as-is because callers
        # may rely on it, but confirm this is intentional.
        self.isTrain = True
        self.isTest = False
|
|
|
class TestOptions(BaseOptions):
    """Options for test / inference runs."""

    @staticmethod
    def _str2bool(value):
        """Convert a command-line token such as 'true'/'false'/'1'/'0' to bool.

        ``type=bool`` must not be used with argparse: ``bool('False')`` is
        ``True``, so such a flag could never be switched off.

        Raises:
            argparse.ArgumentTypeError: if *value* is not a recognised boolean.
        """
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ('yes', 'true', 't', 'y', '1'):
            return True
        if lowered in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))

    def initialize(self):
        """Register the shared options plus the test-specific ones, and set mode flags."""
        BaseOptions.initialize(self)
        self.parser.add_argument('--batch_size', type=int, default=1, help='batch size')
        self.parser.add_argument('--pretrain_model_path', type=str, default='./pretrain_model/raft-small.pth', help='path')
        self.parser.add_argument('--savepath', type=str, default='flow_result',
                                 help='save path')
        # NOTE(review): machine-specific absolute default path; callers on other
        # hosts must pass --datapath explicitly.
        self.parser.add_argument('--datapath', type=str, default='/data_b/yudong/paper_code/TalkingHead-NeRF/data_guancha/guancha_flow.txt',
                                 help='path')
        self.parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
        self.parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
        self.parser.add_argument('--iters', type=int, default=12)
        self.parser.add_argument('--shuffle', type=self._str2bool, nargs='?', const=True, default=True,
                                 help='whether to shuffle data')

        self.isTrain = False
        self.isTest = True
|
|