#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Generator

import numpy as np
import torch
from safetensors.torch import load_file, save_file

from lerobot.common.constants import RNG_STATE
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict


def serialize_python_rng_state() -> dict[str, torch.Tensor]:
    """
    Returns the rng state for `random` in the form of a flat dict[str, torch.Tensor] to be saved using
    `safetensors.save_file()` or `torch.save()`.

    The dictionary holds the state version, the internal state tuple, and the cached `gauss_next` value
    (encoded as a 0/1 flag plus a float, since `gauss_next` may be `None`).
    """
    version, internal_state, gauss_next = random.getstate()
    has_gauss = gauss_next is not None
    return {
        "py_rng_version": torch.tensor([version], dtype=torch.int64),
        "py_rng_state": torch.tensor(internal_state, dtype=torch.int64),
        # `gauss_next` is part of the state returned by `random.getstate()`; dropping it would make
        # `random.gauss()` diverge after a save/restore round trip.
        "py_rng_has_gauss": torch.tensor([int(has_gauss)], dtype=torch.int64),
        "py_rng_gauss_next": torch.tensor([gauss_next if has_gauss else 0.0], dtype=torch.float64),
    }


def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`.

    Dictionaries without the `py_rng_has_gauss` / `py_rng_gauss_next` entries are accepted; in that case
    `gauss_next` is restored as `None`.
    """
    gauss_next = None
    has_gauss = rng_state_dict.get("py_rng_has_gauss")
    if has_gauss is not None and has_gauss.item() and "py_rng_gauss_next" in rng_state_dict:
        gauss_next = rng_state_dict["py_rng_gauss_next"].item()

    py_state = (
        rng_state_dict["py_rng_version"].item(),
        tuple(rng_state_dict["py_rng_state"].tolist()),
        gauss_next,
    )
    random.setstate(py_state)


def serialize_numpy_rng_state() -> dict[str, torch.Tensor]:
    """
    Returns the rng state for `numpy` in the form of a flat dict[str, torch.Tensor] to be saved using
    `safetensors.save_file()` or `torch.save()`.
    """
    np_state = np.random.get_state()
    # Ensure no breaking changes from numpy
    assert np_state[0] == "MT19937", "Unexpected numpy global RNG state layout"
    return {
        "np_rng_state_values": torch.tensor(np_state[1], dtype=torch.int64),
        "np_rng_state_index": torch.tensor([np_state[2]], dtype=torch.int64),
        "np_rng_has_gauss": torch.tensor([np_state[3]], dtype=torch.int64),
        # Stored in double precision so the cached Gaussian is restored exactly.
        "np_rng_cached_gaussian": torch.tensor([np_state[4]], dtype=torch.float64),
    }


def deserialize_numpy_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Restores the rng state for `numpy` from a dictionary produced by `serialize_numpy_rng_state()`.
    """
    np_state = (
        "MT19937",
        rng_state_dict["np_rng_state_values"].numpy(),
        rng_state_dict["np_rng_state_index"].item(),
        rng_state_dict["np_rng_has_gauss"].item(),
        rng_state_dict["np_rng_cached_gaussian"].item(),
    )
    np.random.set_state(np_state)


def serialize_torch_rng_state() -> dict[str, torch.Tensor]:
    """
    Returns the rng state for `torch` in the form of a flat dict[str, torch.Tensor] to be saved using
    `safetensors.save_file()` or `torch.save()`.

    The CUDA rng state is included only when `torch.cuda.is_available()` is true.
    """
    torch_rng_state_dict = {"torch_rng_state": torch.get_rng_state()}
    if torch.cuda.is_available():
        torch_rng_state_dict["torch_cuda_rng_state"] = torch.cuda.get_rng_state()
    return torch_rng_state_dict


def deserialize_torch_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Restores the rng state for `torch` from a dictionary produced by `serialize_torch_rng_state()`.

    The CUDA rng state is restored only when CUDA is available and the dictionary contains it.
    """
    torch.set_rng_state(rng_state_dict["torch_rng_state"])
    if torch.cuda.is_available() and "torch_cuda_rng_state" in rng_state_dict:
        torch.cuda.set_rng_state(rng_state_dict["torch_cuda_rng_state"])


def serialize_rng_state() -> dict[str, torch.Tensor]:
    """
    Returns the rng state for `random`, `numpy`, and `torch`, in the form of a flat
    dict[str, torch.Tensor] to be saved using `safetensors.save_file()` or `torch.save()`.
    """
    py_rng_state_dict = serialize_python_rng_state()
    np_rng_state_dict = serialize_numpy_rng_state()
    torch_rng_state_dict = serialize_torch_rng_state()

    return {
        **py_rng_state_dict,
        **np_rng_state_dict,
        **torch_rng_state_dict,
    }


def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Restores the rng state for `random`, `numpy`, and `torch` from a dictionary produced by
    `serialize_rng_state()`.
    """
    # Entries are routed to each library by key prefix ("py", "np", "torch").
    py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")}
    np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")}
    torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")}

    deserialize_python_rng_state(py_rng_state_dict)
    deserialize_numpy_rng_state(np_rng_state_dict)
    deserialize_torch_rng_state(torch_rng_state_dict)


def save_rng_state(save_dir: Path) -> None:
    """Serialize the current rng states and write them to `save_dir / RNG_STATE` with safetensors."""
    rng_state_dict = serialize_rng_state()
    flat_rng_state_dict = flatten_dict(rng_state_dict)
    save_file(flat_rng_state_dict, save_dir / RNG_STATE)


def load_rng_state(save_dir: Path) -> None:
    """Read `save_dir / RNG_STATE` with safetensors and restore the rng states it contains."""
    flat_rng_state_dict = load_file(save_dir / RNG_STATE)
    rng_state_dict = unflatten_dict(flat_rng_state_dict)
    deserialize_rng_state(rng_state_dict)


def get_rng_state() -> dict[str, Any]:
    """Get the random state for `random`, `numpy`, and `torch`.

    The CUDA state is included under "torch_cuda_random_state" only when CUDA is available.
    """
    random_state_dict = {
        "random_state": random.getstate(),
        "numpy_random_state": np.random.get_state(),
        "torch_random_state": torch.random.get_rng_state(),
    }
    if torch.cuda.is_available():
        random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
    return random_state_dict


def set_rng_state(random_state_dict: dict[str, Any]):
    """Set the random state for `random`, `numpy`, and `torch`.

    Args:
        random_state_dict: A dictionary of the form returned by `get_rng_state`.
    """
    random.setstate(random_state_dict["random_state"])
    np.random.set_state(random_state_dict["numpy_random_state"])
    torch.random.set_rng_state(random_state_dict["torch_random_state"])
    # The CUDA entry is optional in `get_rng_state`, so only restore it when it is present.
    if torch.cuda.is_available() and "torch_cuda_random_state" in random_state_dict:
        torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])


def set_seed(seed: int) -> None:
    """Set seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
    """Set the seed when entering a context, and restore the prior random state at exit.

    The prior state is restored even if the body of the `with` block raises.

    Example usage:

    ```
    a = random.random()  # produces some random number
    with seeded_context(1337):
        b = random.random()  # produces some other random number
    c = random.random()  # produces yet another random number, but the same it would have if we never made `b`
    ```
    """
    random_state_dict = get_rng_state()
    set_seed(seed)
    try:
        yield None
    finally:
        # Restore on both normal exit and exceptions so the seeded state never leaks out of the context.
        set_rng_state(random_state_dict)