Kaushik Bar commited on
Commit
b33e613
·
1 Parent(s): 523248f

zsic_unicl

Browse files
Files changed (1) hide show
  1. config.py +245 -0
config.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Unified Contrastive Learning (UniCL)
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Jianwei Yang ([email protected])
6
+ # Based on Swin Transformer written by Zhe Liu
7
+ # --------------------------------------------------------
8
+
9
+ import os
10
+ import yaml
11
+ from yacs.config import CfgNode as CN
12
+
13
+ _C = CN()
14
+ _C.VERBOSE = False
15
+
16
+ # Base config files
17
+ _C.BASE = ['']
18
+
19
+ # -----------------------------------------------------------------------------
20
+ # Data settings
21
+ # -----------------------------------------------------------------------------
22
+ _C.DATA = CN()
23
+ # Batch size for a single GPU, could be overwritten by command line argument
24
+ _C.DATA.BATCH_SIZE = 128
25
+ # Path to dataset, could be overwritten by command line argument
26
+ _C.DATA.DATA_PATH = ''
27
+ # Dataset name
28
+ _C.DATA.DATASET = 'imagenet'
29
+ # Input image size
30
+ _C.DATA.IMG_SIZE = 224
31
+ # Interpolation to resize image (random, bilinear, bicubic)
32
+ _C.DATA.INTERPOLATION = 'bicubic'
33
+ # Use zipped dataset instead of folder dataset
34
+ # could be overwritten by command line argument
35
+ _C.DATA.ZIP_MODE = False
36
+ # Cache Data in Memory, could be overwritten by command line argument
37
+ _C.DATA.CACHE_MODE = 'part'
38
+ # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
39
+ _C.DATA.PIN_MEMORY = True
40
+ # Number of data loading threads
41
+ _C.DATA.NUM_WORKERS = 8
42
+
43
+ # -----------------------------------------------------------------------------
44
+ # Model settings
45
+ # -----------------------------------------------------------------------------
46
+ _C.MODEL = CN()
47
+ # Model name
48
+ _C.MODEL.NAME = ''
49
+ # Checkpoint to resume, could be overwritten by command line argument
50
+ _C.MODEL.RESUME = ''
51
+ # Number of classes, overwritten in data preparation
52
+ _C.MODEL.NUM_CLASSES = 0
53
+ # Label Smoothing
54
+ _C.MODEL.LABEL_SMOOTHING = 0.1
55
+ # Whether load pretrained model
56
+ _C.MODEL.PRETRAINED = ''
57
+ # Projection dimension
58
+ _C.MODEL.DIM_PROJECTION = 512
59
+ # Mode specific
60
+ _C.MODEL.SPEC = CN(new_allowed=True)
61
+ # -----------------------------------------------------------------------------
62
+ # Build Image Encoder
63
+ # -----------------------------------------------------------------------------
64
+ _C.MODEL.IMAGE_ENCODER = CN()
65
+ # Image encoder type
66
+ _C.MODEL.IMAGE_ENCODER.TYPE = 'swin'
67
+ # Input image size
68
+ _C.MODEL.IMAGE_ENCODER.IMG_SIZE = 224
69
+ # Dropout rate
70
+ _C.MODEL.IMAGE_ENCODER.DROP_RATE = 0.0
71
+ # Drop path rate
72
+ _C.MODEL.IMAGE_ENCODER.DROP_PATH_RATE = 0.1
73
+
74
+ # Swin Transformer parameters
75
+ _C.MODEL.IMAGE_ENCODER.SWIN = CN()
76
+ _C.MODEL.IMAGE_ENCODER.SWIN.PATCH_SIZE = 4
77
+ _C.MODEL.IMAGE_ENCODER.SWIN.IN_CHANS = 3
78
+ _C.MODEL.IMAGE_ENCODER.SWIN.EMBED_DIM = 96
79
+ _C.MODEL.IMAGE_ENCODER.SWIN.DEPTHS = [2, 2, 6, 2]
80
+ _C.MODEL.IMAGE_ENCODER.SWIN.NUM_HEADS = [3, 6, 12, 24]
81
+ _C.MODEL.IMAGE_ENCODER.SWIN.WINDOW_SIZE = 7
82
+ _C.MODEL.IMAGE_ENCODER.SWIN.MLP_RATIO = 4.
83
+ _C.MODEL.IMAGE_ENCODER.SWIN.QKV_BIAS = True
84
+ _C.MODEL.IMAGE_ENCODER.SWIN.QK_SCALE = None
85
+ _C.MODEL.IMAGE_ENCODER.SWIN.APE = False
86
+ _C.MODEL.IMAGE_ENCODER.SWIN.PATCH_NORM = True
87
+
88
+ # FocalNet parameters
89
+ _C.MODEL.IMAGE_ENCODER.FOCAL = CN()
90
+ _C.MODEL.IMAGE_ENCODER.FOCAL.PATCH_SIZE = 4
91
+ _C.MODEL.IMAGE_ENCODER.FOCAL.IN_CHANS = 3
92
+ _C.MODEL.IMAGE_ENCODER.FOCAL.EMBED_DIM = 96
93
+ _C.MODEL.IMAGE_ENCODER.FOCAL.DEPTHS = [2, 2, 6, 2]
94
+ _C.MODEL.IMAGE_ENCODER.FOCAL.MLP_RATIO = 4.
95
+ _C.MODEL.IMAGE_ENCODER.FOCAL.PATCH_NORM = True
96
+ _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_LEVELS = [2, 2, 2, 2]
97
+ _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_WINDOWS = [3, 3, 3, 3]
98
+ _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_FACTORS = [2, 2, 2, 2]
99
+ _C.MODEL.IMAGE_ENCODER.FOCAL.USE_CONV_EMBED = False
100
+ _C.MODEL.IMAGE_ENCODER.FOCAL.USE_LAYERSCALE = False
101
+ _C.MODEL.IMAGE_ENCODER.FOCAL.USE_POSTLN = False
102
+
103
+ # -----------------------------------------------------------------------------
104
+ # Build Text Encoder
105
+ # -----------------------------------------------------------------------------
106
+ _C.MODEL.TEXT_ENCODER = CN()
107
+
108
+ _C.MODEL.TEXT_ENCODER.NAME = 'transformer'
109
+ _C.MODEL.TEXT_ENCODER.LOAD_PRETRAINED = False
110
+ _C.MODEL.TEXT_ENCODER.PRETRAINED = ''
111
+ _C.MODEL.TEXT_ENCODER.TOKENIZER = 'clip'
112
+ _C.MODEL.TEXT_ENCODER.CONTEXT_LENGTH = 77
113
+ _C.MODEL.TEXT_ENCODER.WIDTH = 1024
114
+ _C.MODEL.TEXT_ENCODER.HEADS = 16
115
+ _C.MODEL.TEXT_ENCODER.LAYERS = 12
116
+ _C.MODEL.TEXT_ENCODER.AUTOGRESSIVE = True
117
+
118
+ # -----------------------------------------------------------------------------
119
+ # Training settings
120
+ # -----------------------------------------------------------------------------
121
+ _C.TRAIN = CN()
122
+ _C.TRAIN.START_EPOCH = 0
123
+ _C.TRAIN.EPOCHS = 32
124
+ _C.TRAIN.WARMUP_EPOCHS = 5
125
+ _C.TRAIN.WEIGHT_DECAY = 0.1
126
+ _C.TRAIN.BASE_LR = 5e-4
127
+ _C.TRAIN.WARMUP_LR = 5e-7
128
+ _C.TRAIN.MIN_LR = 5e-6
129
+ # Clip gradient norm
130
+ _C.TRAIN.CLIP_GRAD = 5.0
131
+ # Auto resume from latest checkpoint
132
+ _C.TRAIN.AUTO_RESUME = True
133
+ # Gradient accumulation steps
134
+ # could be overwritten by command line argument
135
+ _C.TRAIN.ACCUMULATION_STEPS = 0
136
+ # Whether to use gradient checkpointing to save memory
137
+ # could be overwritten by command line argument
138
+ _C.TRAIN.USE_CHECKPOINT = False
139
+
140
+ # LR scheduler
141
+ _C.TRAIN.LR_SCHEDULER = CN()
142
+ _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
143
+ # Epoch interval to decay LR, used in StepLRScheduler
144
+ _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
145
+ # LR decay rate, used in StepLRScheduler
146
+ _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
147
+
148
+ # Optimizer
149
+ _C.TRAIN.OPTIMIZER = CN()
150
+ _C.TRAIN.OPTIMIZER.NAME = 'adamw'
151
+ # Optimizer Epsilon
152
+ _C.TRAIN.OPTIMIZER.EPS = 1e-8
153
+ # Optimizer Betas
154
+ _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
155
+ # SGD momentum
156
+ _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
157
+
158
+ # -----------------------------------------------------------------------------
159
+ # Augmentation settings
160
+ # -----------------------------------------------------------------------------
161
+ _C.AUG = CN()
162
+ # Color jitter factor
163
+ _C.AUG.COLOR_JITTER = 0.4
164
+ # Use AutoAugment policy. "v0" or "original"
165
+ _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
166
+ # Random erase prob
167
+ _C.AUG.REPROB = 0.25
168
+ # Random erase mode
169
+ _C.AUG.REMODE = 'pixel'
170
+ # Random erase count
171
+ _C.AUG.RECOUNT = 1
172
+ # Mixup alpha, mixup enabled if > 0
173
+ _C.AUG.MIXUP = 0.8
174
+ # Cutmix alpha, cutmix enabled if > 0
175
+ _C.AUG.CUTMIX = 1.0
176
+ # Cutmix min/max ratio, overrides alpha and enables cutmix if set
177
+ _C.AUG.CUTMIX_MINMAX = None
178
+ # Probability of performing mixup or cutmix when either/both is enabled
179
+ _C.AUG.MIXUP_PROB = 1.0
180
+ # Probability of switching to cutmix when both mixup and cutmix enabled
181
+ _C.AUG.MIXUP_SWITCH_PROB = 0.5
182
+ # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
183
+ _C.AUG.MIXUP_MODE = 'batch'
184
+
185
+ # -----------------------------------------------------------------------------
186
+ # Testing settings
187
+ # -----------------------------------------------------------------------------
188
+ _C.TEST = CN()
189
+ # Whether to use center crop when testing
190
+ _C.TEST.CROP = True
191
+
192
+ # -----------------------------------------------------------------------------
193
+ # Misc
194
+ # -----------------------------------------------------------------------------
195
+ # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
196
+ # overwritten by command line argument
197
+ _C.AMP_OPT_LEVEL = ''
198
+ # Path to output folder, overwritten by command line argument
199
+ _C.OUTPUT = ''
200
+ # Tag of experiment, overwritten by command line argument
201
+ _C.TAG = 'default'
202
+ # Frequency to save checkpoint
203
+ _C.SAVE_FREQ = 1
204
+ # Frequency to logging info
205
+ _C.PRINT_FREQ = 100
206
+ # Fixed random seed
207
+ _C.SEED = 0
208
+ # Perform evaluation only, overwritten by command line argument
209
+ _C.EVAL_MODE = False
210
+ # Test throughput only, overwritten by command line argument
211
+ _C.THROUGHPUT_MODE = False
212
+ # Debug only so that skip dataloader initialization, overwritten by command line argument
213
+ _C.DEBUG_MODE = False
214
+ # local rank for DistributedDataParallel, given by command line argument
215
+ _C.LOCAL_RANK = 0
216
+
217
+
218
+ def _update_config_from_file(config, cfg_file):
219
+ config.defrost()
220
+ with open(cfg_file, 'r') as f:
221
+ yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
222
+
223
+ for cfg in yaml_cfg.setdefault('BASE', ['']):
224
+ if cfg:
225
+ _update_config_from_file(
226
+ config, os.path.join(os.path.dirname(cfg_file), cfg)
227
+ )
228
+ print('=> merge config from {}'.format(cfg_file))
229
+ config.merge_from_file(cfg_file)
230
+ config.freeze()
231
+
232
+
233
+ def update_config(config, args):
234
+ _update_config_from_file(config, args.cfg)
235
+ config.freeze()
236
+
237
+
238
+ def get_config(args):
239
+ """Get a yacs CfgNode object with default values."""
240
+ # Return a clone so that the defaults will not be altered
241
+ # This is for the "local variable" use pattern
242
+ config = _C.clone()
243
+ update_config(config, args)
244
+
245
+ return config