# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Evaluate a saved CLRS baseline checkpoint and report its test score.

`get_score(submission_folder)` rebuilds the samplers from the flag values,
restores `best.pkl` from `<submission_folder>/checkpoints` and returns the
`score` entry of the test statistics.
"""

import os
# disable logging until training starts
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import functools
import os
import shutil
from typing import Any, Dict, List

from absl import app
from absl import flags
from absl import logging
# disable logging until training starts
logging.set_verbosity(logging.ERROR)

import clrs
import jax
import numpy as np
import requests
import tensorflow as tf

import sys
# `baselines` is imported from the sibling `env` directory.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../env")))
from baselines import BaselineModel, BaselineModelChunked
import pickle


flags.DEFINE_list('algorithms', ['floyd_warshall'], 'Which algorithms to run.')
flags.DEFINE_list('train_lengths', ['4', '7', '11', '13', '16'],
                  'Which training sizes to use. A size of -1 means '
                  'use the benchmark dataset.')
flags.DEFINE_integer('length_needle', -8,
                     'Length of needle for training and validation '
                     '(not testing) in string matching algorithms. '
                     'A negative value randomizes the length for each sample '
                     'between 1 and the opposite of the value. '
                     'A value of 0 means use always 1/4 of the length of '
                     'the haystack (the default sampler behavior).')
flags.DEFINE_integer('seed', 42, 'Random seed to set')

flags.DEFINE_boolean('random_pos', True,
                     'Randomize the pos input common to all algos.')
flags.DEFINE_boolean('enforce_permutations', True,
                     'Whether to enforce permutation-type node pointers.')
flags.DEFINE_boolean('enforce_pred_as_input', True,
                     'Whether to change pred_h hints into pred inputs.')
flags.DEFINE_integer('batch_size', 32, 'Batch size used for training.')
flags.DEFINE_boolean('chunked_training', False,
                     'Whether to use chunking for training.')
flags.DEFINE_integer('chunk_length', 16,
                     'Time chunk length used for training (if '
                     '`chunked_training` is True.')
flags.DEFINE_integer('train_steps', 1000, 'Number of training iterations.')
flags.DEFINE_integer('eval_every', 50, 'Evaluation frequency (in steps).')
flags.DEFINE_integer('test_every', 500, 'Evaluation frequency (in steps).')
flags.DEFINE_integer('log_every', 50, 'Logging frequency (in steps).')

flags.DEFINE_integer('hidden_size', 128,
                     'Number of hidden units of the model.')
flags.DEFINE_integer('nb_heads', 1, 'Number of heads for GAT processors')
flags.DEFINE_integer('nb_msg_passing_steps', 1,
                     'Number of message passing steps to run per hint.')
flags.DEFINE_float('learning_rate', 0.001, 'Learning rate to use.')
flags.DEFINE_float('grad_clip_max_norm', 1.0,
                   'Gradient clipping by norm. 0.0 disables grad clipping')
flags.DEFINE_float('dropout_prob', 0.0, 'Dropout rate to use.')
flags.DEFINE_float('hint_teacher_forcing', 0.0,
                   'Probability that ground-truth teacher hints are encoded '
                   'during training instead of predicted hints. Only '
                   'pertinent in encoded_decoded modes.')
flags.DEFINE_enum('hint_mode', 'encoded_decoded',
                  ['encoded_decoded', 'decoded_only', 'none'],
                  'How should hints be used? Note, each mode defines a '
                  'separate task, with various difficulties. `encoded_decoded` '
                  'requires the model to explicitly materialise hint sequences '
                  'and therefore is hardest, but also most aligned to the '
                  'underlying algorithmic rule. Hence, `encoded_decoded` '
                  'should be treated as the default mode for our benchmark. '
                  'In `decoded_only`, hints are only used for defining '
                  'reconstruction losses. Often, this will perform well, but '
                  'note that we currently do not make any efforts to '
                  'counterbalance the various hint losses. Hence, for certain '
                  'tasks, the best performance will now be achievable with no '
                  'hint usage at all (`none`).')
flags.DEFINE_enum('hint_repred_mode', 'soft', ['soft', 'hard', 'hard_on_eval'],
                  'How to process predicted hints when fed back as inputs.'
                  'In soft mode, we use softmaxes for categoricals, pointers '
                  'and mask_one, and sigmoids for masks. '
                  'In hard mode, we use argmax instead of softmax, and hard '
                  'thresholding of masks. '
                  'In hard_on_eval mode, soft mode is '
                  'used for training and hard mode is used for evaluation.')
flags.DEFINE_boolean('use_ln', True,
                     'Whether to use layer normalisation in the processor.')
flags.DEFINE_boolean('use_lstm', False,
                     'Whether to insert an LSTM after message passing.')
flags.DEFINE_integer('nb_triplet_fts', 8,
                     'How many triplet features to compute?')

flags.DEFINE_enum('encoder_init', 'xavier_on_scalars',
                  ['default', 'xavier_on_scalars'],
                  'Initialiser to use for the encoders.')
flags.DEFINE_enum('processor_type', 'triplet_gmpnn',
                  ['deepsets', 'mpnn', 'pgn', 'pgn_mask',
                   'triplet_mpnn', 'triplet_pgn', 'triplet_pgn_mask',
                   'gat', 'gatv2', 'gat_full', 'gatv2_full',
                   'gpgn', 'gpgn_mask', 'gmpnn',
                   'triplet_gpgn', 'triplet_gpgn_mask', 'triplet_gmpnn'],
                  'Processor type to use as the network P.')

flags.DEFINE_string('checkpoint_path', '../env/checkpoints',
                    'Path in which checkpoints are saved.')
flags.DEFINE_string('dataset_path', '/tmp/CLRS30',
                    'Path in which dataset is stored.')
flags.DEFINE_boolean('freeze_processor', False,
                     'Whether to freeze the processor of the model.')

FLAGS = flags.FLAGS


# Algorithms for which `make_sampler` applies `clrs.process_pred_as_input`
# when `enforce_pred_as_input` is set.
PRED_AS_INPUT_ALGOS = [
    'binary_search',
    'minimum',
    'find_maximum_subarray',
    'find_maximum_subarray_kadane',
    'matrix_chain_order',
    'lcs_length',
    'optimal_bst',
    'activity_selector',
    'task_scheduling',
    'naive_string_matcher',
    'kmp_matcher',
    'jarvis_march']

# Algorithms whose training samplers use a single, enlarged haystack length.
_STRING_MATCHING_ALGOS = ('naive_string_matcher', 'kmp_matcher')


def unpack(v):
    """Return `v.item()` when available, otherwise `v` unchanged.

    Falls back to `v` itself when it has no `.item()` method or when
    `.item()` raises `ValueError`.
    """
    try:
        return v.item()  # DeviceArray
    except (AttributeError, ValueError):
        return v


def _iterate_sampler(sampler, batch_size):
    """Yield `sampler.next(batch_size)` forever."""
    while True:
        yield sampler.next(batch_size)


def _maybe_download_dataset(dataset_path):
    """Download CLRS30 dataset if needed.

    Args:
      dataset_path: Directory under which the dataset folder lives.

    Returns:
      The path of the dataset folder inside `dataset_path`.

    Raises:
      requests.HTTPError: If the download returns an error status.
    """
    dataset_folder = os.path.join(dataset_path, clrs.get_clrs_folder())
    if os.path.isdir(dataset_folder):
        logging.info('Dataset found at %s. Skipping download.', dataset_folder)
        return dataset_folder
    logging.info('Dataset not found in %s. Downloading...', dataset_folder)

    clrs_url = clrs.get_dataset_gcp_url()
    response = requests.get(clrs_url, allow_redirects=True)
    # Fail before touching the filesystem: an error page is not an archive.
    response.raise_for_status()

    clrs_file = os.path.join(dataset_path, os.path.basename(clrs_url))
    os.makedirs(dataset_folder)
    try:
        with open(clrs_file, 'wb') as archive:
            archive.write(response.content)
        shutil.unpack_archive(clrs_file, extract_dir=dataset_folder)
    except Exception:
        # A leftover (empty or partial) folder would make the next call take
        # the "dataset found" shortcut above, so remove it before re-raising.
        shutil.rmtree(dataset_folder, ignore_errors=True)
        raise
    finally:
        if os.path.exists(clrs_file):
            os.remove(clrs_file)
    return dataset_folder


def make_sampler(length: int,
                 rng: Any,
                 algorithm: str,
                 split: str,
                 batch_size: int,
                 multiplier: int,
                 randomize_pos: bool,
                 enforce_pred_as_input: bool,
                 enforce_permutations: bool,
                 chunked: bool,
                 chunk_length: int,
                 sampler_kwargs: Dict[str, Any]):
    """Create a sampler with given options.

    Args:
      length: Size of samples (i.e., number of nodes in the graph).
        A length of -1 will mean that the benchmark
        dataset (for the given split) is used. Positive sizes will instantiate
        samplers of the corresponding size.
      rng: Numpy random state.
      algorithm: The name of the algorithm to sample from.
      split: 'train', 'val' or 'test'.
      batch_size: Samples per batch.
      multiplier: Integer multiplier for the number of samples in the dataset,
        only used for positive sizes. Negative multiplier means infinite samples.
      randomize_pos: Whether to randomize the `pos` input.
      enforce_pred_as_input: Whether to convert fixed pred_h hints to inputs.
      enforce_permutations: Whether to enforce permutation pointers.
      chunked: Whether to chunk the dataset.
      chunk_length: Unroll length of chunks, if `chunked` is True.
      sampler_kwargs: Extra args passed to the sampler.
    Returns:
      A sampler (iterator), the number of samples in the iterator (negative
      if infinite samples), and the spec.
    """
    if length < 0:  # load from file
        dataset_folder = _maybe_download_dataset(FLAGS.dataset_path)
        sampler, num_samples, spec = clrs.create_dataset(
            folder=dataset_folder,
            algorithm=algorithm,
            batch_size=batch_size,
            split=split)
        sampler = sampler.as_numpy_iterator()
    else:
        num_samples = clrs.CLRS30[split]['num_samples'] * multiplier
        sampler, spec = clrs.build_sampler(
            algorithm,
            seed=rng.randint(2**32),
            num_samples=num_samples,
            length=length,
            **sampler_kwargs,
        )
        sampler = _iterate_sampler(sampler, batch_size)

    if randomize_pos:
        sampler = clrs.process_random_pos(sampler, rng)
    if enforce_pred_as_input and algorithm in PRED_AS_INPUT_ALGOS:
        spec, sampler = clrs.process_pred_as_input(spec, sampler)
    spec, sampler = clrs.process_permutations(spec, sampler, enforce_permutations)
    if chunked:
        sampler = clrs.chunkify(sampler, chunk_length)
    return sampler, num_samples, spec


def make_multi_sampler(sizes, rng, **kwargs):
    """Create a sampler with cycling sample sizes.

    Args:
      sizes: Sample lengths; one underlying sampler is built per entry.
      rng: Numpy random state, forwarded to `make_sampler`.
      **kwargs: Remaining keyword arguments of `make_sampler`.

    Returns:
      A generator that yields one batch from each underlying sampler in turn
      (forever), the summed sample count, and the spec of the last sampler
      that was built.

    Raises:
      ValueError: If `sizes` is empty.
    """
    ss = []
    tot_samples = 0
    spec = None
    for length in sizes:
        sampler, num_samples, spec = make_sampler(length, rng, **kwargs)
        ss.append(sampler)
        tot_samples += num_samples
    if not ss:
        raise ValueError('`sizes` must contain at least one sample length.')

    def cycle_samplers():
        while True:
            for s in ss:
                yield next(s)

    return cycle_samplers(), tot_samples, spec


def _concat(dps, axis):
    """Concatenate a list of pytrees leaf-wise along `axis`."""
    return jax.tree_util.tree_map(lambda *x: np.concatenate(x, axis), *dps)


def collect_and_eval(sampler, predict_fn, sample_count, rng_key, extras):
    """Collect batches of output and hint preds and evaluate them.

    Args:
      sampler: Iterator yielding feedback objects with `outputs` and `features`.
      predict_fn: Callable `(rng_key, features)` whose first return value is
        the predictions for the batch.
      sample_count: Batches are drawn until at least this many samples were
        processed.
      rng_key: JAX PRNG key; split once per batch.
      extras: Optional dict merged into the evaluation result.

    Returns:
      A dict with the result of `clrs.evaluate` (plus `extras`), with every
      value passed through `unpack`.
    """
    processed_samples = 0
    preds = []
    outputs = []
    while processed_samples < sample_count:
        feedback = next(sampler)
        batch_size = feedback.outputs[0].data.shape[0]
        outputs.append(feedback.outputs)
        new_rng_key, rng_key = jax.random.split(rng_key)
        cur_preds, _ = predict_fn(new_rng_key, feedback.features)
        preds.append(cur_preds)
        processed_samples += batch_size
    outputs = _concat(outputs, axis=0)
    preds = _concat(preds, axis=0)
    out = clrs.evaluate(outputs, preds)
    if extras:
        out.update(extras)
    return {k: unpack(v) for k, v in out.items()}


def create_samplers(rng, train_lengths: List[int]):
    """Create all the samplers.

    Args:
      rng: Numpy random state shared by all samplers.
      train_lengths: Training sample lengths; the caller's list is not
        modified.

    Returns:
      A tuple `(train_samplers, val_samplers, val_sample_counts,
      test_samplers, test_sample_counts, spec_list)` with one entry per
      algorithm in `FLAGS.algorithms`.
    """
    train_samplers = []
    val_samplers = []
    val_sample_counts = []
    test_samplers = []
    test_sample_counts = []
    spec_list = []

    for algorithm in FLAGS.algorithms:
        # Make full dataset pipeline run on CPU (including prefetching).
        with tf.device('/cpu:0'):
            # Per-algorithm copy: the string-matching adjustment below must not
            # leak into the lengths used for the algorithms that follow.
            algo_train_lengths = list(train_lengths)

            if algorithm in _STRING_MATCHING_ALGOS:
                # Fixed haystack + needle; variability will be in needle
                # Still, for chunked training, we maintain as many samplers
                # as train lengths, since, for each length there is a separate
                # state, and we must keep the 1:1 relationship between states
                # and samplers.
                max_length = max(algo_train_lengths)
                if max_length > 0:  # if < 0, we are using the benchmark data
                    max_length = (max_length * 5) // 4
                num_train_samplers = (
                    len(algo_train_lengths) if FLAGS.chunked_training else 1)
                algo_train_lengths = [max_length] * num_train_samplers

            logging.info('Creating samplers for algo %s', algorithm)

            p = tuple([0.1 + 0.1 * i for i in range(9)])
            if algorithm in ['articulation_points', 'bridges',
                             'mst_kruskal', 'bipartite_matching']:
                # Choose a lower connection probability for the above
                # algorithms, otherwise trajectories are very long
                p = tuple(np.array(p) / 2)
            length_needle = FLAGS.length_needle
            sampler_kwargs = dict(p=p, length_needle=length_needle)
            if length_needle == 0:
                # Omit the argument so the sampler falls back to its default.
                sampler_kwargs.pop('length_needle')

            common_sampler_args = dict(
                algorithm=algorithm,
                rng=rng,
                enforce_pred_as_input=FLAGS.enforce_pred_as_input,
                enforce_permutations=FLAGS.enforce_permutations,
                chunk_length=FLAGS.chunk_length,
            )

            train_args = dict(sizes=algo_train_lengths,
                              split='train',
                              batch_size=FLAGS.batch_size,
                              multiplier=-1,
                              randomize_pos=FLAGS.random_pos,
                              chunked=FLAGS.chunked_training,
                              sampler_kwargs=sampler_kwargs,
                              **common_sampler_args)
            train_sampler, _, spec = make_multi_sampler(**train_args)

            mult = clrs.CLRS_30_ALGS_SETTINGS[algorithm]['num_samples_multiplier']
            val_args = dict(sizes=[np.amax(algo_train_lengths)],
                            split='val',
                            batch_size=32,
                            multiplier=2 * mult,
                            randomize_pos=FLAGS.random_pos,
                            chunked=False,
                            sampler_kwargs=sampler_kwargs,
                            **common_sampler_args)
            val_sampler, val_samples, spec = make_multi_sampler(**val_args)

            # A size of -1 makes `make_sampler` load the benchmark test split.
            test_args = dict(sizes=[-1],
                             split='test',
                             batch_size=32,
                             multiplier=2 * mult,
                             randomize_pos=False,
                             chunked=False,
                             sampler_kwargs={},
                             **common_sampler_args)
            test_sampler, test_samples, spec = make_multi_sampler(**test_args)

        spec_list.append(spec)
        train_samplers.append(train_sampler)
        val_samplers.append(val_sampler)
        val_sample_counts.append(val_samples)
        test_samplers.append(test_sampler)
        test_sample_counts.append(test_samples)

    return (train_samplers,
            val_samplers, val_sample_counts,
            test_samplers, test_sample_counts,
            spec_list)


def get_score(submission_folder):
    """Restore the best checkpoint of a submission and return its test score.

    Args:
      submission_folder: Directory containing a `checkpoints` sub-folder with
        `model_params.pkl` and the `best.pkl` checkpoint.

    Returns:
      The `score` entry of the test statistics of the last algorithm in
      `FLAGS.algorithms`.

    Raises:
      ValueError: If no algorithm was evaluated.
    """
    # When used as a library (outside `app.run`) flags are not parsed yet;
    # parse with the program name only so the defaults become readable.
    if not FLAGS.is_parsed():
        FLAGS(["eval.py"])

    train_lengths = [int(x) for x in FLAGS.train_lengths]

    rng = np.random.RandomState(FLAGS.seed)
    rng_key = jax.random.PRNGKey(rng.randint(2**32))

    checkpoint_path = os.path.join(submission_folder, 'checkpoints')

    # Create samplers
    (train_samplers,
     val_samplers, val_sample_counts,
     test_samplers, test_sample_counts,
     spec_list) = create_samplers(rng, train_lengths)

    # NOTE(security): unpickling executes arbitrary code; only evaluate
    # submission folders that come from a trusted source.
    with open(os.path.join(checkpoint_path, 'model_params.pkl'), 'rb') as f:
        model_params = pickle.load(f)

    # The processor factory is stored as its constructor arguments.
    processor_type, use_ln, nb_triplet_fts, nb_heads = (
        model_params["processor_factory"])
    model_params["processor_factory"] = clrs.get_processor_factory(
        processor_type,
        use_ln=use_ln,
        nb_triplet_fts=nb_triplet_fts,
        nb_heads=nb_heads
    )
    model_params["checkpoint_path"] = checkpoint_path

    eval_model = BaselineModel(
        spec=spec_list,
        dummy_trajectory=[next(t) for t in val_samplers],
        **model_params
    )

    # Initialize model.
    feedback_list = [next(t) for t in train_samplers]
    all_features = [f.features for f in feedback_list]
    eval_model.init(all_features, FLAGS.seed + 1)

    logging.set_verbosity(logging.INFO)
    logging.info('Restoring best model from checkpoint...')
    eval_model.restore_model('best.pkl', only_load_processor=False)

    test_stats = None
    for algo_idx in range(len(train_samplers)):
        new_rng_key, rng_key = jax.random.split(rng_key)
        val_stats = collect_and_eval(
            val_samplers[algo_idx],
            functools.partial(eval_model.predict, algorithm_index=algo_idx),
            val_sample_counts[algo_idx],
            new_rng_key,
            extras={})
        logging.info('(val) algo %s: %s', FLAGS.algorithms[algo_idx], val_stats)

        new_rng_key, rng_key = jax.random.split(rng_key)
        test_stats = collect_and_eval(
            test_samplers[algo_idx],
            functools.partial(eval_model.predict, algorithm_index=algo_idx),
            test_sample_counts[algo_idx],
            new_rng_key,
            extras={})
        logging.info('(test) algo %s : %s', FLAGS.algorithms[algo_idx], test_stats)

    if test_stats is None:
        raise ValueError('No algorithms were evaluated; check --algorithms.')
    return test_stats['score']


def _main(argv):
    """Command-line entry point: print the score of `argv[1]`."""
    if len(argv) != 2:
        raise app.UsageError(
            'Expected exactly one positional argument: the submission folder.')
    # Print rather than return: `app.run` hands the return value to `sys.exit`.
    print(get_score(argv[1]))


if __name__ == '__main__':
    app.run(_main)