Spaces:
Runtime error
Runtime error
| # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Run training of one or more algorithmic tasks from CLRS.""" | |
| import os | |
| # disable logging until training starts | |
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' | |
| import functools | |
| import os | |
| import shutil | |
| from typing import Any, Dict, List | |
| from absl import app | |
| from absl import flags | |
| from absl import logging | |
| # disable logging until training starts | |
| logging.set_verbosity(logging.ERROR) | |
| import clrs | |
| import jax | |
| import numpy as np | |
| import requests | |
| import tensorflow as tf | |
| from baselines import BaselineModel, BaselineModelChunked | |
| import pickle | |
| import copy | |
# ----------------------------------------------------------------------------
# Command-line flags. All are read through the module-level `FLAGS` object.
# ----------------------------------------------------------------------------

# Data / task selection.
flags.DEFINE_list('algorithms', ['floyd_warshall'], 'Which algorithms to run.')
flags.DEFINE_list('train_lengths', ['4', '7', '11', '13', '16'],
                  'Which training sizes to use. A size of -1 means '
                  'use the benchmark dataset.')
flags.DEFINE_integer('length_needle', -8,
                     'Length of needle for training and validation '
                     '(not testing) in string matching algorithms. '
                     'A negative value randomizes the length for each sample '
                     'between 1 and the opposite of the value. '
                     'A value of 0 means use always 1/4 of the length of '
                     'the haystack (the default sampler behavior).')
flags.DEFINE_integer('seed', 42, 'Random seed to set')

# Sampler post-processing.
flags.DEFINE_boolean('random_pos', True,
                     'Randomize the pos input common to all algos.')
flags.DEFINE_boolean('enforce_permutations', True,
                     'Whether to enforce permutation-type node pointers.')
flags.DEFINE_boolean('enforce_pred_as_input', True,
                     'Whether to change pred_h hints into pred inputs.')

# Training schedule.
flags.DEFINE_integer('batch_size', 32, 'Batch size used for training.')
flags.DEFINE_boolean('chunked_training', False,
                     'Whether to use chunking for training.')
flags.DEFINE_integer('chunk_length', 16,
                     'Time chunk length used for training (if '
                     '`chunked_training` is True).')
flags.DEFINE_integer('train_steps', 500, 'Number of training iterations.')
flags.DEFINE_integer('eval_every', 50, 'Evaluation frequency (in steps).')
# NOTE(review): main() in this file never reads FLAGS.test_every; the test set
# is evaluated once, after training finishes.
flags.DEFINE_integer('test_every', 500,
                     'Test-set evaluation frequency (in steps).')
flags.DEFINE_integer('log_every', 50, 'Logging frequency (in steps).')

# Model architecture and optimisation.
flags.DEFINE_integer('hidden_size', 128,
                     'Number of hidden units of the model.')
flags.DEFINE_integer('nb_heads', 1, 'Number of heads for GAT processors')
flags.DEFINE_integer('nb_msg_passing_steps', 1,
                     'Number of message passing steps to run per hint.')
flags.DEFINE_float('learning_rate', 0.001, 'Learning rate to use.')
flags.DEFINE_float('grad_clip_max_norm', 1.0,
                   'Gradient clipping by norm. 0.0 disables grad clipping')
flags.DEFINE_float('dropout_prob', 0.0, 'Dropout rate to use.')
flags.DEFINE_float('hint_teacher_forcing', 0.0,
                   'Probability that ground-truth teacher hints are encoded '
                   'during training instead of predicted hints. Only '
                   'pertinent in encoded_decoded modes.')
