Spaces:
Running
Running
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Encoder utilities.""" | |
import functools | |
import chex | |
from clrs._src import probing | |
from clrs._src import specs | |
import haiku as hk | |
import jax.numpy as jnp | |
_Array = chex.Array | |
_DataPoint = probing.DataPoint | |
_Location = specs.Location | |
_Spec = specs.Spec | |
_Stage = specs.Stage | |
_Type = specs.Type | |
def construct_encoders(stage: str, loc: str, t: str,
                       hidden_dim: int, init: str, name: str):
  """Builds the linear encoder(s) for a single data point.

  Args:
    stage: Stage of the data point (e.g. hint vs. input).
    loc: Location of the data point (node/edge/graph).
    t: Type of the data point.
    hidden_dim: Output dimensionality of each linear encoder.
    init: Initialisation scheme, either 'default' or 'xavier_on_scalars'.
    name: Base name for the Haiku linear module(s).

  Returns:
    A list containing one `hk.Linear`, or two for edge pointers (which
    need separate sender- and receiver-side encodings).

  Raises:
    ValueError: If `init` names an unsupported scheme.
  """
  if init not in ('default', 'xavier_on_scalars'):
    raise ValueError(f'Encoder initialiser {init} not supported.')
  # Scalar hints get a truncated-normal init scaled by 1/sqrt(hidden_dim)
  # under the 'xavier_on_scalars' scheme; everything else uses Haiku's
  # default initialiser.
  use_scaled_init = (init == 'xavier_on_scalars'
                     and stage == _Stage.HINT and t == _Type.SCALAR)
  if use_scaled_init:
    w_init = hk.initializers.TruncatedNormal(
        stddev=1.0 / jnp.sqrt(hidden_dim))
  else:
    w_init = None
  make_linear = functools.partial(
      hk.Linear, w_init=w_init, name=f'{name}_enc_linear')
  encoders = [make_linear(hidden_dim)]
  if loc == _Location.EDGE and t == _Type.POINTER:
    # Edge pointers need two-way encoders.
    encoders.append(make_linear(hidden_dim))
  return encoders
def preprocess(dp: _DataPoint, nb_nodes: int) -> _DataPoint:
  """Pre-processes a data point so it is ready for encoding.

  POINTER data arrives as compressed node indices and is expanded to a full
  one-hot representation. SOFT_POINTER data is already expanded, so only its
  type is rewritten to POINTER so downstream encoding treats it identically.
  All other data is simply cast to float32.

  Args:
    dp: A DataPoint to prepare for encoding.
    nb_nodes: Number of nodes in the graph, necessary to expand pointers to
      the right dimension.

  Returns:
    The datapoint, with data and possibly type modified.
  """
  if dp.type_ == _Type.POINTER:
    # Compressed index representation -> one-hot over nodes.
    return probing.DataPoint(
        name=dp.name, location=dp.location, type_=_Type.POINTER,
        data=hk.one_hot(dp.data, nb_nodes))
  new_type = _Type.POINTER if dp.type_ == _Type.SOFT_POINTER else dp.type_
  return probing.DataPoint(
      name=dp.name, location=dp.location, type_=new_type,
      data=dp.data.astype(jnp.float32))
def accum_adj_mat(dp: _DataPoint, adj_mat: _Array) -> _Array:
  """Folds a data point's connectivity into the adjacency matrix.

  Node-located (permutation) pointers and edge-located masks contribute
  edges; the data is symmetrised (batch-wise transpose over the last two
  axes) and thresholded before being OR-ed into `adj_mat`.
  """
  is_node_pointer = (dp.location == _Location.NODE and
                     dp.type_ in (_Type.POINTER, _Type.PERMUTATION_POINTER))
  is_edge_mask = (dp.location == _Location.EDGE and dp.type_ == _Type.MASK)
  if is_node_pointer or is_edge_mask:
    symmetric = dp.data + jnp.transpose(dp.data, (0, 2, 1))
    # Pointers are thresholded at 0.5, masks at 0.0 — as in the original.
    adj_mat += (symmetric > (0.5 if is_node_pointer else 0.0))
  return (adj_mat > 0.).astype('float32')  # pytype: disable=attribute-error  # numpy-scalars
def accum_edge_fts(encoders, dp: _DataPoint, edge_fts: _Array) -> _Array:
  """Encodes a data point and adds its contribution to edge features.

  Node-located (permutation) pointers and all edge-located data contribute;
  anything else leaves `edge_fts` unchanged.
  """
  node_pointer = (dp.location == _Location.NODE and
                  dp.type_ in (_Type.POINTER, _Type.PERMUTATION_POINTER))
  if node_pointer:
    edge_fts += _encode_inputs(encoders, dp)
    return edge_fts
  if dp.location != _Location.EDGE:
    return edge_fts
  encoding = _encode_inputs(encoders, dp)
  if dp.type_ == _Type.POINTER:
    # Aggregate pointer contributions across sender and receiver nodes.
    receiver_encoding = encoders[1](jnp.expand_dims(dp.data, -1))
    edge_fts += (jnp.mean(encoding, axis=1) +
                 jnp.mean(receiver_encoding, axis=2))
  else:
    edge_fts += encoding
  return edge_fts
def accum_node_fts(encoders, dp: _DataPoint, node_fts: _Array) -> _Array:
  """Encodes a data point and adds its contribution to node features.

  Contributors are node-located data that is not a (permutation) pointer,
  plus graph-located pointers; everything else is a no-op.
  """
  pointer_like = dp.type_ in (_Type.POINTER, _Type.PERMUTATION_POINTER)
  non_pointer_node = (dp.location == _Location.NODE and not pointer_like)
  graph_pointer = (dp.location == _Location.GRAPH and
                   dp.type_ == _Type.POINTER)
  if non_pointer_node or graph_pointer:
    node_fts += _encode_inputs(encoders, dp)
  return node_fts
def accum_graph_fts(encoders, dp: _DataPoint,
                    graph_fts: _Array) -> _Array:
  """Encodes a data point and adds its contribution to graph features.

  Only graph-located, non-pointer data contributes (graph-located pointers
  are routed to node features instead).
  """
  if dp.location != _Location.GRAPH or dp.type_ == _Type.POINTER:
    return graph_fts
  graph_fts += _encode_inputs(encoders, dp)
  return graph_fts
def _encode_inputs(encoders, dp: _DataPoint) -> _Array:
  """Applies the first encoder to a data point's raw data.

  Categorical data already carries a trailing class dimension; all other
  types get a trailing feature axis of size 1 before the linear layer.
  """
  if dp.type_ == _Type.CATEGORICAL:
    data = dp.data
  else:
    data = jnp.expand_dims(dp.data, -1)
  return encoders[0](data)