# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for calculating losses."""
| from typing import Dict, List, Tuple | |
| import chex | |
| from clrs._src import probing | |
| from clrs._src import specs | |
| import haiku as hk | |
| import jax | |
| import jax.numpy as jnp | |
| _Array = chex.Array | |
| _DataPoint = probing.DataPoint | |
| _Location = specs.Location | |
| _OutputClass = specs.OutputClass | |
| _PredTrajectory = Dict[str, _Array] | |
| _PredTrajectories = List[_PredTrajectory] | |
| _Type = specs.Type | |
| EPS = 1e-12 | |
| def _expand_to(x: _Array, y: _Array) -> _Array: | |
| while len(y.shape) > len(x.shape): | |
| x = jnp.expand_dims(x, -1) | |
| return x | |
def _expand_and_broadcast_to(x: _Array, y: _Array) -> _Array:
  """Expand `x` with trailing axes, then broadcast it to `y`'s full shape."""
  expanded = _expand_to(x, y)
  return jnp.broadcast_to(expanded, y.shape)
def output_loss_chunked(truth: _DataPoint, pred: _Array,
                        is_last: _Array, nb_nodes: int) -> float:
  """Output loss for time-chunked training."""
  type_mask = None  # Optional validity mask derived from the truth data.
  if truth.type_ == _Type.SCALAR:
    loss = (pred - truth.data) ** 2
  elif truth.type_ == _Type.MASK:
    # Numerically stable sigmoid binary cross-entropy on logits.
    loss = (jnp.maximum(pred, 0) - pred * truth.data +
            jnp.log1p(jnp.exp(-jnp.abs(pred))))
    type_mask = (truth.data != _OutputClass.MASKED)
  elif truth.type_ in [_Type.MASK_ONE, _Type.CATEGORICAL]:
    # Only positions with at least one POSITIVE entry contribute.
    type_mask = jnp.any(truth.data == _OutputClass.POSITIVE, axis=-1)
    valid_truth = truth.data * (truth.data != _OutputClass.MASKED).astype(
        jnp.float32)
    loss = -jnp.sum(valid_truth * jax.nn.log_softmax(pred), axis=-1)
  elif truth.type_ == _Type.POINTER:
    loss = -jnp.sum(
        hk.one_hot(truth.data, nb_nodes) * jax.nn.log_softmax(pred), axis=-1)
  elif truth.type_ == _Type.PERMUTATION_POINTER:
    # Predictions are NxN logits aiming to represent a doubly stochastic
    # matrix. Compute the cross entropy between doubly stochastic pred and
    # truth_data.
    loss = -jnp.sum(truth.data * pred, axis=-1)
  # Only the last chunk step of each sample contributes to the output loss.
  last_step = _expand_and_broadcast_to(is_last, loss)
  mask = last_step if type_mask is None else type_mask * last_step
  total_mask = jnp.maximum(jnp.sum(mask), EPS)
  return jnp.sum(jnp.where(mask, loss, 0.0)) / total_mask
def output_loss(truth: _DataPoint, pred: _Array, nb_nodes: int) -> float:
  """Output loss for full-sample training.

  Args:
    truth: Ground-truth output `DataPoint`; `truth.type_` selects the loss.
    pred: Raw predictions (logits for discrete types, values for SCALAR).
    nb_nodes: Number of nodes, used to one-hot POINTER truth values.

  Returns:
    Scalar loss averaged over the unmasked elements.

  Raises:
    ValueError: If `truth.type_` is not a supported output type.
  """
  if truth.type_ == _Type.SCALAR:
    total_loss = jnp.mean((pred - truth.data)**2)
  elif truth.type_ == _Type.MASK:
    # Numerically stable sigmoid binary cross-entropy on logits.
    loss = (
        jnp.maximum(pred, 0) - pred * truth.data +
        jnp.log1p(jnp.exp(-jnp.abs(pred))))
    mask = (truth.data != _OutputClass.MASKED).astype(jnp.float32)
    total_loss = jnp.sum(loss * mask) / jnp.sum(mask)
  elif truth.type_ in [_Type.MASK_ONE, _Type.CATEGORICAL]:
    # Zero-out MASKED truth entries; normalize by the POSITIVE count.
    masked_truth = truth.data * (truth.data != _OutputClass.MASKED).astype(
        jnp.float32)
    total_loss = (-jnp.sum(masked_truth * jax.nn.log_softmax(pred)) /
                  jnp.sum(truth.data == _OutputClass.POSITIVE))
  elif truth.type_ == _Type.POINTER:
    total_loss = (
        jnp.mean(-jnp.sum(
            hk.one_hot(truth.data, nb_nodes) * jax.nn.log_softmax(pred),
            axis=-1)))
  elif truth.type_ == _Type.PERMUTATION_POINTER:
    # Predictions are NxN logits aiming to represent a doubly stochastic
    # matrix. Compute the cross entropy between doubly stochastic pred and
    # truth_data.
    total_loss = jnp.mean(-jnp.sum(truth.data * pred, axis=-1))
  else:
    # Previously an unsupported type crashed with an opaque NameError on
    # `total_loss`; fail fast with a descriptive error instead.
    raise ValueError(f'Unsupported output type: {truth.type_}')
  return total_loss