flags.DEFINE_enum('hint_mode', 'encoded_decoded',
                  ['encoded_decoded', 'decoded_only', 'none'],
                  'How should hints be used? Note, each mode defines a '
                  'separate task, with various difficulties. `encoded_decoded` '
                  'requires the model to explicitly materialise hint sequences '
                  'and therefore is hardest, but also most aligned to the '
                  'underlying algorithmic rule. Hence, `encoded_decoded` '
                  'should be treated as the default mode for our benchmark. '
                  'In `decoded_only`, hints are only used for defining '
                  'reconstruction losses. Often, this will perform well, but '
                  'note that we currently do not make any efforts to '
                  'counterbalance the various hint losses. Hence, for certain '
                  'tasks, the best performance will now be achievable with no '
                  'hint usage at all (`none`).')
flags.DEFINE_enum('hint_repred_mode', 'soft', ['soft', 'hard', 'hard_on_eval'],
                  'How to process predicted hints when fed back as inputs. '
                  'In soft mode, we use softmaxes for categoricals, pointers '
                  'and mask_one, and sigmoids for masks. '
                  'In hard mode, we use argmax instead of softmax, and hard '
                  'thresholding of masks. '
                  'In hard_on_eval mode, soft mode is '
                  'used for training and hard mode is used for evaluation.')
flags.DEFINE_boolean('use_ln', True,
                     'Whether to use layer normalisation in the processor.')
flags.DEFINE_boolean('use_lstm', False,
                     'Whether to insert an LSTM after message passing.')
flags.DEFINE_integer('nb_triplet_fts', 8,
                     'How many triplet features to compute?')

flags.DEFINE_enum('encoder_init', 'xavier_on_scalars',
                  ['default', 'xavier_on_scalars'],
                  'Initialiser to use for the encoders.')
flags.DEFINE_enum('processor_type', 'triplet_gmpnn',
                  ['deepsets', 'mpnn', 'pgn', 'pgn_mask',
                   'triplet_mpnn', 'triplet_pgn', 'triplet_pgn_mask',
                   'gat', 'gatv2', 'gat_full', 'gatv2_full',
                   'gpgn', 'gpgn_mask', 'gmpnn',
                   'triplet_gpgn', 'triplet_gpgn_mask', 'triplet_gmpnn'],
                  'Processor type to use as the network P.')

# Filesystem locations.
flags.DEFINE_string('checkpoint_path', './checkpoints',
                    'Path in which checkpoints are saved.')
flags.DEFINE_string('dataset_path', '/tmp/CLRS30',
                    'Path in which dataset is stored.')
flags.DEFINE_boolean('freeze_processor', False,
                     'Whether to freeze the processor of the model.')

FLAGS = flags.FLAGS
# Algorithms whose fixed `pred_h` hints are turned into `pred` inputs by
# `make_sampler` when `enforce_pred_as_input` is enabled.
PRED_AS_INPUT_ALGOS = [
    # Searching / array scans.
    'binary_search', 'minimum',
    'find_maximum_subarray', 'find_maximum_subarray_kadane',
    # Dynamic programming.
    'matrix_chain_order', 'lcs_length', 'optimal_bst',
    # Greedy.
    'activity_selector', 'task_scheduling',
    # String matching.
    'naive_string_matcher', 'kmp_matcher',
    # Geometry.
    'jarvis_march',
]
def unpack(v):
  """Return `v.item()` when that works, otherwise `v` unchanged.

  Turns single-element arrays/array scalars into plain Python scalars. Objects
  without an `item` method, or whose `item()` raises ValueError (e.g.
  multi-element arrays), are passed through as-is.
  """
  result = v
  try:
    result = v.item()  # DeviceArray / numpy scalar -> Python scalar
  except (ValueError, AttributeError):
    pass
  return result
