# Reference: https://github.com/ShunyuYao/DFA-NeRF
import argparse


class BaseOptions:
    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.initialized = False

    def initialize(self):
        # self.parser.add_argument('--model_save_path', type=str, default='snapshot/small_filter_wo_ct_wi_bn/real_data/combine/', help='path')
        self.parser.add_argument('--model_save_path', type=str, default='snapshot/version1/', help='path')
        self.parser.add_argument('--num_threads', type=int, default=2, help='number of threads')
        self.parser.add_argument('--max_dataset_size', type=int, default=150000, help='max dataset size')

        self.parser.add_argument('--n_epochs', type=int, default=40000, help='number of iterations')
        self.parser.add_argument('--dropout', type=float, default=0.0, help='dropout')
        self.parser.add_argument('--init_type', type=str, default='uniform', help='[uniform | xavier]')
        self.parser.add_argument('--frequency_print_batch', type=int, default=1000, help='print log messages every N iterations')
        self.parser.add_argument('--frequency_save_model', type=int, default=2000, help='save a model checkpoint every N iterations')
        # Parse booleans explicitly: with type=bool, any non-empty string (even 'False') becomes True.
        self.parser.add_argument('--small', type=lambda s: s.lower() in ('true', '1', 'yes'), default=True, help='use small model')
        self.parser.add_argument('--use_batch_norm', action='store_true', help='use batch normalization')
        self.parser.add_argument('--smooth_2nd', type=lambda s: s.lower() in ('true', '1', 'yes'), default=True, help='')

        # Loss weights for Gauss-Newton optimization
        self.parser.add_argument('--lambda_2d', type=float, default=0.001, help='weight of 2D projection loss')
        self.parser.add_argument('--lambda_depth', type=float, default=1.0, help='weight of depth loss')
        self.parser.add_argument('--lambda_reg', type=float, default=1.0, help='weight of regularization loss')
        
        self.parser.add_argument('--num_adja', type=int, default=6, help='number of graph nodes that influence each point')
        self.parser.add_argument('--num_corres', type=int, default=20000, help='number of correspondences')
        self.parser.add_argument('--iter_num', type=int, default=3, help='number of Gauss-Newton iterations')
        self.parser.add_argument('--width', type=int, default=512, help='image width')  # 480
        self.parser.add_argument('--height', type=int, default=512, help='image height')  # 640
        self.parser.add_argument('--crop_width', type=int, default=240, help='crop width')
        self.parser.add_argument('--crop_height', type=int, default=320, help='crop height')
        self.parser.add_argument('--max_num_edges', type=int, default=30000, help='maximum number of graph edges')
        self.parser.add_argument('--max_num_nodes', type=int, default=1500, help='maximum number of graph nodes')
        self.parser.add_argument('--fdim', type=int, default=128)

        # Loss weights for training
        self.parser.add_argument('--lambda_weights', type=float, default=0.0, help='weight of weights loss')#75
        self.parser.add_argument('--lambda_corres', type=float, default=1.0, help='weight of corres loss')#0, 1
        self.parser.add_argument('--lambda_graph', type=float, default=10.0, help='weight of graph loss')#1000, 5
        self.parser.add_argument('--lambda_warp', type=float, default=10.0, help='weight of warp loss')#1000, 5

        
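    def parse(self):
        if not self.initialized:
            self.initialize()
            # Mark as initialized so a second parse() call does not re-register
            # the same arguments (argparse raises on duplicate option strings).
            self.initialized = True

        self.opt = self.parser.parse_args()
        # isTrain/isTest are set by the subclasses' initialize(); default to
        # False so that BaseOptions itself can also be parsed.
        self.opt.isTrain = getattr(self, 'isTrain', False)
        self.opt.isTest = getattr(self, 'isTest', False)

        return self.opt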
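
# Hypothetical helper (not part of the original DFA-NeRF code): a minimal
# sketch of how the --lambda_* weights above could be applied, assuming each
# one scales a scalar loss term and the weighted terms are summed. The term
# names ('2d', 'depth', 'reg', 'weights', 'corres', 'graph', 'warp') follow the
# option names; how the actual training and Gauss-Newton code combine them is
# an assumption.
def combine_weighted_losses(losses, weights, prefix='lambda_'):
    """Return the sum of weights[prefix + name] * value over `losses`.

    losses:  dict mapping a term name (e.g. 'corres') to a loss value that
             supports * and + (a float or a framework tensor).
    weights: dict of option values, e.g. vars(opt) after parse().
    """
    total = 0.0
    for name, value in losses.items():
        total = total + weights[prefix + name] * value
    return total


# Example (hypothetical loss variables):
#   opt = TrainOptions().parse()
#   loss = combine_weighted_losses(
#       {'corres': corres_loss, 'graph': graph_loss, 'warp': warp_loss}, vars(opt))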

class TrainOptions(BaseOptions):
    # Override
    def initialize(self):
        BaseOptions.initialize(self)
        #syn_datasets/syn_new_train_data.txt
        self.parser.add_argument('--datapath', type=str, default='./data/train_data.txt', help='path')
        self.parser.add_argument('--pretrain_model_path', type=str, default='./pretrain_model/raft-small.pth', help='path to the pretrained model checkpoint')
        self.parser.add_argument('--lr_C', type=float, default=0.00001, help='initial learning rate of the correspondence network')  # 0.01
        self.parser.add_argument('--optimizer_C', type=str, default='sgd', help='[sgd | adam]')
        self.parser.add_argument('--lr_W', type=float, default=0.00001, help='initial learning rate of the weight network')
        self.parser.add_argument('--lr_BSW', type=float, default=0.00001, help='initial learning rate of the BSW network')
        self.parser.add_argument('--optimizer_W', type=str, default='sgd', help='[sgd | adam]')
        self.parser.add_argument('--optimizer_BSW', type=str, default='sgd', help='[sgd | adam]')
        self.parser.add_argument('--lr_decay_epoch', type=int, default=8000, help='multiply the learning rate by --lr_decay every N iterations')
        self.parser.add_argument('--lr_decay', type=float, default=0.1, help='learning-rate decay factor')
        self.parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay coefficient')  # 0.0005
        self.parser.add_argument('--batch_size', type=int, default=4, help='batch size')
        self.parser.add_argument('--shuffle', type=lambda s: s.lower() in ('true', '1', 'yes'), default=True, help='whether to shuffle data')

        self.parser.add_argument('--validation', type=str, nargs='+')
        #self.parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
        self.parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
        self.parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
        self.parser.add_argument('--iters', type=int, default=12)

        self.parser.add_argument('--clip', type=float, default=1.0)
        self.parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
        self.parser.add_argument('--add_noise', action='store_true')

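        # Parse booleans explicitly: with type=bool, any non-empty string (even 'False') becomes True.
        self.parser.add_argument('--train_bsw', type=lambda s: s.lower() in ('true', '1', 'yes'), default=True, help='whether to train bsw network')
        self.parser.add_argument('--train_weight', type=lambda s: s.lower() in ('true', '1', 'yes'), default=True, help='whether to train weight network')
        self.parser.add_argument('--train_corres', type=lambda s: s.lower() in ('true', '1', 'yes'), default=True, help='whether to train corresPred network')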

        self.isTrain = True
        self.isTest = False
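
# Hypothetical helpers (not part of the original DFA-NeRF code), sketched under
# two assumptions: (1) --lr_decay_epoch / --lr_decay describe a step schedule
# equivalent to torch.optim.lr_scheduler.StepLR(step_size=lr_decay_epoch,
# gamma=lr_decay), and (2) --gamma / --iters weight the per-iteration flow
# predictions the way RAFT's sequence loss does, so later refinement
# iterations count more.
def step_decay_lr(lr_init, lr_decay, lr_decay_epoch, iteration):
    """Learning rate after `iteration` steps: lr_init * lr_decay ** (iteration // lr_decay_epoch)."""
    return lr_init * lr_decay ** (iteration // lr_decay_epoch)


def sequence_loss_weights(num_predictions, gamma):
    """Weight gamma ** (N - i - 1) for prediction i; the final prediction gets weight 1."""
    return [gamma ** (num_predictions - i - 1) for i in range(num_predictions)]


# With the defaults (lr 1e-5, decay 0.1 every 8000 iterations), the rate drops
# to 1e-6 at iteration 8000; with --iters 12 and --gamma 0.8, the first
# prediction is weighted 0.8 ** 11 (about 0.086) and the last 1.0.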

class ValOptions(BaseOptions):
    def initialize(self):
        BaseOptions.initialize(self)
        self.parser.add_argument('--batch_size', type=int, default=4, help='batch size')
        self.parser.add_argument('--datapath', type=str, default='./data/val_data.txt', help='path')
        self.parser.add_argument('--shuffle', type=lambda s: s.lower() in ('true', '1', 'yes'), default=True, help='whether to shuffle data')
        self.parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
        self.parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
        self.parser.add_argument('--iters', type=int, default=12)
        self.isTrain = True
        self.isTest = False

class TestOptions(BaseOptions):
    def initialize(self):
        BaseOptions.initialize(self)
        self.parser.add_argument('--batch_size', type=int, default=1, help='batch size')
        self.parser.add_argument('--pretrain_model_path', type=str, default='./pretrain_model/raft-small.pth', help='path to the pretrained model checkpoint')

        # self.parser.add_argument('--datapath', type=str, default='./data/real_train_data_1128_1.txt', help='path')
        # self.parser.add_argument('--datapath', type=str, default='./data_test_flow/test_data.txt', help='path')
        self.parser.add_argument('--savepath', type=str, default='flow_result', help='save path')
        self.parser.add_argument('--datapath', type=str,
                                 default='/data_b/yudong/paper_code/TalkingHead-NeRF/data_guancha/guancha_flow.txt',
                                 help='path')
        self.parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
        self.parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
        self.parser.add_argument('--iters', type=int, default=12)
        self.parser.add_argument('--shuffle', type=lambda s: s.lower() in ('true', '1', 'yes'), default=True, help='whether to shuffle data')
        self.isTrain = False
        self.isTest = True