# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run training of one or more algorithmic tasks from CLRS."""
import os

# Disable TensorFlow C++ logging until training starts.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import functools
import shutil
from typing import Any, Dict, List
from absl import app
from absl import flags
from absl import logging
# Disable absl logging until training starts.
logging.set_verbosity(logging.ERROR)
import clrs
import jax
import numpy as np
import requests
import tensorflow as tf
from baselines import BaselineModel, BaselineModelChunked
import pickle
import copy
flags.DEFINE_list('algorithms', ['floyd_warshall'], 'Which algorithms to run.')
flags.DEFINE_list('train_lengths', ['4', '7', '11', '13', '16'],
'Which training sizes to use. A size of -1 means '
'use the benchmark dataset.')
flags.DEFINE_integer('length_needle', -8,
'Length of needle for training and validation '
'(not testing) in string matching algorithms. '
'A negative value randomizes the length for each sample '
'between 1 and the opposite of the value. '
'A value of 0 means use always 1/4 of the length of '
'the haystack (the default sampler behavior).')
flags.DEFINE_integer('seed', 42, 'Random seed to set')
flags.DEFINE_boolean('random_pos', True,
'Randomize the pos input common to all algos.')
flags.DEFINE_boolean('enforce_permutations', True,
'Whether to enforce permutation-type node pointers.')
flags.DEFINE_boolean('enforce_pred_as_input', True,
'Whether to change pred_h hints into pred inputs.')
flags.DEFINE_integer('batch_size', 32, 'Batch size used for training.')
flags.DEFINE_boolean('chunked_training', False,
'Whether to use chunking for training.')
flags.DEFINE_integer('chunk_length', 16,
                     'Time chunk length used for training (if '
                     '`chunked_training` is True).')
flags.DEFINE_integer('train_steps', 500, 'Number of training iterations.')
flags.DEFINE_integer('eval_every', 50, 'Evaluation frequency (in steps).')
flags.DEFINE_integer('test_every', 500, 'Test frequency (in steps).')
flags.DEFINE_integer('log_every', 50, 'Logging frequency (in steps).')
flags.DEFINE_integer('hidden_size', 128,
'Number of hidden units of the model.')
flags.DEFINE_integer('nb_heads', 1, 'Number of heads for GAT processors.')
flags.DEFINE_integer('nb_msg_passing_steps', 1,
'Number of message passing steps to run per hint.')
flags.DEFINE_float('learning_rate', 0.001, 'Learning rate to use.')
flags.DEFINE_float('grad_clip_max_norm', 1.0,
                   'Gradient clipping by norm. 0.0 disables grad clipping.')
flags.DEFINE_float('dropout_prob', 0.0, 'Dropout rate to use.')
flags.DEFINE_float('hint_teacher_forcing', 0.0,
'Probability that ground-truth teacher hints are encoded '
'during training instead of predicted hints. Only '
'pertinent in encoded_decoded modes.')
flags.DEFINE_enum('hint_mode', 'encoded_decoded',
['encoded_decoded', 'decoded_only', 'none'],
'How should hints be used? Note, each mode defines a '
'separate task, with various difficulties. `encoded_decoded` '
'requires the model to explicitly materialise hint sequences '
'and therefore is hardest, but also most aligned to the '
'underlying algorithmic rule. Hence, `encoded_decoded` '
'should be treated as the default mode for our benchmark. '
'In `decoded_only`, hints are only used for defining '
'reconstruction losses. Often, this will perform well, but '
'note that we currently do not make any efforts to '
'counterbalance the various hint losses. Hence, for certain '
'tasks, the best performance will now be achievable with no '
'hint usage at all (`none`).')
flags.DEFINE_enum('hint_repred_mode', 'soft', ['soft', 'hard', 'hard_on_eval'],
                  'How to process predicted hints when fed back as inputs. '
'In soft mode, we use softmaxes for categoricals, pointers '
'and mask_one, and sigmoids for masks. '
'In hard mode, we use argmax instead of softmax, and hard '
'thresholding of masks. '
'In hard_on_eval mode, soft mode is '
'used for training and hard mode is used for evaluation.')
flags.DEFINE_boolean('use_ln', True,
'Whether to use layer normalisation in the processor.')
flags.DEFINE_boolean('use_lstm', False,
'Whether to insert an LSTM after message passing.')
flags.DEFINE_integer('nb_triplet_fts', 8,
                     'Number of triplet features to compute.')
flags.DEFINE_enum('encoder_init', 'xavier_on_scalars',
['default', 'xavier_on_scalars'],
'Initialiser to use for the encoders.')
flags.DEFINE_enum('processor_type', 'triplet_gmpnn',
['deepsets', 'mpnn', 'pgn', 'pgn_mask',
'triplet_mpnn', 'triplet_pgn', 'triplet_pgn_mask',
'gat', 'gatv2', 'gat_full', 'gatv2_full',
'gpgn', 'gpgn_mask', 'gmpnn',
'triplet_gpgn', 'triplet_gpgn_mask', 'triplet_gmpnn'],
'Processor type to use as the network P.')
flags.DEFINE_string('checkpoint_path', './checkpoints',
'Path in which checkpoints are saved.')
flags.DEFINE_string('dataset_path', '/tmp/CLRS30',
'Path in which dataset is stored.')
flags.DEFINE_boolean('freeze_processor', False,
'Whether to freeze the processor of the model.')
FLAGS = flags.FLAGS
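
# Illustrative invocation (assuming this file is saved as run.py; the flag
# values are examples, not recommendations):
#   python run.py --algorithms=bfs,insertion_sort \
#       --train_lengths=4,7,11,13,16 --processor_type=triplet_gmpnn \
#       --train_steps=10000 --checkpoint_path=./checkpoints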
PRED_AS_INPUT_ALGOS = [
'binary_search',
'minimum',
'find_maximum_subarray',
'find_maximum_subarray_kadane',
'matrix_chain_order',
'lcs_length',
'optimal_bst',
'activity_selector',
'task_scheduling',
'naive_string_matcher',
'kmp_matcher',
'jarvis_march']
def unpack(v):
  """Return the Python scalar inside a 0-d DeviceArray, else `v` unchanged."""
  try:
    return v.item()  # DeviceArray
  except (AttributeError, ValueError):
    return v
def _iterate_sampler(sampler, batch_size):
  """Yield batches from a CLRS sampler indefinitely."""
  while True:
    yield sampler.next(batch_size)
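
# For example, `batches = _iterate_sampler(sampler, FLAGS.batch_size)` turns a
# CLRS sampler into an infinite stream where each `next(batches)` is one batch.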
def _maybe_download_dataset(dataset_path):
"""Download CLRS30 dataset if needed."""
dataset_folder = os.path.join(dataset_path, clrs.get_clrs_folder())
if os.path.isdir(dataset_folder):
logging.info('Dataset found at %s. Skipping download.', dataset_folder)
return dataset_folder
logging.info('Dataset not found in %s. Downloading...', dataset_folder)
clrs_url = clrs.get_dataset_gcp_url()
request = requests.get(clrs_url, allow_redirects=True)
clrs_file = os.path.join(dataset_path, os.path.basename(clrs_url))
os.makedirs(dataset_folder)
  with open(clrs_file, 'wb') as f:
    f.write(request.content)
shutil.unpack_archive(clrs_file, extract_dir=dataset_folder)
os.remove(clrs_file)
return dataset_folder
def make_sampler(length: int,
rng: Any,
algorithm: str,
split: str,
batch_size: int,
multiplier: int,
randomize_pos: bool,
enforce_pred_as_input: bool,
enforce_permutations: bool,
chunked: bool,
chunk_length: int,
sampler_kwargs: Dict[str, Any]):
"""Create a sampler with given options.
Args:
length: Size of samples (i.e., number of nodes in the graph).
A length of -1 will mean that the benchmark
dataset (for the given split) is used. Positive sizes will instantiate
samplers of the corresponding size.
rng: Numpy random state.
algorithm: The name of the algorithm to sample from.
split: 'train', 'val' or 'test'.
batch_size: Samples per batch.
multiplier: Integer multiplier for the number of samples in the dataset,
only used for positive sizes. Negative multiplier means infinite samples.
randomize_pos: Whether to randomize the `pos` input.
enforce_pred_as_input: Whether to convert fixed pred_h hints to inputs.
enforce_permutations: Whether to enforce permutation pointers.
chunked: Whether to chunk the dataset.
chunk_length: Unroll length of chunks, if `chunked` is True.
sampler_kwargs: Extra args passed to the sampler.
Returns:
A sampler (iterator), the number of samples in the iterator (negative
if infinite samples), and the spec.
"""
if length < 0: # load from file
dataset_folder = _maybe_download_dataset(FLAGS.dataset_path)
sampler, num_samples, spec = clrs.create_dataset(folder=dataset_folder,
algorithm=algorithm,
batch_size=batch_size,
split=split)
sampler = sampler.as_numpy_iterator()
else:
num_samples = clrs.CLRS30[split]['num_samples'] * multiplier
sampler, spec = clrs.build_sampler(
algorithm,
seed=rng.randint(2**32),
num_samples=num_samples,
length=length,
**sampler_kwargs,
)
sampler = _iterate_sampler(sampler, batch_size)
if randomize_pos:
sampler = clrs.process_random_pos(sampler, rng)
if enforce_pred_as_input and algorithm in PRED_AS_INPUT_ALGOS:
spec, sampler = clrs.process_pred_as_input(spec, sampler)
spec, sampler = clrs.process_permutations(spec, sampler, enforce_permutations)
if chunked:
sampler = clrs.chunkify(sampler, chunk_length)
return sampler, num_samples, spec
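
# A minimal sketch of direct usage (illustrative values; the training loop
# below always goes through make_multi_sampler instead):
#   sampler, num_samples, spec = make_sampler(
#       length=16, rng=np.random.RandomState(0), algorithm='insertion_sort',
#       split='train', batch_size=32, multiplier=-1, randomize_pos=True,
#       enforce_pred_as_input=True, enforce_permutations=True,
#       chunked=False, chunk_length=16, sampler_kwargs={})
#   feedback = next(sampler)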
def make_multi_sampler(sizes, rng, **kwargs):
"""Create a sampler with cycling sample sizes."""
ss = []
tot_samples = 0
for length in sizes:
sampler, num_samples, spec = make_sampler(length, rng, **kwargs)
ss.append(sampler)
tot_samples += num_samples
def cycle_samplers():
while True:
for s in ss:
yield next(s)
return cycle_samplers(), tot_samples, spec
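
# E.g. with sizes=[4, 7, 11], successive batches cycle through 4-, 7- and
# 11-node samples; tot_samples is negative (infinite) when multiplier < 0.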
def _concat(dps, axis):
  """Concatenate the corresponding leaves of a list of matching pytrees."""
  return jax.tree_util.tree_map(lambda *x: np.concatenate(x, axis), *dps)
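
# E.g. _concat([{'x': a}, {'x': b}], axis=0) gives {'x': np.concatenate([a, b], 0)}
# for any two matching pytrees; this is how the batches of outputs and
# predictions are merged in collect_and_eval below.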
def collect_and_eval(sampler, predict_fn, sample_count, rng_key, extras):
"""Collect batches of output and hint preds and evaluate them."""
processed_samples = 0
preds = []
outputs = []
while processed_samples < sample_count:
feedback = next(sampler)
batch_size = feedback.outputs[0].data.shape[0]
outputs.append(feedback.outputs)
new_rng_key, rng_key = jax.random.split(rng_key)
cur_preds, _ = predict_fn(new_rng_key, feedback.features)
preds.append(cur_preds)
processed_samples += batch_size
outputs = _concat(outputs, axis=0)
preds = _concat(preds, axis=0)
out = clrs.evaluate(outputs, preds)
if extras:
out.update(extras)
return {k: unpack(v) for k, v in out.items()}
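
# The returned dict holds the metrics from clrs.evaluate (including the
# overall 'score' consumed by the training loop below) merged with `extras`.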
def create_samplers(rng, train_lengths: List[int]):
"""Create all the samplers."""
train_samplers = []
val_samplers = []
val_sample_counts = []
test_samplers = []
test_sample_counts = []
spec_list = []
for algo_idx, algorithm in enumerate(FLAGS.algorithms):
# Make full dataset pipeline run on CPU (including prefetching).
with tf.device('/cpu:0'):
if algorithm in ['naive_string_matcher', 'kmp_matcher']:
# Fixed haystack + needle; variability will be in needle
# Still, for chunked training, we maintain as many samplers
# as train lengths, since, for each length there is a separate state,
# and we must keep the 1:1 relationship between states and samplers.
max_length = max(train_lengths)
if max_length > 0: # if < 0, we are using the benchmark data
max_length = (max_length * 5) // 4
train_lengths = [max_length]
if FLAGS.chunked_training:
train_lengths = train_lengths * len(train_lengths)
logging.info('Creating samplers for algo %s', algorithm)
p = tuple([0.1 + 0.1 * i for i in range(9)])
if p and algorithm in ['articulation_points', 'bridges',
'mst_kruskal', 'bipartite_matching']:
# Choose a lower connection probability for the above algorithms,
# otherwise trajectories are very long
p = tuple(np.array(p) / 2)
length_needle = FLAGS.length_needle
sampler_kwargs = dict(p=p, length_needle=length_needle)
if length_needle == 0:
sampler_kwargs.pop('length_needle')
common_sampler_args = dict(
algorithm=FLAGS.algorithms[algo_idx],
rng=rng,
enforce_pred_as_input=FLAGS.enforce_pred_as_input,
enforce_permutations=FLAGS.enforce_permutations,
chunk_length=FLAGS.chunk_length,
)
train_args = dict(sizes=train_lengths,
split='train',
batch_size=FLAGS.batch_size,
multiplier=-1,
randomize_pos=FLAGS.random_pos,
chunked=FLAGS.chunked_training,
sampler_kwargs=sampler_kwargs,
**common_sampler_args)
train_sampler, _, spec = make_multi_sampler(**train_args)
mult = clrs.CLRS_30_ALGS_SETTINGS[algorithm]['num_samples_multiplier']
val_args = dict(sizes=[np.amax(train_lengths)],
split='val',
batch_size=32,
multiplier=2 * mult,
randomize_pos=FLAGS.random_pos,
chunked=False,
sampler_kwargs=sampler_kwargs,
**common_sampler_args)
val_sampler, val_samples, spec = make_multi_sampler(**val_args)
test_args = dict(sizes=[-1],
split='test',
batch_size=32,
multiplier=2 * mult,
randomize_pos=False,
chunked=False,
sampler_kwargs={},
**common_sampler_args)
test_sampler, test_samples, spec = make_multi_sampler(**test_args)
spec_list.append(spec)
train_samplers.append(train_sampler)
val_samplers.append(val_sampler)
val_sample_counts.append(val_samples)
test_samplers.append(test_sampler)
test_sample_counts.append(test_samples)
return (train_samplers,
val_samplers, val_sample_counts,
test_samplers, test_sample_counts,
spec_list)
def main(unused_argv):
if FLAGS.hint_mode == 'encoded_decoded':
encode_hints = True
decode_hints = True
elif FLAGS.hint_mode == 'decoded_only':
encode_hints = False
decode_hints = True
elif FLAGS.hint_mode == 'none':
encode_hints = False
decode_hints = False
else:
raise ValueError('Hint mode not in {encoded_decoded, decoded_only, none}.')
train_lengths = [int(x) for x in FLAGS.train_lengths]
rng = np.random.RandomState(FLAGS.seed)
rng_key = jax.random.PRNGKey(rng.randint(2**32))
# Create samplers
(train_samplers,
val_samplers, val_sample_counts,
test_samplers, test_sample_counts,
spec_list) = create_samplers(rng, train_lengths)
processor_factory = clrs.get_processor_factory(
FLAGS.processor_type,
use_ln=FLAGS.use_ln,
nb_triplet_fts=FLAGS.nb_triplet_fts,
nb_heads=FLAGS.nb_heads
)
model_params = dict(
processor_factory=processor_factory,
hidden_dim=FLAGS.hidden_size,
encode_hints=encode_hints,
decode_hints=decode_hints,
encoder_init=FLAGS.encoder_init,
use_lstm=FLAGS.use_lstm,
learning_rate=FLAGS.learning_rate,
grad_clip_max_norm=FLAGS.grad_clip_max_norm,
checkpoint_path=FLAGS.checkpoint_path,
freeze_processor=FLAGS.freeze_processor,
dropout_prob=FLAGS.dropout_prob,
hint_teacher_forcing=FLAGS.hint_teacher_forcing,
hint_repred_mode=FLAGS.hint_repred_mode,
nb_msg_passing_steps=FLAGS.nb_msg_passing_steps,
)
  # Save spec_list and model_params so the trained model can be rebuilt later;
  # do not change or delete!
if not os.path.exists(FLAGS.checkpoint_path):
os.makedirs(FLAGS.checkpoint_path)
with open(os.path.join(FLAGS.checkpoint_path, 'spec_list.pkl'), 'wb') as f:
pickle.dump(spec_list, f)
model_params_save = copy.deepcopy(model_params)
  model_params_save['processor_factory'] = (
      FLAGS.processor_type, FLAGS.use_ln, FLAGS.nb_triplet_fts, FLAGS.nb_heads)
with open(os.path.join(FLAGS.checkpoint_path, 'model_params.pkl'), 'wb') as f:
pickle.dump(model_params_save, f)
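  # A minimal sketch of reloading these artifacts later (paths as saved above;
  # rebuilding the full model additionally requires BaselineModel):
  #   with open(os.path.join(FLAGS.checkpoint_path, 'spec_list.pkl'), 'rb') as f:
  #     restored_spec_list = pickle.load(f)
  #   with open(os.path.join(FLAGS.checkpoint_path, 'model_params.pkl'), 'rb') as f:
  #     restored_model_params = pickle.load(f)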
eval_model = BaselineModel(
spec=spec_list,
dummy_trajectory=[next(t) for t in val_samplers],
**model_params
)
if FLAGS.chunked_training:
train_model = BaselineModelChunked(
spec=spec_list,
dummy_trajectory=[next(t) for t in train_samplers],
**model_params
)
else:
train_model = eval_model
# Training loop.
best_score = -1.0
current_train_items = [0] * len(FLAGS.algorithms)
step = 0
next_eval = 0
  # Make sure scores improve on the first step, but do not overtake the best
  # score until every algorithm has had at least one evaluation.
val_scores = [-99999.9] * len(FLAGS.algorithms)
length_idx = 0
while step < FLAGS.train_steps:
feedback_list = [next(t) for t in train_samplers]
# Initialize model.
if step == 0:
all_features = [f.features for f in feedback_list]
if FLAGS.chunked_training:
# We need to initialize the model with samples of all lengths for
# all algorithms. Also, we need to make sure that the order of these
# sample sizes is the same as the order of the actual training sizes.
all_length_features = [all_features] + [
[next(t).features for t in train_samplers]
for _ in range(len(train_lengths))]
train_model.init(all_length_features[:-1], FLAGS.seed + 1)
else:
train_model.init(all_features, FLAGS.seed + 1)
# Training step.
# enable logging now that we have initialized the model
logging.set_verbosity(logging.INFO)
for algo_idx in range(len(train_samplers)):
feedback = feedback_list[algo_idx]
rng_key, new_rng_key = jax.random.split(rng_key)
if FLAGS.chunked_training:
# In chunked training, we must indicate which training length we are
# using, so the model uses the correct state.
length_and_algo_idx = (length_idx, algo_idx)
else:
# In non-chunked training, all training lengths can be treated equally,
# since there is no state to maintain between batches.
length_and_algo_idx = algo_idx
cur_loss = train_model.feedback(rng_key, feedback, length_and_algo_idx)
rng_key = new_rng_key
if FLAGS.chunked_training:
examples_in_chunk = np.sum(feedback.features.is_last).item()
else:
examples_in_chunk = len(feedback.features.lengths)
current_train_items[algo_idx] += examples_in_chunk
if step % FLAGS.log_every == 0:
logging.info('Algo %s step %i current loss %f, current_train_items %i.',
FLAGS.algorithms[algo_idx], step,
cur_loss, current_train_items[algo_idx])
# Periodically evaluate model
if step >= next_eval:
eval_model.params = train_model.params
for algo_idx in range(len(train_samplers)):
common_extras = {'examples_seen': current_train_items[algo_idx],
'step': step,
'algorithm': FLAGS.algorithms[algo_idx]}
# Validation info.
new_rng_key, rng_key = jax.random.split(rng_key)
val_stats = collect_and_eval(
val_samplers[algo_idx],
functools.partial(eval_model.predict, algorithm_index=algo_idx),
val_sample_counts[algo_idx],
new_rng_key,
extras=common_extras)
logging.info('(val) algo %s step %d: %s',
FLAGS.algorithms[algo_idx], step, val_stats)
val_scores[algo_idx] = val_stats['score']
next_eval += FLAGS.eval_every
# If best total score, update best checkpoint.
# Also save a best checkpoint on the first step.
msg = (f'best avg val score was '
f'{best_score/len(FLAGS.algorithms):.3f}, '
f'current avg val score is {np.mean(val_scores):.3f}, '
f'val scores are: ')
msg += ', '.join(
['%s: %.3f' % (x, y) for (x, y) in zip(FLAGS.algorithms, val_scores)])
if (sum(val_scores) > best_score) or step == 0:
best_score = sum(val_scores)
logging.info('Checkpointing best model, %s', msg)
train_model.save_model('best.pkl')
else:
logging.info('Not saving new best model, %s', msg)
step += 1
length_idx = (length_idx + 1) % len(train_lengths)
logging.info('Restoring best model from checkpoint...')
eval_model.restore_model('best.pkl', only_load_processor=False)
for algo_idx in range(len(train_samplers)):
common_extras = {'examples_seen': current_train_items[algo_idx],
'step': step,
'algorithm': FLAGS.algorithms[algo_idx]}
new_rng_key, rng_key = jax.random.split(rng_key)
test_stats = collect_and_eval(
test_samplers[algo_idx],
functools.partial(eval_model.predict, algorithm_index=algo_idx),
test_sample_counts[algo_idx],
new_rng_key,
extras=common_extras)
    logging.info('(test) algo %s: %s', FLAGS.algorithms[algo_idx], test_stats)
logging.info('Done!')
if __name__ == '__main__':
app.run(main)