| def _iterate_sampler(sampler, batch_size): | |
| while True: | |
| yield sampler.next(batch_size) | |
def _maybe_download_dataset(dataset_path):
  """Download CLRS30 dataset if needed.

  Args:
    dataset_path: Root directory. The dataset lives in
      `os.path.join(dataset_path, clrs.get_clrs_folder())`.

  Returns:
    The path of the dataset folder.

  Raises:
    requests.HTTPError: if the download request returns an error status.
  """
  dataset_folder = os.path.join(dataset_path, clrs.get_clrs_folder())
  if os.path.isdir(dataset_folder):
    logging.info('Dataset found at %s. Skipping download.', dataset_folder)
    return dataset_folder
  logging.info('Dataset not found in %s. Downloading...', dataset_folder)

  clrs_url = clrs.get_dataset_gcp_url()
  request = requests.get(clrs_url, allow_redirects=True)
  # Fail before touching the filesystem: otherwise an error page would be
  # written to disk and the archive step would fail confusingly.
  request.raise_for_status()

  clrs_file = os.path.join(dataset_path, os.path.basename(clrs_url))
  os.makedirs(dataset_folder, exist_ok=True)
  with open(clrs_file, 'wb') as f:
    f.write(request.content)
  try:
    shutil.unpack_archive(clrs_file, extract_dir=dataset_folder)
  except Exception:
    # Don't leave a partial folder behind: the isdir() check above would
    # treat it as a complete dataset on every later run.
    shutil.rmtree(dataset_folder, ignore_errors=True)
    raise
  finally:
    os.remove(clrs_file)
  return dataset_folder
def make_sampler(length: int,
                 rng: Any,
                 algorithm: str,
                 split: str,
                 batch_size: int,
                 multiplier: int,
                 randomize_pos: bool,
                 enforce_pred_as_input: bool,
                 enforce_permutations: bool,
                 chunked: bool,
                 chunk_length: int,
                 sampler_kwargs: Dict[str, Any]):
  """Create a sampler with given options.

  Args:
    length: Size of samples (i.e., number of nodes in the graph).
      A length of -1 will mean that the benchmark
      dataset (for the given split) is used. Positive sizes will instantiate
      samplers of the corresponding size.
    rng: Numpy random state.
    algorithm: The name of the algorithm to sample from.
    split: 'train', 'val' or 'test'.
    batch_size: Samples per batch.
    multiplier: Integer multiplier for the number of samples in the dataset,
      only used for positive sizes. Negative multiplier means infinite samples.
    randomize_pos: Whether to randomize the `pos` input.
    enforce_pred_as_input: Whether to convert fixed pred_h hints to inputs.
    enforce_permutations: Whether to enforce permutation pointers.
    chunked: Whether to chunk the dataset.
    chunk_length: Unroll length of chunks, if `chunked` is True.
    sampler_kwargs: Extra args passed to the sampler (only used for positive
      sizes; ignored when loading the benchmark dataset).

  Returns:
    A sampler (iterator), the number of samples in the iterator (negative
    if infinite samples), and the spec.
  """
  if length < 0:  # load from file
    dataset_folder = _maybe_download_dataset(FLAGS.dataset_path)
    sampler, num_samples, spec = clrs.create_dataset(folder=dataset_folder,
                                                     algorithm=algorithm,
                                                     batch_size=batch_size,
                                                     split=split)
    sampler = sampler.as_numpy_iterator()
  else:
    num_samples = clrs.CLRS30[split]['num_samples'] * multiplier
    sampler, spec = clrs.build_sampler(
        algorithm,
        # Explicit int64: where NumPy's default integer is 32-bit (e.g.
        # Windows with NumPy < 2), randint(2**32) raises "high is out of
        # bounds". On 64-bit-default platforms the drawn value is unchanged.
        seed=rng.randint(2**32, dtype=np.int64),
        num_samples=num_samples,
        length=length,
        **sampler_kwargs,
    )
    sampler = _iterate_sampler(sampler, batch_size)

  # Post-processing wrappers, applied in this fixed order.
  if randomize_pos:
    sampler = clrs.process_random_pos(sampler, rng)
  if enforce_pred_as_input and algorithm in PRED_AS_INPUT_ALGOS:
    spec, sampler = clrs.process_pred_as_input(spec, sampler)
  spec, sampler = clrs.process_permutations(spec, sampler, enforce_permutations)
  if chunked:
    sampler = clrs.chunkify(sampler, chunk_length)
  return sampler, num_samples, spec
