File size: 6,837 Bytes
85e3d20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Model base classes and utilities."""

from typing import Dict, List, Tuple
import chex
from clrs._src import probing
from clrs._src import specs
import numpy as np


_Array = chex.Array
Result = Dict[str, probing.DataPoint]


def fuse_perm_and_mask(perm: probing.DataPoint,
                       mask: probing.DataPoint) -> probing.DataPoint:
  """Collapse a permutation pointer + mask pair into a plain node pointer.

  Wherever the mask is active, the node points to itself; everywhere else
  the pointer follows the argmax of the permutation matrix.

  Args:
    perm: a node permutation_pointer; data shape is expected to be
      [..., N, N], and ideally one-hot over the last two dimensions
      (one-hotness is not verified here).
    mask: a mask_one over nodes; data shape is expected to be
      [..., N], and ideally one-hot over the last dimension
      (one-hotness is not verified here).
  Returns:
    A node pointer with shape [..., N].
  """
  assert perm.type_ == specs.Type.PERMUTATION_POINTER
  assert perm.location == specs.Location.NODE
  assert mask.name == perm.name + '_mask'
  assert mask.type_ == specs.Type.MASK_ONE
  assert mask.location == specs.Location.NODE
  assert perm.data.shape[-1] == perm.data.shape[-2]
  assert perm.data.shape[:-1] == mask.data.shape
  nb_nodes = perm.data.shape[-1]
  self_pointers = np.arange(nb_nodes)          # node i points to itself
  perm_pointers = np.argmax(perm.data, axis=-1)  # follow the permutation
  fused = np.where(mask.data > 0.5, self_pointers, perm_pointers)
  return probing.DataPoint(name=perm.name,
                           type_=specs.Type.POINTER,
                           location=perm.location,
                           data=fused)


def _reduce_permutations_tuple(
    targets: Tuple[probing.DataPoint, ...]) -> Tuple[probing.DataPoint, ...]:
  """Reduce node pointer + mask_one permutation to just node pointer.

  Permutation pointers are assumed to be immediately followed by their
  associated mask in `targets`; each such pair is fused into one pointer.
  """
  reduced = []
  fused_count = 0
  pos = 0
  while pos < len(targets):
    current = targets[pos]
    if current.type_ == specs.Type.PERMUTATION_POINTER:
      # The mask directly follows its permutation; consume both.
      reduced.append(fuse_perm_and_mask(current, targets[pos + 1]))
      pos += 2
      fused_count += 1
    else:
      reduced.append(current)
      pos += 1

  assert len(reduced) == len(targets) - fused_count
  return tuple(reduced)


def _reduce_permutations_dict(predictions: Result) -> Result:
  """Reduce node pointer + mask_one permutation to just node pointer.

  Each permutation pointer named `k` is fused with its companion mask
  named `k + '_mask'`; the mask entry is dropped from the result.
  """
  reduced = {}
  fused_count = 0
  for name, datapoint in predictions.items():
    if name.endswith('_mask'):
      base = name[:-5]
      if (base in predictions and
          predictions[base].type_ == specs.Type.PERMUTATION_POINTER):
        # Consumed together with its permutation datapoint below.
        continue
    if datapoint.type_ != specs.Type.PERMUTATION_POINTER:
      reduced[name] = datapoint
      continue
    companion_mask = predictions[name + '_mask']
    reduced[name] = fuse_perm_and_mask(datapoint, companion_mask)
    fused_count += 1

  assert len(reduced) == len(predictions) - fused_count
  return reduced


def evaluate_hints(
    hints: Tuple[probing.DataPoint, ...],
    lengths: _Array,
    hint_preds: List[Result],
) -> Dict[str, _Array]:
  """Evaluate hint predictions against ground-truth hints.

  Args:
    hints: ground-truth hint datapoints. Each is indexed by a leading time
      axis (via `idx` in `_evaluate`) — assumes data is [T, B, ...];
      TODO confirm against caller.
    lengths: per-sample trajectory lengths, used to ignore time steps past
      the end of each trajectory.
    hint_preds: one prediction Result per time step; step i is scored
      against the truth at time i+1 (time 0 is the initial hint, which is
      given rather than predicted).

  Returns:
    A dict mapping each hint name to a scalar score averaged over valid
    time steps, plus a '<name>_along_time' array of per-step scores.
  """
  evals = {}
  # Fuse permutation-pointer + mask pairs so truths and predictions are
  # compared in the same (plain pointer) representation.
  hints = _reduce_permutations_tuple(hints)
  hint_preds = [_reduce_permutations_dict(h) for h in hint_preds]
  for truth in hints:
    assert truth.name in hint_preds[0]
    # Score step i's prediction against the truth at time i+1, restricted
    # to samples still running at that step (handled inside _evaluate).
    eval_along_time = [_evaluate(truth, p[truth.name],
                                 idx=i+1, lengths=lengths)
                       for (i, p) in enumerate(hint_preds)]
    # Weight each step's score by the number of samples still valid at
    # that step, then normalize by the total count of valid steps.
    evals[truth.name] = np.sum(
        [x * np.sum(i+1 < lengths)
         for i, x in enumerate(eval_along_time)]) / np.sum(lengths - 1)
    evals[truth.name + '_along_time'] = np.array(eval_along_time)

  # Unlike outputs, the hints sometimes include scalars, which don't have
  # a meaningful eval score. So we don't compute a global 'hint score' as we
  # do for outputs.
  return evals


def evaluate(
    outputs: Tuple[probing.DataPoint, ...],
    predictions: Result,
) -> Dict[str, float]:
  """Evaluate output predictions against ground-truth outputs.

  Returns a dict of per-output scores plus a 'score' entry holding their
  mean.
  """
  # Fuse permutation-pointer + mask pairs on both sides first, so truths
  # and predictions are compared in the same representation.
  outputs = _reduce_permutations_tuple(outputs)
  predictions = _reduce_permutations_dict(predictions)
  evals = {}
  for truth in outputs:
    assert truth.name in predictions
    evals[truth.name] = _evaluate(truth, predictions[truth.name])
  # Single scalar summary: the mean of all per-output scores.
  evals['score'] = sum(v.item() for v in evals.values()) / len(evals)
  return evals


def _evaluate(truth, pred, idx=None, lengths=None):
  """Score a single hint or output prediction against its ground truth.

  When `idx` is given (hint evaluation), truth data is sliced at time
  `idx` and both sides are restricted to samples whose trajectory is
  still running (`idx < lengths`).
  """
  assert pred.name == truth.name
  assert pred.location == truth.location
  assert pred.type_ == truth.type_

  if truth.type_ not in _EVAL_FN:
    raise ValueError('Invalid type')
  score_fn = _EVAL_FN[truth.type_]

  if idx is None:
    # Output evaluation: compare the full data directly.
    return score_fn(pred.data, truth.data)

  if np.all(idx >= lengths):
    # No sample is still running at this step; nothing to score.
    return 0.
  still_running = idx < lengths
  return score_fn(pred.data[still_running], truth.data[idx][still_running])


def _eval_one(pred, truth):
  """Argmax-match accuracy, ignoring positions fully masked in the truth."""
  valid = np.all(truth != specs.OutputClass.MASKED, axis=-1)
  hits = (np.argmax(pred, -1) == np.argmax(truth, -1)) * valid
  return np.sum(hits) / np.sum(valid)


def _mask_fn(pred, truth):
  """Evaluate outputs of type MASK, and account for any class imbalance.

  Returns the F1 score over the non-masked positions, thresholding both
  prediction and truth at 0.5.
  """
  valid = (truth != specs.OutputClass.MASKED).astype(np.float32)

  # Counts for F1: true/false positives and false negatives, restricted
  # to non-masked positions.
  tp = np.sum((pred > 0.5) * (truth > 0.5) * valid)
  fp = np.sum((pred > 0.5) * (truth < 0.5) * valid)
  fn = np.sum((pred < 0.5) * (truth > 0.5) * valid)

  # Degenerate denominators default to perfect precision/recall.
  precision = tp / (tp + fp) if tp + fp > 0 else np.float32(1.0)
  recall = tp / (tp + fn) if tp + fn > 0 else np.float32(1.0)

  if precision + recall > 0.0:
    return 2.0 * precision * recall / (precision + recall)
  return np.float32(0.0)

# Dispatch table mapping a probe type to its scoring function
# (pred_data, truth_data) -> scalar. SCALAR uses mean squared error
# (lower is better); the remaining entries are accuracy-style scores.
_EVAL_FN = {
    specs.Type.SCALAR:
        lambda pred, truth: np.mean((pred - truth)**2),
    specs.Type.MASK: _mask_fn,
    specs.Type.MASK_ONE:
        _eval_one,
    specs.Type.CATEGORICAL:
        _eval_one,
    specs.Type.POINTER:
        lambda pred, truth: np.mean((pred == truth) * 1.0),
}