def hint_loss_chunked(
    truth: _DataPoint,
    pred: _Array,
    is_first: _Array,
    nb_nodes: int,
):
  """Hint loss for time-chunked training."""
  elem_loss, valid = _hint_loss(
      truth_data=truth.data,
      truth_type=truth.type_,
      pred=pred,
      nb_nodes=nb_nodes,
  )
  # The first chunk step of a sample carries no preceding hint, so drop it.
  not_first = (1 - _expand_to(is_first, elem_loss)).astype(jnp.float32)
  valid = valid * not_first
  return jnp.sum(elem_loss * valid) / jnp.maximum(jnp.sum(valid), EPS)
def hint_loss(
    truth: _DataPoint,
    preds: List[_Array],
    lengths: _Array,
    nb_nodes: int,
    verbose: bool = False,
):
  """Hint loss for full-sample training."""
  nb_steps = truth.data.shape[0] - 1
  # Hints at step t are predicted from step t-1, so skip the initial step.
  elem_loss, valid = _hint_loss(
      truth_data=truth.data[1:],
      truth_type=truth.type_,
      pred=jnp.stack(preds),
      nb_nodes=nb_nodes,
  )
  # Zero-out steps beyond each sample's trajectory length.
  valid = valid * _is_not_done_broadcast(
      lengths, jnp.arange(nb_steps)[:, None], elem_loss)
  mean_loss = jnp.sum(elem_loss * valid) / jnp.maximum(jnp.sum(valid), EPS)
  if verbose:
    return {'loss_' + truth.name: mean_loss}
  return 0. + mean_loss
def _hint_loss(
    truth_data: _Array,
    truth_type: str,
    pred: _Array,
    nb_nodes: int,
) -> Tuple[_Array, _Array]:
  """Hint loss helper.

  Args:
    truth_data: Ground-truth hint values.
    truth_type: A `specs.Type` value selecting the elementwise loss.
    pred: Raw predictions (logits for discrete types, values for SCALAR).
    nb_nodes: Number of nodes, used to one-hot POINTER truth values.

  Returns:
    A `(loss, mask)` pair: the unreduced elementwise loss and a float mask
    of the entries that should contribute to the reduction.

  Raises:
    ValueError: If `truth_type` is not a supported hint type.
  """
  mask = None
  if truth_type == _Type.SCALAR:
    loss = (pred - truth_data)**2
  elif truth_type == _Type.MASK:
    # Numerically stable sigmoid binary cross-entropy on logits.
    loss = (jnp.maximum(pred, 0) - pred * truth_data +
            jnp.log1p(jnp.exp(-jnp.abs(pred))))
    mask = (truth_data != _OutputClass.MASKED).astype(jnp.float32)  # pytype: disable=attribute-error  # numpy-scalars
  elif truth_type == _Type.MASK_ONE:
    loss = -jnp.sum(truth_data * jax.nn.log_softmax(pred), axis=-1,
                    keepdims=True)
  elif truth_type == _Type.CATEGORICAL:
    loss = -jnp.sum(truth_data * jax.nn.log_softmax(pred), axis=-1)
    mask = jnp.any(truth_data == _OutputClass.POSITIVE, axis=-1).astype(
        jnp.float32)
  elif truth_type == _Type.POINTER:
    loss = -jnp.sum(
        hk.one_hot(truth_data, nb_nodes) * jax.nn.log_softmax(pred),
        axis=-1)
  elif truth_type == _Type.PERMUTATION_POINTER:
    # Predictions are NxN logits aiming to represent a doubly stochastic
    # matrix. Compute the cross entropy between doubly stochastic pred and
    # truth_data.
    loss = -jnp.sum(truth_data * pred, axis=-1)
  else:
    # Previously an unsupported type left `loss` unbound and crashed with an
    # opaque NameError below; fail fast with a descriptive error instead.
    raise ValueError(f'Unsupported hint type: {truth_type}')
  if mask is None:
    mask = jnp.ones_like(loss)
  return loss, mask
| def _is_not_done_broadcast(lengths, i, tensor): | |
| is_not_done = (lengths > i + 1) * 1.0 | |
| while len(is_not_done.shape) < len(tensor.shape): # pytype: disable=attribute-error # numpy-scalars | |
| is_not_done = jnp.expand_dims(is_not_done, -1) | |
| return is_not_done | |