# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """JAX implementation of CLRS basic network.""" import functools from typing import Dict, List, Optional, Tuple import chex from clrs._src import decoders from clrs._src import encoders from clrs._src import probing from clrs._src import processors from clrs._src import samplers from clrs._src import specs import haiku as hk import jax import jax.numpy as jnp _Array = chex.Array _DataPoint = probing.DataPoint _Features = samplers.Features _FeaturesChunked = samplers.FeaturesChunked _Location = specs.Location _Spec = specs.Spec _Stage = specs.Stage _Trajectory = samplers.Trajectory _Type = specs.Type @chex.dataclass class _MessagePassingScanState: hint_preds: chex.Array output_preds: chex.Array hiddens: chex.Array lstm_state: Optional[hk.LSTMState] @chex.dataclass class _MessagePassingOutputChunked: hint_preds: chex.Array output_preds: chex.Array @chex.dataclass class MessagePassingStateChunked: inputs: chex.Array hints: chex.Array is_first: chex.Array hint_preds: chex.Array hiddens: chex.Array lstm_state: Optional[hk.LSTMState] class Net(hk.Module): """Building blocks (networks) used to encode and decode messages.""" def __init__( self, spec: List[_Spec], hidden_dim: int, encode_hints: bool, decode_hints: bool, processor_factory: processors.ProcessorFactory, use_lstm: bool, encoder_init: str, dropout_prob: float, hint_teacher_forcing: float, hint_repred_mode='soft', nb_dims=None, nb_msg_passing_steps=1, name: str = 'net', ): """Constructs a `Net`.""" super().__init__(name=name) self._dropout_prob = dropout_prob self._hint_teacher_forcing = hint_teacher_forcing self._hint_repred_mode = hint_repred_mode self.spec = spec self.hidden_dim = hidden_dim self.encode_hints = encode_hints self.decode_hints = decode_hints self.processor_factory = processor_factory self.nb_dims = nb_dims self.use_lstm = use_lstm self.encoder_init = encoder_init self.nb_msg_passing_steps = nb_msg_passing_steps def _msg_passing_step(self, mp_state: _MessagePassingScanState, i: int, hints: List[_DataPoint], repred: bool, lengths: chex.Array, batch_size: int, nb_nodes: int, inputs: _Trajectory, first_step: bool, spec: _Spec, encs: Dict[str, List[hk.Module]], decs: Dict[str, Tuple[hk.Module]], return_hints: bool, return_all_outputs: bool ): if self.decode_hints and not first_step: assert self._hint_repred_mode in ['soft', 'hard', 'hard_on_eval'] hard_postprocess = (self._hint_repred_mode == 'hard' or (self._hint_repred_mode == 'hard_on_eval' and repred)) decoded_hint = decoders.postprocess(spec, mp_state.hint_preds, sinkhorn_temperature=0.1, sinkhorn_steps=25, hard=hard_postprocess) if repred and self.decode_hints and not first_step: cur_hint = [] for hint in decoded_hint: cur_hint.append(decoded_hint[hint]) else: cur_hint = [] needs_noise = (self.decode_hints and not first_step and self._hint_teacher_forcing < 1.0) if needs_noise: # For noisy teacher forcing, choose which examples in the batch to force force_mask = jax.random.bernoulli( hk.next_rng_key(), self._hint_teacher_forcing, (batch_size,)) else: force_mask = None for hint in hints: hint_data = jnp.asarray(hint.data)[i] _, loc, typ = spec[hint.name] if needs_noise: if (typ == _Type.POINTER and decoded_hint[hint.name].type_ == _Type.SOFT_POINTER): # When using soft pointers, the decoded hints cannot be summarised # as indices (as would happen in hard postprocessing), so we need # to raise the ground-truth hint (potentially used for teacher # forcing) to its one-hot version. hint_data = hk.one_hot(hint_data, nb_nodes) typ = _Type.SOFT_POINTER hint_data = jnp.where(_expand_to(force_mask, hint_data), hint_data, decoded_hint[hint.name].data) cur_hint.append( probing.DataPoint( name=hint.name, location=loc, type_=typ, data=hint_data)) hiddens, output_preds_cand, hint_preds, lstm_state = self._one_step_pred( inputs, cur_hint, mp_state.hiddens, batch_size, nb_nodes, mp_state.lstm_state, spec, encs, decs, repred) if first_step: output_preds = output_preds_cand else: output_preds = {} for outp in mp_state.output_preds: is_not_done = _is_not_done_broadcast(lengths, i, output_preds_cand[outp]) output_preds[outp] = is_not_done * output_preds_cand[outp] + ( 1.0 - is_not_done) * mp_state.output_preds[outp] new_mp_state = _MessagePassingScanState( # pytype: disable=wrong-arg-types # numpy-scalars hint_preds=hint_preds, output_preds=output_preds, hiddens=hiddens, lstm_state=lstm_state) # Save memory by not stacking unnecessary fields accum_mp_state = _MessagePassingScanState( # pytype: disable=wrong-arg-types # numpy-scalars hint_preds=hint_preds if return_hints else None, output_preds=output_preds if return_all_outputs else None, hiddens=None, lstm_state=None) # Complying to jax.scan, the first returned value is the state we carry over # the second value is the output that will be stacked over steps. return new_mp_state, accum_mp_state def __call__(self, features_list: List[_Features], repred: bool, algorithm_index: int, return_hints: bool, return_all_outputs: bool): """Process one batch of data. Args: features_list: A list of _Features objects, each with the inputs, hints and lengths for a batch o data corresponding to one algorithm. The list should have either length 1, at train/evaluation time, or length equal to the number of algorithms this Net is meant to process, at initialization. repred: False during training, when we have access to ground-truth hints. True in validation/test mode, when we have to use our own hint predictions. algorithm_index: Which algorithm is being processed. It can be -1 at initialisation (either because we are initialising the parameters of the module or because we are intialising the message-passing state), meaning that all algorithms should be processed, in which case `features_list` should have length equal to the number of specs of the Net. Otherwise, `algorithm_index` should be between 0 and `length(self.spec) - 1`, meaning only one of the algorithms will be processed, and `features_list` should have length 1. return_hints: Whether to accumulate and return the predicted hints, when they are decoded. return_all_outputs: Whether to return the full sequence of outputs, or just the last step's output. Returns: A 2-tuple with (output predictions, hint predictions) for the selected algorithm. """ if algorithm_index == -1: algorithm_indices = range(len(features_list)) else: algorithm_indices = [algorithm_index] assert len(algorithm_indices) == len(features_list) self.encoders, self.decoders = self._construct_encoders_decoders() self.processor = self.processor_factory(self.hidden_dim) # Optionally construct LSTM. if self.use_lstm: self.lstm = hk.LSTM( hidden_size=self.hidden_dim, name='processor_lstm') lstm_init = self.lstm.initial_state else: self.lstm = None lstm_init = lambda x: 0 for algorithm_index, features in zip(algorithm_indices, features_list): inputs = features.inputs hints = features.hints lengths = features.lengths batch_size, nb_nodes = _data_dimensions(features) nb_mp_steps = max(1, hints[0].data.shape[0] - 1) hiddens = jnp.zeros((batch_size, nb_nodes, self.hidden_dim)) if self.use_lstm: lstm_state = lstm_init(batch_size * nb_nodes) lstm_state = jax.tree_util.tree_map( lambda x, b=batch_size, n=nb_nodes: jnp.reshape(x, [b, n, -1]), lstm_state) else: lstm_state = None mp_state = _MessagePassingScanState( # pytype: disable=wrong-arg-types # numpy-scalars hint_preds=None, output_preds=None, hiddens=hiddens, lstm_state=lstm_state) # Do the first step outside of the scan because it has a different # computation graph. common_args = dict( hints=hints, repred=repred, inputs=inputs, batch_size=batch_size, nb_nodes=nb_nodes, lengths=lengths, spec=self.spec[algorithm_index], encs=self.encoders[algorithm_index], decs=self.decoders[algorithm_index], return_hints=return_hints, return_all_outputs=return_all_outputs, ) mp_state, lean_mp_state = self._msg_passing_step( mp_state, i=0, first_step=True, **common_args) # Then scan through the rest. scan_fn = functools.partial( self._msg_passing_step, first_step=False, **common_args) output_mp_state, accum_mp_state = hk.scan( scan_fn, mp_state, jnp.arange(nb_mp_steps - 1) + 1, length=nb_mp_steps - 1) # We only return the last algorithm's output. That's because # the output only matters when a single algorithm is processed; the case # `algorithm_index==-1` (meaning all algorithms should be processed) # is used only to init parameters. accum_mp_state = jax.tree_util.tree_map( lambda init, tail: jnp.concatenate([init[None], tail], axis=0), lean_mp_state, accum_mp_state) def invert(d): """Dict of lists -> list of dicts.""" if d: return [dict(zip(d, i)) for i in zip(*d.values())] if return_all_outputs: output_preds = {k: jnp.stack(v) for k, v in accum_mp_state.output_preds.items()} else: output_preds = output_mp_state.output_preds hint_preds = invert(accum_mp_state.hint_preds) return output_preds, hint_preds def _construct_encoders_decoders(self): """Constructs encoders and decoders, separate for each algorithm.""" encoders_ = [] decoders_ = [] enc_algo_idx = None for (algo_idx, spec) in enumerate(self.spec): enc = {} dec = {} for name, (stage, loc, t) in spec.items(): if stage == _Stage.INPUT or ( stage == _Stage.HINT and self.encode_hints): # Build input encoders. if name == specs.ALGO_IDX_INPUT_NAME: if enc_algo_idx is None: enc_algo_idx = [hk.Linear(self.hidden_dim, name=f'{name}_enc_linear')] enc[name] = enc_algo_idx else: enc[name] = encoders.construct_encoders( stage, loc, t, hidden_dim=self.hidden_dim, init=self.encoder_init, name=f'algo_{algo_idx}_{name}') if stage == _Stage.OUTPUT or ( stage == _Stage.HINT and self.decode_hints): # Build output decoders. dec[name] = decoders.construct_decoders( loc, t, hidden_dim=self.hidden_dim, nb_dims=self.nb_dims[algo_idx][name], name=f'algo_{algo_idx}_{name}') encoders_.append(enc) decoders_.append(dec) return encoders_, decoders_ def _one_step_pred( self, inputs: _Trajectory, hints: _Trajectory, hidden: _Array, batch_size: int, nb_nodes: int, lstm_state: Optional[hk.LSTMState], spec: _Spec, encs: Dict[str, List[hk.Module]], decs: Dict[str, Tuple[hk.Module]], repred: bool, ): """Generates one-step predictions.""" # Initialise empty node/edge/graph features and adjacency matrix. node_fts = jnp.zeros((batch_size, nb_nodes, self.hidden_dim)) edge_fts = jnp.zeros((batch_size, nb_nodes, nb_nodes, self.hidden_dim)) graph_fts = jnp.zeros((batch_size, self.hidden_dim)) adj_mat = jnp.repeat( jnp.expand_dims(jnp.eye(nb_nodes), 0), batch_size, axis=0) # ENCODE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Encode node/edge/graph features from inputs and (optionally) hints. trajectories = [inputs] if self.encode_hints: trajectories.append(hints) for trajectory in trajectories: for dp in trajectory: try: dp = encoders.preprocess(dp, nb_nodes) assert dp.type_ != _Type.SOFT_POINTER adj_mat = encoders.accum_adj_mat(dp, adj_mat) encoder = encs[dp.name] edge_fts = encoders.accum_edge_fts(encoder, dp, edge_fts) node_fts = encoders.accum_node_fts(encoder, dp, node_fts) graph_fts = encoders.accum_graph_fts(encoder, dp, graph_fts) except Exception as e: raise Exception(f'Failed to process {dp}') from e # PROCESS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nxt_hidden = hidden for _ in range(self.nb_msg_passing_steps): nxt_hidden, nxt_edge = self.processor( node_fts, edge_fts, graph_fts, adj_mat, nxt_hidden, batch_size=batch_size, nb_nodes=nb_nodes, ) if not repred: # dropout only on training nxt_hidden = hk.dropout(hk.next_rng_key(), self._dropout_prob, nxt_hidden) if self.use_lstm: # lstm doesn't accept multiple batch dimensions (in our case, batch and # nodes), so we vmap over the (first) batch dimension. nxt_hidden, nxt_lstm_state = jax.vmap(self.lstm)(nxt_hidden, lstm_state) else: nxt_lstm_state = None h_t = jnp.concatenate([node_fts, hidden, nxt_hidden], axis=-1) if nxt_edge is not None: e_t = jnp.concatenate([edge_fts, nxt_edge], axis=-1) else: e_t = edge_fts # DECODE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Decode features and (optionally) hints. hint_preds, output_preds = decoders.decode_fts( decoders=decs, spec=spec, h_t=h_t, adj_mat=adj_mat, edge_fts=e_t, graph_fts=graph_fts, inf_bias=self.processor.inf_bias, inf_bias_edge=self.processor.inf_bias_edge, repred=repred, ) return nxt_hidden, output_preds, hint_preds, nxt_lstm_state class NetChunked(Net): """A Net that will process time-chunked data instead of full samples.""" def _msg_passing_step(self, mp_state: MessagePassingStateChunked, xs, repred: bool, init_mp_state: bool, batch_size: int, nb_nodes: int, spec: _Spec, encs: Dict[str, List[hk.Module]], decs: Dict[str, Tuple[hk.Module]], ): """Perform one message passing step. This function is unrolled along the time axis to process a data chunk. Args: mp_state: message-passing state. Includes the inputs, hints, beginning-of-sample markers, hint predictions, hidden and lstm state to be used for prediction in the current step. xs: A 3-tuple of with the next timestep's inputs, hints, and beginning-of-sample markers. These will replace the contents of the `mp_state` at the output, in readiness for the next unroll step of the chunk (or the first step of the next chunk). Besides, the next timestep's hints are necessary to compute diffs when `decode_diffs` is True. repred: False during training, when we have access to ground-truth hints. True in validation/test mode, when we have to use our own hint predictions. init_mp_state: Indicates if we are calling the method just to initialise the message-passing state, before the beginning of training or validation. batch_size: Size of batch dimension. nb_nodes: Number of nodes in graph. spec: The spec of the algorithm being processed. encs: encoders for the algorithm being processed. decs: decoders for the algorithm being processed. Returns: A 2-tuple with the next mp_state and an output consisting of hint predictions and output predictions. """ def _as_prediction_data(hint): if hint.type_ == _Type.POINTER: return hk.one_hot(hint.data, nb_nodes) return hint.data nxt_inputs, nxt_hints, nxt_is_first = xs inputs = mp_state.inputs is_first = mp_state.is_first hints = mp_state.hints if init_mp_state: prev_hint_preds = {h.name: _as_prediction_data(h) for h in hints} hints_for_pred = hints else: prev_hint_preds = mp_state.hint_preds if self.decode_hints: if repred: force_mask = jnp.zeros(batch_size, dtype=bool) elif self._hint_teacher_forcing == 1.0: force_mask = jnp.ones(batch_size, dtype=bool) else: force_mask = jax.random.bernoulli( hk.next_rng_key(), self._hint_teacher_forcing, (batch_size,)) assert self._hint_repred_mode in ['soft', 'hard', 'hard_on_eval'] hard_postprocess = ( self._hint_repred_mode == 'hard' or (self._hint_repred_mode == 'hard_on_eval' and repred)) decoded_hints = decoders.postprocess(spec, prev_hint_preds, sinkhorn_temperature=0.1, sinkhorn_steps=25, hard=hard_postprocess) hints_for_pred = [] for h in hints: typ = h.type_ hint_data = h.data if (typ == _Type.POINTER and decoded_hints[h.name].type_ == _Type.SOFT_POINTER): hint_data = hk.one_hot(hint_data, nb_nodes) typ = _Type.SOFT_POINTER hints_for_pred.append(probing.DataPoint( name=h.name, location=h.location, type_=typ, data=jnp.where(_expand_to(is_first | force_mask, hint_data), hint_data, decoded_hints[h.name].data))) else: hints_for_pred = hints hiddens = jnp.where(is_first[..., None, None], 0.0, mp_state.hiddens) if self.use_lstm: lstm_state = jax.tree_util.tree_map( lambda x: jnp.where(is_first[..., None, None], 0.0, x), mp_state.lstm_state) else: lstm_state = None hiddens, output_preds, hint_preds, lstm_state = self._one_step_pred( inputs, hints_for_pred, hiddens, batch_size, nb_nodes, lstm_state, spec, encs, decs, repred) new_mp_state = MessagePassingStateChunked( # pytype: disable=wrong-arg-types # numpy-scalars hiddens=hiddens, lstm_state=lstm_state, hint_preds=hint_preds, inputs=nxt_inputs, hints=nxt_hints, is_first=nxt_is_first) mp_output = _MessagePassingOutputChunked( # pytype: disable=wrong-arg-types # numpy-scalars hint_preds=hint_preds, output_preds=output_preds) return new_mp_state, mp_output def __call__(self, features_list: List[_FeaturesChunked], mp_state_list: List[MessagePassingStateChunked], repred: bool, init_mp_state: bool, algorithm_index: int): """Process one chunk of data. Args: features_list: A list of _FeaturesChunked objects, each with the inputs, hints and beginning- and end-of-sample markers for a chunk (i.e., fixed time length) of data corresponding to one algorithm. All features are expected to have dimensions chunk_length x batch_size x ... The list should have either length 1, at train/evaluation time, or length equal to the number of algorithms this Net is meant to process, at initialization. mp_state_list: list of message-passing states. Each message-passing state includes the inputs, hints, beginning-of-sample markers, hint prediction, hidden and lstm state from the end of the previous chunk, for one algorithm. The length of the list should be the same as the length of `features_list`. repred: False during training, when we have access to ground-truth hints. True in validation/test mode, when we have to use our own hint predictions. init_mp_state: Indicates if we are calling the network just to initialise the message-passing state, before the beginning of training or validation. If True, `algorithm_index` (see below) must be -1 in order to initialize the message-passing state of all algorithms. algorithm_index: Which algorithm is being processed. It can be -1 at initialisation (either because we are initialising the parameters of the module or because we are intialising the message-passing state), meaning that all algorithms should be processed, in which case `features_list` and `mp_state_list` should have length equal to the number of specs of the Net. Otherwise, `algorithm_index` should be between 0 and `length(self.spec) - 1`, meaning only one of the algorithms will be processed, and `features_list` and `mp_state_list` should have length 1. Returns: A 2-tuple consisting of: - A 2-tuple with (output predictions, hint predictions) for the selected algorithm. Each of these has chunk_length x batch_size x ... data, where the first time slice contains outputs for the mp_state that was passed as input, and the last time slice contains outputs for the next-to-last slice of the input features. The outputs that correspond to the final time slice of the input features will be calculated when the next chunk is processed, using the data in the mp_state returned here (see below). If `init_mp_state` is True, we return None instead of the 2-tuple. - The mp_state (message-passing state) for the next chunk of data of the selected algorithm. If `init_mp_state` is True, we return initial mp states for all the algorithms. """ if algorithm_index == -1: algorithm_indices = range(len(features_list)) else: algorithm_indices = [algorithm_index] assert not init_mp_state # init state only allowed with all algorithms assert len(algorithm_indices) == len(features_list) assert len(algorithm_indices) == len(mp_state_list) self.encoders, self.decoders = self._construct_encoders_decoders() self.processor = self.processor_factory(self.hidden_dim) # Optionally construct LSTM. if self.use_lstm: self.lstm = hk.LSTM( hidden_size=self.hidden_dim, name='processor_lstm') lstm_init = self.lstm.initial_state else: self.lstm = None lstm_init = lambda x: 0 if init_mp_state: output_mp_states = [] for algorithm_index, features, mp_state in zip( algorithm_indices, features_list, mp_state_list): inputs = features.inputs hints = features.hints batch_size, nb_nodes = _data_dimensions_chunked(features) if self.use_lstm: lstm_state = lstm_init(batch_size * nb_nodes) lstm_state = jax.tree_util.tree_map( lambda x, b=batch_size, n=nb_nodes: jnp.reshape(x, [b, n, -1]), lstm_state) mp_state.lstm_state = lstm_state mp_state.inputs = jax.tree_util.tree_map(lambda x: x[0], inputs) mp_state.hints = jax.tree_util.tree_map(lambda x: x[0], hints) mp_state.is_first = jnp.zeros(batch_size, dtype=int) mp_state.hiddens = jnp.zeros((batch_size, nb_nodes, self.hidden_dim)) next_is_first = jnp.ones(batch_size, dtype=int) mp_state, _ = self._msg_passing_step( mp_state, (mp_state.inputs, mp_state.hints, next_is_first), repred=repred, init_mp_state=True, batch_size=batch_size, nb_nodes=nb_nodes, spec=self.spec[algorithm_index], encs=self.encoders[algorithm_index], decs=self.decoders[algorithm_index], ) output_mp_states.append(mp_state) return None, output_mp_states for algorithm_index, features, mp_state in zip( algorithm_indices, features_list, mp_state_list): inputs = features.inputs hints = features.hints is_first = features.is_first batch_size, nb_nodes = _data_dimensions_chunked(features) scan_fn = functools.partial( self._msg_passing_step, repred=repred, init_mp_state=False, batch_size=batch_size, nb_nodes=nb_nodes, spec=self.spec[algorithm_index], encs=self.encoders[algorithm_index], decs=self.decoders[algorithm_index], ) mp_state, scan_output = hk.scan( scan_fn, mp_state, (inputs, hints, is_first), ) # We only return the last algorithm's output and state. That's because # the output only matters when a single algorithm is processed; the case # `algorithm_index==-1` (meaning all algorithms should be processed) # is used only to init parameters. return (scan_output.output_preds, scan_output.hint_preds), mp_state def _data_dimensions(features: _Features) -> Tuple[int, int]: """Returns (batch_size, nb_nodes).""" for inp in features.inputs: if inp.location in [_Location.NODE, _Location.EDGE]: return inp.data.shape[:2] assert False def _data_dimensions_chunked(features: _FeaturesChunked) -> Tuple[int, int]: """Returns (batch_size, nb_nodes).""" for inp in features.inputs: if inp.location in [_Location.NODE, _Location.EDGE]: return inp.data.shape[1:3] assert False def _expand_to(x: _Array, y: _Array) -> _Array: while len(y.shape) > len(x.shape): x = jnp.expand_dims(x, -1) return x def _is_not_done_broadcast(lengths, i, tensor): is_not_done = (lengths > i + 1) * 1.0 while len(is_not_done.shape) < len(tensor.shape): # pytype: disable=attribute-error # numpy-scalars is_not_done = jnp.expand_dims(is_not_done, -1) return is_not_done