# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Encoder utilities."""

import functools
import chex
from clrs._src import probing
from clrs._src import specs
import haiku as hk
import jax.numpy as jnp

_Array = chex.Array
_DataPoint = probing.DataPoint
_Location = specs.Location
_Spec = specs.Spec
_Stage = specs.Stage
_Type = specs.Type


def construct_encoders(stage: str, loc: str, t: str,
                       hidden_dim: int, init: str, name: str):
  """Constructs encoders."""
  if init == 'xavier_on_scalars' and stage == _Stage.HINT and t == _Type.SCALAR:
    initialiser = hk.initializers.TruncatedNormal(
        stddev=1.0 / jnp.sqrt(hidden_dim))
  elif init in ['default', 'xavier_on_scalars']:
    initialiser = None
  else:
    raise ValueError(f'Encoder initialiser {init} not supported.')
  linear = functools.partial(
      hk.Linear,
      w_init=initialiser,
      name=f'{name}_enc_linear')
  encoders = [linear(hidden_dim)]
  if loc == _Location.EDGE and t == _Type.POINTER:
    # Edge pointers need two-way encoders.
    encoders.append(linear(hidden_dim))

  return encoders
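

# A minimal usage sketch (illustrative, not part of the library): `hk.Linear`
# creates parameters, so `construct_encoders` must be called inside an
# `hk.transform`-ed function. `hidden_dim=128` and the probe name 'pos' are
# arbitrary choices for this example (assumes `import jax`).
#
#   def f(x):
#     encoders = construct_encoders(
#         specs.Stage.INPUT, specs.Location.NODE, specs.Type.SCALAR,
#         hidden_dim=128, init='default', name='pos')
#     return encoders[0](x)
#
#   net = hk.transform(f)
#   params = net.init(jax.random.PRNGKey(0), jnp.ones((4, 16, 1)))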


def preprocess(dp: _DataPoint, nb_nodes: int) -> _DataPoint:
  """Pre-process data point.

  Make sure that the data is ready to be encoded into features.
  If the data is of POINTER type, we expand the compressed index representation
  to a full one-hot. But if the data is a SOFT_POINTER, the representation
  is already expanded and we just overwrite the type as POINTER so that
  it is treated as such for encoding.

  Args:
    dp: A DataPoint to prepare for encoding.
    nb_nodes: Number of nodes in the graph, necessary to expand pointers to
      the right dimension.
  Returns:
    The datapoint, with data and possibly type modified.
  """
  new_type = dp.type_
  if dp.type_ == _Type.POINTER:
    data = hk.one_hot(dp.data, nb_nodes)
  else:
    data = dp.data.astype(jnp.float32)
    if dp.type_ == _Type.SOFT_POINTER:
      new_type = _Type.POINTER
  dp = probing.DataPoint(
      name=dp.name, location=dp.location, type_=new_type, data=data)

  return dp
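

# Illustrative example (hypothetical probe): a NODE-location POINTER with
# batch size 1 and 3 nodes. `preprocess` expands the integer pointers
# [1, 2, 0] into a one-hot array of shape [1, 3, 3]; a SOFT_POINTER of shape
# [1, 3, 3] would instead pass through unchanged, with only its type rewritten
# to POINTER.
#
#   dp = probing.DataPoint(name='pi', location=specs.Location.NODE,
#                          type_=specs.Type.POINTER,
#                          data=jnp.array([[1, 2, 0]]))
#   dp = preprocess(dp, nb_nodes=3)  # dp.data.shape == (1, 3, 3)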


def accum_adj_mat(dp: _DataPoint, adj_mat: _Array) -> _Array:
  """Accumulates the adjacency matrix from pointer and mask data points."""
  if dp.location == _Location.NODE and dp.type_ in [_Type.POINTER,
                                                    _Type.PERMUTATION_POINTER]:
    # Symmetrise node pointers: i and j are adjacent if either points at the
    # other.
    adj_mat += ((dp.data + jnp.transpose(dp.data, (0, 2, 1))) > 0.5)
  elif dp.location == _Location.EDGE and dp.type_ == _Type.MASK:
    # Symmetrise edge masks.
    adj_mat += ((dp.data + jnp.transpose(dp.data, (0, 2, 1))) > 0.0)

  return (adj_mat > 0.).astype('float32')  # pytype: disable=attribute-error  # numpy-scalars
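

# Shape sketch (assumed conventions): for B graphs with N nodes, node pointers
# arrive one-hot as [B, N, N] and edge masks as [B, N, N], so the symmetrised
# sums above are well-defined. Continuing the `preprocess` example:
#
#   adj_mat = jnp.zeros((1, 3, 3))
#   adj_mat = accum_adj_mat(dp, adj_mat)  # 0/1 float32 matrix, shape (1, 3, 3)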


def accum_edge_fts(encoders, dp: _DataPoint, edge_fts: _Array) -> _Array:
  """Encodes and accumulates edge features."""
  if dp.location == _Location.NODE and dp.type_ in [_Type.POINTER,
                                                    _Type.PERMUTATION_POINTER]:
    encoding = _encode_inputs(encoders, dp)
    edge_fts += encoding

  elif dp.location == _Location.EDGE:
    encoding = _encode_inputs(encoders, dp)
    if dp.type_ == _Type.POINTER:
      # Aggregate pointer contributions across sender and receiver nodes.
      encoding_2 = encoders[1](jnp.expand_dims(dp.data, -1))
      edge_fts += jnp.mean(encoding, axis=1) + jnp.mean(encoding_2, axis=2)
    else:
      edge_fts += encoding

  return edge_fts
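

# Shape sketch (assumed conventions): with B graphs, N nodes and hidden size
# h, `edge_fts` has shape [B, N, N, h]. A NODE-location pointer's one-hot data
# [B, N, N] encodes directly to [B, N, N, h]; an EDGE-location pointer's
# one-hot data is [B, N, N, N], so its two linear encodings are each averaged
# over one of the incident-node axes (axis=1 and axis=2) to recover
# [B, N, N, h].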


def accum_node_fts(encoders, dp: _DataPoint, node_fts: _Array) -> _Array:
  """Encodes and accumulates node features."""
  is_pointer = (dp.type_ in [_Type.POINTER, _Type.PERMUTATION_POINTER])
  if ((dp.location == _Location.NODE and not is_pointer) or
      (dp.location == _Location.GRAPH and dp.type_ == _Type.POINTER)):
    encoding = _encode_inputs(encoders, dp)
    node_fts += encoding

  return node_fts


def accum_graph_fts(encoders, dp: _DataPoint,
                    graph_fts: _Array) -> _Array:
  """Encodes and accumulates graph features."""
  if dp.location == _Location.GRAPH and dp.type_ != _Type.POINTER:
    encoding = _encode_inputs(encoders, dp)
    graph_fts += encoding

  return graph_fts


def _encode_inputs(encoders, dp: _DataPoint) -> _Array:
  """Encodes a data point with its primary (first) encoder."""
  if dp.type_ == _Type.CATEGORICAL:
    # Categorical data already carries a trailing feature (class) axis.
    encoding = encoders[0](dp.data)
  else:
    # All other types get a trailing axis of size 1 before the linear layer.
    encoding = encoders[0](jnp.expand_dims(dp.data, -1))
  return encoding
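

# End-to-end sketch (illustrative only; shapes, names and `hidden_dim` are
# assumptions, not part of the CLRS API): encoding one scalar node input into
# node features inside a Haiku transform (assumes `import jax`).
#
#   def encode(data):  # data: [B, N], one scalar per node
#     dp = probing.DataPoint(name='key', location=specs.Location.NODE,
#                            type_=specs.Type.SCALAR, data=data)
#     dp = preprocess(dp, nb_nodes=data.shape[1])
#     encoders = construct_encoders(
#         specs.Stage.INPUT, dp.location, dp.type_,
#         hidden_dim=128, init='default', name=dp.name)
#     node_fts = jnp.zeros((*data.shape[:2], 128))
#     return accum_node_fts(encoders, dp, node_fts)  # [B, N, 128]
#
#   net = hk.transform(encode)
#   params = net.init(jax.random.PRNGKey(0), jnp.ones((4, 16)))
#   node_fts = net.apply(params, None, jnp.ones((4, 16)))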