# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import Optional, Sequence from mmengine.hooks import Hook from mmengine.runner import Runner from mmdet.registry import HOOKS from mmdet.structures import DetDataSample @HOOKS.register_module() class MemoryProfilerHook(Hook): """Memory profiler hook recording memory information including virtual memory, swap memory, and the memory of the current process. Args: interval (int): Checking interval (every k iterations). Default: 50. """ def __init__(self, interval: int = 50) -> None: try: from psutil import swap_memory, virtual_memory self._swap_memory = swap_memory self._virtual_memory = virtual_memory except ImportError: raise ImportError('psutil is not installed, please install it by: ' 'pip install psutil') try: from memory_profiler import memory_usage self._memory_usage = memory_usage except ImportError: raise ImportError( 'memory_profiler is not installed, please install it by: ' 'pip install memory_profiler') self.interval = interval def _record_memory_information(self, runner: Runner) -> None: """Regularly record memory information. Args: runner (:obj:`Runner`): The runner of the training or evaluation process. """ # in Byte virtual_memory = self._virtual_memory() swap_memory = self._swap_memory() # in MB process_memory = self._memory_usage()[0] factor = 1024 * 1024 runner.logger.info( 'Memory information ' 'available_memory: ' f'{round(virtual_memory.available / factor)} MB, ' 'used_memory: ' f'{round(virtual_memory.used / factor)} MB, ' f'memory_utilization: {virtual_memory.percent} %, ' 'available_swap_memory: ' f'{round((swap_memory.total - swap_memory.used) / factor)}' ' MB, ' f'used_swap_memory: {round(swap_memory.used / factor)} MB, ' f'swap_memory_utilization: {swap_memory.percent} %, ' 'current_process_memory: ' f'{round(process_memory)} MB') def after_train_iter(self, runner: Runner, batch_idx: int, data_batch: Optional[dict] = None, outputs: Optional[dict] = None) -> None: """Regularly record memory information. Args: runner (:obj:`Runner`): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. data_batch (dict, optional): Data from dataloader. Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ if self.every_n_inner_iters(batch_idx, self.interval): self._record_memory_information(runner) def after_val_iter( self, runner: Runner, batch_idx: int, data_batch: Optional[dict] = None, outputs: Optional[Sequence[DetDataSample]] = None) -> None: """Regularly record memory information. Args: runner (:obj:`Runner`): The runner of the validation process. batch_idx (int): The index of the current batch in the val loop. data_batch (dict, optional): Data from dataloader. Defaults to None. outputs (Sequence[:obj:`DetDataSample`], optional): Outputs from model. Defaults to None. """ if self.every_n_inner_iters(batch_idx, self.interval): self._record_memory_information(runner) def after_test_iter( self, runner: Runner, batch_idx: int, data_batch: Optional[dict] = None, outputs: Optional[Sequence[DetDataSample]] = None) -> None: """Regularly record memory information. Args: runner (:obj:`Runner`): The runner of the testing process. batch_idx (int): The index of the current batch in the test loop. data_batch (dict, optional): Data from dataloader. Defaults to None. outputs (Sequence[:obj:`DetDataSample`], optional): Outputs from model. Defaults to None. """ if self.every_n_inner_iters(batch_idx, self.interval): self._record_memory_information(runner)