# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Tuple, Union

from mmengine.model import BaseModel
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import OptTrackSampleList, TrackSampleList
from mmdet.utils import OptConfigType, OptMultiConfig


@MODELS.register_module()
class BaseMOTModel(BaseModel, metaclass=ABCMeta):
    """Base class for multiple object tracking.

    Subclasses must implement :meth:`loss` and :meth:`predict`; the
    ``'tensor'`` forward mode (:meth:`_forward`) is not supported by default.

    Args:
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`TrackDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
        init_cfg (dict or list[dict]): Initialization config dict.
    """

    def __init__(self,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

    def freeze_module(
            self, module: Union[List[str], Tuple[str, ...], str]) -> None:
        """Freeze one or more sub-modules during training.

        Each named sub-module is put into eval mode and every one of its
        parameters gets ``requires_grad = False``.

        Args:
            module (str | list[str] | tuple[str]): Attribute name(s) of the
                sub-module(s) to freeze, e.g. ``'detector'``.

        Raises:
            TypeError: If ``module`` is not a str, list or tuple.
        """
        if isinstance(module, str):
            module_names = [module]
        elif isinstance(module, (list, tuple)):
            module_names = module
        else:
            raise TypeError('module must be a str, list or tuple.')

        # Use a distinct loop name so the ``module`` argument is not shadowed.
        for name in module_names:
            submodule = getattr(self, name)
            # NOTE(review): if a later ``self.train()`` recurses into children
            # (as ``torch.nn.Module.train`` does), this eval mode will not
            # persist; only ``requires_grad = False`` does. Confirm callers
            # do not rely on the sub-module staying in eval mode.
            submodule.eval()
            for param in submodule.parameters():
                param.requires_grad = False

    @property
    def with_detector(self) -> bool:
        """bool: whether the framework has a detector."""
        return hasattr(self, 'detector') and self.detector is not None

    @property
    def with_reid(self) -> bool:
        """bool: whether the framework has a reid model."""
        return hasattr(self, 'reid') and self.reid is not None

    @property
    def with_motion(self) -> bool:
        """bool: whether the framework has a motion model."""
        return hasattr(self, 'motion') and self.motion is not None

    @property
    def with_track_head(self) -> bool:
        """bool: whether the framework has a track_head."""
        return hasattr(self, 'track_head') and self.track_head is not None

    @property
    def with_tracker(self) -> bool:
        """bool: whether the framework has a tracker."""
        return hasattr(self, 'tracker') and self.tracker is not None

    def forward(self,
                inputs: Dict[str, Tensor],
                data_samples: OptTrackSampleList = None,
                mode: str = 'predict',
                **kwargs):
        """The unified entry for a forward process in both training and test.

        The method accepts three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return a tensor or tuple of
          tensors without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`TrackDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating; those are done in :meth:`train_step`.

        Args:
            inputs (Dict[str, Tensor]): A dict of tensors of shape
                (N, T, C, H, W) encoding input images. Typically these should
                be mean centered and std scaled. N denotes the batch size and
                T the number of key/reference frames.

                - img (Tensor): The key images.
                - ref_img (Tensor): The reference images.
            data_samples (list[:obj:`TrackDataSample`], optional): The
                annotation data of every sample. Defaults to None.
            mode (str): Return what kind of value. Defaults to 'predict'.
            **kwargs: Forwarded unchanged to :meth:`loss`, :meth:`predict`
                or :meth:`_forward`.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of :obj:`TrackDataSample`.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            RuntimeError: If ``mode`` is not one of the three modes above.
        """
        if mode == 'loss':
            return self.loss(inputs, data_samples, **kwargs)
        elif mode == 'predict':
            return self.predict(inputs, data_samples, **kwargs)
        elif mode == 'tensor':
            return self._forward(inputs, data_samples, **kwargs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')

    @abstractmethod
    def loss(self, inputs: Dict[str, Tensor], data_samples: TrackSampleList,
             **kwargs) -> Union[dict, tuple]:
        """Calculate losses from a batch of inputs and data samples."""
        pass

    @abstractmethod
    def predict(self, inputs: Dict[str, Tensor],
                data_samples: TrackSampleList, **kwargs) -> TrackSampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing."""
        pass

    def _forward(self,
                 inputs: Dict[str, Tensor],
                 data_samples: OptTrackSampleList = None,
                 **kwargs):
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Not implemented in this base class; subclasses that want to support
        ``mode='tensor'`` must override it.

        Args:
            inputs (Dict[str, Tensor]): A dict of tensors of shape
                (N, T, C, H, W).
            data_samples (List[:obj:`TrackDataSample`], optional): The
                Data Samples. It usually includes information such as
                `gt_instance`.

        Returns:
            tuple[list]: A tuple of features from ``head`` forward.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            "_forward function (namely 'tensor' mode) is not supported now")