Spaces:
Sleeping
Sleeping
# mypy: ignore-errors

"""Wrapper to mimic (parts of) the np.random API surface.

NumPy has strict guarantees on reproducibility etc.; here we don't give any.

Open question: NumPy's default floating-point dtype is float64, whereas the
functions here produce floats in ``_dtypes_impl.default_dtypes().float_dtype``.
"""
from __future__ import annotations | |
import functools | |
from math import sqrt | |
from typing import Optional | |
import torch | |
from . import _dtypes_impl, _util | |
from ._normalizations import array_or_scalar, ArrayLike, normalizer | |
# Public np.random-style names exported by this module (``sample`` and
# ``random`` are aliases of ``random_sample``).
__all__ = [
    "seed",
    "random_sample", "sample", "random",
    "rand", "randn", "normal",
    "choice", "randint", "shuffle", "uniform",
]
def use_numpy_random(): | |
# local import to avoid ref cycles | |
import torch._dynamo.config as config | |
return config.use_numpy_random_stream | |
def deco_stream(func):
    """Let ``func`` be served by ``numpy.random`` when dynamo asks for it.

    The returned wrapper checks ``use_numpy_random()`` on every call. If it is
    false, ``func`` runs unchanged. Otherwise the function with the same
    ``__name__`` in ``numpy.random`` is called instead: ``ndarray`` arguments
    are handed over as NumPy arrays (via ``.tensor.numpy()``), and a NumPy
    array result is wrapped back into an ``ndarray``; any other result
    (Python scalar or ``None``) is returned as-is.
    """

    # wraps() preserves __name__/__doc__ and sets __wrapped__; the name matters
    # because dispatch below looks the NumPy function up by func.__name__.
    @functools.wraps(func)
    def inner(*args, **kwds):
        if not use_numpy_random():
            return func(*args, **kwds)

        # Deferred imports: only needed on the NumPy-stream path.
        import numpy

        from ._ndarray import ndarray

        f = getattr(numpy.random, func.__name__)

        def to_numpy(obj):
            # numpy funcs accept numpy ndarrays, unwrap
            return obj.tensor.numpy() if isinstance(obj, ndarray) else obj

        args = tuple(to_numpy(arg) for arg in args)
        kwds = {key: to_numpy(val) for key, val in kwds.items()}

        value = f(*args, **kwds)
        # `value` can be either numpy.ndarray or python scalar (or None)
        if isinstance(value, numpy.ndarray):
            value = ndarray(torch.as_tensor(value))
        return value

    return inner
@deco_stream
def seed(seed=None):
    """Seed the global random number generator.

    With dynamo's NumPy random stream enabled, ``deco_stream`` forwards the
    call to ``numpy.random.seed``; otherwise torch's generator is seeded via
    ``torch.random.manual_seed``.

    Args:
        seed: the seed value; ``None`` leaves the torch generator untouched.
    """
    # NOTE(review): np.random.seed(None) reseeds from OS entropy, whereas this
    # is a no-op for None -- confirm that difference is intended.
    if seed is not None:
        torch.random.manual_seed(seed)
@deco_stream
def random_sample(size=None):
    """Draw floats uniformly from [0, 1) (``Tensor.uniform_`` with defaults).

    Args:
        size: output shape (int or tuple of ints); ``None`` means a single draw.

    Returns:
        A scalar when ``size`` is ``None`` or ``()``, otherwise an array, both
        produced by ``array_or_scalar`` from a tensor in the default
        floating-point dtype (``_dtypes_impl.default_dtypes().float_dtype``).
    """
    if size is None:
        size = ()
    dtype = _dtypes_impl.default_dtypes().float_dtype
    values = torch.empty(size, dtype=dtype).uniform_()
    return array_or_scalar(values, return_scalar=size == ())
@deco_stream
def rand(*size):
    """Uniform [0, 1) samples, with the shape given as positional dimensions.

    ``rand(2, 3)`` is ``random_sample((2, 3))``; ``rand()`` is
    ``random_sample(None)`` and therefore yields a scalar.
    """
    if size == ():
        size = None
    return random_sample(size)
# ``sample`` and ``random`` are plain aliases of ``random_sample``.
sample = random = random_sample
@deco_stream
def uniform(low=0.0, high=1.0, size=None):
    """Draw floats uniformly from ``[low, high)``.

    Args:
        low: lower bound (inclusive).
        high: upper bound (exclusive).
        size: output shape (int or tuple of ints); ``None`` means a single draw.

    Returns:
        A scalar when ``size`` is ``None`` or ``()``, otherwise an array, both
        produced by ``array_or_scalar`` from a tensor in the default
        floating-point dtype.
    """
    if size is None:
        size = ()
    dtype = _dtypes_impl.default_dtypes().float_dtype
    values = torch.empty(size, dtype=dtype).uniform_(low, high)
    return array_or_scalar(values, return_scalar=size == ())
