Spaces:
Sleeping
Sleeping
"""Implementation of utility functions that can be applied to spaces.

These functions mostly take care of flattening and unflattening elements of spaces
to facilitate their usage in learning code.
"""
import operator as op | |
from collections import OrderedDict | |
from functools import reduce, singledispatch | |
from typing import Dict as TypingDict | |
from typing import TypeVar, Union, cast | |
import numpy as np | |
from gym.spaces import ( | |
Box, | |
Dict, | |
Discrete, | |
Graph, | |
GraphInstance, | |
MultiBinary, | |
MultiDiscrete, | |
Sequence, | |
Space, | |
Text, | |
Tuple, | |
) | |
@singledispatch
def flatdim(space: Space) -> int:
    """Return the number of dimensions a flattened equivalent of this space would have.

    Dispatches on the type of ``space``; the implementations for the built-in
    ``gym.spaces`` types are registered below this function.

    Example usage::

        >>> from gym.spaces import Dict, Discrete
        >>> space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
        >>> flatdim(space)
        5

    Args:
        space: The space to return the number of dimensions of the flattened spaces

    Returns:
        The number of dimensions for the flattened spaces

    Raises:
         NotImplementedError: if the space is not defined in ``gym.spaces``.
         ValueError: if the space cannot be flattened into a :class:`Box`
    """
    # Fallback for space types without a registered implementation.
    if not space.is_np_flattenable:
        raise ValueError(
            f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
        )

    raise NotImplementedError(f"Unknown space: `{space}`")


@flatdim.register(Box)
@flatdim.register(MultiBinary)
def _flatdim_box_multibinary(space: Union[Box, MultiBinary]) -> int:
    """Number of scalar entries: the product of the shape (``1`` for shape ``()``)."""
    return reduce(op.mul, space.shape, 1)


@flatdim.register(Discrete)
def _flatdim_discrete(space: Discrete) -> int:
    """A Discrete sample is one-hot encoded, so it needs ``n`` dimensions."""
    return int(space.n)


@flatdim.register(MultiDiscrete)
def _flatdim_multidiscrete(space: MultiDiscrete) -> int:
    """Each sub-space is one-hot encoded, so the total is the sum of ``nvec``."""
    return int(np.sum(space.nvec))


@flatdim.register(Tuple)
def _flatdim_tuple(space: Tuple) -> int:
    """Sum of the flat dimensions of the sub-spaces, if all are numpy-flattenable."""
    if space.is_np_flattenable:
        return sum(flatdim(s) for s in space.spaces)
    raise ValueError(
        f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
    )


@flatdim.register(Dict)
def _flatdim_dict(space: Dict) -> int:
    """Sum of the flat dimensions of the sub-spaces, if all are numpy-flattenable."""
    if space.is_np_flattenable:
        return sum(flatdim(s) for s in space.spaces.values())
    raise ValueError(
        f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
    )


@flatdim.register(Graph)
def _flatdim_graph(space: Graph) -> int:
    """Always raises: the number of nodes/edges (and so the flat size) varies per sample."""
    raise ValueError(
        "Cannot get flattened size as the Graph Space in Gym has a dynamic size."
    )


@flatdim.register(Text)
def _flatdim_text(space: Text) -> int:
    """One integer (character index) per possible character position."""
    return space.max_length
T = TypeVar("T")
FlatType = Union[np.ndarray, TypingDict, tuple, GraphInstance]


@singledispatch
def flatten(space: Space[T], x: T) -> FlatType:
    """Flatten a data point from a space.

    This is useful when e.g. points from spaces must be passed to a neural
    network, which only understands flat arrays of floats.

    Dispatches on the type of ``space``; the implementations for the built-in
    ``gym.spaces`` types are registered below this function.

    Args:
        space: The space that ``x`` is flattened by
        x: The value to flatten

    Returns:
        - For ``Box`` and ``MultiBinary``, this is a flattened array
        - For ``Discrete`` and ``MultiDiscrete``, this is a flattened one-hot array of the sample
        - For ``Text``, this is an ``int32`` array of character indices of length
          ``max_length``, padded with ``len(space.character_set)``
        - For ``Tuple`` and ``Dict``, this is a concatenated array of the flattened subspaces;
          if some subspace cannot be flattened to an array (e.g. a ``Graph`` or ``Sequence``),
          it is instead a ``tuple`` / ``OrderedDict`` of the flattened subspace values
        - For ``Sequence``, this is a tuple of the flattened items
        - For graph spaces, returns `GraphInstance` where:
            - `nodes` are n x k arrays
            - `edges` are either:
                - m x k arrays
                - None
            - `edge_links` are either:
                - m x 2 arrays
                - None

    Raises:
        NotImplementedError: If the space is not defined in ``gym.spaces``.
    """
    raise NotImplementedError(f"Unknown space: `{space}`")


