Spaces:
Sleeping
Sleeping
"""Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space.""" | |
from typing import NamedTuple, Optional, Sequence, Tuple, Union | |
import numpy as np | |
from gym.logger import warn | |
from gym.spaces.box import Box | |
from gym.spaces.discrete import Discrete | |
from gym.spaces.multi_discrete import MultiDiscrete | |
from gym.spaces.space import Space | |
class GraphInstance(NamedTuple):
    """Container for one sample drawn from a :class:`Graph` space.

    Fields, in positional order:

    * ``nodes``: array of shape ``(n, ...)`` with one feature entry per node for ``n``
      nodes; the trailing ``...`` dimensions must follow the node space's shape.
    * ``edges``: optional array of shape ``(m, ...)`` with one feature entry per edge for
      ``m`` edges; the trailing ``...`` dimensions must follow the edge space's shape.
    * ``edge_links``: optional integer array of shape ``(m, 2)``; row ``i`` holds the
      indices of the two nodes joined by edge ``i``.
    """

    # Per-node feature array (always present).
    nodes: np.ndarray
    # Per-edge feature array, or None for a graph without edge data.
    edges: Optional[np.ndarray]
    # (m x 2) node-index pairs for each edge, or None for a graph without edge data.
    edge_links: Optional[np.ndarray]
