# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Encoder utilities.""" import functools import chex from clrs._src import probing from clrs._src import specs import haiku as hk import jax.numpy as jnp _Array = chex.Array _DataPoint = probing.DataPoint _Location = specs.Location _Spec = specs.Spec _Stage = specs.Stage _Type = specs.Type def construct_encoders(stage: str, loc: str, t: str, hidden_dim: int, init: str, name: str): """Constructs encoders.""" if init == 'xavier_on_scalars' and stage == _Stage.HINT and t == _Type.SCALAR: initialiser = hk.initializers.TruncatedNormal( stddev=1.0 / jnp.sqrt(hidden_dim)) elif init in ['default', 'xavier_on_scalars']: initialiser = None else: raise ValueError(f'Encoder initialiser {init} not supported.') linear = functools.partial( hk.Linear, w_init=initialiser, name=f'{name}_enc_linear') encoders = [linear(hidden_dim)] if loc == _Location.EDGE and t == _Type.POINTER: # Edge pointers need two-way encoders. encoders.append(linear(hidden_dim)) return encoders def preprocess(dp: _DataPoint, nb_nodes: int) -> _DataPoint: """Pre-process data point. Make sure that the data is ready to be encoded into features. If the data is of POINTER type, we expand the compressed index representation to a full one-hot. But if the data is a SOFT_POINTER, the representation is already expanded and we just overwrite the type as POINTER so that it is treated as such for encoding. Args: dp: A DataPoint to prepare for encoding. nb_nodes: Number of nodes in the graph, necessary to expand pointers to the right dimension. Returns: The datapoint, with data and possibly type modified. """ new_type = dp.type_ if dp.type_ == _Type.POINTER: data = hk.one_hot(dp.data, nb_nodes) else: data = dp.data.astype(jnp.float32) if dp.type_ == _Type.SOFT_POINTER: new_type = _Type.POINTER dp = probing.DataPoint( name=dp.name, location=dp.location, type_=new_type, data=data) return dp def accum_adj_mat(dp: _DataPoint, adj_mat: _Array) -> _Array: """Accumulates adjacency matrix.""" if dp.location == _Location.NODE and dp.type_ in [_Type.POINTER, _Type.PERMUTATION_POINTER]: adj_mat += ((dp.data + jnp.transpose(dp.data, (0, 2, 1))) > 0.5) elif dp.location == _Location.EDGE and dp.type_ == _Type.MASK: adj_mat += ((dp.data + jnp.transpose(dp.data, (0, 2, 1))) > 0.0) return (adj_mat > 0.).astype('float32') # pytype: disable=attribute-error # numpy-scalars def accum_edge_fts(encoders, dp: _DataPoint, edge_fts: _Array) -> _Array: """Encodes and accumulates edge features.""" if dp.location == _Location.NODE and dp.type_ in [_Type.POINTER, _Type.PERMUTATION_POINTER]: encoding = _encode_inputs(encoders, dp) edge_fts += encoding elif dp.location == _Location.EDGE: encoding = _encode_inputs(encoders, dp) if dp.type_ == _Type.POINTER: # Aggregate pointer contributions across sender and receiver nodes. encoding_2 = encoders[1](jnp.expand_dims(dp.data, -1)) edge_fts += jnp.mean(encoding, axis=1) + jnp.mean(encoding_2, axis=2) else: edge_fts += encoding return edge_fts def accum_node_fts(encoders, dp: _DataPoint, node_fts: _Array) -> _Array: """Encodes and accumulates node features.""" is_pointer = (dp.type_ in [_Type.POINTER, _Type.PERMUTATION_POINTER]) if ((dp.location == _Location.NODE and not is_pointer) or (dp.location == _Location.GRAPH and dp.type_ == _Type.POINTER)): encoding = _encode_inputs(encoders, dp) node_fts += encoding return node_fts def accum_graph_fts(encoders, dp: _DataPoint, graph_fts: _Array) -> _Array: """Encodes and accumulates graph features.""" if dp.location == _Location.GRAPH and dp.type_ != _Type.POINTER: encoding = _encode_inputs(encoders, dp) graph_fts += encoding return graph_fts def _encode_inputs(encoders, dp: _DataPoint) -> _Array: if dp.type_ == _Type.CATEGORICAL: encoding = encoders[0](dp.data) else: encoding = encoders[0](jnp.expand_dims(dp.data, -1)) return encoding