# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Encoder utilities."""
import functools
import chex
from clrs._src import probing
from clrs._src import specs
import haiku as hk
import jax.numpy as jnp
_Array = chex.Array
_DataPoint = probing.DataPoint
_Location = specs.Location
_Spec = specs.Spec
_Stage = specs.Stage
_Type = specs.Type
def construct_encoders(stage: str, loc: str, t: str,
                       hidden_dim: int, init: str, name: str):
  """Builds the list of linear encoders for one probe.

  Args:
    stage: Probe stage (e.g. hint).
    loc: Probe location (node / edge / graph).
    t: Probe type.
    hidden_dim: Output dimensionality of each linear encoder.
    init: Initialiser scheme, either 'default' or 'xavier_on_scalars'.
    name: Name prefix for the Haiku linear modules.

  Returns:
    A list holding one `hk.Linear`; edge pointers get a second one so they
    can be encoded two-way.

  Raises:
    ValueError: If `init` names an unsupported scheme.
  """
  if init not in ('default', 'xavier_on_scalars'):
    raise ValueError(f'Encoder initialiser {init} not supported.')
  # Scalar hints optionally get a truncated-normal init scaled by hidden_dim;
  # everything else keeps Haiku's default initialiser.
  use_scaled_init = (init == 'xavier_on_scalars' and stage == _Stage.HINT
                     and t == _Type.SCALAR)
  w_init = (hk.initializers.TruncatedNormal(stddev=1.0 / jnp.sqrt(hidden_dim))
            if use_scaled_init else None)
  make_linear = functools.partial(
      hk.Linear, w_init=w_init, name=f'{name}_enc_linear')
  encoders = [make_linear(hidden_dim)]
  if loc == _Location.EDGE and t == _Type.POINTER:
    # Edge pointers need two-way encoders.
    encoders.append(make_linear(hidden_dim))
  return encoders
def preprocess(dp: _DataPoint, nb_nodes: int) -> _DataPoint:
  """Pre-processes a data point so it is ready to be encoded into features.

  POINTER data arrives as compressed node indices and is expanded to a full
  one-hot representation. SOFT_POINTER data is already expanded, so it is
  only relabelled as POINTER for downstream encoding. Any other data is
  simply cast to float32.

  Args:
    dp: A DataPoint to prepare for encoding.
    nb_nodes: Number of nodes in the graph; the one-hot width used when
      expanding pointers.

  Returns:
    A new DataPoint with data (and possibly type) converted.
  """
  if dp.type_ == _Type.POINTER:
    new_type = dp.type_
    data = hk.one_hot(dp.data, nb_nodes)
  elif dp.type_ == _Type.SOFT_POINTER:
    new_type = _Type.POINTER
    data = dp.data.astype(jnp.float32)
  else:
    new_type = dp.type_
    data = dp.data.astype(jnp.float32)
  return probing.DataPoint(
      name=dp.name, location=dp.location, type_=new_type, data=data)
def accum_adj_mat(dp: _DataPoint, adj_mat: _Array) -> _Array:
  """Folds a data point's connectivity into the adjacency matrix.

  Node pointers and edge masks both contribute symmetrised edges; other
  data points leave connectivity untouched. The result is re-binarised so
  repeated accumulation stays a 0/1 float matrix.
  """
  is_node_pointer = (dp.location == _Location.NODE and
                     dp.type_ in [_Type.POINTER, _Type.PERMUTATION_POINTER])
  is_edge_mask = (dp.location == _Location.EDGE and dp.type_ == _Type.MASK)
  if is_node_pointer or is_edge_mask:
    symmetric = dp.data + jnp.transpose(dp.data, (0, 2, 1))
    # Pointers are thresholded at 0.5, masks at 0.0 — same as the original
    # per-branch comparisons.
    adj_mat += (symmetric > (0.5 if is_node_pointer else 0.0))
  return (adj_mat > 0.).astype('float32')  # pytype: disable=attribute-error  # numpy-scalars
def accum_edge_fts(encoders, dp: _DataPoint, edge_fts: _Array) -> _Array:
  """Encodes a data point and adds its contribution to the edge features."""
  node_pointer = (dp.location == _Location.NODE and
                  dp.type_ in [_Type.POINTER, _Type.PERMUTATION_POINTER])
  if node_pointer:
    edge_fts = edge_fts + _encode_inputs(encoders, dp)
  elif dp.location == _Location.EDGE:
    encoding = _encode_inputs(encoders, dp)
    if dp.type_ != _Type.POINTER:
      edge_fts = edge_fts + encoding
    else:
      # Edge pointers are encoded two-way: mean over the sender axis of the
      # first encoding plus mean over the receiver axis of the second.
      encoding_2 = encoders[1](jnp.expand_dims(dp.data, -1))
      edge_fts = edge_fts + jnp.mean(encoding, axis=1) + jnp.mean(
          encoding_2, axis=2)
  return edge_fts
def accum_node_fts(encoders, dp: _DataPoint, node_fts: _Array) -> _Array:
  """Encodes a data point and adds its contribution to the node features."""
  pointer_like = dp.type_ in [_Type.POINTER, _Type.PERMUTATION_POINTER]
  # Non-pointer node data and graph-level pointers land in node features;
  # everything else is handled by the edge/graph accumulators.
  encode_here = (
      (dp.location == _Location.NODE and not pointer_like) or
      (dp.location == _Location.GRAPH and dp.type_ == _Type.POINTER))
  if encode_here:
    node_fts = node_fts + _encode_inputs(encoders, dp)
  return node_fts
def accum_graph_fts(encoders, dp: _DataPoint,
                    graph_fts: _Array) -> _Array:
  """Encodes a data point and adds its contribution to the graph features."""
  # Graph-level pointers go to node features instead, so skip them here.
  if dp.location != _Location.GRAPH or dp.type_ == _Type.POINTER:
    return graph_fts
  return graph_fts + _encode_inputs(encoders, dp)
def _encode_inputs(encoders, dp: _DataPoint) -> _Array:
  """Applies the primary encoder to a data point's raw data."""
  # Categorical data already carries a trailing feature axis; all other
  # types get a singleton axis appended before the linear layer.
  data = (dp.data if dp.type_ == _Type.CATEGORICAL
          else jnp.expand_dims(dp.data, -1))
  return encoders[0](data)