File size: 3,125 Bytes
57ae837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
from yacs.config import CfgNode as CN

# Default configuration tree (yacs CfgNode).
# get_cfg() clones it, then merges a YAML config file and command-line
# KEY VALUE overrides on top. CN() is created with yacs' default
# new_allowed=False, so merging rejects keys that are not declared here:
# every option a config file or the command line may set needs a default below.
cfg = CN()

# --- General run settings ---
cfg.TITLE = 'default'
cfg.OUTPUT_DIR = 'results'
cfg.EXP_NAME = 'default'
cfg.DEVICE = 'cuda'
cfg.DEBUG = False
cfg.EVAL = False  # forced to True by get_cfg() when called with test=True
cfg.RESUME = False
cfg.LOGDIR = ''
cfg.NUM_WORKERS = 5
cfg.SEED_VALUE = -1  # NOTE(review): -1 presumably means "do not fix the seed" — confirm in the trainer
cfg.SUMMARY_ITER = 50
cfg.MODEL_CONFIG = ''
cfg.FLIP_EVAL = False

# --- Training / optimisation settings ---
cfg.TRAIN = CN()
cfg.TRAIN.STAGE = 'stage1'
cfg.TRAIN.DATASET_EVAL = '3dpw'
cfg.TRAIN.CHECKPOINT = ''
cfg.TRAIN.BATCH_SIZE = 64
cfg.TRAIN.START_EPOCH = 0
cfg.TRAIN.END_EPOCH = 999
cfg.TRAIN.OPTIM = 'Adam'
cfg.TRAIN.LR = 3e-4
cfg.TRAIN.LR_FINETUNE = 5e-5
cfg.TRAIN.LR_PATIENCE = 5
cfg.TRAIN.LR_DECAY_RATIO = 0.1
cfg.TRAIN.WD = 0.0  # presumably weight decay — confirm against the optimizer setup
cfg.TRAIN.MOMENTUM = 0.9
# NOTE(review): presumably epoch indices for a step LR schedule — confirm against the trainer
cfg.TRAIN.MILESTONES = [50, 70]

# --- Dataset settings ---
cfg.DATASET = CN()
cfg.DATASET.SEQLEN = 81  # NOTE(review): presumably clip length in frames — confirm in the data loader
# NOTE(review): five entries, presumably per-dataset sampling ratios; the
# dataset order is defined by the data loader — confirm there.
cfg.DATASET.RATIO = [1.0, 0, 0, 0, 0]

# --- Model settings ---
cfg.MODEL = CN()
cfg.MODEL.BACKBONE = 'vit'

# --- Loss settings ---
# Scalar weights for the individual loss terms (consumed by the loss code,
# which is not visible here).
cfg.LOSS = CN()
cfg.LOSS.SHAPE_LOSS_WEIGHT = 0.001
cfg.LOSS.JOINT2D_LOSS_WEIGHT = 5.
cfg.LOSS.JOINT3D_LOSS_WEIGHT = 5.
cfg.LOSS.VERTS3D_LOSS_WEIGHT = 1.
cfg.LOSS.POSE_LOSS_WEIGHT = 1.
cfg.LOSS.CASCADED_LOSS_WEIGHT = 0.0
cfg.LOSS.CONTACT_LOSS_WEIGHT = 0.04
cfg.LOSS.ROOT_VEL_LOSS_WEIGHT = 0.001
cfg.LOSS.ROOT_POSE_LOSS_WEIGHT = 0.4
cfg.LOSS.SLIDING_LOSS_WEIGHT = 0.5
cfg.LOSS.CAMERA_LOSS_WEIGHT = 0.04
cfg.LOSS.LOSS_WEIGHT = 60.
# NOTE(review): presumably the number of initial epochs during which the camera loss is skipped — confirm
cfg.LOSS.CAMERA_LOSS_SKIP_EPOCH = 5


def get_cfg_defaults():
    """Return an independent copy of the module-level default configuration.

    Callers may merge files or overrides into the returned node freely; the
    shared ``cfg`` defaults stay untouched for any later caller.
    """
    defaults = cfg.clone()
    return defaults


def get_cfg(args, test):
    """Build the run configuration from defaults, a config file and CLI overrides.

    Merge order (later sources win):
      1. Defaults from ``get_cfg_defaults()``.
      2. The file at ``args.cfg``, if that path exists.
      3. ``args.opts``: a flat ``[KEY, VALUE, KEY, VALUE, ...]`` list.
      4. ``EVAL = True`` when ``test`` is truthy.

    Args:
        args: Namespace-like object with ``cfg`` (config file path; may be
            empty or None) and ``opts`` (list of override tokens; may be
            None or empty).
        test: If truthy, force ``EVAL`` to True in the result.

    Returns:
        A fresh CfgNode; the module-level defaults are not modified.
    """
    import os
    import warnings

    cfg = get_cfg_defaults()

    cfg_path = getattr(args, 'cfg', None)
    if cfg_path:
        if os.path.exists(cfg_path):
            cfg.merge_from_file(cfg_path)
        else:
            # A mistyped path used to be skipped silently, leaving the run on
            # pure defaults with no hint; keep going, but say so.
            warnings.warn(
                'Config file {!r} not found; using default configuration.'.format(cfg_path),
                stacklevel=2)

    # merge_from_list(None) would fail; treat a missing override list as empty.
    opts = getattr(args, 'opts', None) or []
    if opts:
        cfg.merge_from_list(opts)

    if test:
        cfg.merge_from_list(['EVAL', True])

    return cfg.clone()


def bool_arg(value):
    """Convert a command-line string into a bool; intended as an argparse ``type``.

    Case-insensitive, surrounding whitespace ignored:
      'yes', 'true', 't', 'y', '1'  -> True
      'no', 'false', 'f', 'n', '0'  -> False
    A value that is already a bool is returned unchanged.

    Raises:
        argparse.ArgumentTypeError: for any other value. argparse turns this
            into a usage error instead of silently storing None, which is
            what happened when the function fell off the end.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('yes', 'true', 't', 'y', '1'):
        return True
    if normalized in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'Boolean value expected, got {!r}'.format(value))


def parse_args(test=False, argv=None):
    """Parse command-line arguments and build the run configuration.

    Args:
        test: Forwarded to ``get_cfg``; when truthy, ``EVAL`` is forced to
            True in the returned config.
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads ``sys.argv[1:]``, as before. Pass a list
            to drive this from scripts, notebooks or tests.

    Returns:
        Tuple ``(cfg, cfg_file, args)``:
          - the merged CfgNode,
          - the config-file path given with ``-c/--cfg``,
          - the parsed argparse Namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--cfg', type=str, default='./configs/debug.yaml', help='cfg file path')
    parser.add_argument(
        "--eval-set", type=str, default='3dpw', help="Evaluation dataset")
    parser.add_argument(
        "--eval-split", type=str, default='test', help="Evaluation data split")
    parser.add_argument('--render', default=False, type=bool_arg,
                        help='Render SMPL meshes after the evaluation')
    parser.add_argument('--save-results', default=False, type=bool_arg,
                        help='Save SMPL parameters after the evaluation')
    # Everything after the named options is collected verbatim and merged
    # into the config as KEY VALUE pairs by get_cfg().
    parser.add_argument(
        "opts", default=None, nargs=argparse.REMAINDER,
        help="Modify config options using the command-line")

    args = parser.parse_args(argv)
    print(args, end='\n\n')
    cfg = get_cfg(args, test)

    return cfg, args.cfg, args