# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model base classes and utilities."""
from typing import Dict, List, Tuple
import chex
from clrs._src import probing
from clrs._src import specs
import numpy as np
_Array = chex.Array
Result = Dict[str, probing.DataPoint]


def fuse_perm_and_mask(perm: probing.DataPoint,
                       mask: probing.DataPoint) -> probing.DataPoint:
  """Replace permutation pointers active in the mask with self-pointers.

  Args:
    perm: a node permutation_pointer; data shape is expected to be
      [..., N, N], and ideally one-hot over the last two dimensions, although
      this method does not check for one-hotness.
    mask: a mask_one over nodes; data shape is expected to be
      [..., N], and ideally one-hot over the last dimension, although
      this method does not check for one-hotness.

  Returns:
    A node pointer with shape [..., N].
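
  Example:
    An illustrative sketch with hypothetical names and values; the mask
    marks node 0, which therefore becomes a self-pointer:

    >>> perm = probing.DataPoint(
    ...     name='pred', type_=specs.Type.PERMUTATION_POINTER,
    ...     location=specs.Location.NODE,
    ...     data=np.array([[0., 1., 0.], [0., 0., 1.], [1., 0., 0.]]))
    >>> mask = probing.DataPoint(
    ...     name='pred_mask', type_=specs.Type.MASK_ONE,
    ...     location=specs.Location.NODE, data=np.array([1., 0., 0.]))
    >>> fuse_perm_and_mask(perm, mask).data
    array([0, 2, 0])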
"""
assert perm.type_ == specs.Type.PERMUTATION_POINTER
assert perm.location == specs.Location.NODE
assert mask.name == perm.name + '_mask'
assert mask.type_ == specs.Type.MASK_ONE
assert mask.location == specs.Location.NODE
assert perm.data.shape[-1] == perm.data.shape[-2]
assert perm.data.shape[:-1] == mask.data.shape
data = np.where(mask.data > 0.5,
np.arange(perm.data.shape[-1]), # self-pointers
np.argmax(perm.data, axis=-1)) # original pointers
return probing.DataPoint(name=perm.name,
type_=specs.Type.POINTER,
location=perm.location,
data=data)


def _reduce_permutations_tuple(
    targets: Tuple[probing.DataPoint, ...]) -> Tuple[probing.DataPoint, ...]:
  """Reduce node pointer + mask_one permutation to just node pointer."""
  out_targets = []
  n_perms = 0
  i = 0
  while i < len(targets):
    truth = targets[i]
    if truth.type_ != specs.Type.PERMUTATION_POINTER:
      out_targets.append(truth)
      i += 1
      continue
    # A permutation pointer is immediately followed by its `_mask` datapoint,
    # so consume the pair together and fuse it into a single node pointer.
    truth_mask = targets[i + 1]
    out_targets.append(fuse_perm_and_mask(truth, truth_mask))
    i += 2
    n_perms += 1
  assert len(out_targets) == len(targets) - n_perms
  return tuple(out_targets)


def _reduce_permutations_dict(predictions: Result) -> Result:
  """Reduce node pointer + mask_one permutation to just node pointer."""
  out_preds = {}
  n_perms = 0
  for k, pred in predictions.items():
    if (k.endswith('_mask') and k[:-5] in predictions and
        predictions[k[:-5]].type_ == specs.Type.PERMUTATION_POINTER):
      # This mask will be processed with its associated permutation datapoint.
      continue
    if pred.type_ != specs.Type.PERMUTATION_POINTER:
      out_preds[k] = pred
      continue
    pred_mask = predictions[k + '_mask']
    out_preds[k] = fuse_perm_and_mask(pred, pred_mask)
    n_perms += 1
  assert len(out_preds) == len(predictions) - n_perms
  return out_preds
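
# For instance (hypothetical names): a prediction dict holding a permutation
# pointer 'pred' of shape [B, N, N] plus its companion 'pred_mask' of shape
# [B, N] is reduced to a single 'pred' node pointer of shape [B, N]. The
# tuple variant above performs the same reduction on adjacent (perm, mask)
# datapoint pairs.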


def evaluate_hints(
    hints: Tuple[probing.DataPoint, ...],
    lengths: _Array,
    hint_preds: List[Result],
) -> Dict[str, _Array]:
  """Evaluate hint predictions."""
  evals = {}
  hints = _reduce_permutations_tuple(hints)
  hint_preds = [_reduce_permutations_dict(h) for h in hint_preds]
  for truth in hints:
    assert truth.name in hint_preds[0]
    # Hints are predicted from step 1 onwards (the step-0 hint is given as
    # input), so prediction i is scored against the true hint at step i + 1.
    eval_along_time = [_evaluate(truth, p[truth.name],
                                 idx=i+1, lengths=lengths)
                       for (i, p) in enumerate(hint_preds)]
    # Average over time, weighting each step's score by the number of
    # trajectories still running at that step; a trajectory of length L
    # contributes L - 1 predicted steps in total.
    evals[truth.name] = np.sum(
        [x * np.sum(i+1 < lengths)
         for i, x in enumerate(eval_along_time)]) / np.sum(lengths - 1)
    evals[truth.name + '_along_time'] = np.array(eval_along_time)

  # Unlike outputs, the hints sometimes include scalars, which don't have
  # a meaningful eval score. So we don't compute a global 'hint score' as we
  # do for outputs.
  return evals
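
# A minimal shape sketch (assumed here, not checked): with batch size B and
# padded trajectory length T, each ground-truth hint's data has shape
# [T, B, ...], `lengths` has shape [B], and `hint_preds` is a list with one
# dict per predicted step (steps 1..T-1), mapping each hint name to a
# predicted DataPoint of shape [B, ...].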


def evaluate(
    outputs: Tuple[probing.DataPoint, ...],
    predictions: Result,
) -> Dict[str, float]:
  """Evaluate output predictions."""
  evals = {}
  outputs = _reduce_permutations_tuple(outputs)
  predictions = _reduce_permutations_dict(predictions)
  for truth in outputs:
    assert truth.name in predictions
    pred = predictions[truth.name]
    evals[truth.name] = _evaluate(truth, pred)
  # Return a single scalar score that is the mean of all output scores.
  evals['score'] = sum([v.item() for v in evals.values()]) / len(evals)
  return evals
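
# Illustrative use (hypothetical names): given a ground-truth mask_one output
#   out = probing.DataPoint(name='x', type_=specs.Type.MASK_ONE,
#                           location=specs.Location.NODE,
#                           data=np.array([[0., 1., 0.]]))
# a perfect prediction scores 1.0 on both the output and the mean:
#   evaluate((out,), {'x': out})  # -> {'x': 1.0, 'score': 1.0}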


def _evaluate(truth, pred, idx=None, lengths=None):
  """Evaluate single prediction of hint or output."""
  assert pred.name == truth.name
  assert pred.location == truth.location
  assert pred.type_ == truth.type_
  if truth.type_ not in _EVAL_FN:
    raise ValueError(f'Invalid type: {truth.type_}')
  truth_data = truth.data
  pred_data = pred.data
  if idx is not None:
    # Hint evaluation: compare the prediction for time step `idx` against the
    # true hint at that step, restricted to trajectories that are still
    # running there (i.e. those with length greater than `idx`).
    if np.all(idx >= lengths):
      return 0.
    truth_data = truth_data[idx][idx < lengths]
    pred_data = pred_data[idx < lengths]
  return _EVAL_FN[truth.type_](pred_data, truth_data)


def _eval_one(pred, truth):
  """Accuracy of one-hot predictions, ignoring positions masked in truth."""
  mask = np.all(truth != specs.OutputClass.MASKED, axis=-1)
  return np.sum(
      (np.argmax(pred, -1) == np.argmax(truth, -1)) * mask) / np.sum(mask)


def _mask_fn(pred, truth):
  """Evaluate outputs of type MASK, and account for any class imbalance."""
  mask = (truth != specs.OutputClass.MASKED).astype(np.float32)

  # Use F1 score for the masked outputs to address any imbalance
  tp = np.sum((((pred > 0.5) * (truth > 0.5)) * 1.0) * mask)
  fp = np.sum((((pred > 0.5) * (truth < 0.5)) * 1.0) * mask)
  fn = np.sum((((pred < 0.5) * (truth > 0.5)) * 1.0) * mask)

  # Protect against division by zero
  if tp + fp > 0:
    precision = tp / (tp + fp)
  else:
    precision = np.float32(1.0)
  if tp + fn > 0:
    recall = tp / (tp + fn)
  else:
    recall = np.float32(1.0)

  if precision + recall > 0.0:
    f_1 = 2.0 * precision * recall / (precision + recall)
  else:
    f_1 = np.float32(0.0)

  return f_1
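
# Worked example (illustrative): for pred = [0.9, 0.2, 0.8, 0.4] against
# truth = [1, 0, 0, 1] with nothing masked, tp = 1, fp = 1 and fn = 1, so
# precision = recall = 0.5 and the returned F1 score is 0.5.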


_EVAL_FN = {
    specs.Type.SCALAR:
        lambda pred, truth: np.mean((pred - truth)**2),
    specs.Type.MASK: _mask_fn,
    specs.Type.MASK_ONE:
        _eval_one,
    specs.Type.CATEGORICAL:
        _eval_one,
    specs.Type.POINTER:
        lambda pred, truth: np.mean((pred == truth) * 1.0),
}