# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for calculating losses."""
from typing import Dict, List, Tuple | |
import chex | |
from clrs._src import probing | |
from clrs._src import specs | |
import haiku as hk | |
import jax | |
import jax.numpy as jnp | |
_Array = chex.Array | |
_DataPoint = probing.DataPoint | |
_Location = specs.Location | |
_OutputClass = specs.OutputClass | |
_PredTrajectory = Dict[str, _Array] | |
_PredTrajectories = List[_PredTrajectory] | |
_Type = specs.Type | |
EPS = 1e-12 | |
def _expand_to(x: _Array, y: _Array) -> _Array: | |
while len(y.shape) > len(x.shape): | |
x = jnp.expand_dims(x, -1) | |
return x | |
def _expand_and_broadcast_to(x: _Array, y: _Array) -> _Array: | |
return jnp.broadcast_to(_expand_to(x, y), y.shape) | |
def output_loss_chunked(truth: _DataPoint, pred: _Array, | |
is_last: _Array, nb_nodes: int) -> float: | |
"""Output loss for time-chunked training.""" | |
mask = None | |
if truth.type_ == _Type.SCALAR: | |
loss = (pred - truth.data)**2 | |
elif truth.type_ == _Type.MASK: | |
loss = ( | |
jnp.maximum(pred, 0) - pred * truth.data + | |
jnp.log1p(jnp.exp(-jnp.abs(pred)))) | |
mask = (truth.data != _OutputClass.MASKED) | |
elif truth.type_ in [_Type.MASK_ONE, _Type.CATEGORICAL]: | |
mask = jnp.any(truth.data == _OutputClass.POSITIVE, axis=-1) | |
masked_truth = truth.data * (truth.data != _OutputClass.MASKED).astype( | |
jnp.float32) | |
loss = -jnp.sum(masked_truth * jax.nn.log_softmax(pred), axis=-1) | |
elif truth.type_ == _Type.POINTER: | |
loss = -jnp.sum( | |
hk.one_hot(truth.data, nb_nodes) * jax.nn.log_softmax(pred), axis=-1) | |
elif truth.type_ == _Type.PERMUTATION_POINTER: | |
# Predictions are NxN logits aiming to represent a doubly stochastic matrix. | |
# Compute the cross entropy between doubly stochastic pred and truth_data | |
loss = -jnp.sum(truth.data * pred, axis=-1) | |
if mask is not None: | |
mask = mask * _expand_and_broadcast_to(is_last, loss) | |
else: | |
mask = _expand_and_broadcast_to(is_last, loss) | |
total_mask = jnp.maximum(jnp.sum(mask), EPS) | |
return jnp.sum(jnp.where(mask, loss, 0.0)) / total_mask | |
def output_loss(truth: _DataPoint, pred: _Array, nb_nodes: int) -> float: | |
"""Output loss for full-sample training.""" | |
if truth.type_ == _Type.SCALAR: | |
total_loss = jnp.mean((pred - truth.data)**2) | |
elif truth.type_ == _Type.MASK: | |
loss = ( | |
jnp.maximum(pred, 0) - pred * truth.data + | |
jnp.log1p(jnp.exp(-jnp.abs(pred)))) | |
mask = (truth.data != _OutputClass.MASKED).astype(jnp.float32) | |
total_loss = jnp.sum(loss * mask) / jnp.sum(mask) | |
elif truth.type_ in [_Type.MASK_ONE, _Type.CATEGORICAL]: | |
masked_truth = truth.data * (truth.data != _OutputClass.MASKED).astype( | |
jnp.float32) | |
total_loss = (-jnp.sum(masked_truth * jax.nn.log_softmax(pred)) / | |
jnp.sum(truth.data == _OutputClass.POSITIVE)) | |
elif truth.type_ == _Type.POINTER: | |
total_loss = ( | |
jnp.mean(-jnp.sum( | |
hk.one_hot(truth.data, nb_nodes) * jax.nn.log_softmax(pred), | |
axis=-1))) | |
elif truth.type_ == _Type.PERMUTATION_POINTER: | |
# Predictions are NxN logits aiming to represent a doubly stochastic matrix. | |
# Compute the cross entropy between doubly stochastic pred and truth_data | |
total_loss = jnp.mean(-jnp.sum(truth.data * pred, axis=-1)) | |
return total_loss | |
def hint_loss_chunked( | |
truth: _DataPoint, | |
pred: _Array, | |
is_first: _Array, | |
nb_nodes: int, | |
): | |
"""Hint loss for time-chunked training.""" | |
loss, mask = _hint_loss( | |
truth_data=truth.data, | |
truth_type=truth.type_, | |
pred=pred, | |
nb_nodes=nb_nodes, | |
) | |
mask *= (1 - _expand_to(is_first, loss)).astype(jnp.float32) | |
loss = jnp.sum(loss * mask) / jnp.maximum(jnp.sum(mask), EPS) | |
return loss | |
def hint_loss( | |
truth: _DataPoint, | |
preds: List[_Array], | |
lengths: _Array, | |
nb_nodes: int, | |
verbose: bool = False, | |
): | |
"""Hint loss for full-sample training.""" | |
total_loss = 0. | |
verbose_loss = {} | |
length = truth.data.shape[0] - 1 | |
loss, mask = _hint_loss( | |
truth_data=truth.data[1:], | |
truth_type=truth.type_, | |
pred=jnp.stack(preds), | |
nb_nodes=nb_nodes, | |
) | |
mask *= _is_not_done_broadcast(lengths, jnp.arange(length)[:, None], loss) | |
loss = jnp.sum(loss * mask) / jnp.maximum(jnp.sum(mask), EPS) | |
if verbose: | |
verbose_loss['loss_' + truth.name] = loss | |
else: | |
total_loss += loss | |
return verbose_loss if verbose else total_loss | |
def _hint_loss( | |
truth_data: _Array, | |
truth_type: str, | |
pred: _Array, | |
nb_nodes: int, | |
) -> Tuple[_Array, _Array]: | |
"""Hint loss helper.""" | |
mask = None | |
if truth_type == _Type.SCALAR: | |
loss = (pred - truth_data)**2 | |
elif truth_type == _Type.MASK: | |
loss = (jnp.maximum(pred, 0) - pred * truth_data + | |
jnp.log1p(jnp.exp(-jnp.abs(pred)))) | |
mask = (truth_data != _OutputClass.MASKED).astype(jnp.float32) # pytype: disable=attribute-error # numpy-scalars | |
elif truth_type == _Type.MASK_ONE: | |
loss = -jnp.sum(truth_data * jax.nn.log_softmax(pred), axis=-1, | |
keepdims=True) | |
elif truth_type == _Type.CATEGORICAL: | |
loss = -jnp.sum(truth_data * jax.nn.log_softmax(pred), axis=-1) | |
mask = jnp.any(truth_data == _OutputClass.POSITIVE, axis=-1).astype( | |
jnp.float32) | |
elif truth_type == _Type.POINTER: | |
loss = -jnp.sum( | |
hk.one_hot(truth_data, nb_nodes) * jax.nn.log_softmax(pred), | |
axis=-1) | |
elif truth_type == _Type.PERMUTATION_POINTER: | |
# Predictions are NxN logits aiming to represent a doubly stochastic matrix. | |
# Compute the cross entropy between doubly stochastic pred and truth_data | |
loss = -jnp.sum(truth_data * pred, axis=-1) | |
if mask is None: | |
mask = jnp.ones_like(loss) | |
return loss, mask | |
def _is_not_done_broadcast(lengths, i, tensor): | |
is_not_done = (lengths > i + 1) * 1.0 | |
while len(is_not_done.shape) < len(tensor.shape): # pytype: disable=attribute-error # numpy-scalars | |
is_not_done = jnp.expand_dims(is_not_done, -1) | |
return is_not_done | |