Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
"""Helper functions for multigrid training.""" | |
import numpy as np | |
import timesformer.utils.logging as logging | |
logger = logging.get_logger(__name__) | |
class MultigridSchedule(object): | |
""" | |
This class defines multigrid training schedule and update cfg accordingly. | |
""" | |
def init_multigrid(self, cfg): | |
""" | |
Update cfg based on multigrid settings. | |
Args: | |
cfg (configs): configs that contains training and multigrid specific | |
hyperparameters. Details can be seen in | |
slowfast/config/defaults.py. | |
Returns: | |
cfg (configs): the updated cfg. | |
""" | |
self.schedule = None | |
# We may modify cfg.TRAIN.BATCH_SIZE, cfg.DATA.NUM_FRAMES, and | |
# cfg.DATA.TRAIN_CROP_SIZE during training, so we store their original | |
# value in cfg and use them as global variables. | |
cfg.MULTIGRID.DEFAULT_B = cfg.TRAIN.BATCH_SIZE | |
cfg.MULTIGRID.DEFAULT_T = cfg.DATA.NUM_FRAMES | |
cfg.MULTIGRID.DEFAULT_S = cfg.DATA.TRAIN_CROP_SIZE | |
if cfg.MULTIGRID.LONG_CYCLE: | |
self.schedule = self.get_long_cycle_schedule(cfg) | |
cfg.SOLVER.STEPS = [0] + [s[-1] for s in self.schedule] | |
# Fine-tuning phase. | |
cfg.SOLVER.STEPS[-1] = ( | |
cfg.SOLVER.STEPS[-2] + cfg.SOLVER.STEPS[-1] | |
) // 2 | |
cfg.SOLVER.LRS = [ | |
cfg.SOLVER.GAMMA ** s[0] * s[1][0] for s in self.schedule | |
] | |
# Fine-tuning phase. | |
cfg.SOLVER.LRS = cfg.SOLVER.LRS[:-1] + [ | |
cfg.SOLVER.LRS[-2], | |
cfg.SOLVER.LRS[-1], | |
] | |
cfg.SOLVER.MAX_EPOCH = self.schedule[-1][-1] | |
elif cfg.MULTIGRID.SHORT_CYCLE: | |
cfg.SOLVER.STEPS = [ | |
int(s * cfg.MULTIGRID.EPOCH_FACTOR) for s in cfg.SOLVER.STEPS | |
] | |
cfg.SOLVER.MAX_EPOCH = int( | |
cfg.SOLVER.MAX_EPOCH * cfg.MULTIGRID.EPOCH_FACTOR | |
) | |
return cfg | |
def update_long_cycle(self, cfg, cur_epoch): | |
""" | |
Before every epoch, check if long cycle shape should change. If it | |
should, update cfg accordingly. | |
Args: | |
cfg (configs): configs that contains training and multigrid specific | |
hyperparameters. Details can be seen in | |
slowfast/config/defaults.py. | |
cur_epoch (int): current epoch index. | |
Returns: | |
cfg (configs): the updated cfg. | |
changed (bool): do we change long cycle shape at this epoch? | |
""" | |
base_b, base_t, base_s = get_current_long_cycle_shape( | |
self.schedule, cur_epoch | |
) | |
if base_s != cfg.DATA.TRAIN_CROP_SIZE or base_t != cfg.DATA.NUM_FRAMES: | |
cfg.DATA.NUM_FRAMES = base_t | |
cfg.DATA.TRAIN_CROP_SIZE = base_s | |
cfg.TRAIN.BATCH_SIZE = base_b * cfg.MULTIGRID.DEFAULT_B | |
bs_factor = ( | |
float(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS) | |
/ cfg.MULTIGRID.BN_BASE_SIZE | |
) | |
if bs_factor < 1: | |
cfg.BN.NORM_TYPE = "sync_batchnorm" | |
cfg.BN.NUM_SYNC_DEVICES = int(1.0 / bs_factor) | |
elif bs_factor > 1: | |
cfg.BN.NORM_TYPE = "sub_batchnorm" | |
cfg.BN.NUM_SPLITS = int(bs_factor) | |
else: | |
cfg.BN.NORM_TYPE = "batchnorm" | |
cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE = cfg.DATA.SAMPLING_RATE * ( | |
cfg.MULTIGRID.DEFAULT_T // cfg.DATA.NUM_FRAMES | |
) | |
logger.info("Long cycle updates:") | |
logger.info("\tBN.NORM_TYPE: {}".format(cfg.BN.NORM_TYPE)) | |
if cfg.BN.NORM_TYPE == "sync_batchnorm": | |
logger.info( | |
"\tBN.NUM_SYNC_DEVICES: {}".format(cfg.BN.NUM_SYNC_DEVICES) | |
) | |
elif cfg.BN.NORM_TYPE == "sub_batchnorm": | |
logger.info("\tBN.NUM_SPLITS: {}".format(cfg.BN.NUM_SPLITS)) | |
logger.info("\tTRAIN.BATCH_SIZE: {}".format(cfg.TRAIN.BATCH_SIZE)) | |
logger.info( | |
"\tDATA.NUM_FRAMES x LONG_CYCLE_SAMPLING_RATE: {}x{}".format( | |
cfg.DATA.NUM_FRAMES, cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE | |
) | |
) | |
logger.info( | |
"\tDATA.TRAIN_CROP_SIZE: {}".format(cfg.DATA.TRAIN_CROP_SIZE) | |
) | |
return cfg, True | |
else: | |
return cfg, False | |
def get_long_cycle_schedule(self, cfg): | |
""" | |
Based on multigrid hyperparameters, define the schedule of a long cycle. | |
Args: | |
cfg (configs): configs that contains training and multigrid specific | |
hyperparameters. Details can be seen in | |
slowfast/config/defaults.py. | |
Returns: | |
schedule (list): Specifies a list long cycle base shapes and their | |
corresponding training epochs. | |
""" | |
steps = cfg.SOLVER.STEPS | |
default_size = float( | |
cfg.DATA.NUM_FRAMES * cfg.DATA.TRAIN_CROP_SIZE ** 2 | |
) | |
default_iters = steps[-1] | |
# Get shapes and average batch size for each long cycle shape. | |
avg_bs = [] | |
all_shapes = [] | |
for t_factor, s_factor in cfg.MULTIGRID.LONG_CYCLE_FACTORS: | |
base_t = int(round(cfg.DATA.NUM_FRAMES * t_factor)) | |
base_s = int(round(cfg.DATA.TRAIN_CROP_SIZE * s_factor)) | |
if cfg.MULTIGRID.SHORT_CYCLE: | |
shapes = [ | |
[ | |
base_t, | |
cfg.MULTIGRID.DEFAULT_S | |
* cfg.MULTIGRID.SHORT_CYCLE_FACTORS[0], | |
], | |
[ | |
base_t, | |
cfg.MULTIGRID.DEFAULT_S | |
* cfg.MULTIGRID.SHORT_CYCLE_FACTORS[1], | |
], | |
[base_t, base_s], | |
] | |
else: | |
shapes = [[base_t, base_s]] | |
# (T, S) -> (B, T, S) | |
shapes = [ | |
[int(round(default_size / (s[0] * s[1] * s[1]))), s[0], s[1]] | |
for s in shapes | |
] | |
avg_bs.append(np.mean([s[0] for s in shapes])) | |
all_shapes.append(shapes) | |
# Get schedule regardless of cfg.MULTIGRID.EPOCH_FACTOR. | |
total_iters = 0 | |
schedule = [] | |
for step_index in range(len(steps) - 1): | |
step_epochs = steps[step_index + 1] - steps[step_index] | |
for long_cycle_index, shapes in enumerate(all_shapes): | |
cur_epochs = ( | |
step_epochs * avg_bs[long_cycle_index] / sum(avg_bs) | |
) | |
cur_iters = cur_epochs / avg_bs[long_cycle_index] | |
total_iters += cur_iters | |
schedule.append((step_index, shapes[-1], cur_epochs)) | |
iter_saving = default_iters / total_iters | |
final_step_epochs = cfg.SOLVER.MAX_EPOCH - steps[-1] | |
# We define the fine-tuning phase to have the same amount of iteration | |
# saving as the rest of the training. | |
ft_epochs = final_step_epochs / iter_saving * avg_bs[-1] | |
schedule.append((step_index + 1, all_shapes[-1][2], ft_epochs)) | |
# Obtrain final schedule given desired cfg.MULTIGRID.EPOCH_FACTOR. | |
x = ( | |
cfg.SOLVER.MAX_EPOCH | |
* cfg.MULTIGRID.EPOCH_FACTOR | |
/ sum(s[-1] for s in schedule) | |
) | |
final_schedule = [] | |
total_epochs = 0 | |
for s in schedule: | |
epochs = s[2] * x | |
total_epochs += epochs | |
final_schedule.append((s[0], s[1], int(round(total_epochs)))) | |
print_schedule(final_schedule) | |
return final_schedule | |
def print_schedule(schedule): | |
""" | |
Log schedule. | |
""" | |
logger.info("Long cycle index\tBase shape\tEpochs") | |
for s in schedule: | |
logger.info("{}\t{}\t{}".format(s[0], s[1], s[2])) | |
def get_current_long_cycle_shape(schedule, epoch): | |
""" | |
Given a schedule and epoch index, return the long cycle base shape. | |
Args: | |
schedule (configs): configs that contains training and multigrid specific | |
hyperparameters. Details can be seen in | |
slowfast/config/defaults.py. | |
cur_epoch (int): current epoch index. | |
Returns: | |
shapes (list): A list describing the base shape in a long cycle: | |
[batch size relative to default, | |
number of frames, spatial dimension]. | |
""" | |
for s in schedule: | |
if epoch < s[-1]: | |
return s[1] | |
return schedule[-1][1] | |