# Defaults for infer.py.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - TASK_PREFIX
# - TASK_FEATURE_LENGTHS
# - CHECKPOINT_PATH
# - INFER_OUTPUT_DIR
#
# Commonly overridden options:
#
# - infer.mode
# - infer.checkpoint_period
# - infer.shard_id
# - infer.num_shards
# - DatasetConfig.split
# - DatasetConfig.batch_size
# - DatasetConfig.use_cached
# - RestoreCheckpointConfig.is_tensorflow
# - RestoreCheckpointConfig.mode
# - PjitPartitioner.num_partitions
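#
# Example invocation (an illustrative sketch only -- the paths, task prefix,
# and feature lengths are placeholders, and MODEL is assumed to come from a
# separate model gin file):
#
#   python -m t5x.infer \
#     --gin_file=/path/to/model.gin \
#     --gin_file=infer.gin \
#     --gin.TASK_PREFIX="'mt3'" \
#     --gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 1024}" \
#     --gin.CHECKPOINT_PATH="'/path/to/checkpoint'" \
#     --gin.INFER_OUTPUT_DIR="'/path/to/output'"
#
# The remaining %gin.REQUIRED bindings (NUM_VELOCITY_BINS, PROGRAM_GRANULARITY,
# ONSETS_ONLY, USE_TIES) must also be supplied, either in a gin file or via
# further --gin flags.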

from __gin__ import dynamic_registration

import __main__ as infer_script
from mt3 import inference
from mt3 import preprocessors
from mt3 import tasks
from mt3 import vocabularies
from t5x import partitioning
from t5x import utils

# Must be overridden
TASK_PREFIX = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
CHECKPOINT_PATH = %gin.REQUIRED
INFER_OUTPUT_DIR = %gin.REQUIRED

# Number of velocity bins: set to 1 (no velocity) or 127
NUM_VELOCITY_BINS = %gin.REQUIRED
VOCAB_CONFIG = @vocabularies.VocabularyConfig()
vocabularies.VocabularyConfig.num_velocity_bins = %NUM_VELOCITY_BINS
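# Example binding (commented out; the value must match the one the checkpoint
# was trained with -- e.g. 1 for the multi-instrument MT3 model, 127 for the
# ismir2021 piano model):
# NUM_VELOCITY_BINS = 1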

# Program granularity: set to 'flat', 'midi_class', or 'full'
PROGRAM_GRANULARITY = %gin.REQUIRED
preprocessors.map_midi_programs.granularity_type = %PROGRAM_GRANULARITY
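# Example binding (commented out; roughly, 'flat' collapses all programs to a
# single one, 'midi_class' groups programs by General MIDI instrument class,
# and 'full' keeps every program distinct):
# PROGRAM_GRANULARITY = 'full'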

TASK_SUFFIX = 'test'
tasks.construct_task_name:
  task_prefix = %TASK_PREFIX
  vocab_config = %VOCAB_CONFIG
  task_suffix = %TASK_SUFFIX

ONSETS_ONLY = %gin.REQUIRED
USE_TIES = %gin.REQUIRED
inference.write_inferences_to_file:
  vocab_config = %VOCAB_CONFIG
  onsets_only = %ONSETS_ONLY
  use_ties = %USE_TIES
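# Example combination (commented out; these must also match the checkpoint's
# training configuration -- e.g. the multi-instrument MT3 model uses the
# "ties" note representation, while onsets_only produces onset-only output):
# ONSETS_ONLY = False
# USE_TIES = True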

infer_script.infer:
  mode = 'predict'
  model = %MODEL  # imported from separate gin file
  output_dir = %INFER_OUTPUT_DIR
  dataset_cfg = @utils.DatasetConfig()
  partitioner = @partitioning.PjitPartitioner()
  restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
  # This is a hack, but pass an extremely large value here to make sure the
  # entire dataset fits in a single epoch. Otherwise, segments from a single
  # example may end up in different epochs after splitting.
  checkpoint_period = 1000000
  shard_id = 0
  num_shards = 1
  write_fn = @inference.write_inferences_to_file
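# To split inference across parallel jobs, give each job a distinct shard
# (an illustrative sketch; job i of 4 would bind shard_id = i):
# infer_script.infer:
#   shard_id = 0
#   num_shards = 4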

utils.DatasetConfig:
  mixture_or_task_name = @tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  use_cached = True
  split = 'eval'
  batch_size = 32
  shuffle = False
  seed = 0
  pack = False
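# These are commonly overridden from the command line, for example (note the
# nested quotes gin needs for string-valued bindings):
#   --gin.utils.DatasetConfig.split="'test'"
#   --gin.utils.DatasetConfig.batch_size=8
#   --gin.utils.DatasetConfig.use_cached=False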

partitioning.PjitPartitioner.num_partitions = 1

utils.RestoreCheckpointConfig:
  path = %CHECKPOINT_PATH
  mode = 'specific'
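# 'specific' restores exactly the checkpoint at CHECKPOINT_PATH; t5x's
# RestoreCheckpointConfig also accepts mode = 'latest' to restore the newest
# checkpoint under a directory instead.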