@deco_stream
def randn(*size):
    """Standard-normal samples, with the shape given as positional dimensions.

    ``randn()`` (no dimensions) yields a scalar; otherwise an array in the
    default floating-point dtype, both produced by ``array_or_scalar``.
    """
    dtype = _dtypes_impl.default_dtypes().float_dtype
    values = torch.randn(size, dtype=dtype)
    return array_or_scalar(values, return_scalar=size == ())
@deco_stream
def normal(loc=0.0, scale=1.0, size=None):
    """Draw from a normal distribution with mean ``loc`` and std ``scale``.

    Args:
        loc: mean of the distribution.
        scale: standard deviation of the distribution.
        size: output shape (int or tuple of ints); ``None`` means a single draw.

    Returns:
        A scalar when ``size`` is ``None`` or ``()``, otherwise an array, both
        produced by ``array_or_scalar`` from a tensor in the default
        floating-point dtype.
    """
    if size is None:
        size = ()
    dtype = _dtypes_impl.default_dtypes().float_dtype
    values = torch.empty(size, dtype=dtype).normal_(loc, scale)
    return array_or_scalar(values, return_scalar=size == ())
@deco_stream
def shuffle(x):
    """Shuffle ``x`` in place along its first axis; returns ``None``.

    Accepts a ``torch.Tensor`` or an ``ndarray``. The underlying tensor is
    permuted with ``torch.randperm`` and written back with ``copy_``.

    Raises:
        NotImplementedError: for any other input type (e.g. Python lists).
    """
    # no @normalizer because we do not cast e.g. lists to tensors
    from ._ndarray import ndarray

    if isinstance(x, torch.Tensor):
        tensor = x
    elif isinstance(x, ndarray):
        tensor = x.tensor
    else:
        raise NotImplementedError("We do not random.shuffle lists in-place")

    perm = torch.randperm(tensor.shape[0])
    xp = tensor[perm]
    tensor.copy_(xp)
@deco_stream
def randint(low, high=None, size=None):
    """Draw integers from ``[low, high)``; ``randint(n)`` draws from ``[0, n)``.

    Args:
        low: lower bound (inclusive), or the exclusive upper bound when
            ``high`` is ``None``.
        high: upper bound (exclusive).
        size: int or tuple/list of ints giving the output shape; ``None``
            means a single draw.

    Returns:
        A scalar (``array_or_scalar`` is given ``int``) when ``size`` is
        ``None`` or ``()``, otherwise an array.
    """
    if size is None:
        size = ()
    if not isinstance(size, (tuple, list)):
        size = (size,)
    if high is None:
        low, high = 0, low
    values = torch.randint(low, high, size=size)
    return array_or_scalar(values, int, return_scalar=size == ())
@deco_stream
@normalizer
def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike] = None):
    """Draw a random sample from a 1-D array, or from ``arange(a)`` for scalar ``a``.

    ``@normalizer`` converts ``a`` (and ``p`` when given) to tensors before the
    body runs, which the tensor method calls below rely on.

    Args:
        a: population to sample from; a 0-d input ``n`` means ``arange(n)``.
        size: number of draws (int) or output shape (sequence of ints);
            ``None`` draws a single element.
        replace: sample with replacement when True.
        p: optional probabilities, one per element of ``a``; uniform if None.

    Returns:
        The sampled elements ``a[indices]``.

    Raises:
        ValueError: if ``a`` is not 1-D (after scalar expansion), if ``p`` does
            not have the same shape as ``a``, or if ``p`` does not sum to 1.
    """
    # https://stackoverflow.com/questions/59461811/random-choice-with-pytorch
    # Only a *scalar* n means "sample from arange(n)"; a one-element array
    # such as [7] must be sampled from as-is (numel() == 1 conflated the two).
    if a.ndim == 0:
        a = torch.arange(a)
    # TODO: check a.dtype is integer -- cf np.random.choice(3.4) which raises
    if a.ndim != 1:
        raise ValueError("a must be 1-dimensional")

    # number of draws
    if size is None:
        num_el = 1
    elif _util.is_sequence(size):
        num_el = 1
        for el in size:
            num_el *= el
    else:
        num_el = size

    # prepare the probabilities
    if p is None:
        p = torch.ones_like(a) / a.shape[0]
    elif p.shape != a.shape:
        # a mismatched p would otherwise index out of range or silently
        # restrict sampling to a prefix of `a`
        raise ValueError("'a' and 'p' must have same size")

    # cf https://github.com/numpy/numpy/blob/main/numpy/random/mtrand.pyx#L973
    atol = sqrt(torch.finfo(p.dtype).eps)
    if abs(p.sum() - 1.0) > atol:
        raise ValueError("probabilities do not sum to 1.")

    # actually sample
    indices = torch.multinomial(p, num_el, replacement=replace)
    if _util.is_sequence(size):
        indices = indices.reshape(size)
    # NOTE(review): for size=None NumPy returns a scalar, whereas this returns
    # a one-element result of shape (1,).
    samples = a[indices]
    return samples