# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""Model construction functions."""

import torch
from fvcore.common.registry import Registry

MODEL_REGISTRY = Registry("MODEL")
MODEL_REGISTRY.__doc__ = """
Registry for video model.

The registered object will be called with `obj(cfg)`.
The call should return a `torch.nn.Module` object.
"""


def build_model(cfg, gpu_id=None):
    """
    Builds the video model.

    Args:
        cfg (configs): configs that contain the hyper-parameters to build the
            backbone. Details can be seen in slowfast/config/defaults.py.
        gpu_id (Optional[int]): specify the GPU index on which to build the model.
    """
    if torch.cuda.is_available():
        assert (
            cfg.NUM_GPUS <= torch.cuda.device_count()
        ), "Cannot use more GPU devices than available"
    else:
        assert (
            cfg.NUM_GPUS == 0
        ), "CUDA is not available. Please set `NUM_GPUS: 0` for running on CPUs."
    # Construct the model.
    name = cfg.MODEL.MODEL_NAME
    model = MODEL_REGISTRY.get(name)(cfg)

    if cfg.NUM_GPUS:
        if gpu_id is None:
            # Determine the GPU used by the current process.
            cur_device = torch.cuda.current_device()
        else:
            cur_device = gpu_id
        # Transfer the model to the current GPU device.
        model = model.cuda(device=cur_device)
    # Use multi-process data parallel model in the multi-GPU setting.
    if cfg.NUM_GPUS > 1:
        # Make the model replica operate on the current device.
        model = torch.nn.parallel.DistributedDataParallel(
            module=model, device_ids=[cur_device], output_device=cur_device
        )
    return model
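
# Usage sketch (assumes a SlowFast-style config; `get_cfg` and the YAML path
# below are illustrative, not defined in this file):
#
#   from slowfast.config.defaults import get_cfg
#
#   cfg = get_cfg()
#   cfg.merge_from_file("configs/Kinetics/SLOWFAST_8x8_R50.yaml")
#   model = build_model(cfg)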