import os
import random

import numpy as np
import torch


def set_seed(seed: int, cudnn: bool = False) -> None:
    """Seed Python, NumPy and PyTorch random number generators.

    Sets the ``PYTHONHASHSEED`` environment variable and seeds ``random``,
    ``numpy.random`` and ``torch`` (plus every CUDA device when CUDA is
    available) with the same value.

    Args:
        seed: Seed value applied to every generator.
        cudnn: When true and cuDNN is available, also switch cuDNN to
            deterministic mode and disable its benchmark auto-tuner.
    """
    # NOTE: this only updates the environment of the current process; hash
    # randomization of an already-running interpreter is not changed by it.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # May affect performance
    # ref: https://pytorch.org/docs/stable/notes/randomness.html
    # `is_available` must be *called*: the bare function object is always
    # truthy, which made the availability check a no-op.
    if cudnn and torch.backends.cudnn.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def worker_init_fn(worker_id: int) -> None:
    """Re-seed NumPy's global RNG with a value offset by ``worker_id``.

    The new seed is the first word of the current global NumPy RNG state plus
    ``worker_id``, reduced modulo ``2**32`` so it always lies in the range
    accepted by ``np.random.seed``.

    The name and signature suggest use as the ``worker_init_fn`` argument of
    ``torch.utils.data.DataLoader`` -- confirm against the caller.

    Args:
        worker_id: Integer offset added to the base seed.
    """
    # Convert to a Python int first so the addition cannot overflow a
    # fixed-width NumPy scalar.
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % 2**32)