Spaces:
Build error
Build error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import logging | |
| import random | |
| from typing import List, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| from mmengine.dist import get_rank, sync_random_seed | |
| from mmengine.logging import print_log | |
| from mmengine.utils import digit_version, is_list_of | |
| from mmengine.utils.dl_utils import TORCH_VERSION | |
def calc_dynamic_intervals(
    start_interval: int,
    dynamic_interval_list: Optional[List[Tuple[int, int]]] = None
) -> Tuple[List[int], List[int]]:
    """Calculate dynamic intervals.

    Args:
        start_interval (int): The interval used in the beginning.
        dynamic_interval_list (List[Tuple[int, int]], optional): The
            first element in each tuple is a milestone and the second
            element is an interval. The interval is used after the
            corresponding milestone. Defaults to None.

    Returns:
        Tuple[List[int], List[int]]: A list of milestones, always starting
        with ``0``, and a list of the corresponding intervals, always
        starting with ``start_interval``.

    Raises:
        TypeError: If ``dynamic_interval_list`` is not a list of tuples.
    """
    if dynamic_interval_list is None:
        return [0], [start_interval]

    # Validate explicitly instead of with ``assert`` so the check is not
    # stripped when Python runs with ``-O``. Same condition as
    # ``is_list_of(dynamic_interval_list, tuple)``.
    if not isinstance(dynamic_interval_list, list) or not all(
            isinstance(item, tuple) for item in dynamic_interval_list):
        raise TypeError(
            '`dynamic_interval_list` must be a list of (milestone, interval) '
            f'tuples, but got {dynamic_interval_list!r}')

    dynamic_milestones = [0]
    dynamic_milestones.extend(item[0] for item in dynamic_interval_list)
    dynamic_intervals = [start_interval]
    dynamic_intervals.extend(item[1] for item in dynamic_interval_list)
    return dynamic_milestones, dynamic_intervals
def set_random_seed(seed: Optional[int] = None,
                    deterministic: bool = False,
                    diff_rank_seed: bool = False) -> int:
    """Seed Python's ``random``, NumPy and torch (CPU and all CUDA) RNGs.

    Args:
        seed (int, optional): Seed to be used. If None, a seed is obtained
            from ``sync_random_seed()``. Defaults to None.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False. On torch
            >= 1.10 this additionally calls
            ``torch.use_deterministic_algorithms(True)``. Defaults to False.
        diff_rank_seed (bool): Whether to add the value returned by
            ``get_rank()`` to the seed so that different ranks use different
            seeds. Defaults to False.

    Returns:
        int: The seed that was applied, including the optional rank offset.
    """
    if seed is None:
        seed = sync_random_seed()
    if diff_rank_seed:
        seed += get_rank()

    # Feed the same seed to every generator, in a fixed order.
    seeders = (random.seed, np.random.seed, torch.manual_seed,
               torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

    if not deterministic:
        return seed

    # Warn only when benchmark mode is about to be switched off.
    if torch.backends.cudnn.benchmark:
        print_log(
            'torch.backends.cudnn.benchmark is going to be set as '
            '`False` to cause cuDNN to deterministically select an '
            'algorithm',
            logger='current',
            level=logging.WARNING)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # ``torch.use_deterministic_algorithms`` is only called on torch >= 1.10.
    if digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
        torch.use_deterministic_algorithms(True)
    return seed