# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for `losses.py`."""
from typing import Generator
from absl.testing import absltest
from absl.testing import parameterized
from clrs._src import dataset
from clrs._src import losses
from clrs._src import probing
from clrs._src import samplers
from clrs._src import specs
import jax
import jax.numpy as jnp
import numpy as np
_Array = np.ndarray
_Location = specs.Location


def _make_sampler(algo: str, nb_nodes: int) -> samplers.Sampler:
  """Build a CLRS-30 validation sampler for `algo` with `nb_nodes` nodes."""
  sampler, _ = samplers.build_sampler(
      algo,
      seed=samplers.CLRS30['val']['seed'],
      num_samples=samplers.CLRS30['val']['num_samples'],
      length=nb_nodes,
  )
  return sampler


def _make_iterable_sampler(
    algo: str, batch_size: int,
    nb_nodes: int) -> Generator[samplers.Feedback, None, None]:
  """Yield batches of size `batch_size` from a fresh sampler, forever."""
  sampler = _make_sampler(algo, nb_nodes)
  while True:
    yield sampler.next(batch_size)


def _as_pred_data(x, nb_nodes, seed, batch_axis):
  """Fake a prediction from a data point."""
  # Permute along batch axis to make the prediction different.
  key = jax.random.PRNGKey(seed)
  data = jax.random.permutation(key, x.data, axis=batch_axis)
  # Extend to one-hot for pointer types.
  if x.type_ == specs.Type.POINTER:
    return jax.nn.one_hot(data, nb_nodes)
  return data


def _mask_datapoint(x, seed, t_axis=None):
  """Add some masking to data."""
  key = jax.random.PRNGKey(seed)
  data = x.data
  if x.type_ == specs.Type.MASK:
    # Mask some data at random.
    mask_shape = list(data.shape)
    if t_axis is not None:
      mask_shape[t_axis] = 1
    mask = jax.random.uniform(key, tuple(mask_shape)) < 0.2
    data = jnp.where(mask, specs.OutputClass.MASKED, data)
  elif x.type_ in [specs.Type.CATEGORICAL, specs.Type.MASK_ONE]:
    # Mask some data at random (all categories together).
    mask_shape = list(data.shape)[:-1]
    if t_axis is not None:
      mask_shape[t_axis] = 1
    mask = jax.random.uniform(key, tuple(mask_shape)) < 0.2
    data = jnp.where(mask[..., None], specs.OutputClass.MASKED, data)
  return probing.DataPoint(name=x.name, location=x.location, type_=x.type_,
                           data=data)


def _rand_diff(seed, shape):
  """Random values uniformly distributed in [-1, 1)."""
  return 2.0 * jax.random.uniform(jax.random.PRNGKey(seed), shape) - 1.0


def _rand_mask(seed, shape, p=0.5):
  """Random binary mask; each entry is 1 with probability `1 - p`."""
  return (jax.random.uniform(jax.random.PRNGKey(seed), shape) > p).astype(float)


def invert(d):
  """Dict of lists -> list of dicts."""
  if d:
    return [dict(zip(d, i)) for i in zip(*d.values())]


def _create_data(algo, nb_nodes):
  """Sample one full-trajectory batch and a matching chunked batch."""
  batch_size = 8
  ds = _make_iterable_sampler(algo, batch_size, nb_nodes)
  full_sample = next(ds)

  chunk_length = full_sample.features.lengths[0].astype(int)
  chunked_ds = dataset.chunkify(
      _make_iterable_sampler(algo, batch_size, nb_nodes),
      chunk_length)
  chunk_sample = next(chunked_ds)
  return full_sample, chunk_sample
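

# The two algorithms tested below have fixed trajectory lengths, so the chunk
# length taken from `full_sample` applies to every element of the batch and a
# single chunk covers exactly one full trajectory. This is what lets the
# chunked and full-sample losses be compared directly.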


class FullVsChunkLossesTest(parameterized.TestCase):
  """Test that the full and chunked versions of the losses match."""

  # Test two fixed-length algorithms, covering all data types.
  @parameterized.parameters('dfs', 'floyd_warshall')
  def test_output_loss(self, algo):
    nb_nodes = 16
    full_sample, chunk_sample = _create_data(algo, nb_nodes)

    # Calculate output loss.
    for truth_full, truth_chunked in zip(full_sample.outputs,
                                         chunk_sample.outputs):
      chunk_output_loss = losses.output_loss_chunked(
          truth=_mask_datapoint(truth_chunked, seed=0),
          pred=_as_pred_data(truth_chunked, nb_nodes, 0, 1),
          is_last=chunk_sample.features.is_last,
          nb_nodes=nb_nodes,
      )
      full_output_loss = losses.output_loss(
          truth=_mask_datapoint(truth_full, seed=0),
          pred=_as_pred_data(truth_full, nb_nodes, 0, 0),
          nb_nodes=nb_nodes,
      )
      np.testing.assert_allclose(chunk_output_loss, full_output_loss,
                                 rtol=1e-4)

  @parameterized.parameters('dfs', 'floyd_warshall')
  def test_hint_loss(self, algo):
    nb_nodes = 16
    full_sample, chunk_sample = _create_data(algo, nb_nodes)
    for truth_full, truth_chunked in zip(full_sample.features.hints,
                                         chunk_sample.features.hints):
      np.testing.assert_array_equal(truth_full.data, truth_chunked.data)
      pred = _as_pred_data(truth_chunked, nb_nodes, 0, 1)
      chunk_hint_loss = losses.hint_loss_chunked(
          truth=_mask_datapoint(truth_chunked, seed=1, t_axis=0),
          pred=pred,
          is_first=chunk_sample.features.is_first,
          nb_nodes=nb_nodes,
      )
      # The step-0 hint is given as input, so the full-sample hint loss only
      # scores predictions from step 1 onwards.
      full_preds = pred[1:]
      full_hint_loss = losses.hint_loss(
          truth=_mask_datapoint(truth_full, 1, t_axis=0),
          preds=full_preds,
          lengths=full_sample.features.lengths,
          nb_nodes=nb_nodes,
      )
      np.testing.assert_allclose(chunk_hint_loss, full_hint_loss, rtol=1e-4)


if __name__ == '__main__':
  absltest.main()