def make_multi_sampler(sizes, rng, **kwargs):
  """Create a sampler with cycling sample sizes.

  One sampler is built per entry of `sizes` (via `make_sampler`). The returned
  iterator yields one batch from each of them in turn, round-robin, forever.

  Args:
    sizes: Non-empty iterable of sample lengths.
    rng: Numpy random state forwarded to every `make_sampler` call.
    **kwargs: Remaining keyword arguments forwarded to `make_sampler`.

  Returns:
    `(iterator, total_samples, spec)`. `total_samples` is the sum of the
    per-size counts. It is only meaningful when every sampler is finite,
    because infinite samplers report negative counts. `spec` is the spec of
    the last sampler built.

  Raises:
    ValueError: if `sizes` is empty.
  """
  sizes = list(sizes)
  if not sizes:
    # Without this guard `spec` would be unbound, and the cycling generator
    # would loop forever without yielding.
    raise ValueError('make_multi_sampler requires at least one sample size.')

  samplers = []
  tot_samples = 0
  for length in sizes:
    sampler, num_samples, spec = make_sampler(length, rng, **kwargs)
    samplers.append(sampler)
    tot_samples += num_samples

  def cycle_samplers():
    while True:
      for s in samplers:
        yield next(s)

  return cycle_samplers(), tot_samples, spec
| def _concat(dps, axis): | |
| return jax.tree_util.tree_map(lambda *x: np.concatenate(x, axis), *dps) | |
def collect_and_eval(sampler, predict_fn, sample_count, rng_key, extras):
  """Collect batches of output and hint preds and evaluate them.

  Whole batches are drawn from `sampler` until at least `sample_count`
  samples have been processed, so the last batch may overshoot.

  Args:
    sampler: Iterator of feedback batches exposing `.outputs` and `.features`.
    predict_fn: Callable `(rng_key, features)` returning `(preds, _)`.
    sample_count: Number of samples to evaluate; must be positive.
    rng_key: JAX PRNG key; split once per batch.
    extras: Optional dict merged into the returned statistics.

  Returns:
    Dict from `clrs.evaluate(outputs, preds)` updated with `extras`, with
    values passed through `unpack` (array scalars -> Python scalars).

  Raises:
    ValueError: if `sample_count` is not positive. Nothing would be collected,
      and concatenating zero batches fails with an opaque error.
  """
  if sample_count <= 0:
    raise ValueError(f'sample_count must be positive, got {sample_count}.')
  processed_samples = 0
  preds = []
  outputs = []
  while processed_samples < sample_count:
    feedback = next(sampler)
    batch_size = feedback.outputs[0].data.shape[0]
    outputs.append(feedback.outputs)
    new_rng_key, rng_key = jax.random.split(rng_key)
    cur_preds, _ = predict_fn(new_rng_key, feedback.features)
    preds.append(cur_preds)
    processed_samples += batch_size
  outputs = _concat(outputs, axis=0)
  preds = _concat(preds, axis=0)
  out = clrs.evaluate(outputs, preds)
  if extras:
    out.update(extras)
  return {k: unpack(v) for k, v in out.items()}
