# Defaults for training with train.py.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - TASK_PREFIX
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS
# - MODEL_DIR
#
# Commonly overridden options:
#
# - BATCH_SIZE
# - PjitPartitioner.num_partitions
# - Trainer.num_microbatches
# - USE_CACHED_TASKS: whether to look for preprocessed SeqIO data, or preprocess
#   on the fly.
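#
# A typical launch passes this file plus a model gin file to the t5x train
# script and fills in the required bindings via --gin flags. The file paths
# and values below are illustrative only, not part of this config:
#
#   python ${T5X_DIR}/t5x/train.py \
#     --gin_file=model.gin \
#     --gin_file=train.gin \
#     --gin.TASK_PREFIX="'mt3'" \
#     --gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 1024}" \
#     --gin.TRAIN_STEPS=100000 \
#     --gin.MODEL_DIR="'/tmp/mt3_train'" \
#     --gin.NUM_VELOCITY_BINS=1 \
#     --gin.PROGRAM_GRANULARITY="'flat'"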

from __gin__ import dynamic_registration

import __main__ as train_script
import seqio
from mt3 import mixing
from mt3 import preprocessors
from mt3 import tasks
from mt3 import vocabularies
from t5x import gin_utils
from t5x import partitioning
from t5x import utils
from t5x import trainer

# Must be overridden.
TASK_PREFIX = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
TRAIN_STEPS = %gin.REQUIRED
MODEL_DIR = %gin.REQUIRED
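
# A downstream gin file can also bind these directly instead of using the
# command line; for example (values are illustrative only):
#
#   TASK_PREFIX = 'mt3'
#   TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 1024}
#   TRAIN_STEPS = 100000
#   MODEL_DIR = '/tmp/mt3_train'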

# Commonly overridden.
TRAIN_TASK_SUFFIX = 'train'
EVAL_TASK_SUFFIX = 'eval'
USE_CACHED_TASKS = True
BATCH_SIZE = 256

# Sometimes overridden.
EVAL_STEPS = 20

# Convenience overrides.
EVALUATOR_USE_MEMORY_CACHE = True
EVALUATOR_NUM_EXAMPLES = None  # Use all examples in the infer_eval dataset.
JSON_WRITE_N_RESULTS = 0  # Don't write any inferences.

# Number of velocity bins: set to 1 (no velocity) or 127.
NUM_VELOCITY_BINS = %gin.REQUIRED
VOCAB_CONFIG = @vocabularies.VocabularyConfig()
vocabularies.VocabularyConfig.num_velocity_bins = %NUM_VELOCITY_BINS
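# Roughly: 1 bin collapses all note velocities to a single on/off value, while
# 127 quantizes the full MIDI velocity range into 127 distinct tokens.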

# Program granularity: set to 'flat', 'midi_class', or 'full'.
PROGRAM_GRANULARITY = %gin.REQUIRED
preprocessors.map_midi_programs.granularity_type = %PROGRAM_GRANULARITY
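# Roughly: 'flat' collapses all programs to a single instrument, 'midi_class'
# groups programs into their 16 General MIDI instrument classes, and 'full'
# keeps all 128 programs distinct.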

# Maximum number of examples per mix, or None for no mixing.
MAX_EXAMPLES_PER_MIX = None
mixing.mix_transcription_examples.max_examples_per_mix = %MAX_EXAMPLES_PER_MIX
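# For example, MAX_EXAMPLES_PER_MIX = 8 would randomly mix together up to 8
# training examples into one, as a data augmentation.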

train/tasks.construct_task_name:
  task_prefix = %TASK_PREFIX
  vocab_config = %VOCAB_CONFIG
  task_suffix = %TRAIN_TASK_SUFFIX

eval/tasks.construct_task_name:
  task_prefix = %TASK_PREFIX
  vocab_config = %VOCAB_CONFIG
  task_suffix = %EVAL_TASK_SUFFIX

train_script.train:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
  checkpoint_cfg = @utils.CheckpointConfig()
  partitioner = @partitioning.PjitPartitioner()
  trainer_cls = @trainer.Trainer
  total_steps = %TRAIN_STEPS
  eval_steps = %EVAL_STEPS
  eval_period = 5000
  random_seed = None  # use faster, hardware RNG
  summarize_config_fn = @gin_utils.summarize_gin_config
  inference_evaluator_cls = @seqio.Evaluator

seqio.Evaluator:
  logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
  num_examples = %EVALUATOR_NUM_EXAMPLES
  use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE

seqio.JSONLogger:
  write_n_results = %JSON_WRITE_N_RESULTS

train/utils.DatasetConfig:
  mixture_or_task_name = @train/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = True
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = False

train_eval/utils.DatasetConfig:
  mixture_or_task_name = @train/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'eval'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False

infer_eval/utils.DatasetConfig:
  mixture_or_task_name = @eval/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'eval'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False

utils.CheckpointConfig:
  restore = None
  save = @utils.SaveCheckpointConfig()

utils.SaveCheckpointConfig:
  period = 5000
  dtype = 'float32'
  keep = None  # keep all checkpoints
  save_dataset = False  # don't checkpoint dataset state

partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None

trainer.Trainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()

utils.create_learning_rate_scheduler:
  factors = 'constant'
  base_learning_rate = 0.001
  warmup_steps = 1000
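
# Note: with factors = 'constant' the learning rate stays fixed at
# base_learning_rate, so warmup_steps has no effect here; it only applies when
# a warmup factor is included, e.g. (illustrative):
#
#   utils.create_learning_rate_scheduler:
#     factors = 'constant * linear_warmup'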