@flatten.register(Box)
@flatten.register(MultiBinary)
def _flatten_box_multibinary(space, x) -> np.ndarray:
    return np.asarray(x, dtype=space.dtype).flatten()


@flatten.register(Discrete)
def _flatten_discrete(space, x) -> np.ndarray:
    # Length-``n`` one-hot; values begin at ``space.start``, so shift to a 0-based index.
    onehot = np.zeros(space.n, dtype=space.dtype)
    onehot[x - space.start] = 1
    return onehot


@flatten.register(MultiDiscrete)
def _flatten_multidiscrete(space, x) -> np.ndarray:
    # offsets[i] is where the one-hot segment of the i-th sub-space begins;
    # offsets[-1] is the total flattened length.
    offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
    offsets[1:] = np.cumsum(space.nvec.flatten())

    onehot = np.zeros((offsets[-1],), dtype=space.dtype)
    # np.asarray lets plain lists/tuples be flattened as well as ndarrays.
    onehot[offsets[:-1] + np.asarray(x).flatten()] = 1
    return onehot


@flatten.register(Tuple)
def _flatten_tuple(space, x) -> Union[tuple, np.ndarray]:
    if space.is_np_flattenable:
        return np.concatenate(
            [flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
        )
    return tuple(flatten(s, x_part) for x_part, s in zip(x, space.spaces))


@flatten.register(Dict)
def _flatten_dict(space, x) -> Union[dict, np.ndarray]:
    if space.is_np_flattenable:
        return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
    return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items())


@flatten.register(Graph)
def _flatten_graph(space, x) -> GraphInstance:
    """Flatten the node and edge features of a graph sample.

    We're not using :func:`flatten` for the :class:`Box` and :class:`Discrete`
    node/edge spaces because a graph is not a homogeneous space: each feature
    array is flattened row by row (one row per node or edge) instead.
    """

    def _graph_flatten(feature_space, features):
        ret = None
        if feature_space is not None and features is not None:
            if isinstance(feature_space, Box):
                # Keep one row per node/edge and flatten that element's features.
                ret = features.reshape(features.shape[0], -1)
            elif isinstance(feature_space, Discrete):
                # One-hot row per node/edge. The width is ``n`` (the number of
                # distinct values), matching ``flatdim``/``flatten_space`` for a
                # Discrete space; values are shifted by ``start`` to 0-based columns.
                ret = np.zeros(
                    (features.shape[0], feature_space.n), dtype=feature_space.dtype
                )
                ret[np.arange(features.shape[0]), features - feature_space.start] = 1
        return ret

    nodes = _graph_flatten(space.node_space, x.nodes)
    edges = _graph_flatten(space.edge_space, x.edges)

    return GraphInstance(nodes, edges, x.edge_links)


@flatten.register(Text)
def _flatten_text(space: Text, x: str) -> np.ndarray:
    # Unused trailing positions keep ``len(character_set)``, an index outside the
    # character set that ``unflatten`` skips as padding.
    arr = np.full(
        shape=(space.max_length,), fill_value=len(space.character_set), dtype=np.int32
    )
    for i, val in enumerate(x):
        arr[i] = space.character_index(val)
    return arr


@flatten.register(Sequence)
def _flatten_sequence(space, x) -> tuple:
    return tuple(flatten(space.feature_space, item) for item in x)
@singledispatch
def unflatten(space: Space[T], x: FlatType) -> T:
    """Unflatten a data point from a space.

    This reverses the transformation applied by :func:`flatten`. You must ensure
    that the ``space`` argument is the same as for the :func:`flatten` call.

    Dispatches on the type of ``space``; the implementations for the built-in
    ``gym.spaces`` types are registered below this function.

    Args:
        space: The space used to unflatten ``x``
        x: The array to unflatten

    Returns:
        A point with a structure that matches the space.

    Raises:
        NotImplementedError: if the space is not defined in ``gym.spaces``.
        ValueError: if ``x`` is not a valid one-hot encoding for a
            :class:`Discrete` or :class:`MultiDiscrete` space.
    """
    raise NotImplementedError(f"Unknown space: `{space}`")


@unflatten.register(Box)
@unflatten.register(MultiBinary)
def _unflatten_box_multibinary(
    space: Union[Box, MultiBinary], x: np.ndarray
) -> np.ndarray:
    return np.asarray(x, dtype=space.dtype).reshape(space.shape)