def create_samplers(rng, train_lengths: List[int]):
  """Create all the samplers.

  Builds a train, a validation and a test sampler for every algorithm in
  FLAGS.algorithms. All samplers share (and consume) the same `rng`.

  Args:
    rng: Numpy random state.
    train_lengths: Training sample sizes; -1 selects the benchmark dataset.
      Not modified.

  Returns:
    `(train_samplers, val_samplers, val_sample_counts, test_samplers,
    test_sample_counts, spec_list)`, each a list indexed like
    FLAGS.algorithms.
  """
  train_samplers = []
  val_samplers = []
  val_sample_counts = []
  test_samplers = []
  test_sample_counts = []
  spec_list = []

  for algorithm in FLAGS.algorithms:
    # Make full dataset pipeline run on CPU (including prefetching).
    with tf.device('/cpu:0'):

      # Per-algorithm copy: the string-matching override below must not leak
      # into (or compound for) the algorithms processed after this one.
      algo_train_lengths = list(train_lengths)
      if algorithm in ['naive_string_matcher', 'kmp_matcher']:
        # Fixed haystack + needle; variability will be in needle
        # Still, for chunked training, we maintain as many samplers
        # as train lengths, since, for each length there is a separate state,
        # and we must keep the 1:1 relationship between states and samplers.
        max_length = max(train_lengths)
        if max_length > 0:  # if < 0, we are using the benchmark data
          max_length = (max_length * 5) // 4
        algo_train_lengths = [max_length]
        if FLAGS.chunked_training:
          # Replicate once per *original* training length, keeping the 1:1
          # mapping between samplers and chunked-training states.
          algo_train_lengths = algo_train_lengths * len(train_lengths)

      logging.info('Creating samplers for algo %s', algorithm)

      p = tuple([0.1 + 0.1 * i for i in range(9)])
      if algorithm in ['articulation_points', 'bridges',
                       'mst_kruskal', 'bipartite_matching']:
        # Choose a lower connection probability for the above algorithms,
        # otherwise trajectories are very long
        p = tuple(np.array(p) / 2)
      length_needle = FLAGS.length_needle
      sampler_kwargs = dict(p=p, length_needle=length_needle)
      if length_needle == 0:
        # 0 means "use the sampler's default needle length".
        sampler_kwargs.pop('length_needle')

      common_sampler_args = dict(
          algorithm=algorithm,
          rng=rng,
          enforce_pred_as_input=FLAGS.enforce_pred_as_input,
          enforce_permutations=FLAGS.enforce_permutations,
          chunk_length=FLAGS.chunk_length,
      )

      train_args = dict(sizes=algo_train_lengths,
                        split='train',
                        batch_size=FLAGS.batch_size,
                        multiplier=-1,
                        randomize_pos=FLAGS.random_pos,
                        chunked=FLAGS.chunked_training,
                        sampler_kwargs=sampler_kwargs,
                        **common_sampler_args)
      train_sampler, _, spec = make_multi_sampler(**train_args)

      mult = clrs.CLRS_30_ALGS_SETTINGS[algorithm]['num_samples_multiplier']
      val_args = dict(sizes=[np.amax(algo_train_lengths)],
                      split='val',
                      batch_size=32,
                      multiplier=2 * mult,
                      randomize_pos=FLAGS.random_pos,
                      chunked=False,
                      sampler_kwargs=sampler_kwargs,
                      **common_sampler_args)
      val_sampler, val_samples, spec = make_multi_sampler(**val_args)

      test_args = dict(sizes=[-1],
                       split='test',
                       batch_size=32,
                       multiplier=2 * mult,
                       randomize_pos=False,
                       chunked=False,
                       sampler_kwargs={},
                       **common_sampler_args)
      test_sampler, test_samples, spec = make_multi_sampler(**test_args)

    spec_list.append(spec)
    train_samplers.append(train_sampler)
    val_samplers.append(val_sampler)
    val_sample_counts.append(val_samples)
    test_samplers.append(test_sampler)
    test_sample_counts.append(test_samples)

  return (train_samplers,
          val_samplers, val_sample_counts,
          test_samplers, test_sample_counts,
          spec_list)
