# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """Model vocabulary.""" | |
| import dataclasses | |
| import math | |
| from typing import Callable, Optional, Sequence | |
| from mt3 import event_codec | |
| import note_seq | |
| import seqio | |
| import t5.data | |
| import tensorflow as tf | |
| DECODED_EOS_ID = -1 | |
| DECODED_INVALID_ID = -2 | |
| # defaults for vocabulary config | |
| DEFAULT_STEPS_PER_SECOND = 100 | |
| DEFAULT_MAX_SHIFT_SECONDS = 10 | |
| DEFAULT_NUM_VELOCITY_BINS = 127 | |


@dataclasses.dataclass
class VocabularyConfig:
  """Vocabulary configuration parameters."""
  steps_per_second: int = DEFAULT_STEPS_PER_SECOND
  max_shift_seconds: int = DEFAULT_MAX_SHIFT_SECONDS
  num_velocity_bins: int = DEFAULT_NUM_VELOCITY_BINS

  def abbrev_str(self):
    """Returns an abbreviation string for any non-default values."""
    s = ''
    if self.steps_per_second != DEFAULT_STEPS_PER_SECOND:
      s += 'ss%d' % self.steps_per_second
    if self.max_shift_seconds != DEFAULT_MAX_SHIFT_SECONDS:
      s += 'ms%d' % self.max_shift_seconds
    if self.num_velocity_bins != DEFAULT_NUM_VELOCITY_BINS:
      s += 'vb%d' % self.num_velocity_bins
    return s
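
# Usage sketch (values here are illustrative): with all defaults, abbrev_str()
# returns the empty string; overriding a field adds its abbreviation, e.g.
#
#   config = VocabularyConfig(num_velocity_bins=32)
#   config.abbrev_str()  # -> 'vb32'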


def num_velocity_bins_from_codec(codec: event_codec.Codec):
  """Get number of velocity bins from event codec."""
  lo, hi = codec.event_type_range('velocity')
  return hi - lo


def velocity_to_bin(velocity, num_velocity_bins):
  """Quantizes a MIDI velocity into one of `num_velocity_bins` bins."""
  if velocity == 0:
    return 0
  else:
    return math.ceil(num_velocity_bins * velocity / note_seq.MAX_MIDI_VELOCITY)


def bin_to_velocity(velocity_bin, num_velocity_bins):
  """Maps a velocity bin back to a MIDI velocity."""
  if velocity_bin == 0:
    return 0
  else:
    return int(note_seq.MAX_MIDI_VELOCITY * velocity_bin / num_velocity_bins)
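
# Worked example: with num_velocity_bins=32, MIDI velocity 64 encodes to bin
# ceil(32 * 64 / 127) = 17, which decodes back to int(127 * 17 / 32) = 67,
# so the round trip quantizes 64 -> 67. Bin 0 is reserved for note-off in
# both directions.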


def drop_programs(tokens, codec: event_codec.Codec):
  """Drops program change events from a token sequence."""
  min_program_id, max_program_id = codec.event_type_range('program')
  return tokens[(tokens < min_program_id) | (tokens > max_program_id)]


def programs_to_midi_classes(tokens, codec):
  """Modifies program events to be the first program in the MIDI class."""
  min_program_id, max_program_id = codec.event_type_range('program')
  is_program = (tokens >= min_program_id) & (tokens <= max_program_id)
  return tf.where(
      is_program,
      min_program_id + 8 * ((tokens - min_program_id) // 8),
      tokens)
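
# For example, since each General MIDI class spans 8 programs, a token for
# program 30 (a guitar program) is remapped to the token for program 24, the
# first program in the guitar class: min_program_id + 8 * (30 // 8).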


@dataclasses.dataclass
class ProgramGranularity:
  # both tokens_map_fn and program_map_fn should be idempotent
  tokens_map_fn: Callable[[Sequence[int], event_codec.Codec], Sequence[int]]
  program_map_fn: Callable[[int], int]


PROGRAM_GRANULARITIES = {
    # "flat" granularity; drop program change tokens and set NoteSequence
    # programs to zero
    'flat': ProgramGranularity(
        tokens_map_fn=drop_programs,
        program_map_fn=lambda program: 0),

    # map each program to the first program in its MIDI class
    'midi_class': ProgramGranularity(
        tokens_map_fn=programs_to_midi_classes,
        program_map_fn=lambda program: 8 * (program // 8)),

    # leave programs as is
    'full': ProgramGranularity(
        tokens_map_fn=lambda tokens, codec: tokens,
        program_map_fn=lambda program: program)
}
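
# Usage sketch (`tokens` and `codec` are stand-ins; a codec can be built with
# build_codec() below):
#
#   granularity = PROGRAM_GRANULARITIES['midi_class']
#   granularity.program_map_fn(30)                     # -> 24
#   tokens = granularity.tokens_map_fn(tokens, codec)  # remap program tokens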


def build_codec(vocab_config: VocabularyConfig):
  """Build event codec."""
  event_ranges = [
      event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH,
                             note_seq.MAX_MIDI_PITCH),
      # velocity bin 0 is used for note-off
      event_codec.EventRange('velocity', 0, vocab_config.num_velocity_bins),
      # used to indicate that a pitch is present at the beginning of a segment
      # (only has an "off" event as when using ties all pitch events until the
      # "tie" event belong to the tie section)
      event_codec.EventRange('tie', 0, 0),
      event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM,
                             note_seq.MAX_MIDI_PROGRAM),
      event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH,
                             note_seq.MAX_MIDI_PITCH),
  ]
  return event_codec.Codec(
      max_shift_steps=(vocab_config.steps_per_second *
                       vocab_config.max_shift_seconds),
      steps_per_second=vocab_config.steps_per_second,
      event_ranges=event_ranges)
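
# A quick sketch: with the default config, the codec allows time shifts of up
# to 100 steps/second * 10 seconds = 1000 steps.
#
#   codec = build_codec(VocabularyConfig())  # max_shift_steps == 1000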


def vocabulary_from_codec(codec: event_codec.Codec) -> seqio.Vocabulary:
  return GenericTokenVocabulary(
      codec.num_classes, extra_ids=t5.data.DEFAULT_EXTRA_IDS)


class GenericTokenVocabulary(seqio.Vocabulary):
  """Vocabulary with pass-through encoding of tokens."""

  def __init__(self, regular_ids: int, extra_ids: int = 0):
    # The special tokens: 0=PAD, 1=EOS, and 2=UNK
    self._num_special_tokens = 3
    self._num_regular_tokens = regular_ids
    super().__init__(extra_ids=extra_ids)

  @property
  def eos_id(self) -> Optional[int]:
    return 1

  @property
  def unk_id(self) -> Optional[int]:
    return 2

  @property
  def _base_vocab_size(self) -> int:
    """Number of ids.

    Returns:
      an integer, the vocabulary size
    """
    return self._num_special_tokens + self._num_regular_tokens

  def _encode(self, token_ids: Sequence[int]) -> Sequence[int]:
    """Encode a list of token ids as a list of integers.

    To keep the first few ids for special tokens, increase ids by the number
    of special tokens.

    Args:
      token_ids: array of token ids.

    Returns:
      a list of integers (not terminated by EOS)
    """
    encoded = []
    for token_id in token_ids:
      if not 0 <= token_id < self._num_regular_tokens:
        raise ValueError(
            f'token_id {token_id} does not fall within valid range of '
            f'[0, {self._num_regular_tokens})')
      encoded.append(token_id + self._num_special_tokens)
    return encoded

  def _decode(self, ids: Sequence[int]) -> Sequence[int]:
    """Decode a list of integers to a list of token ids.

    The special tokens of PAD and UNK as well as extra_ids will be
    replaced with DECODED_INVALID_ID in the output. If EOS is present, it will
    be the final token in the decoded output and will be represented by
    DECODED_EOS_ID.

    Args:
      ids: a list of integers

    Returns:
      a list of token ids.
    """
    # convert all the extra ids to INVALID_ID
    def _decode_id(encoded_id):
      if encoded_id == self.eos_id:
        return DECODED_EOS_ID
      elif encoded_id < self._num_special_tokens:
        return DECODED_INVALID_ID
      elif encoded_id >= self._base_vocab_size:
        return DECODED_INVALID_ID
      else:
        return encoded_id - self._num_special_tokens
    ids = [_decode_id(int(i)) for i in ids]
    return ids
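
  # Round-trip sketch: _encode([0, 5]) -> [3, 8] (ids shifted past the three
  # special tokens), and _decode([3, 8, 1]) -> [0, 5, DECODED_EOS_ID]; any
  # PAD/UNK or out-of-range id decodes to DECODED_INVALID_ID.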

  def _encode_tf(self, token_ids: tf.Tensor) -> tf.Tensor:
    """Encode a list of tokens to a tf.Tensor.

    Args:
      token_ids: array of audio token ids.

    Returns:
      a 1d tf.Tensor with dtype tf.int32
    """
    with tf.control_dependencies(
        [tf.debugging.assert_less(
            token_ids, tf.cast(self._num_regular_tokens, token_ids.dtype)),
         tf.debugging.assert_greater_equal(
             token_ids, tf.cast(0, token_ids.dtype))
        ]):
      tf_ids = token_ids + self._num_special_tokens
      return tf_ids

  def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
    """Decode in TensorFlow.

    The special tokens of PAD and UNK as well as extra_ids will be
    replaced with DECODED_INVALID_ID in the output. If EOS is present, it and
    all following tokens in the decoded output will be represented by
    DECODED_EOS_ID.

    Args:
      ids: a 1d tf.Tensor with dtype tf.int32

    Returns:
      a 1d tf.Tensor with dtype tf.int32
    """
    # Create a mask that is true from the first EOS position onward.
    # First, create an array that is True wherever there is an EOS, then cumsum
    # that array so that every position at and after the first True is >= 1,
    # then cast back to bool for the final mask.
    eos_and_after = tf.cumsum(
        tf.cast(tf.equal(ids, self.eos_id), tf.int32), exclusive=False, axis=-1)
    eos_and_after = tf.cast(eos_and_after, tf.bool)

    return tf.where(
        eos_and_after,
        DECODED_EOS_ID,
        tf.where(
            tf.logical_and(
                tf.greater_equal(ids, self._num_special_tokens),
                tf.less(ids, self._base_vocab_size)),
            ids - self._num_special_tokens,
            DECODED_INVALID_ID))
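
  # Unlike _decode above, which only maps the EOS token itself, everything
  # from the first EOS onward becomes DECODED_EOS_ID here; e.g. decoding
  # [3, 1, 8] yields [0, DECODED_EOS_ID, DECODED_EOS_ID].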

  def __eq__(self, other):
    their_extra_ids = other.extra_ids
    their_num_regular_tokens = other._num_regular_tokens
    return (self.extra_ids == their_extra_ids and
            self._num_regular_tokens == their_num_regular_tokens)


def num_embeddings(vocabulary: GenericTokenVocabulary) -> int:
  """Vocabulary size as a multiple of 128 for TPU efficiency."""
  return 128 * math.ceil(vocabulary.vocab_size / 128)
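
# Usage sketch tying the pieces together:
#
#   codec = build_codec(VocabularyConfig())
#   vocab = vocabulary_from_codec(codec)
#   table_size = num_embeddings(vocab)  # vocab_size rounded up to a
#                                       # multiple of 128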