class Graph(Space):
    r"""A space representing graph information as a series of `nodes` connected with `edges` according to an adjacency matrix represented as a series of `edge_links`.

    Example usage::

        self.observation_space = spaces.Graph(node_space=spaces.Box(low=-100, high=100, shape=(3,)), edge_space=spaces.Discrete(3))
    """

    def __init__(
        self,
        node_space: Union[Box, Discrete],
        edge_space: Union[None, Box, Discrete],
        seed: Optional[Union[int, np.random.Generator]] = None,
    ):
        r"""Constructor of :class:`Graph`.

        The argument ``node_space`` specifies the base space that each node feature will use.
        This argument must be either a Box or Discrete instance.

        The argument ``edge_space`` specifies the base space that each edge feature will use.
        This argument must be either a None, Box or Discrete instance.

        Args:
            node_space (Union[Box, Discrete]): space of the node features.
            edge_space (Union[None, Box, Discrete]): space of the edge features, or ``None``
                for graphs that carry no edge features.
            seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.

        Raises:
            AssertionError: if ``node_space`` is not a Box/Discrete, or ``edge_space`` is
                neither None nor a Box/Discrete.
        """
        assert isinstance(
            node_space, (Box, Discrete)
        ), f"Values of the node_space should be instances of Box or Discrete, got {type(node_space)}"
        if edge_space is not None:
            # Report the type of the offending *edge* space (not the node space).
            assert isinstance(
                edge_space, (Box, Discrete)
            ), f"Values of the edge_space should be instances of None, Box or Discrete, got {type(edge_space)}"

        self.node_space = node_space
        self.edge_space = edge_space

        # The first two base-Space arguments are left as None (presumably shape and dtype,
        # which are not fixed for graph samples) — confirm against Space.__init__.
        super().__init__(None, None, seed)

    def is_np_flattenable(self):
        """Checks whether this space can be flattened to a :class:`spaces.Box`.

        Always ``False``: the number of nodes and edges differs from sample to sample.
        """
        # NOTE(review): upstream gym declares `is_np_flattenable` as a @property on Space;
        # confirm whether the base class here expects a property rather than a method.
        return False

    def _generate_sample_space(
        self, base_space: Union[None, Box, Discrete], num: int
    ) -> Optional[Union[Box, MultiDiscrete]]:
        """Build a space whose single sample stacks ``num`` samples of ``base_space``.

        Returns ``None`` when ``num`` is 0 or ``base_space`` is None. A Box base space becomes
        a Box of shape ``(num,) + base_space.shape``; a Discrete base space becomes a
        MultiDiscrete with ``num`` entries of size ``base_space.n``. The returned space is
        seeded with this space's ``np_random``.

        Raises:
            TypeError: if ``base_space`` is neither a Box nor a Discrete.
        """
        if num == 0 or base_space is None:
            return None

        if isinstance(base_space, Box):
            return Box(
                low=np.array(max(1, num) * [base_space.low]),
                high=np.array(max(1, num) * [base_space.high]),
                shape=(num,) + base_space.shape,
                dtype=base_space.dtype,
                seed=self.np_random,
            )
        elif isinstance(base_space, Discrete):
            return MultiDiscrete(nvec=[base_space.n] * num, seed=self.np_random)
        else:
            raise TypeError(
                f"Expects base space to be Box or Discrete, actual space: {type(base_space)}."
            )

    @staticmethod
    def _offset_discrete_samples(
        samples: np.ndarray, base_space: Union[None, Box, Discrete]
    ) -> np.ndarray:
        """Shift stacked MultiDiscrete samples by a Discrete base space's ``start`` offset.

        ``_generate_sample_space`` builds the MultiDiscrete from ``base_space.n`` only, so its
        samples ignore ``Discrete.start``; without this shift, samples for a Discrete base space
        with a nonzero ``start`` would not be contained in that space. Box samples and
        Discrete spaces without a (nonzero) ``start`` are returned unchanged.
        """
        start = getattr(base_space, "start", 0) if isinstance(base_space, Discrete) else 0
        return samples + start if start != 0 else samples

    def sample(
        self,
        mask: Optional[
            Tuple[
                Optional[Union[np.ndarray, tuple]],
                Optional[Union[np.ndarray, tuple]],
            ]
        ] = None,
        num_nodes: int = 10,
        num_edges: Optional[int] = None,
    ) -> GraphInstance:
        """Generates a single sample graph with ``num_nodes`` nodes from this space.

        Args:
            mask: An optional ``(node_mask, edge_mask)`` tuple whose entries may each be None.
                Masks are only possible with Discrete spaces (Box spaces don't support sample
                masks). If ``num_edges`` is not provided, ``edge_mask`` is repeated once per
                sampled edge; otherwise it is passed unchanged to the edge sampler.
            num_nodes: The number of nodes that will be sampled, the default is 10 nodes.
            num_edges: An optional number of edges. If omitted, a random integer in
                ``[0, num_nodes * (num_nodes - 1))`` is drawn when ``num_nodes > 1``, else 0.

        Returns:
            A :class:`GraphInstance` with attributes ``.nodes``, ``.edges`` and ``.edge_links``.
            ``edges`` and ``edge_links`` are None when there is no edge space or zero edges.
            Each edge link's two endpoints are drawn independently from ``[0, num_nodes)``,
            so self-loops and duplicate links can occur.

        Raises:
            AssertionError: if ``num_nodes`` is not positive or ``num_edges`` is negative.
        """
        assert (
            num_nodes > 0
        ), f"The number of nodes is expected to be greater than 0, actual value: {num_nodes}"

        if mask is not None:
            node_space_mask, edge_space_mask = mask
        else:
            node_space_mask, edge_space_mask = None, None

        if num_edges is None:
            # we only draw a random edge count when there are at least 2 nodes
            if num_nodes > 1:
                # exclusive upper bound n*(n-1): the number of ordered pairs of distinct nodes
                num_edges = self.np_random.integers(num_nodes * (num_nodes - 1))
            else:
                num_edges = 0

            if edge_space_mask is not None:
                # one copy of the per-edge mask for every entry of the stacked edge space
                edge_space_mask = tuple(edge_space_mask for _ in range(num_edges))
        else:
            if self.edge_space is None:
                warn(
                    f"The number of edges is set ({num_edges}) but the edge space is None."
                )
            assert (
                num_edges >= 0
            ), f"Expects the number of edges to be greater than or equal to 0, actual value: {num_edges}"
        assert num_edges is not None

        sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
        sampled_edge_space = self._generate_sample_space(self.edge_space, num_edges)

        assert sampled_node_space is not None
        sampled_nodes = self._offset_discrete_samples(
            sampled_node_space.sample(node_space_mask), self.node_space
        )
        sampled_edges = (
            self._offset_discrete_samples(
                sampled_edge_space.sample(edge_space_mask), self.edge_space
            )
            if sampled_edge_space is not None
            else None
        )

        sampled_edge_links = None
        if sampled_edges is not None and num_edges > 0:
            sampled_edge_links = self.np_random.integers(
                low=0, high=num_nodes, size=(num_edges, 2)
            )

        return GraphInstance(sampled_nodes, sampled_edges, sampled_edge_links)

    def contains(self, x: GraphInstance) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        ``x`` must be a :class:`GraphInstance` whose every node lies in ``node_space``. Either
        both ``edges`` and ``edge_links`` are None, or both are ndarrays where every edge lies in
        ``edge_space`` (which must not be None) and ``edge_links`` is an integer ``(len(edges), 2)``
        array of indices in ``[0, len(nodes))``.
        """
        if not isinstance(x, GraphInstance):
            return False
        if not isinstance(x.nodes, np.ndarray):
            return False
        if not all(node in self.node_space for node in x.nodes):
            return False

        # Edges and edge links are optional; without both arrays, both must be absent.
        if not (
            isinstance(x.edges, np.ndarray) and isinstance(x.edge_links, np.ndarray)
        ):
            return x.edges is None and x.edge_links is None

        if self.edge_space is None:
            return False
        if not all(edge in self.edge_space for edge in x.edges):
            return False
        if not np.issubdtype(x.edge_links.dtype, np.integer):
            return False
        if x.edge_links.shape != (len(x.edges), 2):
            return False
        return bool(
            np.all(np.logical_and(x.edge_links >= 0, x.edge_links < len(x.nodes)))
        )

    def __repr__(self) -> str:
        """A string representation of this space.

        The representation will include node_space and edge_space

        Returns:
            A representation of the space
        """
        return f"Graph({self.node_space}, {self.edge_space})"

    def __eq__(self, other) -> bool:
        """Check whether `other` is a :class:`Graph` with equal node and edge spaces."""
        return (
            isinstance(other, Graph)
            and (self.node_space == other.node_space)
            and (self.edge_space == other.edge_space)
        )

    def to_jsonable(self, sample_n: Sequence[GraphInstance]) -> list:
        """Convert a batch of samples from this space to a JSONable data type.

        Each sample becomes a dict with a ``"nodes"`` list and, when ``edges`` is not None,
        ``"edges"`` and ``"edge_links"`` lists.
        """
        # NOTE(review): assumes edge_links is present whenever edges is (as sample() produces);
        # a sample with edges but no edge_links raises AttributeError here.
        serialized = []
        for sample in sample_n:
            entry = {"nodes": sample.nodes.tolist()}
            if sample.edges is not None:
                entry["edges"] = sample.edges.tolist()
                entry["edge_links"] = sample.edge_links.tolist()
            serialized.append(entry)
        return serialized

    def from_jsonable(self, sample_n: Sequence[dict]) -> list:
        """Convert a JSONable data type to a batch of samples from this space.

        Inverse of ``to_jsonable``: a dict with an ``"edges"`` key must also carry
        ``"edge_links"``; a dict without ``"edges"`` yields an instance with no edge data.
        """
        samples = []
        for sample in sample_n:
            if "edges" in sample:
                instance = GraphInstance(
                    np.asarray(sample["nodes"]),
                    np.asarray(sample["edges"]),
                    np.asarray(sample["edge_links"]),
                )
            else:
                instance = GraphInstance(np.asarray(sample["nodes"]), None, None)
            samples.append(instance)
        return samples