def main(unused_argv):
  """Train, validate and test a CLRS model configured by the command-line flags.

  Builds samplers and models and writes `spec_list.pkl` and `model_params.pkl`
  into FLAGS.checkpoint_path. It then trains for FLAGS.train_steps steps,
  validating every FLAGS.eval_every steps and checkpointing the model as
  'best.pkl' whenever the summed validation score improves. Finally it
  restores 'best.pkl' and logs test-set statistics for each algorithm.

  Args:
    unused_argv: Positional command-line arguments; ignored.

  Raises:
    ValueError: if FLAGS.hint_mode is not a recognised mode.
  """
  if FLAGS.hint_mode == 'encoded_decoded':
    encode_hints = True
    decode_hints = True
  elif FLAGS.hint_mode == 'decoded_only':
    encode_hints = False
    decode_hints = True
  elif FLAGS.hint_mode == 'none':
    encode_hints = False
    decode_hints = False
  else:
    raise ValueError('Hint mode not in {encoded_decoded, decoded_only, none}.')

  train_lengths = [int(x) for x in FLAGS.train_lengths]

  rng = np.random.RandomState(FLAGS.seed)
  # Explicit int64: where NumPy's default integer is 32-bit (e.g. Windows with
  # NumPy < 2), randint(2**32) raises "high is out of bounds".
  rng_key = jax.random.PRNGKey(rng.randint(2**32, dtype=np.int64))

  # Create samplers
  (train_samplers,
   val_samplers, val_sample_counts,
   test_samplers, test_sample_counts,
   spec_list) = create_samplers(rng, train_lengths)

  processor_factory = clrs.get_processor_factory(
      FLAGS.processor_type,
      use_ln=FLAGS.use_ln,
      nb_triplet_fts=FLAGS.nb_triplet_fts,
      nb_heads=FLAGS.nb_heads
  )
  model_params = dict(
      processor_factory=processor_factory,
      hidden_dim=FLAGS.hidden_size,
      encode_hints=encode_hints,
      decode_hints=decode_hints,
      encoder_init=FLAGS.encoder_init,
      use_lstm=FLAGS.use_lstm,
      learning_rate=FLAGS.learning_rate,
      grad_clip_max_norm=FLAGS.grad_clip_max_norm,
      checkpoint_path=FLAGS.checkpoint_path,
      freeze_processor=FLAGS.freeze_processor,
      dropout_prob=FLAGS.dropout_prob,
      hint_teacher_forcing=FLAGS.hint_teacher_forcing,
      hint_repred_mode=FLAGS.hint_repred_mode,
      nb_msg_passing_steps=FLAGS.nb_msg_passing_steps,
      )

  # save spec_list and model_params; do not change or delete!!
  os.makedirs(FLAGS.checkpoint_path, exist_ok=True)
  with open(os.path.join(FLAGS.checkpoint_path, 'spec_list.pkl'), 'wb') as f:
    pickle.dump(spec_list, f)
  # Store the arguments passed to clrs.get_processor_factory above instead of
  # the factory callable itself. A shallow copy suffices because only this
  # one key is replaced.
  model_params_save = dict(model_params)
  model_params_save['processor_factory'] = (FLAGS.processor_type,
                                            FLAGS.use_ln,
                                            FLAGS.nb_triplet_fts,
                                            FLAGS.nb_heads)
  with open(os.path.join(FLAGS.checkpoint_path, 'model_params.pkl'), 'wb') as f:
    pickle.dump(model_params_save, f)

  eval_model = BaselineModel(
      spec=spec_list,
      dummy_trajectory=[next(t) for t in val_samplers],
      **model_params
  )
  if FLAGS.chunked_training:
    train_model = BaselineModelChunked(
        spec=spec_list,
        dummy_trajectory=[next(t) for t in train_samplers],
        **model_params
        )
  else:
    train_model = eval_model

  # Training loop.
  best_score = -1.0
  current_train_items = [0] * len(FLAGS.algorithms)
  step = 0
  next_eval = 0
  # Make sure scores improve on first step, but not overcome best score
  # until all algos have had at least one evaluation.
  val_scores = [-99999.9] * len(FLAGS.algorithms)
  length_idx = 0

  while step < FLAGS.train_steps:
    feedback_list = [next(t) for t in train_samplers]

    # Initialize model.
    if step == 0:
      all_features = [f.features for f in feedback_list]
      if FLAGS.chunked_training:
        # We need to initialize the model with samples of all lengths for
        # all algorithms. Also, we need to make sure that the order of these
        # sample sizes is the same as the order of the actual training sizes.
        all_length_features = [all_features] + [
            [next(t).features for t in train_samplers]
            for _ in range(len(train_lengths))]
        train_model.init(all_length_features[:-1], FLAGS.seed + 1)
      else:
        train_model.init(all_features, FLAGS.seed + 1)
      # Enable logging now that we have initialized the model.
      logging.set_verbosity(logging.INFO)

    # Training step.
    for algo_idx in range(len(train_samplers)):
      feedback = feedback_list[algo_idx]
      rng_key, new_rng_key = jax.random.split(rng_key)
      if FLAGS.chunked_training:
        # In chunked training, we must indicate which training length we are
        # using, so the model uses the correct state.
        length_and_algo_idx = (length_idx, algo_idx)
      else:
        # In non-chunked training, all training lengths can be treated equally,
        # since there is no state to maintain between batches.
        length_and_algo_idx = algo_idx
      cur_loss = train_model.feedback(rng_key, feedback, length_and_algo_idx)
      rng_key = new_rng_key

      if FLAGS.chunked_training:
        examples_in_chunk = np.sum(feedback.features.is_last).item()
      else:
        examples_in_chunk = len(feedback.features.lengths)
      current_train_items[algo_idx] += examples_in_chunk
      if step % FLAGS.log_every == 0:
        logging.info('Algo %s step %i current loss %f, current_train_items %i.',
                     FLAGS.algorithms[algo_idx], step,
                     cur_loss, current_train_items[algo_idx])

    # Periodically evaluate model
    if step >= next_eval:
      eval_model.params = train_model.params
      for algo_idx in range(len(train_samplers)):
        common_extras = {'examples_seen': current_train_items[algo_idx],
                         'step': step,
                         'algorithm': FLAGS.algorithms[algo_idx]}

        # Validation info.
        new_rng_key, rng_key = jax.random.split(rng_key)
        val_stats = collect_and_eval(
            val_samplers[algo_idx],
            functools.partial(eval_model.predict, algorithm_index=algo_idx),
            val_sample_counts[algo_idx],
            new_rng_key,
            extras=common_extras)
        logging.info('(val) algo %s step %d: %s',
                     FLAGS.algorithms[algo_idx], step, val_stats)
        val_scores[algo_idx] = val_stats['score']

      next_eval += FLAGS.eval_every

      # If best total score, update best checkpoint.
      # Also save a best checkpoint on the first step.
      msg = (f'best avg val score was '
             f'{best_score/len(FLAGS.algorithms):.3f}, '
             f'current avg val score is {np.mean(val_scores):.3f}, '
             f'val scores are: ')
      msg += ', '.join(
          ['%s: %.3f' % (x, y) for (x, y) in zip(FLAGS.algorithms, val_scores)])
      if (sum(val_scores) > best_score) or step == 0:
        best_score = sum(val_scores)
        logging.info('Checkpointing best model, %s', msg)
        train_model.save_model('best.pkl')
      else:
        logging.info('Not saving new best model, %s', msg)

    step += 1
    length_idx = (length_idx + 1) % len(train_lengths)

  logging.info('Restoring best model from checkpoint...')
  eval_model.restore_model('best.pkl', only_load_processor=False)

  for algo_idx in range(len(train_samplers)):
    common_extras = {'examples_seen': current_train_items[algo_idx],
                     'step': step,
                     'algorithm': FLAGS.algorithms[algo_idx]}

    new_rng_key, rng_key = jax.random.split(rng_key)
    test_stats = collect_and_eval(
        test_samplers[algo_idx],
        functools.partial(eval_model.predict, algorithm_index=algo_idx),
        test_sample_counts[algo_idx],
        new_rng_key,
        extras=common_extras)
    logging.info('(test) algo %s : %s', FLAGS.algorithms[algo_idx], test_stats)

  logging.info('Done!')


if __name__ == '__main__':
  # Script entry point: run main() through absl's app runner.
  app.run(main)