Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
"""Wrapper to train and test a video classification model.""" | |
from timesformer.utils.misc import launch_job | |
from timesformer.utils.parser import load_config, parse_args | |
from tools.test_net import test | |
from tools.train_net import train | |
def get_func(cfg): | |
train_func = train | |
test_func = test | |
return train_func, test_func | |
def main(): | |
""" | |
Main function to spawn the train and test process. | |
""" | |
args = parse_args() | |
if args.num_shards > 1: | |
args.output_dir = str(args.job_dir) | |
cfg = load_config(args) | |
train, test = get_func(cfg) | |
# Perform training. | |
if cfg.TRAIN.ENABLE: | |
launch_job(cfg=cfg, init_method=args.init_method, func=train) | |
# Perform multi-clip testing. | |
if cfg.TEST.ENABLE: | |
launch_job(cfg=cfg, init_method=args.init_method, func=test) | |
# Perform model visualization. | |
if cfg.TENSORBOARD.ENABLE and ( | |
cfg.TENSORBOARD.MODEL_VIS.ENABLE | |
or cfg.TENSORBOARD.WRONG_PRED_VIS.ENABLE | |
): | |
launch_job(cfg=cfg, init_method=args.init_method, func=visualize) | |
if __name__ == "__main__": | |
main() | |