# SyncTalk / data_utils / UNFaceFlow / options_test_flow.py
# yinwentao
# DockerFile
# 8d34f50
# ref:https://github.com/ShunyuYao/DFA-NeRF
import argparse
class BaseOptions():
    """Command-line options shared by every option set in this file.

    Builds an ``argparse.ArgumentParser``, registers the common arguments in
    :meth:`initialize`, and returns the parsed namespace from :meth:`parse`
    with ``isTrain`` / ``isTest`` copied onto it.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.initialized = False
        # Subclasses overwrite these in initialize(); the defaults let the
        # base class itself be parsed without an AttributeError.
        self.isTrain = False
        self.isTest = False

    @staticmethod
    def _str2bool(value):
        """Convert a command-line token to a bool.

        ``type=bool`` is unusable with argparse because ``bool("False")`` is
        ``True`` (any non-empty string is truthy).

        Raises:
            argparse.ArgumentTypeError: if the token is not a recognised
                boolean spelling.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', 'off'):
            return False
        raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))

    def initialize(self):
        """Register the arguments common to all option sets on ``self.parser``."""
        add = self.parser.add_argument

        # add('--model_save_path', type=str, default='snapshot/small_filter_wo_ct_wi_bn/real_data/combine/', help='path')
        add('--model_save_path', type=str, default='snapshot/version1/', help='path')
        add('--num_threads', type=int, default=2, help='number of threads')
        add('--max_dataset_size', type=int, default=150000, help='max dataset size')
        add('--n_epochs', type=int, default=40000, help='number of iterations')
        add('--dropout', type=float, default=0.0, help='dropout')
        add('--init_type', type=str, default='uniform', help='[uniform | xavier]')
        add('--frequency_print_batch', type=int, default=1000, help='print messages every set iter')
        add('--frequency_save_model', type=int, default=2000, help='save model every set iter')
        add('--small', type=self._str2bool, default=True, help='use small model')
        add('--use_batch_norm', action='store_true', help='')
        add('--smooth_2nd', type=self._str2bool, default=True, help='')

        # loss weight for Gauss-Newton optimization
        add('--lambda_2d', type=float, default=0.001, help='weight of 2D projection loss')
        add('--lambda_depth', type=float, default=1.0, help='weight of depth loss')
        add('--lambda_reg', type=float, default=1.0, help='weight of regularization loss')

        add('--num_adja', type=int, default=6, help='number of nodes who affect a point')
        add('--num_corres', type=int, default=20000, help='number of corres')
        add('--iter_num', type=int, default=3, help='GN iter num')
        add('--width', type=int, default=512, help='image width')  # 480
        add('--height', type=int, default=512, help='image height')  # 640
        add('--crop_width', type=int, default=240, help='image width')
        add('--crop_height', type=int, default=320, help='image height')
        add('--max_num_edges', type=int, default=30000, help='number of edges')
        add('--max_num_nodes', type=int, default=1500, help='number of nodes')
        add('--fdim', type=int, default=128)

        # loss weight for training
        add('--lambda_weights', type=float, default=0.0, help='weight of weights loss')  # 75
        add('--lambda_corres', type=float, default=1.0, help='weight of corres loss')  # 0, 1
        add('--lambda_graph', type=float, default=10.0, help='weight of graph loss')  # 1000, 5
        add('--lambda_warp', type=float, default=10.0, help='weight of warp loss')  # 1000, 5

        # Mark the parser as populated so parse() does not register the
        # arguments a second time (argparse rejects duplicate option strings).
        self.initialized = True

    def parse(self, args=None):
        """Parse the command line and return the resulting namespace.

        Args:
            args: optional list of argument strings; ``None`` (the default)
                makes argparse read ``sys.argv[1:]``.

        Returns:
            The ``argparse.Namespace``, also stored on ``self.opt``, with
            ``isTrain`` and ``isTest`` attached.
        """
        if not self.initialized:
            self.initialize()
            # Also set here in case an overriding initialize() does not
            # chain to the base implementation.
            self.initialized = True
        self.opt = self.parser.parse_args(args)
        self.opt.isTrain = self.isTrain
        self.opt.isTest = self.isTest
        return self.opt
