Spaces:
Running
Running
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for `losses.py`.""" | |
from typing import Generator | |
from absl.testing import absltest | |
from absl.testing import parameterized | |
from clrs._src import dataset | |
from clrs._src import losses | |
from clrs._src import probing | |
from clrs._src import samplers | |
from clrs._src import specs | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
# Short module-private aliases for types referenced in this test module.
_Array = np.ndarray  # NumPy array type.
_Location = specs.Location  # Alias of the `Location` name exported by `specs`.
def _make_sampler(algo: str, nb_nodes: int) -> samplers.Sampler:
  """Build a sampler for `algo` using the CLRS30 validation settings.

  Args:
    algo: Name of the algorithm, forwarded to `samplers.build_sampler`.
    nb_nodes: Passed to `samplers.build_sampler` as `length`.

  Returns:
    The first element of the pair returned by `samplers.build_sampler`; the
    second element is discarded.
  """
  val_settings = samplers.CLRS30['val']
  built = samplers.build_sampler(
      algo,
      seed=val_settings['seed'],
      num_samples=val_settings['num_samples'],
      length=nb_nodes,
  )
  # `build_sampler` returns a pair; only the sampler itself is needed here.
  sampler, _ = built
  return sampler
def _make_iterable_sampler(
    algo: str, batch_size: int,
    nb_nodes: int) -> Generator[samplers.Feedback, None, None]:
  """Yield batches from a freshly built sampler, forever.

  Args:
    algo: Name of the algorithm, forwarded to `_make_sampler`.
    batch_size: Argument passed to every `sampler.next` call.
    nb_nodes: Forwarded to `_make_sampler`.

  Yields:
    The result of `sampler.next(batch_size)`, one per iteration. The generator
    never terminates on its own.
  """
  source = _make_sampler(algo, nb_nodes)
  while True:
    batch = source.next(batch_size)
    yield batch
def _as_pred_data(x, nb_nodes, seed, batch_axis):
  """Fake a prediction from a data point.

  Args:
    x: Data point exposing `.data` and `.type_`.
    nb_nodes: Number of classes used for the one-hot expansion of pointers.
    seed: Integer seed for the JAX PRNG key that drives the permutation.
    batch_axis: Axis of `x.data` along which entries are shuffled.

  Returns:
    `x.data` permuted along `batch_axis`; for `specs.Type.POINTER` data points
    the permuted data is additionally one-hot encoded with `nb_nodes` classes.
  """
  rng = jax.random.PRNGKey(seed)
  # Shuffling along the batch axis makes the "prediction" differ from truth.
  shuffled = jax.random.permutation(rng, x.data, axis=batch_axis)
  if x.type_ != specs.Type.POINTER:
    return shuffled
  # Pointer types are extended to one-hot.
  return jax.nn.one_hot(shuffled, nb_nodes)
def _mask_datapoint(x, seed, t_axis=None):
  """Return a copy of `x` with some entries replaced by the MASKED marker.

  Only `MASK`, `CATEGORICAL` and `MASK_ONE` data points are altered; any other
  type is rebuilt with its data untouched.

  Args:
    x: Data point exposing `.name`, `.location`, `.type_` and `.data`.
    seed: Integer seed for the JAX PRNG key used to draw the random mask.
    t_axis: Optional axis whose extent in the random mask is set to 1, so the
      same mask is broadcast along that axis.

  Returns:
    A new `probing.DataPoint` with the same name, location and type as `x`.
  """
  rng = jax.random.PRNGKey(seed)
  values = x.data
  elementwise = x.type_ == specs.Type.MASK
  # For these types all categories (the last axis) are masked together.
  grouped = (not elementwise) and x.type_ in [specs.Type.CATEGORICAL,
                                              specs.Type.MASK_ONE]
  if elementwise or grouped:
    dims = list(values.shape)
    if grouped:
      dims = dims[:-1]
    if t_axis is not None:
      dims[t_axis] = 1
    # Each mask entry is selected with probability 0.2.
    hide = jax.random.uniform(rng, tuple(dims)) < 0.2
    if grouped:
      hide = hide[..., None]
    values = jnp.where(hide, specs.OutputClass.MASKED, values)
  return probing.DataPoint(
      name=x.name, location=x.location, type_=x.type_, data=values)
def _rand_diff(seed, shape): | |
return 2.0 * jax.random.uniform(jax.random.PRNGKey(seed), shape) - 1.0 | |
def _rand_mask(seed, shape, p=0.5): | |
return (jax.random.uniform(jax.random.PRNGKey(seed), shape) > p).astype(float) | |
def invert(d):
  """Dict of lists -> list of dicts.

  Args:
    d: Mapping from key to a sequence of values. Entry `i` of every sequence
      is gathered into the `i`-th output dict.

  Returns:
    A list with one dict per position. As with `zip`, the result is truncated
    to the shortest sequence. An empty (or `None`) input yields an empty list,
    so the return value is always a list, as the summary line promises.
  """
  if not d:
    # Previously this fell through and implicitly returned None.
    return []
  keys = list(d)
  return [dict(zip(keys, row)) for row in zip(*d.values())]
def _create_data(algo, nb_nodes):
  """Draw one full batch and one chunked batch for `algo`.

  Args:
    algo: Name of the algorithm, forwarded to `_make_iterable_sampler`.
    nb_nodes: Forwarded to `_make_iterable_sampler`.

  Returns:
    A `(full_sample, chunk_sample)` pair: the first batch from a plain sampler
    stream, and the first item from a second, independently built stream
    wrapped with `dataset.chunkify`.
  """
  batch_size = 8
  full_sample = next(_make_iterable_sampler(algo, batch_size, nb_nodes))
  # The chunk length is taken from the first sample of the full batch.
  # NOTE(review): presumably every sample has this length for the algorithms
  # under test, so that one chunk lines up with one full sample -- confirm.
  chunk_length = full_sample.features.lengths[0].astype(int)
  chunk_stream = dataset.chunkify(
      _make_iterable_sampler(algo, batch_size, nb_nodes),
      chunk_length)
  return full_sample, next(chunk_stream)
class FullVsChunkLossesTest(parameterized.TestCase):
  """Test that the full and chunked versions of the losses match."""

  # Test two algorithms with fixed-length, covering all data types.
  # Both test methods take an `algo` argument, so each one has to be
  # parameterized; without the decorator the test runner would invoke them
  # with no `algo` and fail with a TypeError before exercising any loss.
  @parameterized.parameters('dfs', 'floyd_warshall')
  def test_output_loss(self, algo):
    """Chunked and full output losses agree for every output of `algo`.

    Args:
      algo: Name of the algorithm whose samples are generated.
    """
    nb_nodes = 16
    full_sample, chunk_sample = _create_data(algo, nb_nodes)

    # Calculate output loss.
    for truth_full, truth_chunked in zip(full_sample.outputs,
                                         chunk_sample.outputs):
      # The fake prediction is permuted along axis 1 for the chunked data and
      # along axis 0 for the full data, with the same seed in both cases.
      chunk_output_loss = losses.output_loss_chunked(
          truth=_mask_datapoint(truth_chunked, seed=0),
          pred=_as_pred_data(truth_chunked, nb_nodes, 0, 1),
          is_last=chunk_sample.features.is_last,
          nb_nodes=nb_nodes,
      )
      full_output_loss = losses.output_loss(
          truth=_mask_datapoint(truth_full, seed=0),
          pred=_as_pred_data(truth_full, nb_nodes, 0, 0),
          nb_nodes=nb_nodes,
      )
      np.testing.assert_allclose(chunk_output_loss, full_output_loss, rtol=1e-4)

  @parameterized.parameters('dfs', 'floyd_warshall')
  def test_hint_loss(self, algo):
    """Chunked and full hint losses agree for every hint of `algo`.

    Args:
      algo: Name of the algorithm whose samples are generated.
    """
    nb_nodes = 16
    full_sample, chunk_sample = _create_data(algo, nb_nodes)
    for truth_full, truth_chunked in zip(full_sample.features.hints,
                                         chunk_sample.features.hints):
      # Sanity check: both pipelines must expose identical hint data.
      np.testing.assert_array_equal(truth_full.data, truth_chunked.data)
      pred = _as_pred_data(truth_chunked, nb_nodes, 0, 1)
      chunk_hint_loss = losses.hint_loss_chunked(
          truth=_mask_datapoint(truth_chunked, seed=1, t_axis=0),
          pred=pred,
          is_first=chunk_sample.features.is_first,
          nb_nodes=nb_nodes,
      )

      # The full loss receives the predictions without their first entry
      # along axis 0.
      # NOTE(review): presumably because `hint_loss` has no prediction for the
      # initial time step -- confirm against `losses.hint_loss`.
      full_preds = pred[1:]
      full_hint_loss = losses.hint_loss(
          truth=_mask_datapoint(truth_full, 1, t_axis=0),
          preds=full_preds,
          lengths=full_sample.features.lengths,
          nb_nodes=nb_nodes,
      )
      np.testing.assert_allclose(chunk_hint_loss, full_hint_loss, rtol=1e-4)
# Run the tests through absltest when this module is executed as a script.
if __name__ == '__main__':
  absltest.main()