@unflatten.register(Discrete)
def _unflatten_discrete(space: Discrete, x: np.ndarray) -> int:
    hot = np.nonzero(x)[0]
    if hot.size == 0:
        # Without this check an all-zero vector fails with an opaque IndexError.
        raise ValueError(
            f"{x} is not a valid one-hot encoded vector and can not be unflattened to space {space}. "
            "Not all valid samples in a flattened space can be unflattened."
        )
    # Undo the ``start`` shift applied by ``flatten``.
    return int(space.start + hot[0])


@unflatten.register(MultiDiscrete)
def _unflatten_multidiscrete(space: MultiDiscrete, x: np.ndarray) -> np.ndarray:
    # offsets[i] is where the one-hot segment of the i-th sub-space begins.
    offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
    offsets[1:] = np.cumsum(space.nvec.flatten())

    indices = np.flatnonzero(x)
    # A valid encoding has exactly one hot entry inside each segment. Otherwise the
    # subtraction below would broadcast to a wrong result or fail with a shape error.
    if indices.size != space.nvec.size or not np.all(
        (indices >= offsets[:-1]) & (indices < offsets[1:])
    ):
        raise ValueError(
            f"{x} is not a valid one-hot encoded vector and can not be unflattened to space {space}. "
            "Not all valid samples in a flattened space can be unflattened."
        )
    return np.asarray(indices - offsets[:-1], dtype=space.dtype).reshape(space.shape)


@unflatten.register(Tuple)
def _unflatten_tuple(space: Tuple, x: Union[np.ndarray, tuple]) -> tuple:
    if space.is_np_flattenable:
        assert isinstance(
            x, np.ndarray
        ), f"{space} is numpy-flattenable. Thus, you should only unflatten numpy arrays for this space. Got a {type(x)}"
        # Split the flat array at the cumulative sub-space widths.
        dims = np.asarray([flatdim(s) for s in space.spaces], dtype=np.int_)
        list_flattened = np.split(x, np.cumsum(dims[:-1]))
        return tuple(
            unflatten(s, flattened)
            for flattened, s in zip(list_flattened, space.spaces)
        )

    assert isinstance(
        x, tuple
    ), f"{space} is not numpy-flattenable. Thus, you should only unflatten tuples for this space. Got a {type(x)}"
    return tuple(unflatten(s, flattened) for flattened, s in zip(x, space.spaces))


@unflatten.register(Dict)
def _unflatten_dict(space: Dict, x: Union[np.ndarray, TypingDict]) -> dict:
    if space.is_np_flattenable:
        # Split the flat array at the cumulative sub-space widths (in key order).
        dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_)
        list_flattened = np.split(x, np.cumsum(dims[:-1]))
        return OrderedDict(
            [
                (key, unflatten(s, flattened))
                for flattened, (key, s) in zip(list_flattened, space.spaces.items())
            ]
        )

    assert isinstance(
        x, dict
    ), f"{space} is not numpy-flattenable. Thus, you should only unflatten dictionary for this space. Got a {type(x)}"
    return OrderedDict((key, unflatten(s, x[key])) for key, s in space.spaces.items())


@unflatten.register(Graph)
def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
    """Unflatten the node and edge features of a flattened graph sample.

    We're not using :func:`unflatten` for the :class:`Box` and :class:`Discrete`
    node/edge spaces because a graph is not a homogeneous space: the size of the
    outcome is not fixed, but determined by the number of nodes and edges.
    """

    def _graph_unflatten(feature_space, features):
        ret = None
        if feature_space is not None and features is not None:
            if isinstance(feature_space, Box):
                ret = features.reshape(-1, *feature_space.shape)
            elif isinstance(feature_space, Discrete):
                # Column index of each row's hot entry, shifted back by ``start``
                # to undo the offset subtracted when flattening.
                ret = np.asarray(np.nonzero(features))[-1, :] + feature_space.start
        return ret

    nodes = _graph_unflatten(space.node_space, x.nodes)
    edges = _graph_unflatten(space.edge_space, x.edges)

    return GraphInstance(nodes, edges, x.edge_links)


@unflatten.register(Text)
def _unflatten_text(space: Text, x: np.ndarray) -> str:
    # Indices >= len(character_set) are the padding written by ``flatten``.
    charset_size = len(space.character_set)
    return "".join([space.character_list[val] for val in x if val < charset_size])


@unflatten.register(Sequence)
def _unflatten_sequence(space: Sequence, x: tuple) -> tuple:
    return tuple(unflatten(space.feature_space, item) for item in x)
