# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=line-too-long
# pyformat: disable
r"""This script runs inference on a T5X-compatible model."""
# pyformat: enable
# pylint:enable=line-too-long

import concurrent.futures
import functools
import hashlib
import json
import os
import re
import shutil
import time
from typing import Any, Callable, Iterator, List, Mapping, Optional, Sequence, Tuple, Type

# TODO(adarob): Re-enable once users are notified and tests are updated.
# Must be set before flax imports.
# pylint:disable=g-import-not-at-top
os.environ['FLAX_LAZY_RNG'] = 'no'

from absl import logging
from clu import metric_writers
import jax
from jax.experimental import multihost_utils
import jax.numpy as jnp
import numpy as np
import seqio
from t5x import gin_utils
from t5x import models
from t5x import partitioning
from t5x import utils
import tensorflow as tf
from tensorflow.io import gfile
from typing_extensions import Protocol

# Automatically search for gin files relative to the T5X package.
_DEFAULT_GIN_SEARCH_PATHS = [
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
]

AUTOTUNE = tf.data.experimental.AUTOTUNE


class SummarizeConfigFn(Protocol):

  def __call__(self, model_dir: str,
               summary_writer: Optional[metric_writers.SummaryWriter],
               step: int) -> None:
    ...


class FailFastThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
  """Wrapper for ThreadPoolExecutor that crashes main thread on exceptions.

  NOTE: this class should be used only from the main thread.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._incomplete_futures: List[concurrent.futures.Future] = []

  def check_for_exceptions(self, wait: bool = False):
    """Raises any exceptions from completed futures on the main thread."""
    still_incomplete_futures = []
    for future in self._incomplete_futures:
      try:
        # Block until the future resolves only if `wait` is True; otherwise
        # just poll the futures that have already completed.
        exception = future.exception(timeout=None if wait else 0)
      except concurrent.futures.TimeoutError:
        still_incomplete_futures.append(future)
        continue
      if exception is not None:
        raise exception
    self._incomplete_futures = still_incomplete_futures

  def submit(self, *args, **kwargs) -> concurrent.futures.Future:
    """Submits function to threadpool, capturing the returned future."""
    future = super().submit(*args, **kwargs)
    self._incomplete_futures.append(future)
    self.check_for_exceptions(wait=False)
    return future

  def shutdown(self, *args, wait: bool = False, **kwargs):
    self.check_for_exceptions(wait=wait)
    super().shutdown(*args, **kwargs)
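

# Illustrative sketch (not used by this script): how FailFastThreadPoolExecutor
# surfaces worker exceptions on the main thread. `_square` is a hypothetical
# helper introduced only for this example; the function is never called.
def _example_fail_fast_pool_usage():
  def _square(x):
    if x < 0:
      raise ValueError('negative input')  # Would be re-raised on main thread.
    return x * x

  pool = FailFastThreadPoolExecutor(max_workers=1)
  futures = [pool.submit(_square, i) for i in range(4)]
  # shutdown(wait=True) re-raises any exception from the worker thread here,
  # instead of silently swallowing it.
  pool.shutdown(wait=True)
  return [f.result() for f in futures]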


def create_task_from_tfexample_file(paths: Sequence[str],
                                    file_type: str,
                                    inputs_key: str,
                                    targets_key: Optional[str],
                                    features: Mapping[str, seqio.Feature],
                                    task_id: Optional[str] = None) -> str:
  """Registers an ad-hoc Task for a file-based dataset of TFExamples.

  Args:
    paths: Input file paths; all files should have type `file_type` and contain
      binary-serialized TFExample protos.
    file_type: Input file type; e.g., 'tfrecord', 'recordio', 'sstable'. For
      keyed formats like 'sstable', we ignore the keys and use only the values.
    inputs_key: Name of the TFExample feature containing the input text for
      T5X. The value of this feature should be a UTF-8 encoded string.
    targets_key: Optional name of a TFExample feature containing the target
      text (relevant only in scoring mode). The value of this feature should be
      a UTF-8 encoded string.
    features: Should have entries for keys 'inputs' and (if targets_key is not
      None) 'targets', mapping to `seqio.Feature` objects that specify
      attributes like vocabulary, add_eos, etc. These attributes are used for
      preprocessing and featurizing the input text.
    task_id: Task name identifier. By default, it is set to a unique and
      deterministic hash id. Overridable via this argument.

  Returns:
    Name of the newly-registered Task. This Task has a split named 'infer' that
    contains the preprocessed and featurized input dataset.
  """
  # tf.io.gfile.glob supports lists, in contrast to gfile.glob.
  files = tf.io.gfile.glob(paths)
  if files:
    logging.info('Using tfexample files %s', files)
  else:
    # Fail early if there's something wrong with the input file pattern.
    raise ValueError('Missing or invalid paths: %s' % paths)

  reader = {
      'tfrecord': tf.data.TFRecordDataset,
  }[file_type]

  feature_description = {inputs_key: tf.io.FixedLenFeature([], tf.string)}
  if targets_key:
    feature_description[targets_key] = tf.io.FixedLenFeature([], tf.string)

  # Create a unique, deterministic task name.
  if task_id is None:
    task_id = hashlib.md5(
        ':'.join(list(paths) +
                 [inputs_key, targets_key or '']).encode()).hexdigest()[:10]

  task = seqio.TaskRegistry.add(
      name=f'infer_{task_id}',
      source=seqio.TFExampleDataSource({'infer': paths},
                                       feature_description=feature_description,
                                       reader_cls=reader),
      preprocessors=[
          functools.partial(
              seqio.preprocessors.rekey,
              key_map={
                  'inputs': inputs_key,
                  'targets': targets_key
              }),
          seqio.preprocessors.tokenize_and_append_eos,
      ],
      output_features=features)
  return task.name
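

# Illustrative sketch (not used by this script): registering an ad-hoc Task
# from TFRecord files of serialized TFExamples. The file pattern and the
# SentencePiece model path below are placeholders, not real resources, and the
# function is never called.
def _example_create_tfexample_task() -> str:
  vocab = seqio.SentencePieceVocabulary('/path/to/sentencepiece.model')
  features = {
      'inputs': seqio.Feature(vocabulary=vocab, add_eos=True),
      'targets': seqio.Feature(vocabulary=vocab, add_eos=True),
  }
  # Returns a task name like 'infer_<10-char hash>' whose 'infer' split yields
  # the tokenized examples.
  return create_task_from_tfexample_file(
      paths=['/path/to/examples.tfrecord*'],
      file_type='tfrecord',
      inputs_key='inputs',
      targets_key='targets',
      features=features)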


def merge_chunks_to_file(
    output_dir: str,
    output_fname: str,
    tmp_dir: str,
    step: Optional[int],
) -> None:
  """Merges the predictions from different chunks into a unified file."""
  logging.info('Merging chunk results.')
  # Merge chunks into a single file.
  chunk_paths = sorted(
      gfile.glob(os.path.join(tmp_dir, f'{output_fname}-chunk?????')))
  if not chunk_paths:
    raise FileNotFoundError(
        'No chunk results found! One possible explanation is that your '
        'input did not contain any examples.')
  assert int(chunk_paths[-1][-5:]) + 1 == len(chunk_paths), (
      f'Expecting {int(chunk_paths[-1][-5:]) + 1} chunk paths, found '
      f'{len(chunk_paths)}.')
  output_path = os.path.join(output_dir, output_fname)
  del step
  with gfile.GFile(output_path, 'wb') as merged:
    for chunk_path in chunk_paths:
      with gfile.GFile(chunk_path, 'rb') as ef:
        shutil.copyfileobj(ef, merged)
  logging.info('Results written to %s.', output_path)
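

# For reference, a hypothetical layout that merge_chunks_to_file consumes and
# produces (file names follow the `-chunk?????` pattern globbed above; the
# task name 'my_task' and mode 'predict' are made up):
#
#   <tmp_dir>/my_task-predict.jsonl-00000-of-00001-chunk00000
#   <tmp_dir>/my_task-predict.jsonl-00000-of-00001-chunk00001
#     --> merged into <output_dir>/my_task-predict.jsonl-00000-of-00001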


_Inferences = Tuple[Sequence[Any], Mapping[str, Any]]


def write_inferences_to_file(
    path: str,
    inferences: _Inferences,
    task_ds: tf.data.Dataset,
    mode: str,
    vocabulary: Optional[seqio.Vocabulary] = None,
    json_encoder_cls: Type[json.JSONEncoder] = seqio.TensorAndNumpyEncoder,
    include_all_inputs: bool = False,
    input_fields_to_include: Optional[Sequence[str]] = None,
    output_ids: bool = False) -> None:
  """Writes model predictions, along with pretokenized inputs, to a JSONL file.

  Args:
    path: File path to write to.
    inferences: A tuple containing (predictions, aux_values). If mode is
      'predict' then the `predictions` will be token IDs. If it's 'score' then
      it'll be a collection of scores. `aux_values` will be an empty dictionary
      unless mode is 'predict_with_aux', in which case it'll contain the
      model's auxiliary outputs.
    task_ds: Original task dataset. Features from the task with the suffix
      `_pretokenized` are added to the outputs.
    mode: Prediction mode, either 'predict', 'score' or 'predict_with_aux'.
    vocabulary: Task output vocabulary. Only used in `predict` mode in order to
      decode predicted outputs into strings.
    json_encoder_cls: A JSON encoder class used to customize JSON serialization
      via json.dumps.
    include_all_inputs: If True, will include all model inputs in the output
      JSONL file (including raw tokens) in addition to the pretokenized inputs.
    input_fields_to_include: List of input fields to include in the output
      JSONL file. This list should be None if `include_all_inputs` is set to
      True.
    output_ids: If True, will output the token ID sequence for the output, in
      addition to the decoded text.
  """
  all_predictions, all_aux_values = inferences

  if mode in ('predict', 'predict_with_aux') and vocabulary is None:
    raise ValueError('The `vocabulary` parameter is required in `predict` and '
                     '`predict_with_aux` modes.')

  def _json_compat(value):
    if isinstance(value, bytes):
      return value.decode('utf-8')
    elif isinstance(value, (jnp.bfloat16, jnp.floating)):
      return float(value)
    elif isinstance(value, jnp.integer):
      return float(value)
    elif isinstance(value, (jnp.ndarray, np.ndarray)):
      # Flatten array features.
      return value.tolist()
    else:
      return value

  if include_all_inputs and input_fields_to_include is not None:
    raise ValueError(
        'include_all_inputs and input_fields_to_include should not be set '
        'simultaneously.')

  with gfile.GFile(path, 'w') as f:
    for i, inp in task_ds.enumerate().as_numpy_iterator():
      predictions = all_predictions[i]
      aux_values = {aux_field: v[i] for aux_field, v in all_aux_values.items()}

      if include_all_inputs:
        inputs = inp
      elif input_fields_to_include is not None:
        inputs = {
            k: v for k, v in inp.items() if k in input_fields_to_include or
            (k.endswith('_pretokenized') and
             k[:-len('_pretokenized')] in input_fields_to_include)
        }
      else:
        inputs = {k: v for k, v in inp.items() if k.endswith('_pretokenized')}

      json_dict = {}
      json_dict['inputs'] = {k: _json_compat(v) for k, v in inputs.items()}

      if mode == 'predict':
        assert vocabulary is not None
        json_dict['prediction'] = _json_compat(
            vocabulary.decode_tf(tf.constant(predictions)).numpy())
        if output_ids:
          pred = _json_compat(tf.constant(predictions).numpy())
          # Truncate padding tokens.
          assert isinstance(pred, list)
          pred = pred[:pred.index(0)] if 0 in pred else pred
          json_dict['prediction_tokens'] = pred
      elif mode == 'score':
        json_dict['score'] = _json_compat(predictions)
      elif mode == 'predict_with_aux':
        assert vocabulary is not None
        json_dict['prediction'] = _json_compat(
            vocabulary.decode_tf(tf.constant(predictions)).numpy())
        if output_ids:
          pred = _json_compat(tf.constant(predictions).numpy())
          # Truncate padding tokens.
          pred = pred[:pred.index(0)] if 0 in pred else pred
          json_dict['prediction_tokens'] = pred
        json_dict['aux'] = jax.tree_map(_json_compat, aux_values)
      else:
        raise ValueError(f'Invalid mode: {mode}')

      json_str = json.dumps(json_dict, cls=json_encoder_cls)
      f.write(json_str + '\n')
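

# For reference, a sketch of one output line produced by
# write_inferences_to_file in 'predict' mode with output_ids=True (the field
# values below are made up; 'score' mode writes a 'score' field instead of
# 'prediction'/'prediction_tokens'):
#
#   {"inputs": {"inputs_pretokenized": "translate English to German: hello"},
#    "prediction": "hallo",
#    "prediction_tokens": [3, 107, 9, 195, 32]}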


WriteFn = Callable[[
    str,
    _Inferences,
    tf.data.Dataset,
    str,
    Optional[seqio.Vocabulary],
], None]

MergeFn = Callable[[str, str, str, Optional[int]], None]


def _extract_tokens_and_aux_values(inference_fn_outputs) -> _Inferences:
  """Extracts tokens and aux scores from a cached dataset."""
  all_aux_values = {}
  if isinstance(inference_fn_outputs, tuple):
    indices_and_tokens, all_aux_values = inference_fn_outputs
    indices, tokens = zip(*indices_and_tokens)
    permutation = np.argsort(indices)
    tokens = [tokens[permutation[i]] for i in range(len(permutation))]
    for aux_keys, aux_values in all_aux_values.items():
      all_aux_values[aux_keys] = [
          aux_values[permutation[i]] for i in range(len(permutation))
      ]
  else:
    indices_and_tokens = inference_fn_outputs
    _, tokens = zip(*sorted(indices_and_tokens, key=lambda x: x[0]))
  return tokens, all_aux_values
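

# Illustrative sketch (not used by this script): the inference function yields
# (index, output) pairs whose order may not match the input order;
# _extract_tokens_and_aux_values restores the original example order. The toy
# token lists below are made up and the function is never called.
def _example_extract_tokens():
  outputs = [(1, [5, 6]), (0, [3, 4]), (2, [7, 8])]
  tokens, aux = _extract_tokens_and_aux_values(outputs)
  assert list(tokens) == [[3, 4], [5, 6], [7, 8]]
  assert aux == {}
  return tokens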


def infer(
    *,
    mode: str,
    model: models.BaseTransformerModel,
    dataset_cfg: utils.DatasetConfig,
    restore_checkpoint_cfg: utils.RestoreCheckpointConfig,
    partitioner: partitioning.BasePartitioner,
    output_dir: str,
    checkpoint_period: int,
    shard_id: int = 0,
    num_shards: int = 1,
    merge_chunked_results: bool = True,
    write_fn: WriteFn = write_inferences_to_file,
    checkpoint_ds_iter: bool = True,
    fallback_init_rng: Optional[int] = None,
    merge_fn: MergeFn = merge_chunks_to_file,
    summarize_config_fn: SummarizeConfigFn = gin_utils.summarize_gin_config,
):
  """Infer function.

  Args:
    mode: Either 'predict' to decode targets, 'score' to compute the log
      likelihood of given targets, or 'predict_with_aux' for both.
    model: The model object to use for inference.
    dataset_cfg: Specification for the dataset to infer based on.
    restore_checkpoint_cfg: Specification for the model parameter checkpoint to
      load.
    partitioner: Partitioner for model parameters and data across devices.
    output_dir: Path to directory to write temporary files and final results.
    checkpoint_period: The intermediate results and dataset iterator will be
      checkpointed on each multiple of this number of batches to enable
      continuation after a failure.
    shard_id: Index of the dataset shard for this instance to use if splitting
      the work across multiple jobs.
    num_shards: Total number of dataset shards to split the dataset across.
    merge_chunked_results: Whether to merge results of all chunks into a single
      JSON file.
    write_fn: Callable function used to serialize and write inferences out to
      files.
    checkpoint_ds_iter: If True, will checkpoint the dataset iterator every
      `checkpoint_period` batches to enable faster restores. This must be
      disabled for certain datasets, for example those with stateful iterators
      (e.g., from seqio.FunctionTask), which cannot be checkpointed.
    fallback_init_rng: A random seed used for parameter initialization during
      model re-loading when utils.RestoreCheckpointConfig.fallback_to_scratch
      is set to True. If None, parameter initialization is not allowed during
      model loading and having fallback_to_scratch enabled will result in an
      error.
    merge_fn: Callable function used to merge inferences from multiple files.
    summarize_config_fn: A function that takes in the model directory, an
      optional SummaryWriter, and the step number, and writes a summary of the
      configuration. SummaryWriter will be None in most cases.
  """
  logging.info('Process ID: %d', jax.process_index())

  summarize_config_fn(model_dir=output_dir, summary_writer=None, step=0)

  if mode not in ('predict', 'score', 'predict_with_aux'):
    raise ValueError(
        "`mode` must be one of 'predict', 'score' or 'predict_with_aux'. "
        f"Got '{mode}'")

  # Remove double-slashes in directory path to avoid inconsistencies.
  output_dir = re.sub(r'(?<!gs:)([\/]{2,})', '/', output_dir)

  ds_vocabs = utils.get_vocabulary(dataset_cfg)
  if (ds_vocabs[0] != model.input_vocabulary or
      ds_vocabs[1] != model.output_vocabulary):
    raise ValueError(
        'Model and Task vocabularies do not match.\n'
        f'Task Input: {ds_vocabs[0]}, Model Input: {model.input_vocabulary}\n'
        f'Task Output: {ds_vocabs[1]}, Model Output: {model.output_vocabulary}')

  batch_size = dataset_cfg.batch_size

  # Set up dataset.
  if dataset_cfg.module:
    utils.import_module(dataset_cfg.module)
  host_shard_info = seqio.ShardInfo(index=shard_id, num_shards=num_shards)
  task_or_mixture = seqio.get_mixture_or_task(dataset_cfg.mixture_or_task_name)

  feature_converter = model.FEATURE_CONVERTER_CLS(pack=False)

  def _get_dataset(dataset_provider):
    # TODO(adarob): assert pack is false, shuffle is false, seed?
    return dataset_provider.get_dataset(
        sequence_length=dataset_cfg.task_feature_lengths,
        split=dataset_cfg.split,
        shuffle=False,
        num_epochs=1,
        shard_info=host_shard_info,
        use_cached=dataset_cfg.use_cached,
        seed=dataset_cfg.seed)

  # Each "chunk" corresponds to how often we checkpoint the input dataset and
  # flush the inferences to disk.
  logging.info('Inferring with checkpoints every %d batches of %d examples.',
               checkpoint_period, batch_size)

  logging.info('Initializing model, optimizer, and step functions.')
  element_spec = feature_converter(
      _get_dataset(task_or_mixture),
      dataset_cfg.task_feature_lengths).element_spec
  input_shapes = {
      k: (batch_size,) + spec.shape for k, spec in element_spec.items()
  }
  input_types = {
      k: jnp.dtype(spec.dtype.as_numpy_dtype)
      for k, spec in element_spec.items()
  }

  # Initialize optimizer from the existing checkpoint.
  # TODO(adarob): Support inference over multiple checkpoints.
  train_state_initializer = utils.TrainStateInitializer(
      optimizer_def=None,  # Do not load optimizer state.
      init_fn=model.get_initial_variables,
      input_shapes=input_shapes,
      input_types=input_types,
      partitioner=partitioner)

  # Log the variable shapes information and write it to a file.
  model_info_log_file = os.path.join(output_dir, 'model-info.txt')
  if shard_id == 0:
    utils.log_model_info(model_info_log_file,
                         train_state_initializer.global_train_state_shape,
                         partitioner)

  # Disable strictness since we are dropping the optimizer state.
  restore_checkpoint_cfg.strict = False
  if fallback_init_rng is not None:
    fallback_init_rng = jax.random.PRNGKey(fallback_init_rng)
  train_state = train_state_initializer.from_checkpoint(
      [restore_checkpoint_cfg], init_rng=fallback_init_rng)

  if mode == 'predict':
    infer_step = model.predict_batch
  elif mode == 'predict_with_aux':
    infer_step = model.predict_batch_with_aux
  else:  # mode == 'score'
    infer_step = model.score_batch
  infer_fn = functools.partial(
      utils.get_infer_fn(
          infer_step=infer_step,
          batch_size=batch_size,
          train_state_axes=train_state_initializer.train_state_axes,
          partitioner=partitioner),
      train_state=train_state)

  def infer_task(task: seqio.Task):
    tmp_dir = os.path.join(output_dir,
                           f'tmp-{task.name}-{shard_id:05}-of-{num_shards:05}')
    if jax.process_index() == 0:
      gfile.makedirs(tmp_dir)

    # Use `max_workers=1` to ensure writes occur sequentially.
    write_thread_pool = FailFastThreadPoolExecutor(max_workers=1)

    logging.info("Loading dataset for task '%s'.", task.name)
    ds = _get_dataset(task)
    model_ds = feature_converter(
        ds, task_feature_lengths=dataset_cfg.task_feature_lengths)

    # Zip task and model features.
    # (task, model)
    infer_ds = tf.data.Dataset.zip((ds, model_ds))

    # Create batches the size of each chunk and index them.
    # (i, [(task, model)] * chunk_size)
    infer_ds = infer_ds.padded_batch(
        checkpoint_period * batch_size, drop_remainder=False).enumerate()
    infer_ds_iter: Iterator[Tuple[int, Any]] = iter(infer_ds.prefetch(AUTOTUNE))

    if checkpoint_ds_iter:
      # Create checkpoint manager and restore state, if applicable.
      ckpt_path = os.path.join(tmp_dir, 'input.ckpt')
      input_ckpt = tf.train.Checkpoint(ds=infer_ds_iter)
      if gfile.glob(ckpt_path + '*'):
        logging.info('Restoring input iterator from %s', ckpt_path)
        input_ckpt.read(ckpt_path).assert_consumed()

    output_fname = f'{task.name}-{mode}.jsonl-{shard_id:05}-of-{num_shards:05}'
    if gfile.exists(os.path.join(output_dir, output_fname)):
      logging.info(
          "File %s exists. Skipping inference for shard %d/%d of task '%s'.",
          output_fname, shard_id, num_shards, task.name)
      return
    logging.info("Starting inference loop for shard %d of %d of task '%s'.",
                 shard_id, num_shards, task.name)

    def _write_chunk_and_canonicalize_ckpt(chunk: int, chunk_path: str,
                                           inferences: _Inferences,
                                           task_ds: tf.data.Dataset,
                                           chunk_ckpt_path: Optional[str]):
      write_tick = time.time()
      logging.info('Writing chunk %d results to %s', chunk, chunk_path)
      write_fn(chunk_path, inferences, task_ds, mode,
               task.output_features['targets'].vocabulary)
      with gfile.GFile(chunk_path + '.COMPLETED', 'w') as f:
        f.write('')
      write_time = time.time() - write_tick
      num_examples = len(inferences[0])
      logging.info('Writing completed in %02f seconds (%02f examples/sec).',
                   write_time, num_examples / write_time)
      update_measurement_series('writing_total_sec', chunk, write_time)
      update_measurement_series('writing_examples_per_sec', chunk,
                                num_examples / write_time)

      if chunk_ckpt_path:
        # Canonicalize checkpoint.
        for fname in gfile.glob(chunk_ckpt_path + '*'):
          gfile.rename(
              fname, fname.replace(chunk_ckpt_path, ckpt_path), overwrite=True)

    # Main loop over "chunks".
    for chunk, chunk_batch in infer_ds_iter:
      logging.info('Starting chunk %d', chunk)
      chunk_tick = time.time()

      # Load the dataset for the next chunk. We can't use `infer_ds_iter`
      # directly since `infer_fn` needs to know the exact size of each chunk,
      # which may be smaller for the final one.
      chunk_ds = tf.data.Dataset.from_tensor_slices(chunk_batch)
      chunk_ds.cache().prefetch(AUTOTUNE)

      # Unzip chunk dataset into pretokenized and model datasets.
      task_ds = chunk_ds.map(lambda p, m: p, num_parallel_calls=AUTOTUNE)
      model_ds = chunk_ds.map(lambda p, m: m, num_parallel_calls=AUTOTUNE)

      # Get a chunk-specific RNG key.
      chunk_rng = jax.random.fold_in(jax.random.PRNGKey(0), chunk)

      chunk_path = os.path.join(tmp_dir, f'{output_fname}-chunk{chunk:05}')
      if gfile.exists(chunk_path + '.COMPLETED') and not checkpoint_ds_iter:
        logging.info('Skipping chunk %s. Chunk file already exists.', chunk)
        continue

      logging.info('Running inference on %d batches.', checkpoint_period)
      inferences = _extract_tokens_and_aux_values(
          infer_fn(model_ds.enumerate(), rng=chunk_rng))

      if jax.process_index() == 0:
        chunk_time = time.time() - chunk_tick
        num_examples = len(inferences[0])
        logging.info('Chunk completed in %02f seconds (%02f examples/sec).',
                     chunk_time, num_examples / chunk_time)
        update_measurement_series('inference_total_sec', chunk, chunk_time)
        update_measurement_series('inference_examples_per_sec', chunk,
                                  num_examples / chunk_time)

        chunk_ckpt_path = None
        if checkpoint_ds_iter:
          # Store the iterator checkpoint in a temporary location before
          # writing the model output asynchronously. After the outputs are
          # written, the checkpoint will be moved to the canonical location to
          # be used if a restart occurs.
          ckpt_tick = time.time()
          chunk_ckpt_path = input_ckpt.write(
              os.path.join(tmp_dir, f'{chunk}.ckpt'))
          logging.info(
              'Checkpoint written to temporary location in %02f seconds.',
              time.time() - ckpt_tick)

        # These will execute sequentially since the ThreadPool size is 1.
        write_thread_pool.submit(
            _write_chunk_and_canonicalize_ckpt,
            chunk=chunk,
            chunk_path=chunk_path,
            inferences=inferences,
            task_ds=task_ds,
            chunk_ckpt_path=chunk_ckpt_path)

      # Wait for checkpoint to be written before continuing.
      multihost_utils.sync_global_devices(
          f'{task.name}:checkpoint_chunk{chunk:05}')

    logging.info("Finished inference for task '%s'.", task.name)
    logging.info('Waiting for chunk writes to complete.')
    write_thread_pool.shutdown(wait=True)

    if jax.process_index() == 0 and merge_chunked_results:
      step = None if train_state is None else int(train_state.step)
      merge_fn(output_dir, output_fname, tmp_dir, step)
      logging.info('Deleting temporary files.')
      gfile.rmtree(tmp_dir)

    # Wait for host 0 to finish writing before exiting.
    multihost_utils.sync_global_devices(f'{task.name}:complete')

  for task in seqio.get_subtasks(task_or_mixture):
    logging.info("Starting inference for task '%s'", task.name)
    infer_task(task)
  logging.info('DONE')


def update_measurement_series(series_name: str, step: int, value: float):
  """Not implemented externally."""
  del series_name, step, value


if __name__ == '__main__':
  # pylint:disable=g-import-not-at-top
  from absl import app
  from absl import flags
  import gin
  # pylint:enable=g-import-not-at-top

  FLAGS = flags.FLAGS

  jax.config.parse_flags_with_absl()

  flags.DEFINE_integer(
      'shard_id',
      default=None,
      help='Index to use for splitting the Task across multiple inference '
      'runs. NB: If set, this overrides --gin.infer.shard_id')

  flags.DEFINE_multi_string(
      'gin_file',
      default=None,
      help='Path to gin configuration file. Multiple paths may be passed and '
      'will be imported in the given order, with later configurations '
      'overriding earlier ones.')

  flags.DEFINE_multi_string(
      'gin_bindings', default=[], help='Individual gin bindings.')

  flags.DEFINE_list(
      'gin_search_paths',
      default=['.'],
      help='Comma-separated list of gin config path prefixes to be prepended '
      'to suffixes given via `--gin_file`. Only the first prefix that '
      'produces a valid path for each suffix will be used.')

  flags.DEFINE_string(
      'tfds_data_dir', None,
      'If set, this directory will be used to store datasets prepared by '
      'TensorFlow Datasets that are not available in the public TFDS GCS '
      'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
      'all `Task`s.')
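
  # Illustrative invocation (the paths, gin files, macro names such as
  # CHECKPOINT_PATH and INFER_OUTPUT_DIR, and shell quoting are placeholders
  # that depend on the gin configs you pass; consult the T5X documentation for
  # configurations matching your model):
  #
  #   python -m t5x.infer \
  #     --gin_file=path/to/model.gin \
  #     --gin_file=t5x/configs/runs/infer.gin \
  #     --gin.CHECKPOINT_PATH=\"/path/to/checkpoint\" \
  #     --gin.INFER_OUTPUT_DIR=\"/path/to/output\"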

  def main(argv: Sequence[str]):
    """Wrapper for pdb post mortems."""
    _main(argv)

  def _main(argv: Sequence[str]):
    """True main function."""
    if len(argv) > 1:
      raise app.UsageError('Too many command-line arguments.')

    if FLAGS.tfds_data_dir:
      seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir)

    # Create a gin-configurable version of `infer`.
    infer_using_gin = gin.configurable(infer)

    gin_utils.parse_gin_flags(
        # User-provided gin paths take precedence if relative paths conflict.
        FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
        FLAGS.gin_file,
        FLAGS.gin_bindings)

    # See http://yaqs/7882016229479677952 for further gin-config discussion.
    def _get_gin_parameter(key: str) -> Any:
      value = gin.query_parameter(key)
      if isinstance(value, gin.config.ConfigurableReference):
        if value.evaluate:
          return value.scoped_configurable_fn()
        return value.scoped_configurable_fn
      return value

    shard_id = (
        FLAGS.shard_id
        if FLAGS.shard_id is not None else _get_gin_parameter('infer.shard_id'))
    if shard_id == 0:
      gin_utils.summarize_gin_config(
          model_dir=_get_gin_parameter('infer.output_dir'),
          summary_writer=None,
          step=0)
    if FLAGS.shard_id is not None:
      # We fall back to this flag since XM does not support sweeps over flags
      # with '.' in them (it treats them like nested dictionaries).
      # TODO(adarob): Figure out a workaround so we can deprecate this flag.
      infer_using_gin(shard_id=FLAGS.shard_id)
    else:
      infer_using_gin()

  gin_utils.run(main)