Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import logging | |
import random | |
from typing import List, Optional, Tuple | |
import numpy as np | |
import torch | |
from mmengine.dist import get_rank, sync_random_seed | |
from mmengine.logging import print_log | |
from mmengine.utils import digit_version, is_list_of | |
from mmengine.utils.dl_utils import TORCH_VERSION | |
def calc_dynamic_intervals( | |
start_interval: int, | |
dynamic_interval_list: Optional[List[Tuple[int, int]]] = None | |
) -> Tuple[List[int], List[int]]: | |
"""Calculate dynamic intervals. | |
Args: | |
start_interval (int): The interval used in the beginning. | |
dynamic_interval_list (List[Tuple[int, int]], optional): The | |
first element in the tuple is a milestone and the second | |
element is a interval. The interval is used after the | |
corresponding milestone. Defaults to None. | |
Returns: | |
Tuple[List[int], List[int]]: a list of milestone and its corresponding | |
intervals. | |
""" | |
if dynamic_interval_list is None: | |
return [0], [start_interval] | |
assert is_list_of(dynamic_interval_list, tuple) | |
dynamic_milestones = [0] | |
dynamic_milestones.extend( | |
[dynamic_interval[0] for dynamic_interval in dynamic_interval_list]) | |
dynamic_intervals = [start_interval] | |
dynamic_intervals.extend( | |
[dynamic_interval[1] for dynamic_interval in dynamic_interval_list]) | |
return dynamic_milestones, dynamic_intervals | |
def set_random_seed(seed: Optional[int] = None, | |
deterministic: bool = False, | |
diff_rank_seed: bool = False) -> int: | |
"""Set random seed. | |
Args: | |
seed (int, optional): Seed to be used. | |
deterministic (bool): Whether to set the deterministic option for | |
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` | |
to True and `torch.backends.cudnn.benchmark` to False. | |
Defaults to False. | |
diff_rank_seed (bool): Whether to add rank number to the random seed to | |
have different random seed in different threads. Defaults to False. | |
""" | |
if seed is None: | |
seed = sync_random_seed() | |
if diff_rank_seed: | |
rank = get_rank() | |
seed += rank | |
random.seed(seed) | |
np.random.seed(seed) | |
torch.manual_seed(seed) | |
# torch.cuda.manual_seed(seed) | |
torch.cuda.manual_seed_all(seed) | |
# os.environ['PYTHONHASHSEED'] = str(seed) | |
if deterministic: | |
if torch.backends.cudnn.benchmark: | |
print_log( | |
'torch.backends.cudnn.benchmark is going to be set as ' | |
'`False` to cause cuDNN to deterministically select an ' | |
'algorithm', | |
logger='current', | |
level=logging.WARNING) | |
torch.backends.cudnn.deterministic = True | |
torch.backends.cudnn.benchmark = False | |
if digit_version(TORCH_VERSION) >= digit_version('1.10.0'): | |
torch.use_deterministic_algorithms(True) | |
return seed | |