@singledispatch
def flatten_space(space: Space) -> Union[Box, Dict, Sequence, Tuple, Graph]:
    """Flatten a space into a space that is as flat as possible.

    This function will attempt to flatten `space` into a single :class:`Box` space.
    However, this might not be possible when `space` is an instance of :class:`Graph`,
    :class:`Sequence` or a compound space that contains a :class:`Graph` or :class:`Sequence` space.
    This is equivalent to :func:`flatten`, but operates on the space itself. The
    result for numpy-flattenable spaces is always a `Box` with flat boundaries that
    has exactly :func:`flatdim` dimensions. The result for graph spaces is always a
    `Graph` with `node_space` being a `Box` with flat boundaries and `edge_space`
    being a `Box` with flat boundaries or `None`. Flattening a sample of the
    original space has the same effect as taking a sample of the flattened space.

    Example::

        >>> box = Box(0.0, 1.0, shape=(3, 4, 5))
        >>> box
        Box(3, 4, 5)
        >>> flatten_space(box)
        Box(60,)
        >>> flatten(box, box.sample()) in flatten_space(box)
        True

    Example that flattens a discrete space::

        >>> discrete = Discrete(5)
        >>> flatten_space(discrete)
        Box(5,)
        >>> flatten(discrete, discrete.sample()) in flatten_space(discrete)
        True

    Example that recursively flattens a dict::

        >>> space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))})
        >>> flatten_space(space)
        Box(6,)
        >>> flatten(space, space.sample()) in flatten_space(space)
        True

    Example that flattens a graph::

        >>> space = Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5))
        >>> flatten_space(space)
        Graph(Box(-100.0, 100.0, (12,), float32), Box(0, 1, (5,), int64))
        >>> flatten(space, space.sample()) in flatten_space(space)
        True

    Args:
        space: The space to flatten

    Returns:
        A flattened :class:`Box` when ``space`` is numpy-flattenable; otherwise a
        :class:`Graph`, :class:`Sequence`, :class:`Tuple` or :class:`Dict` whose
        components are flattened as far as possible.

    Raises:
        NotImplementedError: if the space is not defined in ``gym.spaces``.
    """
    raise NotImplementedError(f"Unknown space: `{space}`")


@flatten_space.register(Box)
def _flatten_space_box(space: Box) -> Box:
    return Box(space.low.flatten(), space.high.flatten(), dtype=space.dtype)


@flatten_space.register(Discrete)
@flatten_space.register(MultiBinary)
@flatten_space.register(MultiDiscrete)
def _flatten_space_binary(space: Union[Discrete, MultiBinary, MultiDiscrete]) -> Box:
    # These spaces flatten to 0/1 (one-hot or binary) vectors of length flatdim(space).
    return Box(low=0, high=1, shape=(flatdim(space),), dtype=space.dtype)


def _concatenate_flat_boxes(boxes) -> Box:
    """Join already-flattened :class:`Box` spaces end to end into a single :class:`Box`."""
    return Box(
        low=np.concatenate([b.low for b in boxes]),
        high=np.concatenate([b.high for b in boxes]),
        # Promote to a dtype able to hold every sub-space's values.
        dtype=np.result_type(*[b.dtype for b in boxes]),
    )


@flatten_space.register(Tuple)
def _flatten_space_tuple(space: Tuple) -> Union[Box, Tuple]:
    if space.is_np_flattenable:
        return _concatenate_flat_boxes([flatten_space(s) for s in space.spaces])
    return Tuple(spaces=[flatten_space(s) for s in space.spaces])


@flatten_space.register(Dict)
def _flatten_space_dict(space: Dict) -> Union[Box, Dict]:
    if space.is_np_flattenable:
        return _concatenate_flat_boxes(
            [flatten_space(s) for s in space.spaces.values()]
        )
    return Dict(
        spaces=OrderedDict(
            (key, flatten_space(subspace)) for key, subspace in space.spaces.items()
        )
    )


@flatten_space.register(Graph)
def _flatten_space_graph(space: Graph) -> Graph:
    return Graph(
        node_space=flatten_space(space.node_space),
        edge_space=flatten_space(space.edge_space)
        if space.edge_space is not None
        else None,
    )


@flatten_space.register(Text)
def _flatten_space_text(space: Text) -> Box:
    # Upper bound len(character_set) includes the padding index used by ``flatten``.
    return Box(
        low=0, high=len(space.character_set), shape=(space.max_length,), dtype=np.int32
    )


@flatten_space.register(Sequence)
def _flatten_space_sequence(space: Sequence) -> Sequence:
    return Sequence(flatten_space(space.feature_space))