class TrainOptions(BaseOptions):
    """Option set that adds the training arguments and sets ``isTrain = True``."""

    @staticmethod
    def _str2bool(value):
        """Convert a command-line token to a bool.

        Used instead of ``type=bool``, which maps every non-empty string
        (including ``"False"``) to ``True``. Kept on this class so the boolean
        flags below do not rely on a helper defined elsewhere.

        Raises:
            argparse.ArgumentTypeError: if the token is not a recognised
                boolean spelling.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', 'off'):
            return False
        raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))

    # Override
    def initialize(self):
        """Register the shared arguments, then the training-specific ones."""
        BaseOptions.initialize(self)
        add = self.parser.add_argument

        # syn_datasets/syn_new_train_data.txt
        add('--datapath', type=str, default='./data/train_data.txt', help='path')
        add('--pretrain_model_path', type=str, default='./pretrain_model/raft-small.pth', help='path')
        add('--lr_C', type=float, default=0.00001, help='initial learning rate')  # 0.01
        add('--optimizer_C', type=str, default='sgd', help='[sgd | adam]')
        add('--lr_W', type=float, default=0.00001, help='initial learning rate')
        add('--lr_BSW', type=float, default=0.00001, help='initial learning rate')
        add('--optimizer_W', type=str, default='sgd', help='[sgd | adam]')
        add('--optimizer_BSW', type=str, default='sgd', help='[sgd | adam]')
        add('--lr_decay_epoch', type=int, default=8000, help='multiply by a gamma every set iter')
        add('--lr_decay', type=float, default=0.1, help='coefficient of lr decay')
        add('--weight_decay', type=float, default=1e-4, help='0.0005coefficient of weight decay')
        add('--batch_size', type=int, default=4, help='batch size')
        add('--shuffle', type=self._str2bool, default=True, help='whether to shuffle data')

        add('--validation', type=str, nargs='+')
        # add('--image_size', type=int, nargs='+', default=[384, 512])
        add('--gpus', type=int, nargs='+', default=[0,1])
        add('--mixed_precision', action='store_true', help='use mixed precision')
        add('--iters', type=int, default=12)
        add('--clip', type=float, default=1.0)
        add('--gamma', type=float, default=0.8, help='exponential weighting')
        add('--add_noise', action='store_true')
        add('--train_bsw', type=self._str2bool, default=True, help='whether to train bsw network')
        add('--train_weight', type=self._str2bool, default=True, help='whether to train weight network')
        add('--train_corres', type=self._str2bool, default=True, help='whether to train corresPred network')

        # Copied onto the parsed namespace by parse().
        self.isTrain = True
        self.isTest = False
class ValOptions(BaseOptions):
    """Option set that adds the validation arguments."""

    @staticmethod
    def _str2bool(value):
        """Convert a command-line token to a bool.

        Used instead of ``type=bool``, which maps every non-empty string
        (including ``"False"``) to ``True``. Kept on this class so the boolean
        flag below does not rely on a helper defined elsewhere.

        Raises:
            argparse.ArgumentTypeError: if the token is not a recognised
                boolean spelling.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', 'off'):
            return False
        raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))

    def initialize(self):
        """Register the shared arguments, then the validation-specific ones."""
        BaseOptions.initialize(self)
        add = self.parser.add_argument

        add('--batch_size', type=int, default=4, help='batch size')
        add('--datapath', type=str, default='./data/val_data.txt', help='path')
        add('--shuffle', type=self._str2bool, default=True, help='whether to shuffle data')
        add('--mixed_precision', action='store_true', help='use mixed precision')
        add('--alternate_corr', action='store_true', help='use efficent correlation implementation')
        add('--iters', type=int, default=12)

        # NOTE(review): validation reports isTrain=True, same as TrainOptions;
        # presumably intentional (validation inside the training loop) —
        # confirm against the code that reads opt.isTrain.
        self.isTrain = True
        self.isTest = False
class TestOptions(BaseOptions):
    """Option set that adds the test arguments and sets ``isTest = True``."""

    @staticmethod
    def _str2bool(value):
        """Convert a command-line token to a bool.

        Used instead of ``type=bool``, which maps every non-empty string
        (including ``"False"``) to ``True``. Kept on this class so the boolean
        flag below does not rely on a helper defined elsewhere.

        Raises:
            argparse.ArgumentTypeError: if the token is not a recognised
                boolean spelling.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', 'off'):
            return False
        raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))

    def initialize(self):
        """Register the shared arguments, then the test-specific ones."""
        BaseOptions.initialize(self)
        add = self.parser.add_argument

        add('--batch_size', type=int, default=1, help='batch size')
        add('--pretrain_model_path', type=str, default='./pretrain_model/raft-small.pth', help='path')
        # add('--datapath', type=str, default='./data/real_train_data_1128_1.txt', help='path')
        # add('--datapath', type=str, default='./data_test_flow/test_data.txt', help='path')
        add('--savepath', type=str, default='flow_result',
            help='save path')
        # NOTE(review): the default below is an absolute path that looks
        # machine-specific; callers most likely need to pass --datapath.
        add('--datapath', type=str, default='/data_b/yudong/paper_code/TalkingHead-NeRF/data_guancha/guancha_flow.txt',
            help='path')
        add('--mixed_precision', action='store_true', help='use mixed precision')
        add('--alternate_corr', action='store_true', help='use efficent correlation implementation')
        add('--iters', type=int, default=12)
        add('--shuffle', type=self._str2bool, default=True, help='whether to shuffle data')

        # Copied onto the parsed namespace by parse().
        self.isTrain = False
        self.isTest = True