diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..065822bf4f12d03eb850ecdb5750d3bd9de0623a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,42 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +example/inspiremusic/inspiremusic_01.wav filter=lfs diff=lfs merge=lfs -text +example/inspiremusic/inspiremusic_noflow_01.wav filter=lfs diff=lfs merge=lfs -text +example/inspiremusic/inspiremusic_w_cfm_chorus.wav filter=lfs diff=lfs merge=lfs -text +example/inspiremusic/inspiremusic_w_cfm_intro.wav filter=lfs diff=lfs merge=lfs -text +example/inspiremusic/inspiremusic_w_cfm_verse_ras.wav filter=lfs diff=lfs merge=lfs -text +example/inspiremusic/inspiremusic_w_cfm_verse_topk.wav filter=lfs diff=lfs merge=lfs -text +example/inspiremusic/inspiremusic_w_cfm_verse.wav filter=lfs diff=lfs merge=lfs -text +example/inspiremusic/inspiremusic_wo_cfm_chorus.wav filter=lfs diff=lfs merge=lfs -text +example/inspiremusic/inspiremusic_wo_cfm_intro.wav filter=lfs diff=lfs merge=lfs -text +example/inspiremusic/inspiremusic_wo_cfm_verse_topk.wav filter=lfs diff=lfs merge=lfs -text +example/inspiremusic/inspiremusic_wo_cfm_verse.wav filter=lfs diff=lfs merge=lfs -text +example/ras/chorus/chorus_01.wav filter=lfs diff=lfs merge=lfs -text +example/ras/chorus/chorus_02.wav filter=lfs diff=lfs merge=lfs -text +example/ras/chorus/chorus_03.wav filter=lfs diff=lfs merge=lfs -text +example/ras/chorus/chorus_04.wav filter=lfs diff=lfs merge=lfs -text +example/ras/chorus/chorus_05.wav filter=lfs diff=lfs merge=lfs -text +example/ras/chorus/chorus_06.wav filter=lfs diff=lfs merge=lfs -text +example/ras/chorus/chorus_07.wav filter=lfs diff=lfs merge=lfs -text +example/ras/chorus/chorus_08.wav filter=lfs diff=lfs merge=lfs -text +example/ras/chorus/chorus_09.wav filter=lfs diff=lfs merge=lfs -text +example/ras/chorus/chorus_10.wav filter=lfs diff=lfs merge=lfs -text +example/ras/intro/intro_01.wav filter=lfs diff=lfs merge=lfs -text +example/ras/intro/intro_02.wav filter=lfs diff=lfs merge=lfs -text +example/ras/intro/intro_03.wav filter=lfs diff=lfs merge=lfs -text +example/ras/intro/intro_04.wav filter=lfs diff=lfs merge=lfs -text +example/ras/intro/intro_05.wav filter=lfs diff=lfs merge=lfs -text +example/ras/intro/intro_06.wav filter=lfs diff=lfs merge=lfs -text +example/ras/outro/outro_01.wav filter=lfs diff=lfs merge=lfs -text +example/ras/outro/outro_02.wav filter=lfs diff=lfs merge=lfs -text +example/ras/outro/outro_03.wav filter=lfs diff=lfs merge=lfs -text +example/ras/outro/outro_04.wav filter=lfs diff=lfs merge=lfs -text +example/ras/verse/verse_01.wav filter=lfs diff=lfs merge=lfs -text +example/ras/verse/verse_02.wav filter=lfs diff=lfs merge=lfs -text +example/ras/verse/verse_03.wav filter=lfs diff=lfs merge=lfs -text +example/ras/verse/verse_04.wav filter=lfs diff=lfs merge=lfs -text +example/ras/verse/verse_05.wav filter=lfs diff=lfs merge=lfs -text +example/ras/verse/verse_06.wav filter=lfs diff=lfs merge=lfs -text +example/ras/verse/verse_07.wav filter=lfs diff=lfs merge=lfs -text +example/ras/verse/verse_08.wav filter=lfs diff=lfs merge=lfs -text diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..3d1b157d53a4d636737eefb44d513fa9ec0515d5 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/Matcha-TTS"] + path = third_party/Matcha-TTS + url = 
https://github.com/shivammehta25/Matcha-TTS.git diff --git a/README.md b/README.md index bb2209425210ca69bac075e47d676ee7b01a3bd6..af752cb142fc070c2418e6a75bedb85b58767139 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,13 @@ --- title: InspireMusic -emoji: 🏃 -colorFrom: blue -colorTo: blue +emoji: 🎶 +colorFrom: indigo +colorTo: purple sdk: gradio -sdk_version: 5.23.1 app_file: app.py pinned: false license: apache-2.0 -short_description: InspireMusic +short_description: Music Generation - text to music, music continuation. --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file diff --git a/app.py b/app.py index 04cc31aa8d0e06aeaac3b59bb361ed71d831e43f..49deeebbf1b51f56fa75546b5d4fff2c954acf0f 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,240 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Chong Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +os.system('nvidia-smi') +os.system('apt update -y && apt-get install -y apt-utils && apt install -y unzip') +os.environ['PYTHONPATH'] = 'third_party/Matcha-TTS' +os.system('mkdir pretrained_models && cd pretrained_models && git clone https://huggingface.co/FunAudioLLM/InspireMusic-Base.git &&git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-Long.git &&git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B.git &&git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-24kHz.git &&git clone https://huggingface.co/FunAudioLLM/InspireMusic-Base-24kHz.git && for i in InspireMusic-Base InspireMusic-Base-24kHz InspireMusic-1.5B InspireMusic-1.5B-24kHz InspireMusic-1.5B-Long; do sed -i -e "s/\.\.\/\.\.\///g" ${i}/inspiremusic.yaml; done && cd ..') + +import sys +import torch +print(torch.backends.cudnn.version()) + +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR)) + +import spaces import gradio as gr +from inspiremusic.cli.inference import InspireMusicUnified, set_env_variables +import torchaudio +import datetime +import hashlib +import threading +import time +import importlib + +MODELS = ["InspireMusic-1.5B-Long", "InspireMusic-1.5B", "InspireMusic-Base", "InspireMusic-1.5B-24kHz", "InspireMusic-Base-24kHz"] +AUDIO_PROMPT_DIR = "demo/audio_prompts" +OUTPUT_AUDIO_DIR = "demo/outputs" + +DEMO_TEXT_PROMPTS = ["Jazz music with drum beats.", + "A captivating classical piano performance, this piece exudes a dynamic and intense atmosphere, showcasing intricate and expressive instrumental artistry.", + "A soothing instrumental piece blending elements of light music and pop, featuring a gentle guitar rendition. The overall feel is serene and reflective, likely instrumental with no vocals.", + "The instrumental rock piece features dynamic oscillations and wave-like progressions, creating an immersive and energetic atmosphere. 
The music is purely instrumental, with no vocals, and it blends elements of rock and post-rock for a powerful and evocative experience.", + "The classical instrumental piece exudes a haunting and evocative atmosphere, characterized by its intricate guitar work and profound emotional depth.", + "Experience a dynamic blend of instrumental electronic music with futuristic house vibes, featuring energetic beats and a captivating rhythm. The tracks are likely instrumental, focusing on the immersive soundscapes rather than vocal performances."] + +# Shared flag to control the process +stop_flag = threading.Event() + +def cancel_process(): + """ + Sets the stop_flag to stop the long-running process. + """ + stop_flag.set() + return "Cancellation requested. Please wait for the process to stop." + +def generate_filename(): + hash_object = hashlib.sha256(str(int(datetime.datetime.now().timestamp())).encode()) + hash_string = hash_object.hexdigest() + return hash_string + +def get_args( + task, text="", audio=None, model_name="InspireMusic-Base", + chorus="intro", + output_sample_rate=48000, max_generate_audio_seconds=30.0, time_start = 0.0, time_end=30.0, trim=False): + + if "24kHz" in model_name: + output_sample_rate = 24000 + + if output_sample_rate == 24000: + fast = True + else: + fast = False + # This function constructs the arguments required for InspireMusic + args = { + "task" : task, + "text" : text, + "audio_prompt" : audio, + "model_name" : model_name, + "chorus" : chorus, + "fast" : fast, + "fade_out" : True, + "trim" : trim, + "output_sample_rate" : output_sample_rate, + "min_generate_audio_seconds": 10.0, + "max_generate_audio_seconds": max_generate_audio_seconds, + "max_audio_prompt_length": 5.0, + "model_dir" : os.path.join("pretrained_models", + model_name), + "result_dir" : OUTPUT_AUDIO_DIR, + "output_fn" : generate_filename(), + "format" : "wav", + "time_start" : time_start, + "time_end": time_end, + "fade_out_duration": 1.0, + } + + if args["time_start"] is None: + args["time_start"] = 0.0 + args["time_end"] = args["time_start"] + args["max_generate_audio_seconds"] + + print(args) + return args + + +def trim_audio(audio_file, cut_seconds=5): + audio, sr = torchaudio.load(audio_file) + num_samples = cut_seconds * sr + cutted_audio = audio[:, :num_samples] + output_path = os.path.join(AUDIO_PROMPT_DIR, "audio_prompt_" + generate_filename() + ".wav") + torchaudio.save(output_path, cutted_audio, sr) + return output_path + +@spaces.GPU(duration=120) +def music_generation(args): + set_env_variables() + model = InspireMusicUnified( + model_name=args["model_name"], + model_dir=args["model_dir"], + min_generate_audio_seconds=args["min_generate_audio_seconds"], + max_generate_audio_seconds=args["max_generate_audio_seconds"], + sample_rate=24000, + output_sample_rate=args["output_sample_rate"], + load_jit=True, + load_onnx=False, + fast=args["fast"], + result_dir=args["result_dir"]) + + output_path = model.inference( + task=args["task"], + text=args["text"], + audio_prompt=args["audio_prompt"], + chorus=args["chorus"], + time_start=args["time_start"], + time_end=args["time_end"], + output_fn=args["output_fn"], + max_audio_prompt_length=args["max_audio_prompt_length"], + fade_out_duration=args["fade_out_duration"], + output_format=args["format"], + fade_out_mode=args["fade_out"], + trim=args["trim"]) + return output_path + + +def demo_inspiremusic_t2m(text, model_name, chorus, + output_sample_rate, max_generate_audio_seconds): + args = get_args( + task='text-to-music', text=text, audio=None, 
+ model_name=model_name, chorus=chorus, + output_sample_rate=output_sample_rate, + max_generate_audio_seconds=max_generate_audio_seconds) + return music_generation(args) + +def demo_inspiremusic_con(text, audio, model_name, chorus, + output_sample_rate, max_generate_audio_seconds): + args = get_args( + task='continuation', text=text, audio=trim_audio(audio, cut_seconds=5), + model_name=model_name, chorus=chorus, + output_sample_rate=output_sample_rate, + max_generate_audio_seconds=max_generate_audio_seconds) + return music_generation(args) + +def process(args, progress=gr.Progress()): + progress(0, desc="Starting process...") + idx = 1 + for i in range(idx): + if stop_flag.is_set(): + progress(i / idx, desc="Process canceled.") + break + music_generation(args) + time.sleep(1) + progress((i + 1) / idx, desc=f"Processing step {i + 1}/{idx}") + return "Process completed successfully." + +def main(): + with gr.Blocks(theme=gr.themes.Soft()) as demo: + gr.Markdown(""" + # InspireMusic + - Support music generation tasks with long-form and high audio quality, sampling rates up to 48kHz. + - Github: https://github.com/FunAudioLLM/InspireMusic/ | ModelScope Studio: https://modelscope.cn/studios/iic/InspireMusic + - Available music generation models: [InspireMusic-1.5B-Long](https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-Long), [InspireMusic-1.5B](https://huggingface.co/FunAudioLLM/InspireMusic-1.5B), [InspireMusic-Base](https://huggingface.co/FunAudioLLM/InspireMusic-Base), [InspireMusic-1.5B-24kHz](https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-24kHz), [InspireMusic-Base-24kHz](https://huggingface.co/FunAudioLLM/InspireMusic-Base-24kHz). Both on Huggingface and ModelScope. + - Currently only support English text prompts. + - This page is for demo purpose, if you want to generate long-form audio, e.g., 5mins, please try to deploy locally. Thank you for your support. + """) + + with gr.Row(equal_height=True): + model_name = gr.Dropdown( + MODELS, label="Select Model Name", + value="InspireMusic-1.5B-Long") + chorus = gr.Dropdown(["intro", "verse", "chorus", "outro"], + label="Chorus Mode", value="intro") + output_sample_rate = gr.Dropdown([48000, 24000], + label="Output Audio Sample Rate (Hz)", + value=48000) + max_generate_audio_seconds = gr.Slider(10, 300, + label="Generate Audio Length (s)", + value=30) + + with gr.Row(equal_height=True): + text_input = gr.Textbox(label="Input Text (For Text-to-Music Task)", + value="Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.") + + audio_input = gr.Audio( + label="Input Audio Prompt (For Music Continuation Task)", + type="filepath") + music_output = gr.Audio(label="Generated Music", type="filepath", autoplay=True, show_download_button = True) + + with gr.Row(): + button = gr.Button("Submit Text-to-Music Task") + button.click(demo_inspiremusic_t2m, + inputs=[text_input, model_name, + chorus, + output_sample_rate, + max_generate_audio_seconds], + outputs=music_output) + + generate_button = gr.Button("Submit Music Continuation Task") + generate_button.click(demo_inspiremusic_con, + inputs=[text_input, audio_input, model_name, + chorus, + output_sample_rate, + max_generate_audio_seconds], + outputs=music_output) + cancel_button = gr.Button("Cancel") -def greet(name): - return "Hello " + name + "!!" + cancel_button.click( + fn=cancel_process, + inputs=[], + outputs="Cancel process." 
+ ) + t2m_examples = gr.Examples(examples=DEMO_TEXT_PROMPTS, inputs=[text_input]) + demo.launch() -demo = gr.Interface(fn=greet, inputs="text", outputs="text") -demo.launch() +if __name__ == '__main__': + os.makedirs(AUDIO_PROMPT_DIR, exist_ok=True) + os.makedirs(OUTPUT_AUDIO_DIR, exist_ok=True) + main() diff --git a/example/conf/InspireMusic-1.5B-24kHz.yaml b/example/conf/InspireMusic-1.5B-24kHz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..364ec7879e2d9582f7e162b80150ee615f81696f --- /dev/null +++ b/example/conf/InspireMusic-1.5B-24kHz.yaml @@ -0,0 +1,171 @@ +# set random seed, so that you may reproduce your result. +__set_seed1: !apply:random.seed [1024] +__set_seed2: !apply:numpy.random.seed [1024] +__set_seed3: !apply:torch.manual_seed [1024] +__set_seed4: !apply:torch.cuda.manual_seed_all [1024] + +# fixed params +sample_rate: 24000 +text_encoder_input_size: 512 +llm_input_size: 1536 +llm_output_size: 1536 + +basemodel_path: 'pretrained_models/InspireMusic-1.5B-24kHz/' +generator_path: 'pretrained_models/InspireMusic-1.5B-24kHz/music_tokenizer' + +# model params +# for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. +# for system/third_party class/function, we do not require this. +llm: !new:inspiremusic.llm.llm.LLM + text_encoder_input_size: !ref + llm_input_size: !ref + llm_output_size: !ref + audio_token_size: 4096 + length_normalized_loss: True + lsm_weight: 0 + text_encoder_conf: + name: "none" + llm: !new:inspiremusic.transformer.qwen_encoder.QwenEmbeddingEncoder + input_size: !ref + pretrain_path: !ref + + sampling: !name:inspiremusic.utils.common.topk_sampling + top_k: 350 + train_cfg_ratio: 0.2 + infer_cfg_ratio: 3.0 +flow: !new:inspiremusic.flow.flow.MaskedDiff + input_size: 256 + output_size: 80 + output_type: 'mel' + vocab_size: 4096 + input_frame_rate: 75 + only_mask_loss: True + encoder: !new:inspiremusic.transformer.encoder.ConformerEncoder + output_size: 512 + attention_heads: 4 + linear_units: 1024 + num_blocks: 3 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + normalize_before: True + input_layer: 'linear' + pos_enc_layer_type: 'rel_pos_espnet' + selfattention_layer_type: 'rel_selfattn' + input_size: 256 + use_cnn_module: False + macaron_style: False + length_regulator: !new:inspiremusic.flow.length_regulator.InterpolateRegulator + channels: 512 + sampling_ratios: [1, 1, 1, 1] + decoder: !new:inspiremusic.flow.flow_matching.ConditionalCFM + in_channels: 240 + cfm_params: !new:omegaconf.DictConfig + content: + sigma_min: 1e-06 + solver: 'euler' + t_scheduler: 'cosine' + training_cfg_rate: 0.2 + inference_cfg_rate: 0.7 + reg_loss_type: 'l1' + estimator: !new:inspiremusic.flow.decoder.ConditionalDecoder + in_channels: 1024 + out_channels: 512 + channels: [256, 256] + dropout: 0.0 + attention_head_dim: 64 + n_blocks: 4 + num_mid_blocks: 8 + num_heads: 8 + act_fn: 'gelu' + generator_model_dir: !ref + +hift: !new:inspiremusic.hifigan.generator.HiFTGenerator + in_channels: 80 + base_channels: 512 + nb_harmonics: 8 + sampling_rate: !ref + nsf_alpha: 0.1 + nsf_sigma: 0.003 + nsf_voiced_threshold: 10 + upsample_rates: [8, 8] + upsample_kernel_sizes: [16, 16] + istft_params: + n_fft: 16 + hop_len: 4 + resblock_kernel_sizes: [3, 7, 11] + resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + source_resblock_kernel_sizes: [7, 11] + source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]] + lrelu_slope: 
0.1 + audio_limit: 0.99 + f0_predictor: !new:inspiremusic.hifigan.f0_predictor.ConvRNNF0Predictor + num_class: 1 + in_channels: 80 + cond_channels: 512 + +wavtokenizer: !new:inspiremusic.hifigan.generator.HiFTGenerator + +# processor functions +parquet_opener: !name:inspiremusic.dataset.processor.parquet_opener +get_tokenizer: !name:inspiremusic.text.tokenizer.get_tokenizer + tokenizer_path: !ref + tokenizer_name: "qwen-2.5" +allowed_special: 'all' +tokenize: !name:inspiremusic.dataset.processor.tokenize + get_tokenizer: !ref + allowed_special: !ref +filter: !name:inspiremusic.dataset.processor.filter + max_length: 28000 + min_length: 0 + token_max_length: 200 + token_min_length: 1 +resample: !name:inspiremusic.dataset.processor.resample + resample_rate: !ref +feat_extractor: !name:matcha.utils.audio.mel_spectrogram + n_fft: 1024 + num_mels: 128 + sampling_rate: !ref + hop_size: 256 + win_size: 1024 + fmin: 0 + fmax: 24000 + center: False +compute_fbank: !name:inspiremusic.dataset.processor.compute_fbank + feat_extractor: !ref +parse_embedding: !name:inspiremusic.dataset.processor.parse_embedding + normalize: True +shuffle: !name:inspiremusic.dataset.processor.shuffle + shuffle_size: 1000 +sort: !name:inspiremusic.dataset.processor.sort + sort_size: 500 # sort_size should be less than shuffle_size +batch: !name:inspiremusic.dataset.processor.batch + batch_type: 'dynamic' + max_frames_in_batch: 10000 # llm 12000 +padding: !name:inspiremusic.dataset.processor.padding + +# dataset processor pipeline +data_pipeline: [ + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , +] + + +# train conf +train_conf: + optim: adam + optim_conf: + lr: 0.0001 # change to 0.001 if you want to train flow from scratch + scheduler: warmuplr + scheduler_conf: + warmup_steps: 5000 + max_epoch: 200 + grad_clip: 5 + accum_grad: 2 + log_interval: 100 + save_per_step: 500 diff --git a/example/conf/InspireMusic-1.5B-Long.yaml b/example/conf/InspireMusic-1.5B-Long.yaml new file mode 100644 index 0000000000000000000000000000000000000000..805cb04d845929f15c51bf82d582fedd996f4a92 --- /dev/null +++ b/example/conf/InspireMusic-1.5B-Long.yaml @@ -0,0 +1,171 @@ +# set random seed, so that you may reproduce your result. +__set_seed1: !apply:random.seed [1988] +__set_seed2: !apply:numpy.random.seed [1988] +__set_seed3: !apply:torch.manual_seed [1988] +__set_seed4: !apply:torch.cuda.manual_seed_all [1988] + +# fixed params +sample_rate: 24000 +text_encoder_input_size: 512 +llm_input_size: 1536 +llm_output_size: 1536 + +basemodel_path: 'pretrained_models/InspireMusic-1.5B-Long/' +generator_path: 'pretrained_models/InspireMusic-1.5B-Long/music_tokenizer' + +# model params +# for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. +# for system/third_party class/function, we do not require this. 
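+# The tags below are resolved by hyperpyyaml.load_hyperpyyaml (see inspiremusic/bin/inference.py):
+#   !new:module.Class   instantiates the class, passing the nested keys as keyword arguments
+#   !name:module.func   binds the callable for later use (e.g. the dataset processor pipeline)
+#   !ref <key>          substitutes the value of another top-level key defined in this file
+#   !apply:module.func  calls the function at load time (used above to fix the random seeds)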
+llm: !new:inspiremusic.llm.llm.LLM + text_encoder_input_size: !ref + llm_input_size: !ref + llm_output_size: !ref + audio_token_size: 4096 + length_normalized_loss: True + lsm_weight: 0 + text_encoder_conf: + name: "none" + llm: !new:inspiremusic.transformer.qwen_encoder.QwenEmbeddingEncoder + input_size: !ref + pretrain_path: !ref + + sampling: !name:inspiremusic.utils.common.topk_sampling + top_k: 350 + train_cfg_ratio: 0.2 + infer_cfg_ratio: 3.0 +flow: !new:inspiremusic.flow.flow.MaskedDiff + input_size: 256 + output_size: 80 + output_type: 'mel' + vocab_size: 4096 + input_frame_rate: 75 + only_mask_loss: True + encoder: !new:inspiremusic.transformer.encoder.ConformerEncoder + output_size: 512 + attention_heads: 4 + linear_units: 1024 + num_blocks: 3 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + normalize_before: True + input_layer: 'linear' + pos_enc_layer_type: 'rel_pos_espnet' + selfattention_layer_type: 'rel_selfattn' + input_size: 256 + use_cnn_module: False + macaron_style: False + length_regulator: !new:inspiremusic.flow.length_regulator.InterpolateRegulator + channels: 512 + sampling_ratios: [1, 1, 1, 1] + decoder: !new:inspiremusic.flow.flow_matching.ConditionalCFM + in_channels: 240 + cfm_params: !new:omegaconf.DictConfig + content: + sigma_min: 1e-06 + solver: 'euler' + t_scheduler: 'cosine' + training_cfg_rate: 0.2 + inference_cfg_rate: 0.7 + reg_loss_type: 'l1' + estimator: !new:inspiremusic.flow.decoder.ConditionalDecoder + in_channels: 1024 + out_channels: 512 + channels: [256, 256] + dropout: 0.0 + attention_head_dim: 64 + n_blocks: 4 + num_mid_blocks: 8 + num_heads: 8 + act_fn: 'gelu' + generator_model_dir: !ref + +hift: !new:inspiremusic.hifigan.generator.HiFTGenerator + in_channels: 80 + base_channels: 512 + nb_harmonics: 8 + sampling_rate: !ref + nsf_alpha: 0.1 + nsf_sigma: 0.003 + nsf_voiced_threshold: 10 + upsample_rates: [8, 8] + upsample_kernel_sizes: [16, 16] + istft_params: + n_fft: 16 + hop_len: 4 + resblock_kernel_sizes: [3, 7, 11] + resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + source_resblock_kernel_sizes: [7, 11] + source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]] + lrelu_slope: 0.1 + audio_limit: 0.99 + f0_predictor: !new:inspiremusic.hifigan.f0_predictor.ConvRNNF0Predictor + num_class: 1 + in_channels: 80 + cond_channels: 512 + +wavtokenizer: !new:inspiremusic.hifigan.generator.HiFTGenerator + +# processor functions +parquet_opener: !name:inspiremusic.dataset.processor.parquet_opener +get_tokenizer: !name:inspiremusic.text.tokenizer.get_tokenizer + tokenizer_path: !ref + tokenizer_name: "qwen-2.5" +allowed_special: 'all' +tokenize: !name:inspiremusic.dataset.processor.tokenize + get_tokenizer: !ref + allowed_special: !ref +filter: !name:inspiremusic.dataset.processor.filter + max_length: 28000 + min_length: 0 + token_max_length: 200 + token_min_length: 1 +resample: !name:inspiremusic.dataset.processor.resample + resample_rate: !ref +feat_extractor: !name:matcha.utils.audio.mel_spectrogram + n_fft: 1024 + num_mels: 128 + sampling_rate: !ref + hop_size: 256 + win_size: 1024 + fmin: 0 + fmax: 24000 + center: False +compute_fbank: !name:inspiremusic.dataset.processor.compute_fbank + feat_extractor: !ref +parse_embedding: !name:inspiremusic.dataset.processor.parse_embedding + normalize: True +shuffle: !name:inspiremusic.dataset.processor.shuffle + shuffle_size: 1000 +sort: !name:inspiremusic.dataset.processor.sort + sort_size: 500 # sort_size should be less than shuffle_size +batch: 
!name:inspiremusic.dataset.processor.batch + batch_type: 'dynamic' + max_frames_in_batch: 10000 # llm 12000 +padding: !name:inspiremusic.dataset.processor.padding + +# dataset processor pipeline +data_pipeline: [ + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , +] + + +# train conf +train_conf: + optim: adam + optim_conf: + lr: 0.0001 # change to 0.001 if you want to train flow from scratch + scheduler: warmuplr + scheduler_conf: + warmup_steps: 5000 + max_epoch: 200 + grad_clip: 5 + accum_grad: 2 + log_interval: 100 + save_per_step: 500 diff --git a/example/conf/InspireMusic-1.5B.yaml b/example/conf/InspireMusic-1.5B.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29f90b49856fd078c55c5232e2c242a639aa3aeb --- /dev/null +++ b/example/conf/InspireMusic-1.5B.yaml @@ -0,0 +1,171 @@ +# set random seed, so that you may reproduce your result. +__set_seed1: !apply:random.seed [1988] +__set_seed2: !apply:numpy.random.seed [1988] +__set_seed3: !apply:torch.manual_seed [1988] +__set_seed4: !apply:torch.cuda.manual_seed_all [1988] + +# fixed params +sample_rate: 24000 +text_encoder_input_size: 512 +llm_input_size: 1536 +llm_output_size: 1536 + +basemodel_path: 'pretrained_models/InspireMusic-1.5B/' +generator_path: 'pretrained_models/InspireMusic-1.5B/music_tokenizer' + +# model params +# for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. +# for system/third_party class/function, we do not require this. +llm: !new:inspiremusic.llm.llm.LLM + text_encoder_input_size: !ref + llm_input_size: !ref + llm_output_size: !ref + audio_token_size: 4096 + length_normalized_loss: True + lsm_weight: 0 + text_encoder_conf: + name: "none" + llm: !new:inspiremusic.transformer.qwen_encoder.QwenEmbeddingEncoder + input_size: !ref + pretrain_path: !ref + + sampling: !name:inspiremusic.utils.common.topk_sampling + top_k: 350 + train_cfg_ratio: 0.2 + infer_cfg_ratio: 3.0 +flow: !new:inspiremusic.flow.flow.MaskedDiff + input_size: 256 + output_size: 80 + output_type: 'mel' + vocab_size: 4096 + input_frame_rate: 75 + only_mask_loss: True + encoder: !new:inspiremusic.transformer.encoder.ConformerEncoder + output_size: 512 + attention_heads: 4 + linear_units: 1024 + num_blocks: 3 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + normalize_before: True + input_layer: 'linear' + pos_enc_layer_type: 'rel_pos_espnet' + selfattention_layer_type: 'rel_selfattn' + input_size: 256 + use_cnn_module: False + macaron_style: False + length_regulator: !new:inspiremusic.flow.length_regulator.InterpolateRegulator + channels: 512 + sampling_ratios: [1, 1, 1, 1] + decoder: !new:inspiremusic.flow.flow_matching.ConditionalCFM + in_channels: 240 + cfm_params: !new:omegaconf.DictConfig + content: + sigma_min: 1e-06 + solver: 'euler' + t_scheduler: 'cosine' + training_cfg_rate: 0.2 + inference_cfg_rate: 0.7 + reg_loss_type: 'l1' + estimator: !new:inspiremusic.flow.decoder.ConditionalDecoder + in_channels: 1024 + out_channels: 512 + channels: [256, 256] + dropout: 0.0 + attention_head_dim: 64 + n_blocks: 4 + num_mid_blocks: 8 + num_heads: 8 + act_fn: 'gelu' + generator_model_dir: !ref + +hift: !new:inspiremusic.hifigan.generator.HiFTGenerator + in_channels: 80 + base_channels: 512 + nb_harmonics: 8 + sampling_rate: !ref + nsf_alpha: 0.1 + nsf_sigma: 0.003 + nsf_voiced_threshold: 10 + upsample_rates: [8, 8] + upsample_kernel_sizes: [16, 16] + istft_params: + n_fft: 16 
+ hop_len: 4 + resblock_kernel_sizes: [3, 7, 11] + resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + source_resblock_kernel_sizes: [7, 11] + source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]] + lrelu_slope: 0.1 + audio_limit: 0.99 + f0_predictor: !new:inspiremusic.hifigan.f0_predictor.ConvRNNF0Predictor + num_class: 1 + in_channels: 80 + cond_channels: 512 + +wavtokenizer: !new:inspiremusic.hifigan.generator.HiFTGenerator + +# processor functions +parquet_opener: !name:inspiremusic.dataset.processor.parquet_opener +get_tokenizer: !name:inspiremusic.text.tokenizer.get_tokenizer + tokenizer_path: !ref + tokenizer_name: "qwen-2.5" +allowed_special: 'all' +tokenize: !name:inspiremusic.dataset.processor.tokenize + get_tokenizer: !ref + allowed_special: !ref +filter: !name:inspiremusic.dataset.processor.filter + max_length: 28000 + min_length: 0 + token_max_length: 200 + token_min_length: 1 +resample: !name:inspiremusic.dataset.processor.resample + resample_rate: !ref +feat_extractor: !name:matcha.utils.audio.mel_spectrogram + n_fft: 1024 + num_mels: 128 + sampling_rate: !ref + hop_size: 256 + win_size: 1024 + fmin: 0 + fmax: 24000 + center: False +compute_fbank: !name:inspiremusic.dataset.processor.compute_fbank + feat_extractor: !ref +parse_embedding: !name:inspiremusic.dataset.processor.parse_embedding + normalize: True +shuffle: !name:inspiremusic.dataset.processor.shuffle + shuffle_size: 1000 +sort: !name:inspiremusic.dataset.processor.sort + sort_size: 500 # sort_size should be less than shuffle_size +batch: !name:inspiremusic.dataset.processor.batch + batch_type: 'dynamic' + max_frames_in_batch: 10000 # llm 12000 +padding: !name:inspiremusic.dataset.processor.padding + +# dataset processor pipeline +data_pipeline: [ + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , +] + + +# train conf +train_conf: + optim: adam + optim_conf: + lr: 0.0001 # change to 0.001 if you want to train flow from scratch + scheduler: warmuplr + scheduler_conf: + warmup_steps: 5000 + max_epoch: 200 + grad_clip: 5 + accum_grad: 2 + log_interval: 100 + save_per_step: 500 diff --git a/example/conf/InspireMusic-Base-24kHz.yaml b/example/conf/InspireMusic-Base-24kHz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..34c4daf69132deb173b9ef48dd40289e5b8a3f6c --- /dev/null +++ b/example/conf/InspireMusic-Base-24kHz.yaml @@ -0,0 +1,171 @@ +# set random seed, so that you may reproduce your result. +__set_seed1: !apply:random.seed [1024] +__set_seed2: !apply:numpy.random.seed [1024] +__set_seed3: !apply:torch.manual_seed [1024] +__set_seed4: !apply:torch.cuda.manual_seed_all [1024] + +# fixed params +sample_rate: 24000 +text_encoder_input_size: 512 +llm_input_size: 896 +llm_output_size: 896 + +basemodel_path: 'pretrained_models/InspireMusic-Base-24kHz/' +generator_path: 'pretrained_models/InspireMusic-Base-24kHz/music_tokenizer' + +# model params +# for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. +# for system/third_party class/function, we do not require this. 
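+# Compared with the 1.5B configs in this directory: the Base backbone uses a smaller hidden size
+# (llm_input_size / llm_output_size: 896 instead of 1536) and a "qwen-2.0" tokenizer instead of
+# "qwen-2.5", and this 24kHz variant raises the inference guidance ratio (infer_cfg_ratio: 7.0
+# vs. 3.0); the flow, HiFT and data-pipeline sections match the other configs.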
+llm: !new:inspiremusic.llm.llm.LLM + text_encoder_input_size: !ref + llm_input_size: !ref + llm_output_size: !ref + audio_token_size: 4096 + length_normalized_loss: True + lsm_weight: 0 + text_encoder_conf: + name: "none" + llm: !new:inspiremusic.transformer.qwen_encoder.QwenEmbeddingEncoder + input_size: !ref + pretrain_path: !ref + + sampling: !name:inspiremusic.utils.common.topk_sampling + top_k: 350 + train_cfg_ratio: 0.2 + infer_cfg_ratio: 7.0 +flow: !new:inspiremusic.flow.flow.MaskedDiff + input_size: 256 + output_size: 80 + output_type: 'mel' + vocab_size: 4096 + input_frame_rate: 75 + only_mask_loss: True + encoder: !new:inspiremusic.transformer.encoder.ConformerEncoder + output_size: 512 + attention_heads: 4 + linear_units: 1024 + num_blocks: 3 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + normalize_before: True + input_layer: 'linear' + pos_enc_layer_type: 'rel_pos_espnet' + selfattention_layer_type: 'rel_selfattn' + input_size: 256 + use_cnn_module: False + macaron_style: False + length_regulator: !new:inspiremusic.flow.length_regulator.InterpolateRegulator + channels: 512 + sampling_ratios: [1, 1, 1, 1] + decoder: !new:inspiremusic.flow.flow_matching.ConditionalCFM + in_channels: 240 + cfm_params: !new:omegaconf.DictConfig + content: + sigma_min: 1e-06 + solver: 'euler' + t_scheduler: 'cosine' + training_cfg_rate: 0.2 + inference_cfg_rate: 0.7 + reg_loss_type: 'l1' + estimator: !new:inspiremusic.flow.decoder.ConditionalDecoder + in_channels: 1024 + out_channels: 512 + channels: [256, 256] + dropout: 0.0 + attention_head_dim: 64 + n_blocks: 4 + num_mid_blocks: 8 + num_heads: 8 + act_fn: 'gelu' + generator_model_dir: !ref + +hift: !new:inspiremusic.hifigan.generator.HiFTGenerator + in_channels: 80 + base_channels: 512 + nb_harmonics: 8 + sampling_rate: !ref + nsf_alpha: 0.1 + nsf_sigma: 0.003 + nsf_voiced_threshold: 10 + upsample_rates: [8, 8] + upsample_kernel_sizes: [16, 16] + istft_params: + n_fft: 16 + hop_len: 4 + resblock_kernel_sizes: [3, 7, 11] + resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + source_resblock_kernel_sizes: [7, 11] + source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]] + lrelu_slope: 0.1 + audio_limit: 0.99 + f0_predictor: !new:inspiremusic.hifigan.f0_predictor.ConvRNNF0Predictor + num_class: 1 + in_channels: 80 + cond_channels: 512 + +wavtokenizer: !new:inspiremusic.hifigan.generator.HiFTGenerator + +# processor functions +parquet_opener: !name:inspiremusic.dataset.processor.parquet_opener +get_tokenizer: !name:inspiremusic.text.tokenizer.get_tokenizer + tokenizer_path: !ref + tokenizer_name: "qwen-2.0" +allowed_special: 'all' +tokenize: !name:inspiremusic.dataset.processor.tokenize + get_tokenizer: !ref + allowed_special: !ref +filter: !name:inspiremusic.dataset.processor.filter + max_length: 28000 + min_length: 0 + token_max_length: 200 + token_min_length: 1 +resample: !name:inspiremusic.dataset.processor.resample + resample_rate: !ref +feat_extractor: !name:matcha.utils.audio.mel_spectrogram + n_fft: 1024 + num_mels: 128 + sampling_rate: !ref + hop_size: 256 + win_size: 1024 + fmin: 0 + fmax: 24000 + center: False +compute_fbank: !name:inspiremusic.dataset.processor.compute_fbank + feat_extractor: !ref +parse_embedding: !name:inspiremusic.dataset.processor.parse_embedding + normalize: True +shuffle: !name:inspiremusic.dataset.processor.shuffle + shuffle_size: 1000 +sort: !name:inspiremusic.dataset.processor.sort + sort_size: 500 # sort_size should be less than shuffle_size +batch: 
!name:inspiremusic.dataset.processor.batch + batch_type: 'dynamic' + max_frames_in_batch: 10000 # llm 12000 +padding: !name:inspiremusic.dataset.processor.padding + +# dataset processor pipeline +data_pipeline: [ + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , +] + + +# train conf +train_conf: + optim: adam + optim_conf: + lr: 0.0001 # change to 0.001 if you want to train flow from scratch + scheduler: warmuplr + scheduler_conf: + warmup_steps: 5000 + max_epoch: 200 + grad_clip: 5 + accum_grad: 2 + log_interval: 100 + save_per_step: 500 diff --git a/example/conf/InspireMusic-Base.yaml b/example/conf/InspireMusic-Base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1b4231202abc3e1e9deb57c1e2a723138438aaa4 --- /dev/null +++ b/example/conf/InspireMusic-Base.yaml @@ -0,0 +1,180 @@ +# set random seed, so that you may reproduce your result. +__set_seed1: !apply:random.seed [1024] +__set_seed2: !apply:numpy.random.seed [1024] +__set_seed3: !apply:torch.manual_seed [1024] +__set_seed4: !apply:torch.cuda.manual_seed_all [1024] + +# fixed params +sample_rate: 24000 +target_sample_rate: 48000 +text_encoder_input_size: 512 +llm_input_size: 896 +llm_output_size: 896 + +basemodel_path: 'pretrained_models/InspireMusic-Base/' +generator_path: 'pretrained_models/InspireMusic-Base/music_tokenizer' + +# model params +# for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. +# for system/third_party class/function, we do not require this. +llm: !new:inspiremusic.llm.llm.LLM + text_encoder_input_size: !ref + llm_input_size: !ref + llm_output_size: !ref + audio_token_size: 4096 + length_normalized_loss: True + lsm_weight: 0 + text_encoder_conf: + name: "none" + llm: !new:inspiremusic.transformer.qwen_encoder.QwenEmbeddingEncoder + input_size: !ref + pretrain_path: !ref + + sampling: !name:inspiremusic.utils.common.topk_sampling + top_k: 350 + train_cfg_ratio: 0.2 + infer_cfg_ratio: 3.0 +flow: !new:inspiremusic.flow.flow.MaskedDiff + input_size: 256 + output_size: 80 + output_type: 'mel' + vocab_size: 4096 + input_frame_rate: 75 + only_mask_loss: True + encoder: !new:inspiremusic.transformer.encoder.ConformerEncoder + output_size: 512 + attention_heads: 4 + linear_units: 1024 + num_blocks: 3 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + normalize_before: True + input_layer: 'linear' + pos_enc_layer_type: 'rel_pos_espnet' + selfattention_layer_type: 'rel_selfattn' + input_size: 256 + use_cnn_module: False + macaron_style: False + length_regulator: !new:inspiremusic.flow.length_regulator.InterpolateRegulator + channels: 512 + sampling_ratios: [1, 1, 1, 1] + decoder: !new:inspiremusic.flow.flow_matching.ConditionalCFM + in_channels: 240 + cfm_params: !new:omegaconf.DictConfig + content: + sigma_min: 1e-06 + solver: 'euler' + t_scheduler: 'cosine' + training_cfg_rate: 0.2 + inference_cfg_rate: 0.7 + reg_loss_type: 'l1' + estimator: !new:inspiremusic.flow.decoder.ConditionalDecoder + in_channels: 1024 + out_channels: 512 + channels: [256, 256] + dropout: 0.0 + attention_head_dim: 64 + n_blocks: 4 + num_mid_blocks: 8 + num_heads: 8 + act_fn: 'gelu' + generator_model_dir: !ref + +hift: !new:inspiremusic.hifigan.generator.HiFTGenerator + in_channels: 80 + base_channels: 512 + nb_harmonics: 8 + sampling_rate: !ref + nsf_alpha: 0.1 + nsf_sigma: 0.003 + nsf_voiced_threshold: 10 + upsample_rates: [8, 8] + upsample_kernel_sizes: [16, 16] + 
istft_params: + n_fft: 16 + hop_len: 4 + resblock_kernel_sizes: [3, 7, 11] + resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + source_resblock_kernel_sizes: [7, 11] + source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]] + lrelu_slope: 0.1 + audio_limit: 0.99 + f0_predictor: !new:inspiremusic.hifigan.f0_predictor.ConvRNNF0Predictor + num_class: 1 + in_channels: 80 + cond_channels: 512 + +wavtokenizer: !new:inspiremusic.hifigan.generator.HiFTGenerator + +# processor functions +parquet_opener: !name:inspiremusic.dataset.processor.parquet_opener +get_tokenizer: !name:inspiremusic.text.tokenizer.get_tokenizer + tokenizer_path: !ref + tokenizer_name: "qwen-2.0" +allowed_special: 'all' +tokenize: !name:inspiremusic.dataset.processor.tokenize + get_tokenizer: !ref + allowed_special: !ref +filter: !name:inspiremusic.dataset.processor.filter + max_length: 20000 + min_length: 1 + token_max_length: 200 + token_min_length: 1 + max_acoustic_length: 20000 + min_acoustic_length: 1800 + mode: 'train_flow' + +resample: !name:inspiremusic.dataset.processor.resample + resample_rate: !ref + +feat_extractor: !name:matcha.utils.audio.mel_spectrogram + n_fft: 1024 + num_mels: 128 + sampling_rate: !ref + hop_size: 256 + win_size: 1024 + fmin: 0 + fmax: 24000 + center: False +compute_fbank: !name:inspiremusic.dataset.processor.compute_fbank + feat_extractor: !ref +parse_embedding: !name:inspiremusic.dataset.processor.parse_embedding + normalize: True +shuffle: !name:inspiremusic.dataset.processor.shuffle + shuffle_size: 1000 +sort: !name:inspiremusic.dataset.processor.sort + sort_size: 500 # sort_size should be less than shuffle_size +batch: !name:inspiremusic.dataset.processor.batch + batch_type: 'dynamic' + max_frames_in_batch: 15500 # llm 12000 + # batch_type: 'static' + # batch_size: 2 # llm 12000 +padding: !name:inspiremusic.dataset.processor.padding + mode: 'train' + +# dataset processor pipeline +data_pipeline: [ + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , +] + + +# train conf +train_conf: + optim: adam + optim_conf: + lr: 0.0001 # change to 0.001 if you want to train flow from scratch + scheduler: warmuplr + scheduler_conf: + warmup_steps: 500 + max_epoch: 200 + grad_clip: 5 + accum_grad: 2 + log_interval: 100 + save_per_step: 500 diff --git a/inspiremusic/.DS_Store b/inspiremusic/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5dada59e0d29dcc28673cc34f6f6166ef8f0c615 Binary files /dev/null and b/inspiremusic/.DS_Store differ diff --git a/inspiremusic/__init__.py b/inspiremusic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inspiremusic/bin/export_jit.py b/inspiremusic/bin/export_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..f68203f61246f1167f68682b5a1ff8fc5929f521 --- /dev/null +++ b/inspiremusic/bin/export_jit.py @@ -0,0 +1,74 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
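+# Exports TorchScript artifacts for deployment. A typical invocation from the repo root:
+#   python inspiremusic/bin/export_jit.py --model_dir pretrained_models/InspireMusic-1.5B-Long
+# This writes llm.text_encoder.fp16.zip, llm.llm.fp16.zip and flow.encoder.fp32.zip into
+# --model_dir; these are the scripted modules that loading with load_jit=True is expected to use.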
+ +from __future__ import print_function + +import argparse +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +import os +import sys +import torch +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append('{}/../..'.format(ROOT_DIR)) +sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) +from inspiremusic.cli.inspiremusic import InspireMusic + + +def get_args(): + parser = argparse.ArgumentParser(description='export your model for deployment') + parser.add_argument('--model_dir', + type=str, + default='pretrained_models/InspireMusic', + help='local path') + args = parser.parse_args() + print(args) + return args + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + + torch._C._jit_set_fusion_strategy([('STATIC', 1)]) + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + + inspiremusic = InspireMusic(args.model_dir, load_jit=False, load_onnx=False) + + # 1. export llm text_encoder + llm_text_encoder = inspiremusic.model.llm.text_encoder.half() + script = torch.jit.script(llm_text_encoder) + script = torch.jit.freeze(script) + script = torch.jit.optimize_for_inference(script) + script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir)) + + # 2. export llm llm + llm_llm = inspiremusic.model.llm.llm.half() + script = torch.jit.script(llm_llm) + script = torch.jit.freeze(script, preserved_attrs=['forward_chunk']) + script = torch.jit.optimize_for_inference(script) + script.save('{}/llm.llm.fp16.zip'.format(args.model_dir)) + + # 3. export flow encoder + flow_encoder = inspiremusic.model.flow.encoder + script = torch.jit.script(flow_encoder) + script = torch.jit.freeze(script) + script = torch.jit.optimize_for_inference(script) + script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) + + +if __name__ == '__main__': + main() diff --git a/inspiremusic/bin/export_onnx.py b/inspiremusic/bin/export_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..659ee13e4fc495757d11dfc558bf2b1629f35089 --- /dev/null +++ b/inspiremusic/bin/export_onnx.py @@ -0,0 +1,112 @@ +# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com) +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
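+# Exports the flow-matching decoder estimator to ONNX and checks it against the PyTorch module.
+# A typical invocation from the repo root:
+#   python inspiremusic/bin/export_onnx.py --model_dir pretrained_models/InspireMusic-1.5B-Long
+# This writes flow.decoder.estimator.fp32.onnx (opset 18, dynamic batch and sequence axes) into
+# --model_dir, then compares ONNX Runtime outputs with the PyTorch estimator on random inputs.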
+ +from __future__ import print_function + +import argparse +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +import os +import sys +import onnxruntime +import random +import torch +from tqdm import tqdm +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append('{}/../..'.format(ROOT_DIR)) +sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) +from inspiremusic.cli.inspiremusic import InspireMusic + + +def get_dummy_input(batch_size, seq_len, out_channels, device): + x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device) + mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + t = torch.rand((batch_size), dtype=torch.float32, device=device) + spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device) + cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + return x, mask, mu, t, spks, cond + + +def get_args(): + parser = argparse.ArgumentParser(description='export your model for deployment') + parser.add_argument('--model_dir', + type=str, + default='pretrained_models/InspireMusic', + help='local path') + args = parser.parse_args() + print(args) + return args + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + + inspiremusic = InspireMusic(args.model_dir, load_jit=False, load_onnx=False) + + # 1. export flow decoder estimator + estimator = inspiremusic.model.flow.decoder.estimator + + device = inspiremusic.model.device + batch_size, seq_len = 1, 256 + out_channels = inspiremusic.model.flow.decoder.estimator.out_channels + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) + torch.onnx.export( + estimator, + (x, mask, mu, t, spks, cond), + '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], + output_names=['estimator_out'], + dynamic_axes={ + 'x': {0: 'batch_size', 2: 'seq_len'}, + 'mask': {0: 'batch_size', 2: 'seq_len'}, + 'mu': {0: 'batch_size', 2: 'seq_len'}, + 'cond': {0: 'batch_size', 2: 'seq_len'}, + 't': {0: 'batch_size'}, + 'spks': {0: 'batch_size'}, + 'estimator_out': {0: 'batch_size', 2: 'seq_len'}, + } + ) + + # 2. 
test computation consistency + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] + estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), + sess_options=option, providers=providers) + + for _ in tqdm(range(10)): + x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device) + output_pytorch = estimator(x, mask, mu, t, spks, cond) + ort_inputs = { + 'x': x.cpu().numpy(), + 'mask': mask.cpu().numpy(), + 'mu': mu.cpu().numpy(), + 't': t.cpu().numpy(), + 'spks': spks.cpu().numpy(), + 'cond': cond.cpu().numpy() + } + output_onnx = estimator_onnx.run(None, ort_inputs)[0] + torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) + + +if __name__ == "__main__": + main() diff --git a/inspiremusic/bin/flow_only_infer.py b/inspiremusic/bin/flow_only_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..d84e9c3671546071b02e1e78d071f8e0ded78b94 --- /dev/null +++ b/inspiremusic/bin/flow_only_infer.py @@ -0,0 +1,150 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +import os +import torch +from torch.utils.data import DataLoader +import torchaudio +from hyperpyyaml import load_hyperpyyaml +from tqdm import tqdm +from inspiremusic.cli.model import InspireMusicModel +from inspiremusic.dataset.dataset import Dataset +from inspiremusic.utils.common import MUSIC_STRUCTURE_LABELS + +def get_args(): + parser = argparse.ArgumentParser(description='inference only with flow model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--prompt_data', required=True, help='prompt data file') + parser.add_argument('--flow_model', required=True, help='flow model file') + parser.add_argument('--llm_model', default=None,required=False, help='llm model file') + + parser.add_argument('--music_tokenizer', required=True, help='music tokenizer model file') + parser.add_argument('--wavtokenizer', required=True, help='wavtokenizer model file') + parser.add_argument('--chorus', default="random",required=False, help='chorus tag generation mode, eg. 
random, verse, chorus, intro.') + parser.add_argument('--sample_rate', type=int, default=48000, required=False, + help='sampling rate of generated audio') + parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0, required=False, + help='the minimum generated audio length in seconds') + parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0, required=False, + help='the maximum generated audio length in seconds') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--result_dir', required=True, help='asr result file') + args = parser.parse_args() + print(args) + return args + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + # Init inspiremusic models from configs + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + device = torch.device('cuda' if use_cuda else 'cpu') + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f) + + model = InspireMusicModel(None, configs['flow'], configs['hift'], configs['wavtokenizer']) + model.load(args.llm_model, args.flow_model, args.music_tokenizer, args.wavtokenizer) + + if args.llm_model is None: + model.llm = None + else: + model.llm = model.llm.to(torch.float32) + + if args.flow_model is None: + model.flow = None + + test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=True, partition=False) + test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) + + del configs + os.makedirs(args.result_dir, exist_ok=True) + fn = os.path.join(args.result_dir, 'wav.scp') + f = open(fn, 'w') + with torch.no_grad(): + for _, batch in tqdm(enumerate(test_data_loader)): + utts = batch["utts"] + assert len(utts) == 1, "inference mode only support batchsize 1" + + if "semantic_token" in batch: + token = batch["semantic_token"].to(device) + token_len = batch["semantic_token_len"].to(device) + else: + if audio_token is None: + token = None + token_len = None + else: + token = audio_token.view(audio_token.size(0),-1,4)[:,:,0] + token_len = audio_token_len / 4 + + text_token = batch["text_token"].to(device) + text_token_len = batch["text_token_len"].to(device) + text = batch["text"] + + if "time_start" not in batch.keys(): + batch["time_start"] = torch.randint(0, args.min_generate_audio_seconds, (1,)).to(torch.float64) + if "time_end" not in batch.keys(): + batch["time_end"] = torch.randint(args.min_generate_audio_seconds, args.max_generate_audio_seconds, (1,)).to(torch.float64) + elif (batch["time_end"].numpy()[0] - batch["time_start"].numpy()[0]) < args.min_generate_audio_seconds: + batch["time_end"] = torch.randint(int(batch["time_start"].numpy()[0] + args.min_generate_audio_seconds), int(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds), (1,)).to(torch.float64) + + if "chorus" not in batch.keys(): + batch["chorus"] = torch.randint(1, 5, (1,)) + + if args.chorus == "random": + batch["chorus"] = torch.randint(1, 5, (1,)) + elif args.chorus == "intro": + batch["chorus"] = torch.Tensor([0]) + elif "verse" in args.chorus: + batch["chorus"] = torch.Tensor([1]) + elif args.chorus == "chorus": + batch["chorus"] = torch.Tensor([2]) + elif args.chorus == "outro": + batch["chorus"] = torch.Tensor([4]) + + time_start = batch["time_start"].to(device) + time_end = batch["time_end"].to(device) + chorus = batch["chorus"].to(torch.int) + + 
text_prompt = f"<|{batch['time_start'].numpy()[0]}|><|{MUSIC_STRUCTURE_LABELS[chorus.numpy()[0]]}|><|{batch['text'][0]}|><|{batch['time_end'].numpy()[0]}|>" + chorus = chorus.to(device) + + model_input = {"text": text, "audio_token": token, "audio_token_len": token_len, + "text_token": text_token, "text_token_len": text_token_len, + "embeddings": [time_start, time_end, chorus], "raw_text":text} + + music_audios = [] + for model_output in model.inference(**model_input): + music_audios.append(model_output['music_audio']) + + music_key = utts[0] + music_fn = os.path.join(args.result_dir, '{}.wav'.format(music_key)) + torchaudio.save(music_fn, music_audios[0], sample_rate=args.sample_rate) + f.write('{} {}\n'.format(music_key, music_fn)) + f.flush() + f.close() + logging.info('Result wav.scp saved in {}'.format(fn)) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/inspiremusic/bin/inference.py b/inspiremusic/bin/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..c7429fd34487c7b749cdca6afecd4b3a5dcc4bad --- /dev/null +++ b/inspiremusic/bin/inference.py @@ -0,0 +1,266 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import logging + +logging.getLogger('matplotlib').setLevel(logging.WARNING) +import os +import torch +from torch.utils.data import DataLoader +import torchaudio +from hyperpyyaml import load_hyperpyyaml +from tqdm import tqdm +from inspiremusic.cli.model import InspireMusicModel +from inspiremusic.dataset.dataset import Dataset +import time +from inspiremusic.utils.audio_utils import trim_audio, fade_out, process_audio +from inspiremusic.utils.common import MUSIC_STRUCTURE_LABELS + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +def get_args(): + parser = argparse.ArgumentParser(description='inference only with your model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--prompt_data', required=True, help='prompt data file') + parser.add_argument('--flow_model', default=None, required=False, help='flow model file') + parser.add_argument('--llm_model', default=None,required=False, help='flow model file') + parser.add_argument('--music_tokenizer', required=True, help='music tokenizer model file') + parser.add_argument('--wavtokenizer', required=True, help='wavtokenizer model file') + parser.add_argument('--chorus', default="random",required=False, help='chorus tag generation mode, eg. random, verse, chorus, intro.') + parser.add_argument('--fast', action='store_true', required=False, help='True: fast inference mode, without flow matching for fast inference. 
False: normal inference mode, with flow matching for high quality.') + parser.add_argument('--fp16', default=True, type=bool, required=False, help='inference with fp16 model') + parser.add_argument('--fade_out', default=True, type=bool, required=False, help='add fade out effect to generated audio') + parser.add_argument('--fade_out_duration', default=1.0, type=float, required=False, help='fade out duration in seconds') + parser.add_argument('--trim', default=False, type=bool, required=False, help='trim the silence ending of generated audio') + parser.add_argument('--format', type=str, default="wav", required=False, + choices=["wav", "mp3", "m4a", "flac"], + help='sampling rate of input audio') + parser.add_argument('--sample_rate', type=int, default=24000, required=False, + help='sampling rate of input audio') + parser.add_argument('--output_sample_rate', type=int, default=48000, required=False, choices=[24000, 48000], + help='sampling rate of generated output audio') + parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0, required=False, + help='the minimum generated audio length in seconds') + parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0, required=False, + help='the maximum generated audio length in seconds') + parser.add_argument('--gpu', + type=int, + default=0, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--task', + default='text-to-music', + choices=['text-to-music', 'continuation', "reconstruct", "super_resolution"], + help='choose inference task type. text-to-music: text-to-music task. continuation: music continuation task. reconstruct: reconstruction of original music. super_resolution: convert original 24kHz music into 48kHz music.') + parser.add_argument('--result_dir', required=True, help='asr result file') + args = parser.parse_args() + print(args) + return args + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + if args.fast: + args.output_sample_rate = 24000 + + min_generate_audio_length = int(args.output_sample_rate * args.min_generate_audio_seconds) + max_generate_audio_length = int(args.output_sample_rate * args.max_generate_audio_seconds) + assert args.min_generate_audio_seconds <= args.max_generate_audio_seconds + + # Init inspiremusic models from configs + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + device = torch.device('cuda' if use_cuda else 'cpu') + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f) + + model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], args.fast, args.fp16) + + model.load(args.llm_model, args.flow_model, args.music_tokenizer, args.wavtokenizer) + + if args.llm_model is None: + model.llm = None + else: + model.llm = model.llm.to(torch.float32) + + if args.flow_model is None: + model.flow = None + + test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=True, partition=False) + test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) + + del configs + os.makedirs(args.result_dir, exist_ok=True) + fn = os.path.join(args.result_dir, 'wav.scp') + f = open(fn, 'w') + caption_fn = os.path.join(args.result_dir, 'captions.txt') + caption_f = open(caption_fn, 'w') + + with torch.no_grad(): + for _, batch in tqdm(enumerate(test_data_loader)): + utts = batch["utts"] + + assert len(utts) == 1, "inference 
mode only support batchsize 1" + text_token = batch["text_token"].to(device) + text_token_len = batch["text_token_len"].to(device) + + if "time_start" not in batch.keys(): + batch["time_start"] = torch.randint(0, args.min_generate_audio_seconds, (1,)).to(torch.float64) + + if batch["time_start"].numpy()[0] > 300: + batch["time_start"] = torch.Tensor([0]).to(torch.float64) + + if "time_end" not in batch.keys(): + batch["time_end"] = torch.randint(int(batch["time_start"].numpy()[0] + args.min_generate_audio_seconds), int(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds), (1,)).to(torch.float64) + else: + if (batch["time_end"].numpy()[0] - batch["time_start"].numpy()[0]) < args.min_generate_audio_seconds: + batch["time_end"] = torch.randint(int(batch["time_start"].numpy()[0] + args.min_generate_audio_seconds), int(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds), (1,)).to(torch.float64) + elif (batch["time_end"].numpy()[0] - batch["time_start"].numpy()[0]) > args.max_generate_audio_seconds: + batch["time_end"] = torch.Tensor([(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds)]).to(torch.float64) + + if "chorus" not in batch.keys(): + batch["chorus"] = torch.randint(1, 5, (1,)) + + if args.chorus == "random": + batch["chorus"] = torch.randint(1, 5, (1,)) + elif args.chorus == "intro": + batch["chorus"] = torch.Tensor([0]) + elif "verse" in args.chorus: + batch["chorus"] = torch.Tensor([1]) + elif args.chorus == "chorus": + batch["chorus"] = torch.Tensor([2]) + elif args.chorus == "outro": + batch["chorus"] = torch.Tensor([4]) + else: + batch["chorus"] = batch["chorus"] + + time_start = batch["time_start"].to(device) + time_end = batch["time_end"].to(device) + chorus = batch["chorus"].to(torch.int) + + text_prompt = f"<|{batch['time_start'].numpy()[0]}|><|{MUSIC_STRUCTURE_LABELS[chorus.numpy()[0]]}|><|{batch['text'][0]}|><|{batch['time_end'].numpy()[0]}|>" + chorus = chorus.to(device) + + if batch["acoustic_token"] is None: + audio_token = None + audio_token_len = None + else: + audio_token = batch["acoustic_token"].to(device) + audio_token_len = batch["acoustic_token_len"].to(device) + + text = batch["text"] + + if "semantic_token" in batch: + token = batch["semantic_token"].to(device) + token_len = batch["semantic_token_len"].to(device) + else: + if audio_token is None: + token = None + token_len = None + else: + token = audio_token.view(audio_token.size(0), -1, 4)[:, :, 0] + token_len = audio_token_len / 4 + + if args.task in ['text-to-music', 'continuation']: + # text to music, music continuation + model_input = {"text": text, "audio_token": token, + "audio_token_len": token_len, + "text_token": text_token, + "text_token_len": text_token_len, + "embeddings": [time_start, time_end, chorus], + "raw_text": text, + "sample_rate": args.output_sample_rate, + "duration_to_gen": args.max_generate_audio_seconds, + "task": args.task} + elif args.task in ['reconstruct', 'super_resolution']: + # audio reconstruction, audio super resolution + model_input = {"text": text, "audio_token": audio_token, + "audio_token_len": audio_token_len, + "text_token": text_token, + "text_token_len": text_token_len, + "embeddings": [time_start, time_end, chorus], + "raw_text": text, + "sample_rate": args.output_sample_rate, + "duration_to_gen": args.max_generate_audio_seconds, + "task": args.task} + else: + # zero-shot + model_input = {'text' : text, + 'text_len' : text_token_len, + 'prompt_text' : text_token, + 'prompt_text_len' : text_token_len, + 
'llm_prompt_audio_token' : token, + 'llm_prompt_audio_token_len' : token_len, + 'flow_prompt_audio_token' : audio_token, + 'flow_prompt_audio_token_len': audio_token_len, + 'prompt_audio_feat' : audio_feat, + 'prompt_audio_feat_len' : audio_feat_len, + "embeddings" : [time_start, + time_end, + chorus]} + + music_key = utts[0] + music_audios = [] + music_fn = os.path.join(args.result_dir, f'{music_key}.{args.format}') + bench_start = time.time() + + for model_output in model.inference(**model_input): + music_audios.append(model_output['music_audio']) + bench_end = time.time() + if args.trim: + music_audio = trim_audio(music_audios[0], + sample_rate=args.output_sample_rate, + threshold=0.05, + min_silence_duration=0.8) + else: + music_audio = music_audios[0] + if music_audio.shape[0] != 0: + if music_audio.shape[1] > max_generate_audio_length: + music_audio = music_audio[:, :max_generate_audio_length] + if music_audio.shape[1] >= min_generate_audio_length: + try: + if args.fade_out: + music_audio = fade_out(music_audio, args.output_sample_rate, args.fade_out_duration) + music_audio = music_audio.repeat(2, 1) + if args.format in ["wav", "flac"]: + torchaudio.save(music_fn, music_audio, sample_rate=args.output_sample_rate, encoding="PCM_S", bits_per_sample=24) + elif args.format in ["mp3", "m4a"]: + torchaudio.backend.sox_io_backend.save(filepath=music_fn, src=music_audio, sample_rate=args.output_sample_rate, format=args.format) + else: + logging.info(f"Format is not supported. Please choose from wav, mp3, m4a, flac.") + except Exception as e: + logging.info(f"Error saving file: {e}") + raise + + audio_duration = music_audio.shape[1] / args.output_sample_rate + rtf = (bench_end - bench_start) / audio_duration + logging.info(f"processing time: {int(bench_end - bench_start)}s, audio length: {int(audio_duration)}s, rtf: {rtf}, text prompt: {text_prompt}") + f.write('{} {}\n'.format(music_key, music_fn)) + f.flush() + caption_f.write('{}\t{}\n'.format(music_key, text_prompt)) + caption_f.flush() + else: + logging.info(f"Generate audio length {music_audio.shape[1]} is shorter than min_generate_audio_length.") + else: + logging.info(f"Generate audio is empty, dim = {music_audio.shape[0]}.") + f.close() + logging.info('Result wav.scp saved in {}'.format(fn)) + + +if __name__ == '__main__': + main() diff --git a/inspiremusic/bin/train.py b/inspiremusic/bin/train.py new file mode 100644 index 0000000000000000000000000000000000000000..92a9bae52670a4d7d9b48d944f3e96ac086ca762 --- /dev/null +++ b/inspiremusic/bin/train.py @@ -0,0 +1,194 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
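For reference, the inference loop in inspiremusic/bin/inference.py above serializes its conditioning into a structured text prompt of the form <|time_start|><|label|><|text|><|time_end|> before logging it alongside the generated wav. The following is a minimal sketch of that mapping only; the concrete MUSIC_STRUCTURE_LABELS list is an assumption inferred from the --chorus handling above (intro=0, verse=1, chorus=2, outro=4), not taken from the patch.

# Sketch only: mirrors the prompt construction in bin/inference.py above.
MUSIC_STRUCTURE_LABELS = ["intro", "verse1", "chorus", "verse2", "outro"]  # assumed ordering

def build_text_prompt(text: str, time_start: float, time_end: float, chorus: str = "verse") -> str:
    chorus_ids = {"intro": 0, "verse": 1, "chorus": 2, "outro": 4}   # as in the --chorus branches above
    label = MUSIC_STRUCTURE_LABELS[chorus_ids.get(chorus, 1)]
    return f"<|{time_start}|><|{label}|><|{text}|><|{time_end}|>"

# build_text_prompt("Jazz music with drum beats.", 0.0, 30.0, "intro")
# -> '<|0.0|><|intro|><|Jazz music with drum beats.|><|30.0|>'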
+ +from __future__ import print_function +import argparse +import datetime +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +from copy import deepcopy +import torch +import torch.distributed as dist +import deepspeed +import glob +import os +from hyperpyyaml import load_hyperpyyaml +from torch.cuda.amp import GradScaler, autocast +from torch.distributed.elastic.multiprocessing.errors import record +from peft import get_peft_config, get_peft_model, LoraConfig, TaskType +from inspiremusic.utils.executor import Executor +from inspiremusic.utils.train_utils import ( + init_distributed, + init_dataset_and_dataloader, + init_optimizer_and_scheduler, + init_summarywriter, save_model, + wrap_cuda_model, check_modify_and_save_config) + + +def get_args(): + parser = argparse.ArgumentParser(description='training your network') + parser.add_argument('--train_engine', + default='torch_ddp', + choices=['torch_ddp', 'deepspeed'], + help='Engine for paralleled training') + parser.add_argument('--model', required=True, help='model which will be trained') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--train_data', required=True, help='train data file') + parser.add_argument('--cv_data', required=True, help='cv data file') + parser.add_argument('--checkpoint', help='checkpoint model') + parser.add_argument('--model_dir', required=True, help='save model dir') + parser.add_argument('--tensorboard_dir', + default='tensorboard', + help='tensorboard log dir') + parser.add_argument('--ddp.dist_backend', + dest='dist_backend', + default='nccl', + choices=['nccl', 'gloo'], + help='distributed backend') + parser.add_argument('--num_workers', + default=0, + type=int, + help='number of subprocess workers for reading') + parser.add_argument('--prefetch', + default=100, + type=int, + help='prefetch number') + parser.add_argument('--pin_memory', + action='store_true', + default=True, + help='Use pinned memory buffers used for reading') + parser.add_argument('--deepspeed.save_states', + dest='save_states', + default='model_only', + choices=['model_only', 'model+optimizer'], + help='save model/optimizer states') + parser.add_argument('--timeout', + default=30, + type=int, + help='timeout (in seconds) of inspiremusic_join.') + parser.add_argument('--fp16', + action='store_true', + default=False, + help='Enable fp16 mixed precision training') + parser.add_argument('--lora', + action='store_true', + default=False, + help='Enable LoRA training') + parser.add_argument('--lora_rank', + default=4, + type=int, + help='LoRA rank') + parser.add_argument('--lora_alpha', + default=16, + type=int, + help='LoRA alpha') + parser.add_argument('--lora_dropout', + default=0.1, + type=float, + help='LoRA dropout rate') + parser.add_argument('--lora_target_modules', + nargs='+', + default=["k_proj","v_proj"], + help='Target modules to apply LoRA (e.g., ["q_proj", "v_proj"])') + + parser = deepspeed.add_config_arguments(parser) + args = parser.parse_args() + return args + + +@record +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + + override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model} + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f, overrides=override_dict) + configs['train_conf'].update(vars(args)) + + # Init env for ddp + init_distributed(args) + + # Get dataset & dataloader + train_dataset, cv_dataset, train_data_loader, cv_data_loader = \ + 
init_dataset_and_dataloader(args, configs) + + # Do some sanity checks and save config to arsg.model_dir + configs = check_modify_and_save_config(args, configs) + + # Tensorboard summary + writer = init_summarywriter(args) + + # load checkpoint + model = configs[args.model] + + if args.checkpoint is not None: + model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')) + else: + # Find and load the latest checkpoint + checkpoint_files = glob.glob(os.path.join(args.model_dir, '*.pt')) + + if checkpoint_files: + latest_checkpoint = max(checkpoint_files, key=os.path.getctime) + logging.info(f"Loaded latest checkpoint from {latest_checkpoint}") + + model.load_state_dict(torch.load(latest_checkpoint, map_location='cpu')) + + if args.lora: + logging.info("Applying LoRA to the model...") + if not args.lora_target_modules: + raise ValueError("No target modules specified for LoRA. Please provide --lora_target_modules.") + lora_config = LoraConfig( + task_type="CAUSAL_LM", # Change to appropriate task type + inference_mode=False, + r=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=args.lora_target_modules + ) + model.llm.model = get_peft_model(model.llm.model, lora_config) + # Optionally freeze the base model + else: + logging.info("LoRA is not enabled. Training the full model.") + + # Dispatch model from cpu to gpu + model = wrap_cuda_model(args, model) + + # Get optimizer & scheduler + model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model) + + # Initialize AMP for torch_ddp if fp16 is enabled + scaler = None + if args.fp16: + scaler = GradScaler() + logging.info("Initialized AMP GradScaler for mixed precision training.") + + # Save init checkpoints + info_dict = deepcopy(configs['train_conf']) + + # Get executor + executor = Executor() + + # Start training loop + for epoch in range(info_dict['max_epoch']): + executor.epoch = epoch + train_dataset.set_epoch(epoch) + dist.barrier() + group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout)) + executor.train_one_epoch(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join, scaler=scaler) + dist.destroy_process_group(group_join) + +if __name__ == '__main__': + main() diff --git a/inspiremusic/cli/__init__.py b/inspiremusic/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inspiremusic/cli/frontend.py b/inspiremusic/cli/frontend.py new file mode 100644 index 0000000000000000000000000000000000000000..b8fb99a8083f839f7991710d51a77750ac57ce75 --- /dev/null +++ b/inspiremusic/cli/frontend.py @@ -0,0 +1,100 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
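The --lora_* options in inspiremusic/bin/train.py above feed directly into a peft LoraConfig that wraps the LLM sub-module. A standalone sketch of that wiring with the default flag values follows; the base checkpoint name is a placeholder for illustration, not the InspireMusic LLM itself.

# Sketch only: LoRA setup analogous to bin/train.py above, on a placeholder causal LM.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")  # placeholder checkpoint
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    inference_mode=False,
    r=4,                                  # --lora_rank
    lora_alpha=16,                        # --lora_alpha
    lora_dropout=0.1,                     # --lora_dropout
    target_modules=["k_proj", "v_proj"],  # --lora_target_modules (model-specific names)
)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()  # only the injected LoRA adapters require gradients

The wrapped model then trains exactly like the full model under DDP or DeepSpeed; only the adapter weights are updated.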
+from functools import partial +import torch +from typing import Callable +import re +import inflect +from inspiremusic.cli.model import InspireMusicModel +from inspiremusic.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph +from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer + +class InspireMusicFrontEnd: + def __init__(self, + configs: Callable, + get_tokenizer: Callable, + llm_model: str, + flow_model: str, + music_tokenizer_dir: str, + audio_tokenizer_dir: str, + instruct: bool = False, + fast: bool = False, + fp16: bool = True, + allowed_special: str = 'all'): + self.tokenizer = get_tokenizer() + self.audio_tokenizer_dir = audio_tokenizer_dir + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + self.bandwidth_id = torch.tensor([0]).to(self.device) + self.wavtokenizer = WavTokenizer.from_pretrained_feat(f"{audio_tokenizer_dir}/config.yaml", f"{audio_tokenizer_dir}/model.pt").to(self.device) + + self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], fast, fp16) + self.model = self.model.load(llm_model, flow_model, music_tokenizer_dir, audio_tokenizer_dir) + + self.instruct = instruct + self.allowed_special = allowed_special + self.inflect_parser = inflect.engine() + + def _extract_text_token(self, text): + text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special) + text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device) + text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device) + return text_token, text_token_len + + def _extract_audio_token(self, audio, sample_rate=24000): + audio = torch.tensor(audio, dtype=torch.float32, device=self.device) + _, audio_token = self.wavtokenizer.encode_infer(audio, bandwidth_id=self.bandwidth_id) + audio_token = audio_token.squeeze(0) + audio_token_len = torch.tensor([audio_token.shape[1]], dtype=torch.int32, device=self.device) + return audio_token, audio_token_len + + def text_normalize(self, text, split=True): + text = text.strip() + if contains_chinese(text): + text = text.replace("\n", "") + text = replace_blank(text) + text = replace_corner_mark(text) + text = text.replace(".", "、") + text = text.replace(" - ", ",") + text = remove_bracket(text) + text = re.sub(r'[,,]+$', '。', text) + texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80, + token_min_n=60, merge_len=20, comma_split=False)) + else: + text = spell_out_number(text, self.inflect_parser) + texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80, + token_min_n=60, merge_len=20, comma_split=False)) + if split is False: + return text + return texts + + def frontend_text_to_music(self, text, time_start, time_end, chorus): + text_token, text_token_len = self._extract_text_token(text) + model_input = {"text": text, "audio_token": None, "audio_token_len": None, + "text_token": text_token, "text_token_len": text_token_len, + "embeddings": [time_start, time_end, chorus], "raw_text":text} + return model_input + + def frontend_continuation(self, text, audio, time_start, time_end, chorus, target_sr=24000): + if text is None: + text_token = None + text_token_len = None + else: + text_token, text_token_len = self._extract_text_token(text) + audio_token, audio_token_len = self._extract_audio_token(audio, target_sr) 
+ model_input = {"text": text, "audio_token": audio_token, "audio_token_len": audio_token_len, + "text_token": text_token, "text_token_len": text_token_len, + "embeddings": [time_start, time_end, chorus], "raw_text":text} + return model_input + diff --git a/inspiremusic/cli/inference.py b/inspiremusic/cli/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..38dc3298c07373660a1d95ea495be80bacff184b --- /dev/null +++ b/inspiremusic/cli/inference.py @@ -0,0 +1,312 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import torchaudio +import time +import logging +import argparse +from inspiremusic.cli.inspiremusic import InspireMusic +from inspiremusic.utils.file_utils import logging +import torch +from inspiremusic.utils.audio_utils import trim_audio, fade_out, process_audio + +def set_env_variables(): + os.environ['PYTHONIOENCODING'] = 'UTF-8' + os.environ['TOKENIZERS_PARALLELISM'] = 'False' + main_root = os.getcwd() + bin_dir = os.path.join(main_root, 'inspiremusic') + third_party_matcha_tts_path = os.path.join(main_root, 'third_party', 'Matcha-TTS') + python_path = f"{main_root}:{bin_dir}:{third_party_matcha_tts_path}:{os.environ.get('PYTHONPATH', '')}" + os.environ['PATH'] = python_path + sys.path.extend([main_root, third_party_matcha_tts_path]) + +class InspireMusicUnified: + def __init__(self, + model_name: str = "InspireMusic-1.5B-Long", + model_dir: str = None, + min_generate_audio_seconds: float = 10.0, + max_generate_audio_seconds: float = 30.0, + sample_rate: int = 24000, + output_sample_rate: int = 48000, + load_jit: bool = True, + load_onnx: bool = False, + fast: bool = False, + fp16: bool = True, + gpu: int = 0, + result_dir: str = None, + hub="modelscope"): + os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) + + # Set model_dir or default to downloading if it doesn't exist + if model_dir is None: + model_dir = f"pretrained_models/{model_name}" + else: + model_dir = model_dir.replace("../../", "./") + + if not os.path.isfile(f"{model_dir}/llm.pt"): + if hub == "modelscope": + from modelscope import snapshot_download + if model_name == "InspireMusic-Base": + snapshot_download(f"iic/InspireMusic", local_dir=model_dir) + else: + snapshot_download(f"iic/{model_name}", local_dir=model_dir) + + self.model_dir = model_dir + print(self.model_dir) + + self.sample_rate = sample_rate + self.output_sample_rate = 24000 if fast else output_sample_rate + self.result_dir = result_dir or f"exp/{model_name}" + os.makedirs(self.result_dir, exist_ok=True) + + self.min_generate_audio_seconds = min_generate_audio_seconds + self.max_generate_audio_seconds = max_generate_audio_seconds + self.min_generate_audio_length = int(self.output_sample_rate * self.min_generate_audio_seconds) + self.max_generate_audio_length = int(self.output_sample_rate * self.max_generate_audio_seconds) + assert self.min_generate_audio_seconds <= self.max_generate_audio_seconds, "Min audio seconds must be less than or equal to max audio seconds" + + 
use_cuda = gpu >= 0 and torch.cuda.is_available() + self.device = torch.device('cuda' if use_cuda else 'cpu') + self.model = InspireMusic(self.model_dir, load_jit=load_jit, load_onnx=load_onnx, fast=fast, fp16=fp16) + self.model.model.llm = self.model.model.llm.to(torch.float16) + + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + @torch.inference_mode() + def inference(self, + task: str = 'text-to-music', + text: str = None, + audio_prompt: str = None, # audio prompt file path + chorus: str = "verse", + time_start: float = 0.0, + time_end: float = 30.0, + output_fn: str = "output_audio", + max_audio_prompt_length: float = 5.0, + fade_out_duration: float = 1.0, + output_format: str = "wav", + fade_out_mode: bool = True, + trim: bool = False, + ): + + with torch.no_grad(): + text_prompt = f"<|{time_start}|><|{chorus}|><|{text}|><|{time_end}|>" + chorus_dict = {"random": torch.randint(1, 5, (1,)).item(), "intro" : 0, "verse": 1, "chorus": 2, "outro": 4} + chorus = chorus_dict.get(chorus, 1) + chorus = torch.tensor([chorus], dtype=torch.int).to(self.device) + + time_start_tensor = torch.tensor([time_start], dtype=torch.float64).to(self.device) + time_end_tensor = torch.tensor([time_end], dtype=torch.float64).to(self.device) + + music_fn = os.path.join(self.result_dir, f'{output_fn}.{output_format}') + + bench_start = time.time() + + if task == 'text-to-music': + model_input = { + "text" : text, + "audio_prompt" : audio_prompt, + "time_start" : time_start_tensor, + "time_end" : time_end_tensor, + "chorus" : chorus, + "task" : task, + "stream" : False, + "duration_to_gen": self.max_generate_audio_seconds, + "sr" : self.sample_rate + } + elif task == 'continuation': + if audio_prompt is not None: + audio, _ = process_audio(audio_prompt, self.sample_rate) + if audio.size(1) < self.sample_rate: + logging.warning("Warning: Input prompt audio length is shorter than 1s. 
Please provide an appropriate length audio prompt and try again.") + audio = None + else: + max_audio_prompt_length_samples = int(max_audio_prompt_length * self.sample_rate) + audio = audio[:, :max_audio_prompt_length_samples] # Trimming prompt audio + + model_input = { + "text" : text, + "audio_prompt" : audio, + "time_start" : time_start_tensor, + "time_end" : time_end_tensor, + "chorus" : chorus, + "task" : task, + "stream" : False, + "duration_to_gen": self.max_generate_audio_seconds, + "sr" : self.sample_rate + } + + music_audios = [] + for model_output in self.model.cli_inference(**model_input): + music_audios.append(model_output['music_audio']) + + bench_end = time.time() + + if trim: + music_audio = trim_audio(music_audios[0], + sample_rate=self.output_sample_rate, + threshold=0.05, + min_silence_duration=0.8) + else: + music_audio = music_audios[0] + + if music_audio.shape[0] != 0: + if music_audio.shape[1] > self.max_generate_audio_length: + music_audio = music_audio[:, :self.max_generate_audio_length] + + if music_audio.shape[1] >= self.min_generate_audio_length: + try: + if fade_out_mode: + music_audio = fade_out(music_audio, self.output_sample_rate, fade_out_duration) + + music_audio = music_audio.repeat(2, 1) + + if output_format in ["wav", "flac"]: + torchaudio.save(music_fn, music_audio, + sample_rate=self.output_sample_rate, + encoding="PCM_S", + bits_per_sample=24) + elif output_format in ["mp3", "m4a"]: + torchaudio.backend.sox_io_backend.save( + filepath=music_fn, src=music_audio, + sample_rate=self.output_sample_rate, + format=output_format) + else: + logging.info("Format is not supported. Please choose from wav, mp3, m4a, flac.") + + except Exception as e: + logging.error(f"Error saving file: {e}") + raise + + audio_duration = music_audio.shape[1] / self.output_sample_rate + rtf = (bench_end - bench_start) / audio_duration + logging.info(f"Processing time: {int(bench_end - bench_start)}s, audio length: {int(audio_duration)}s, rtf: {rtf}, text prompt: {text_prompt}") + + else: + logging.error(f"Generated audio length is shorter than minimum required audio length.") + if music_fn: + if os.path.exists(music_fn): + logging.info(f"Generated audio file {music_fn} is saved.") + return music_fn + else: + logging.error(f"{music_fn} does not exist.") + +def get_args(): + parser = argparse.ArgumentParser(description='Run inference with your model') + parser.add_argument('-m', '--model_name', default="InspireMusic-1.5B-Long", + help='Model name') + + parser.add_argument('-d', '--model_dir', + help='Model folder path') + + parser.add_argument('-t', '--text', default="Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.", + help='Prompt text') + + parser.add_argument('-a', '--audio_prompt', default=None, + help='Prompt audio') + + parser.add_argument('-c', '--chorus', default="intro", + help='Chorus tag generation mode (e.g., random, verse, chorus, intro, outro)') + + parser.add_argument('-f', '--fast', type=bool, default=False, + help='Enable fast inference mode (without flow matching)') + + parser.add_argument('-g', '--gpu', type=int, default=0, + help='GPU ID for this rank, -1 for CPU') + + parser.add_argument('--task', default='text-to-music', choices=['text-to-music', 'continuation', 'reconstruct', 'super_resolution'], + help='Inference task type: text-to-music, continuation, reconstruct, super_resolution') + + parser.add_argument('-r', '--result_dir', default="exp/inspiremusic", + help='Directory to 
save generated audio') + + parser.add_argument('-o', '--output_fn', default="output_audio", + help='Output file name') + + parser.add_argument('--format', type=str, default="wav", choices=["wav", "mp3", "m4a", "flac"], + help='Format of output audio') + + parser.add_argument('--sample_rate', type=int, default=24000, + help='Sampling rate of input audio') + + parser.add_argument('--output_sample_rate', type=int, default=48000, choices=[24000, 48000], + help='Sampling rate of generated output audio') + + parser.add_argument('-s', '--time_start', type=float, default=0.0, + help='Start time in seconds') + + parser.add_argument('-e', '--time_end', type=float, default=30.0, + help='End time in seconds') + + parser.add_argument('--max_audio_prompt_length', type=float, default=5.0, + help='Maximum audio prompt length in seconds') + + parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0, + help='Minimum generated audio length in seconds') + + parser.add_argument('--max_generate_audio_seconds', type=float, default=300.0, + help='Maximum generated audio length in seconds') + + parser.add_argument('--fp16', type=bool, default=True, + help='Inference with fp16 model') + + parser.add_argument('--fade_out', type=bool, default=True, + help='Apply fade out effect to generated audio') + + parser.add_argument('--fade_out_duration', type=float, default=1.0, + help='Fade out duration in seconds') + + parser.add_argument('--trim', type=bool, default=False, + help='Trim the silence ending of generated audio') + + args = parser.parse_args() + + if not args.model_dir: + args.model_dir = os.path.join("pretrained_models", args.model_name) + + print(args) + return args + +def main(): + set_env_variables() + args = get_args() + model = InspireMusicUnified(model_name = args.model_name, + model_dir = args.model_dir, + min_generate_audio_seconds = args.min_generate_audio_seconds, + max_generate_audio_seconds = args.max_generate_audio_seconds, + sample_rate = args.sample_rate, + output_sample_rate = args.output_sample_rate, + load_jit = True, + load_onnx = False, + fast = args.fast, + fp16 = args.fp16, + gpu = args.gpu, + result_dir = args.result_dir) + + model.inference(task = args.task, + text = args.text, + audio_prompt = args.audio_prompt, + chorus = args.chorus, + time_start = args.time_start, + time_end = args.time_end, + output_fn = args.output_fn, + max_audio_prompt_length = args.max_audio_prompt_length, + fade_out_duration = args.fade_out_duration, + output_format = args.format, + fade_out_mode = args.fade_out, + trim = args.trim) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/inspiremusic/cli/inspiremusic.py b/inspiremusic/cli/inspiremusic.py new file mode 100644 index 0000000000000000000000000000000000000000..2cb14c5cc6ff4fa3235749fa99aa92c68817be9d --- /dev/null +++ b/inspiremusic/cli/inspiremusic.py @@ -0,0 +1,143 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
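Besides the CLI entry point above, the InspireMusicUnified wrapper defined in inspiremusic/cli/inference.py can be driven directly from Python. A minimal sketch, assuming the InspireMusic-1.5B-Long checkpoint is already available under pretrained_models/; output directory, file name, and prompt text are illustrative.

# Sketch only: programmatic use of InspireMusicUnified from cli/inference.py above.
from inspiremusic.cli.inference import InspireMusicUnified, set_env_variables

set_env_variables()
model = InspireMusicUnified(model_name="InspireMusic-1.5B-Long",
                            result_dir="exp/demo",   # illustrative output directory
                            gpu=0, fast=False, fp16=True)
wav_path = model.inference(task="text-to-music",
                           text="Jazz music with drum beats.",
                           chorus="intro",
                           time_start=0.0, time_end=30.0,
                           output_fn="demo_01",
                           output_format="wav")
# On success, wav_path points at exp/demo/demo_01.wav.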
+import os +import time +from tqdm import tqdm +from hyperpyyaml import load_hyperpyyaml +from inspiremusic.cli.frontend import InspireMusicFrontEnd +from inspiremusic.cli.model import InspireMusicModel +from inspiremusic.utils.file_utils import logging +import torch + +class InspireMusic: + def __init__(self, model_dir, load_jit=True, load_onnx=False, fast = False, fp16=True, hub="modelscope"): + instruct = True if '-Instruct' in model_dir else False + + if model_dir is None: + model_dir = f"pretrained_models/InspireMusic-1.5B-Long" + + if not os.path.isfile(f"{model_dir}/llm.pt"): + model_name = model_dir.split("/")[-1] + if hub == "modelscope": + from modelscope import snapshot_download + if model_name == "InspireMusic-Base": + snapshot_download(f"iic/InspireMusic", local_dir=model_dir) + else: + snapshot_download(f"iic/{model_name}", local_dir=model_dir) + + assert os.path.exists(f'{model_dir}/inspiremusic.yaml') + with open('{}/inspiremusic.yaml'.format(model_dir), 'r') as f: + configs = load_hyperpyyaml(f) + + self.frontend = InspireMusicFrontEnd(configs, + configs['get_tokenizer'], + '{}/llm.pt'.format(model_dir), + '{}/flow.pt'.format(model_dir), + '{}/music_tokenizer/'.format(model_dir), + '{}/wavtokenizer/'.format(model_dir), + instruct, + fast, + fp16, + configs['allowed_special']) + + self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], fast, fp16) + self.model.load('{}/llm.pt'.format(model_dir), + '{}/flow.pt'.format(model_dir), + '{}/music_tokenizer/'.format(model_dir), + '{}/wavtokenizer/model.pt'.format(model_dir)) + del configs + + @torch.inference_mode() + def inference(self, task, text, audio, time_start, time_end, chorus, stream=False, sr=24000): + if task == "text-to-music": + for i in tqdm(self.frontend.text_normalize(text, split=True)): + model_input = self.frontend.frontend_text_to_music(i, time_start, time_end, chorus) + start_time = time.time() + logging.info('prompt text {}'.format(i)) + for model_output in self.model.inference(**model_input, stream=stream): + music_audios_len = model_output['music_audio'].shape[1] / sr + logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len)) + yield model_output + start_time = time.time() + + elif task == "continuation": + if text is None: + if audio is not None: + for i in tqdm(audio): + model_input = self.frontend.frontend_continuation(None, i, time_start, time_end, chorus, sr, max_audio_length) + start_time = time.time() + logging.info('prompt text {}'.format(i)) + for model_output in self.model.continuation_inference(**model_input, stream=stream): + music_audios_len = model_output['music_audio'].shape[1] / sr + logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len)) + yield model_output + start_time = time.time() + else: + if audio is not None: + for i in tqdm(self.frontend.text_normalize(text, split=True)): + model_input = self.frontend.frontend_continuation(i, audio, time_start, time_end, chorus, sr, max_audio_length) + start_time = time.time() + logging.info('prompt text {}'.format(i)) + for model_output in self.model.continuation_inference(**model_input, stream=stream): + music_audios_len = model_output['music_audio'].shape[1] / sr + logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len)) + yield model_output + start_time = time.time() + else: + print("Please input text or audio.") + else: + 
print("Currently only support text-to-music and music continuation tasks.") + + @torch.inference_mode() + def cli_inference(self, text, audio_prompt, time_start, time_end, chorus, task, stream=False, duration_to_gen=30, sr=24000): + if task == "text-to-music": + model_input = self.frontend.frontend_text_to_music(text, time_start, time_end, chorus) + logging.info('prompt text {}'.format(text)) + elif task == "continuation": + model_input = self.frontend.frontend_continuation(text, audio_prompt, time_start, time_end, chorus, sr) + logging.info('prompt audio length: {}'.format(len(audio_prompt))) + + start_time = time.time() + for model_output in self.model.inference(**model_input, duration_to_gen=duration_to_gen, task=task): + music_audios_len = model_output['music_audio'].shape[1] / sr + logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len)) + yield model_output + start_time = time.time() + + @torch.inference_mode() + def inference_zero_shot(self, text, prompt_text, prompt_audio_16k, stream=False, sr=24000): + prompt_text = self.frontend.text_normalize(prompt_text, split=False) + for i in tqdm(self.frontend.text_normalize(text, split=True)): + model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_audio_16k) + start_time = time.time() + logging.info('prompt text {}'.format(i)) + for model_output in self.model.inference(**model_input, stream=stream): + audio_len = model_output['music_audio'].shape[1] / sr + logging.info('yield audio len {}, rtf {}'.format(audio_len, (time.time() - start_time) / audio_len)) + yield model_output + start_time = time.time() + @torch.inference_mode() + def inference_instruct(self, text, spk_id, instruct_text, stream=False, sr=24000): + if self.frontend.instruct is False: + raise ValueError('{} do not support instruct inference'.format(self.model_dir)) + instruct_text = self.frontend.text_normalize(instruct_text, split=False) + for i in tqdm(self.frontend.text_normalize(text, split=True)): + model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text) + start_time = time.time() + logging.info('prompt text {}'.format(i)) + for model_output in self.model.inference(**model_input, stream=stream): + audio_len = model_output['music_audio'].shape[1] / sr + logging.info('yield audio len {}, rtf {}'.format(audio_len, (time.time() - start_time) / audio_len)) + yield model_output + start_time = time.time() diff --git a/inspiremusic/cli/model.py b/inspiremusic/cli/model.py new file mode 100644 index 0000000000000000000000000000000000000000..85b2e626a563a51c478f4a85712fb0a3a1682120 --- /dev/null +++ b/inspiremusic/cli/model.py @@ -0,0 +1,295 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import numpy as np +import threading +import time +from contextlib import nullcontext +import uuid +from inspiremusic.music_tokenizer.vqvae import VQVAE +from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer +from torch.cuda.amp import autocast +import logging +import torch +import os + + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +class InspireMusicModel: + + def __init__(self, + llm: torch.nn.Module, + flow: torch.nn.Module, + music_tokenizer: torch.nn.Module, + wavtokenizer: torch.nn.Module, + fast: bool = False, + fp16: bool = True, + ): + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.llm = llm + self.flow = flow + self.music_tokenizer = music_tokenizer + self.wavtokenizer = wavtokenizer + self.fp16 = fp16 + self.token_min_hop_len = 100 + self.token_max_hop_len = 200 + self.token_overlap_len = 20 + # mel fade in out + self.mel_overlap_len = 34 + self.mel_window = np.hamming(2 * self.mel_overlap_len) + # hift cache + self.mel_cache_len = 20 + self.source_cache_len = int(self.mel_cache_len * 256) + # rtf and decoding related + self.stream_scale_factor = 1 + assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf' + self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() + self.lock = threading.Lock() + # dict used to store session related variable + self.music_token_dict = {} + self.llm_end_dict = {} + self.mel_overlap_dict = {} + self.fast = fast + self.generator = "hifi" + + def load(self, llm_model, flow_model, hift_model, wavtokenizer_model): + if llm_model is not None: + self.llm.load_state_dict(torch.load(llm_model, map_location=self.device)) + self.llm.to(self.device).eval() + else: + self.llm = None + if flow_model is not None: + self.flow.load_state_dict(torch.load(flow_model, map_location=self.device)) + self.flow.to(self.device).eval() + if hift_model is not None: + if ".pt" not in hift_model: + self.music_tokenizer = VQVAE( hift_model + '/config.json', + hift_model + '/model.pt', with_encoder=True) + else: + self.music_tokenizer = VQVAE(os.path.dirname(hift_model) + '/config.json', + hift_model, with_encoder=True) + self.music_tokenizer.to(self.device).eval() + if wavtokenizer_model is not None: + if ".pt" not in wavtokenizer_model: + self.wavtokenizer = WavTokenizer.from_pretrained_feat( wavtokenizer_model + '/config.yaml', + wavtokenizer_model + '/model.pt') + else: + self.wavtokenizer = WavTokenizer.from_pretrained_feat( os.path.dirname(wavtokenizer_model) + '/config.yaml', + wavtokenizer_model ) + self.wavtokenizer.to(self.device) + + def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model): + assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model" + llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device) + self.llm.text_encoder = llm_text_encoder + llm_llm = torch.jit.load(llm_llm_model) + self.llm.llm = llm_llm + flow_encoder = torch.jit.load(flow_encoder_model) + self.flow.encoder = flow_encoder + + def load_onnx(self, flow_decoder_estimator_model): + import onnxruntime + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] + del 
self.flow.decoder.estimator + self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers) + + def llm_job(self, text, audio_token, audio_token_len, prompt_text, llm_prompt_audio_token, embeddings, uuid, duration_to_gen, task): + with self.llm_context: + local_res = [] + with autocast(enabled=self.fp16): + inference_kwargs = { + 'text': text.to(self.device), + 'text_len': torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device), + 'prompt_text': prompt_text.to(self.device), + 'prompt_text_len': torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), + 'prompt_audio_token': llm_prompt_audio_token.to(self.device), + 'prompt_audio_token_len': torch.tensor([llm_prompt_audio_token.shape[1]], dtype=torch.int32).to(self.device), + 'embeddings': embeddings, + 'duration_to_gen': duration_to_gen, + 'task': task + } + + if audio_token is not None: + inference_kwargs['audio_token'] = audio_token.to(self.device) + else: + inference_kwargs['audio_token'] = torch.Tensor([0]).to(self.device) + + if audio_token_len is not None: + inference_kwargs['audio_token_len'] = audio_token_len.to(self.device) + else: + inference_kwargs['audio_token_len'] = torch.Tensor([0]).to(self.device) + + for i in self.llm.inference(**inference_kwargs): + local_res.append(i) + + self.music_token_dict[uuid] = local_res + self.llm_end_dict[uuid] = True + + # def token2wav(self, token, token_len, text, text_len, uuid, sample_rate, finalize=False): + def token2wav(self, token, token_len, uuid, sample_rate, finalize=False, flow_cfg=None): + # if self.flow is not None: + # if isinstance(self.flow,MaskedDiffWithText): + # codec_embed = self.flow.inference(token=token.to(self.device), + # token_len=token_len.to(self.device), + # text_token=text, + # text_token_len=text_len, + # ) + # else: + if flow_cfg is not None: + codec_embed = self.flow.inference_cfg(token=token.to(self.device), + token_len=token_len.to(self.device), + sample_rate=sample_rate + ) + else: + codec_embed = self.flow.inference(token=token.to(self.device), + token_len=token_len.to(self.device), + sample_rate=sample_rate + ) + # use music_tokenizer decoder + wav = self.music_tokenizer.generator(codec_embed) + wav = wav.squeeze(0).cpu().detach() + return wav + + def acoustictoken2wav(self, token): + # use music_tokenizer to generate waveform from token + token = token.view(token.size(0), -1, 4) + # codec = token.view(1, -1, 4) + codec_embed = self.music_tokenizer.quantizer.embed(torch.tensor(token).long().to(self.device)).cuda() + wav = self.music_tokenizer.generator(codec_embed) + wav = wav.squeeze(0).cpu().detach() + return wav + + def semantictoken2wav(self, token): + # fast mode, use wavtokenizer decoder + new_tensor = torch.tensor(token.to(self.device)).unsqueeze(0) + features = self.wavtokenizer.codes_to_features(new_tensor) + bandwidth_id = torch.tensor([0]).to(self.device) + wav = self.wavtokenizer.to(self.device).decode(features, bandwidth_id=bandwidth_id) + wav = wav.cpu().detach() + return wav + + @torch.inference_mode() + def inference(self, text, audio_token, audio_token_len, text_token, text_token_len, embeddings=None, + prompt_text=torch.zeros(1, 0, dtype=torch.int32), + llm_prompt_audio_token=torch.zeros(1, 0, dtype=torch.int32), + flow_prompt_audio_token=torch.zeros(1, 0, dtype=torch.int32), + prompt_audio_feat=torch.zeros(1, 0, 80), sample_rate=48000, duration_to_gen = 30, task="continuation", trim = True, stream=False, **kwargs): + + # this_uuid is used to 
track variables related to this inference thread + # support tasks: + # text to music task + # music continuation task + # require either audio input only or text and audio inputs + + this_uuid = str(uuid.uuid1()) + + if self.llm: + with self.lock: + self.music_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False + + p = threading.Thread(target=self.llm_job, args=(text_token, audio_token, audio_token_len, prompt_text, llm_prompt_audio_token, embeddings, this_uuid, duration_to_gen, task)) + p.start() + + if stream is True: + token_hop_len = self.token_min_hop_len + while True: + time.sleep(0.1) + if len(self.music_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len: + this_music_audio = self.token2wav(token=text_token, + token_len=text_token_len, + uuid=this_uuid, + sample_rate=sample_rate, + finalize=False) + yield {'music_audio': this_music_audio.cpu()} + with self.lock: + self.music_token_dict[this_uuid] = self.music_token_dict[this_uuid][token_hop_len:] + # increase token_hop_len for better audio quality + token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor)) + if self.llm_end_dict[this_uuid] is True and len(self.music_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len: + break + p.join() + # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None + this_music_token = torch.concat(self.music_token_dict[this_uuid], dim=1) + with self.flow_hift_context: + this_music_audio = self.token2wav(token=this_music_token, + prompt_token=flow_prompt_audio_token, + prompt_feat=prompt_audio_feat, + embedding=flow_embedding, + uuid=this_uuid, + sample_rate=sample_rate, + finalize=True) + yield {'music_audio': this_music_audio.cpu()} + else: + # deal with all tokens + if self.fast: + if task == "reconstruct": + assert audio_token is None + this_music_token = audio_token + this_music_audio = self.acoustictoken2wav(token=this_music_token) + else: + if self.llm: + p.join() + print(len(self.music_token_dict[this_uuid])) + this_music_token = torch.concat(self.music_token_dict[this_uuid], dim=1) + print(this_music_token.shape) + else: + this_music_token = text_token + + logging.info("using wavtokenizer generator without flow matching") + this_music_audio = self.semantictoken2wav(token=this_music_token) + print(this_music_audio.shape) + + else: + if self.llm: + p.join() + if len(self.music_token_dict[this_uuid]) != 0: + this_music_token = torch.concat(self.music_token_dict[this_uuid], dim=1) + else: + print(f"The list of tensors is empty for UUID: {this_uuid}") + else: + this_music_token = text_token + logging.info(f"LLM generated audio token length: {this_music_token.shape[1]}") + logging.info(f"using flow matching and {self.generator} generator") + + if self.generator == "hifi": + if (embeddings[1] - embeddings[0]) <= duration_to_gen: + if trim: + trim_length = (int((embeddings[1] - embeddings[0])*75)) + this_music_token = this_music_token[:, :trim_length] + logging.info(f"After trimmed, generated audio token length: {this_music_token.shape[1]}") + elif (embeddings[1] - embeddings[0]) < 1: + logging.info(f"Given audio length={(embeddings[1] - embeddings[0])}, which is too short, please give a longer audio length.") + + this_music_audio = self.token2wav(token=this_music_token, + token_len=torch.LongTensor([this_music_token.size(1)]), + uuid=this_uuid, + sample_rate=sample_rate, + finalize=True) + logging.info(f"Generated audio sequence length: {this_music_audio.shape[1]}") + elif 
self.generator == "wavtokenizer": + if (embeddings[1] - embeddings[0]) < duration_to_gen: + if trim: + trim_length = (int((embeddings[1] - embeddings[0])*75)) + this_music_token = this_music_token[:,:trim_length] + logging.info(f"After trimmed, generated audio token length: {this_music_token.shape[1]}") + elif (embeddings[1] - embeddings[0]) < 1: + logging.info(f"Given audio length={(embeddings[1] - embeddings[0])}, which is too short, please give a longer audio length.") + + this_music_audio = self.semantictoken2wav(token=this_music_token) + + yield {'music_audio': this_music_audio.cpu()} + torch.cuda.synchronize() \ No newline at end of file diff --git a/inspiremusic/dataset/__init__.py b/inspiremusic/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inspiremusic/dataset/dataset.py b/inspiremusic/dataset/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..188872bf21cfbf55bd3df12ff3071e62c8b49f06 --- /dev/null +++ b/inspiremusic/dataset/dataset.py @@ -0,0 +1,154 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import json +import math +from functools import partial + +import torch +import torch.distributed as dist +from torch.utils.data import IterableDataset +from inspiremusic.utils.file_utils import read_lists, read_json_lists + +class Processor(IterableDataset): + + def __init__(self, source, f, *args, **kw): + assert callable(f) + self.source = source + self.f = f + self.args = args + self.kw = kw + + def set_epoch(self, epoch): + self.source.set_epoch(epoch) + + def __iter__(self): + """ Return an iterator over the source dataset processed by the + given processor. 
+ """ + assert self.source is not None + assert callable(self.f) + return self.f(iter(self.source), *self.args, **self.kw) + + def apply(self, f): + assert callable(f) + return Processor(self, f, *self.args, **self.kw) + + +class DistributedSampler: + + def __init__(self, shuffle=True, partition=True): + self.epoch = -1 + self.update() + self.shuffle = shuffle + self.partition = partition + + def update(self): + assert dist.is_available() + if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + else: + self.rank = 0 + self.world_size = 1 + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + self.worker_id = 0 + self.num_workers = 1 + else: + self.worker_id = worker_info.id + self.num_workers = worker_info.num_workers + return dict(rank=self.rank, + world_size=self.world_size, + worker_id=self.worker_id, + num_workers=self.num_workers) + + def set_epoch(self, epoch): + self.epoch = epoch + + def sample(self, data): + """ Sample data according to rank/world_size/num_workers + + Args: + data(List): input data list + + Returns: + List: data list after sample + """ + data = list(range(len(data))) + # force datalist even + + if self.partition: + if self.shuffle: + random.Random(self.epoch).shuffle(data) + if len(data) < self.world_size: + print(len(data), self.world_size) + data = data * math.ceil(self.world_size / len(data)) + data = data[:self.world_size] + data = data[self.rank::self.world_size] + if len(data) < self.num_workers: + data = data * math.ceil(self.num_workers / len(data)) + data = data[:self.num_workers] + data = data[self.worker_id::self.num_workers] + return data + + +class DataList(IterableDataset): + + def __init__(self, lists, shuffle=True, partition=True): + self.lists = lists + self.sampler = DistributedSampler(shuffle, partition) + + def set_epoch(self, epoch): + self.sampler.set_epoch(epoch) + + def __iter__(self): + sampler_info = self.sampler.update() + indexes = self.sampler.sample(self.lists) + for index in indexes: + data = dict(src=self.lists[index]) + data.update(sampler_info) + yield data + + +def Dataset(data_list_file, + data_pipeline, + mode='train', + shuffle=True, + partition=True + ): + """ Construct dataset from arguments + + We have two shuffle stage in the Dataset. The first is global + shuffle at shards tar/raw file level. The second is global shuffle + at training samples level. + + Args: + data_type(str): raw/shard + tokenizer (BaseTokenizer): tokenizer to tokenize + partition(bool): whether to do data partition in terms of rank + """ + assert mode in ['train', 'inference', 'processing'] + lists = read_lists(data_list_file) + + dataset = DataList(lists, + shuffle=shuffle, + partition=partition) + + for func in data_pipeline: + dataset = Processor(dataset, func, mode=mode) + + return dataset diff --git a/inspiremusic/dataset/processor.py b/inspiremusic/dataset/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..21572b593e8e36b96202a1264b62cd73c7b6ecf0 --- /dev/null +++ b/inspiremusic/dataset/processor.py @@ -0,0 +1,595 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import random + +import pyarrow.parquet as pq +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence +import torch.nn.functional as F +import numpy as np +import re + +torchaudio.set_audio_backend('soundfile') + +AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'} +CHORUS = {"intro": 0, "chorus": 1, "verse1": 2, "verse2": 3, "verse": 2, + "outro": 4} + +metadata_pattern = re.compile(r'^\[(ti|ar|al|by|offset):.*\]$') +timestamp_pattern = re.compile(r'^\[\d{2}:\d{2}\.\d{2}\](.*)$') + + +def parquet_opener(data, mode='train', audio_data={}): + """ Give url or local file, return file descriptor + Inplace operation. + + Args: + data(Iterable[str]): url or local file list + + Returns: + Iterable[{src, stream}] + """ + for sample in data: + assert 'src' in sample + + url = sample['src'] + try: + df = pq.read_table(url).to_pandas() + for i in df.index: + sample.update(dict(df.loc[i])) + yield {**sample} + except Exception as ex: + logging.warning('Failed to open {}, ex info {}'.format(url, ex)) + + +def clean_lyrics(data, mode="train"): + for sample in data: + lyrics = sample["text"] + cleaned = [] + for line in lyrics.splitlines(): + if metadata_pattern.match(line): + continue + timestamp_match = timestamp_pattern.match(line) + if timestamp_match: + lyric = timestamp_match.group(1).strip() + if lyric: + cleaned.append(lyric) + else: + if line.strip(): + cleaned.append(line.strip()) + sample["text"] = '\n'.join(cleaned) + yield sample + + +def cut_by_length(data, max_length=8000, num_times=4, mode="train"): + for sample in data: + if "semantic_token" in sample: + sample["semantic_token"] = [ + sample["semantic_token"][0][:max_length]] + if "acoustic_token" not in sample: + sample["acoustic_token"] = sample["speech_token"] + sample["acoustic_token"] = sample["acoustic_token"][ + :max_length * num_times] + + yield sample + + +def filter(data, + max_length=22500, # 22500 #5min #10240 + max_acoustic_length=45000, + min_length=10, + min_acoustic_length=150, + token_max_length=200, + token_min_length=1, + min_output_input_ratio=0.0005, + max_output_input_ratio=1, + mode='train'): + """ Filter sample according to feature and label length + Inplace operation. 
+ + Args:: + data: Iterable[{key, wav, label, sample_rate}] + max_length: drop utterance which is greater than max_length(10ms) + min_length: drop utterance which is less than min_length(10ms) + token_max_length: drop utterance which is greater than + token_max_length, especially when use char unit for + english modeling + token_min_length: drop utterance which is + less than token_max_length + min_output_input_ratio: minimal ration of + token_length / feats_length(10ms) + max_output_input_ratio: maximum ration of + token_length / feats_length(10ms) + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + if mode == "train": + for sample in data: + if "semantic_token" in sample: + new_sample_frames = sample['semantic_token'][0].shape[0] + else: + new_sample_frames = sample['speech_token'] + + if "text_token" in sample: + new_sample_frames += len(sample['text_token']) + + if new_sample_frames > max_length or new_sample_frames < min_length: + print(f"skipped 1 item length={new_sample_frames}") + continue + + sample["chorus"] = sample["chorus"].split(",") + if not isinstance(sample["time_start"], np.ndarray): + sample["time_start"] = [sample["time_start"]] + sample["time_end"] = [sample["time_end"]] + for i, t in enumerate(sample["chorus"]): + if sample["chorus"][i] == "verse": + sample["chorus"][i] = "verse1" + + yield sample + + if mode == "train_flow": + for sample in data: + if "semantic_token" in sample: + new_sample_frames = sample['semantic_token'][0].shape[0] + if "acoustic_token" in sample: + target_sample_frames = sample['acoustic_token'][0].shape[0] + + if new_sample_frames > max_length or new_sample_frames < min_acoustic_length or new_sample_frames < min_length or target_sample_frames > max_acoustic_length: + print( + f"skipped 1 item length={new_sample_frames}, target_length={target_sample_frames}") + continue + + yield sample + + elif mode == "inference": + for sample in data: + yield sample + + +def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'): + """ Resample data. + Inplace operation. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + resample_rate: target resample rate + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'speech' in sample + sample_rate = sample['sample_rate'] + waveform = sample['speech'] + if sample_rate != resample_rate: + if sample_rate < min_sample_rate: + continue + sample['sample_rate'] = resample_rate + sample['speech'] = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=resample_rate)(waveform) + max_val = sample['speech'].abs().max() + if max_val > 1: + sample['speech'] /= max_val + yield sample + + +def truncate(data, truncate_length=24576, mode='train'): + """ Truncate data. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + truncate_length: truncate length + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + waveform = sample['audio'] + if waveform.shape[1] > truncate_length: + start = random.randint(0, waveform.shape[1] - truncate_length) + waveform = waveform[:, start: start + truncate_length] + else: + waveform = torch.concat([waveform, torch.zeros(1, truncate_length - + waveform.shape[1])], + dim=1) + sample['audio'] = waveform + yield sample + + +def upsample(data, resample_rate=48000, min_sample_rate=16000, mode='train', + n_codebook=4): + """ Resample data. + Inplace operation. 
+ + Args: + data: Iterable[{key, wav, label, sample_rate}] + resample_rate: target resample rate + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'semantic_token' in sample + # TODO: unify data processing key names + if 'acoustic_token' not in sample: + continue + + if 'sample_rate' in sample.keys(): + sample_rate = sample['sample_rate'] + else: + sample_rate = 24000 + token = np.array(sample['semantic_token'][0][:-1]) + + # Calculate the repetition factor for resampling + repetition_factor = int(n_codebook * resample_rate / sample_rate) + if sample_rate != resample_rate: + if sample_rate < min_sample_rate: + continue + sample['sample_rate'] = resample_rate + sample['semantic_token'] = np.array( + [np.repeat(token, repetition_factor)]) + + yield sample + +def compute_fbank(data, + feat_extractor, + mode='train'): + """ Extract fbank + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'speech' in sample + assert 'utt' in sample + assert 'text_token' in sample + waveform = sample['speech'] + mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1) + sample['speech_feat'] = mat + del sample['speech'] + yield sample + + +def parse_embedding(data, normalize, mode='train'): + """ Parse utt_embedding/spk_embedding + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + + for sample in data: + sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], + dtype=torch.float32) + sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], + dtype=torch.float32) + if normalize: + sample['utt_embedding'] = F.normalize(sample['utt_embedding'], + dim=0) + sample['spk_embedding'] = F.normalize(sample['spk_embedding'], + dim=0) + yield sample + +def tokenize(data, get_tokenizer, allowed_special, mode='train'): + """ Decode text to chars or BPE + Inplace operation + + Args: + data: Iterable[{key, wav, txt, sample_rate}] + + Returns: + Iterable[{key, wav, txt, tokens, label, sample_rate}] + """ + tokenizer = get_tokenizer() + + for sample in data: + assert 'text' in sample + sample['text_token'] = tokenizer.encode(sample['text'], + allowed_special=allowed_special) + yield sample + + +def shuffle(data, shuffle_size=10000, mode='train'): + """ Local shuffle the data + + Args: + data: Iterable[{key, feat, label}] + shuffle_size: buffer size for shuffle + + Returns: + Iterable[{key, feat, label}] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= shuffle_size: + random.shuffle(buf) + for x in buf: + yield x + buf = [] + # The sample left over + random.shuffle(buf) + for x in buf: + yield x + + +def sort(data, sort_size=500, mode='train'): + """ Sort the data by feature length. 
+ Sort is used after shuffle and before batch, so we can group + utts with similar lengths into a batch, and `sort_size` should + be less than `shuffle_size` + + Args: + data: Iterable[{key, feat, label}] + sort_size: buffer size for sort + + Returns: + Iterable[{key, feat, label}] + """ + + buf = [] + for sample in data: + if sample["chorus"] == "verse": + sample["chorus"] = "verse1" + + if sample["acoustic_token"].shape[0] == 1: + sample["acoustic_token"] = np.concatenate( + sample["acoustic_token"][0]) + else: + sample["acoustic_token"] = np.concatenate(sample["acoustic_token"]) + + sample["acoustic_token"] = torch.from_numpy(sample["acoustic_token"]) + buf.append(sample) + if len(buf) >= sort_size: + buf.sort(key=lambda x: x['acoustic_token'].size(0)) + for x in buf: + yield x + buf = [] + # The sample left over + buf.sort(key=lambda x: x['acoustic_token'].size(0)) + for x in buf: + yield x + + +def static_batch(data, batch_size=32): + """ Static batch the data by `batch_size` + + Args: + data: Iterable[{key, feat, label}] + batch_size: batch size + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + data_empty = True + for sample in data: + data_empty = False + buf.append(sample) + if len(buf) >= batch_size: + yield buf + buf = [] + if data_empty: + raise ValueError("data is empty") + if len(buf) > 0: + yield buf + + +def dynamic_batch(data, max_frames_in_batch=12000, mode='train'): + """ Dynamic batch the data until the total frames in batch + reach `max_frames_in_batch` + + Args: + data: Iterable[{key, feat, label}] + max_frames_in_batch: max_frames in one batch + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + longest_frames = 0 + for sample in data: + assert 'acoustic_token' in sample + assert isinstance(sample['acoustic_token'], torch.Tensor) + + if 'semantic_token' in sample: + new_sample_frames = sample['semantic_token'][0].shape[0] + else: + new_sample_frames = sample['semantic_token'] + + if "text_token" in sample: + new_sample_frames += len(sample['text_token']) + + longest_frames = max(longest_frames, new_sample_frames) + frames_after_padding = longest_frames * (len(buf) + 1) + + if frames_after_padding > max_frames_in_batch: + if len(buf) > 0: + yield buf + buf = [sample] + longest_frames = new_sample_frames + else: + buf.append(sample) + if len(buf) > 0: + yield buf + + +def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, + mode='train'): + """ Wrapper for static/dynamic batch + """ + if mode == 'inference': + return static_batch(data, 1) + elif mode == 'processing': + return static_batch(data, batch_size) + else: + if batch_type == 'static': + return static_batch(data, batch_size) + elif batch_type == 'dynamic': + return dynamic_batch(data, max_frames_in_batch) + else: + logging.fatal('Unsupported batch type {}'.format(batch_type)) + + +def padding(data, mode='train'): + """ Padding the data into training data + + Args: + data: Iterable[List[{key, feat, label}]] + + Returns: + Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] + """ + if mode == "train": + for sample in data: + assert isinstance(sample, list) + if len(sample) != 0: + acoustic_feat_len = torch.tensor( + [x['acoustic_token'].size(0) for x in sample], + dtype=torch.int32) + order = torch.argsort(acoustic_feat_len, descending=True) + utts = [sample[i]['utt'] for i in order] + acoustic_token = [ + sample[i]['acoustic_token'].clone().to(torch.int32) for i in + order] + acoustic_token_len = torch.tensor( + [i.size(0) for i in 
acoustic_token], dtype=torch.int32) + + acoustic_token = pad_sequence(acoustic_token, + batch_first=True, + padding_value=0) + + text = [sample[i]['text'] for i in order] + text_token = [torch.tensor(sample[i]['text_token']).long() for i + in order] + text_token_len = torch.tensor([i.size(0) for i in text_token], + dtype=torch.int32) + text_token = pad_sequence(text_token, batch_first=True, + padding_value=0) + time_start = torch.tensor( + [sample[i]['time_start'] for i in order]) + time_end = torch.tensor([sample[i]['time_end'] for i in order]) + + if isinstance(sample[0]['chorus'], str): + chorus = torch.tensor( + [CHORUS[sample[i]['chorus']] for i in order]) + else: + chorus = [ + torch.tensor([CHORUS[t] for t in sample[i]['chorus']]) + for i in order] + chorus = pad_sequence(chorus, batch_first=True, + padding_value=-1) + + batch = { + "utts" : utts, + "acoustic_token" : acoustic_token, + "acoustic_token_len": acoustic_token_len, + "time_start" : time_start, + "time_end" : time_end, + "chorus" : chorus, + "text" : text, + "text_token" : text_token, + "text_token_len" : text_token_len, + } + + if "semantic_token" in sample[0]: + semantic_token = [ + torch.tensor(sample[i]['semantic_token'][0], + dtype=torch.int32) for i in order] + semantic_token_len = torch.tensor( + [i.size(0) for i in semantic_token], + dtype=torch.int32) + semantic_token = pad_sequence(semantic_token, + batch_first=True, + padding_value=0) + batch.update({"semantic_token" : semantic_token, + "semantic_token_len": semantic_token_len}) + + yield batch + else: + logging.info("WARNING: sample is empty []!") + + elif mode == "inference": + for sample in data: + assert isinstance(sample, list) + utts = [sample[i]['utt'] for i in range(len(sample))] + text = [sample[i]['text'] for i in range(len(sample))] + text_token = [torch.tensor(sample[i]['text_token']).long() for i in + range(len(sample))] + text_token_len = torch.tensor([i.size(0) for i in text_token], + dtype=torch.int32) + text_token = pad_sequence(text_token, batch_first=True, + padding_value=0) + time_start = torch.tensor( + [sample[i]['time_start'] for i in range(len(sample))]) + time_end = torch.tensor( + [sample[i]['time_end'] for i in range(len(sample))]) + + if isinstance(sample[0]['chorus'], str): + chorus = torch.tensor([CHORUS[sample[i]['chorus']] for i in + range(len(sample))]) + else: + chorus = [torch.tensor([CHORUS[t] for t in sample[i]['chorus']]) + for i in range(len(sample))] + chorus = pad_sequence(chorus, batch_first=True, + padding_value=-1) + + if "acoustic_token" in sample[0]: + acoustic_token = [ + sample[i]['acoustic_token'].clone().to(torch.int32) for i in + range(len(sample))] + acoustic_token_len = torch.tensor( + [i.size(0) for i in acoustic_token], dtype=torch.int32) + acoustic_token = pad_sequence(acoustic_token, + batch_first=True, + padding_value=0) + else: + acoustic_token = None + acoustic_token_len = None + + batch = { + "utts" : utts, + "acoustic_token" : acoustic_token, + "acoustic_token_len": acoustic_token_len, + "time_start" : time_start, + "time_end" : time_end, + "chorus" : chorus, + "text" : text, + "text_token" : text_token, + "text_token_len" : text_token_len, + } + + if "semantic_token" in sample[0]: + semantic_token = [torch.tensor(sample[i]['semantic_token'][0], + dtype=torch.int32) for i in + range(len(sample))] + semantic_token_len = torch.tensor( + [i.size(0) for i in semantic_token], dtype=torch.int32) + semantic_token = pad_sequence(semantic_token, + batch_first=True, + padding_value=0) + 
batch.update({"semantic_token" : semantic_token, + "semantic_token_len": semantic_token_len}) + + yield batch diff --git a/inspiremusic/flow/decoder.py b/inspiremusic/flow/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1dff57ba44a1b4952d6836729233368c131fe8 --- /dev/null +++ b/inspiremusic/flow/decoder.py @@ -0,0 +1,277 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn +from einops import pack, rearrange, repeat +from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D +from matcha.models.components.transformer import BasicTransformerBlock + +class Transpose(torch.nn.Module): + def __init__(self, dim0: int, dim1: int): + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x: torch.Tensor): + x = torch.transpose(x, self.dim0, self.dim1) + return x + +class CausalBlock1D(Block1D): + def __init__(self, dim: int, dim_out: int): + super(CausalBlock1D, self).__init__(dim, dim_out) + self.block = torch.nn.Sequential( + CausalConv1d(dim, dim_out, 3), + Transpose(1, 2), + nn.LayerNorm(dim_out), + Transpose(1, 2), + nn.Mish(), + ) + + def forward(self, x: torch.Tensor, mask: torch.Tensor): + output = self.block(x * mask) + return output * mask + + +class CausalResnetBlock1D(ResnetBlock1D): + def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8): + super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups) + self.block1 = CausalBlock1D(dim, dim_out) + self.block2 = CausalBlock1D(dim_out, dim_out) + +class CausalConv1d(torch.nn.Conv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + device=None, + dtype=None + ) -> None: + super(CausalConv1d, self).__init__(in_channels, out_channels, + kernel_size, stride, + padding=0, dilation=dilation, + groups=groups, bias=bias, + padding_mode=padding_mode, + device=device, dtype=dtype) + assert stride == 1 + self.causal_padding = (kernel_size - 1, 0) + + def forward(self, x: torch.Tensor): + x = F.pad(x, self.causal_padding) + x = super(CausalConv1d, self).forward(x) + return x + +class ConditionalDecoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + channels=(256, 256), + dropout=0.05, + attention_head_dim=64, + n_blocks=1, + num_mid_blocks=2, + num_heads=4, + act_fn="snake", + ): + """ + This decoder requires an input with the same shape of the target. So, if your text content + is shorter or longer than the outputs, please re-sampling it before feeding to the decoder. 
+ """ + super().__init__() + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + self.down_blocks = nn.ModuleList([]) + self.mid_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + output_channel = in_channels + for i in range(len(channels)): # pylint: disable=consider-using-enumerate + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) + + for _ in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i in range(len(channels) - 1): + input_channel = channels[i] * 2 + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + resnet = ResnetBlock1D( + dim=input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) + self.final_block = Block1D(channels[-1], channels[-1]) + self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) + self.initialize_weights() + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, mask, mu, t, spks=None, cond=None): + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. 
+ + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + + t = self.time_embeddings(t).to(t.dtype) + t = self.time_mlp(t) + x = pack([x, mu], "b * t")[0] + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = pack([x, spks], "b * t")[0] + if cond is not None: + x = pack([x, cond], "b * t")[0] + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x = resnet(x, mask_down, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + hiddens.append(x) # Save hidden states for skip connections + x = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + masks = masks[:-1] + mask_mid = masks[-1] + + for resnet, transformer_blocks in self.mid_blocks: + x = resnet(x, mask_mid, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + skip = hiddens.pop() + x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] + x = resnet(x, mask_up, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + x = upsample(x * mask_up) + x = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + return output * mask diff --git a/inspiremusic/flow/flow.py b/inspiremusic/flow/flow.py new file mode 100644 index 0000000000000000000000000000000000000000..ac83321e360bfa794a2802cd1247624045707f44 --- /dev/null +++ b/inspiremusic/flow/flow.py @@ -0,0 +1,143 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
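+
+# MaskedDiff below is the flow-matching acoustic model: it embeds semantic
+# tokens, passes them through a token encoder and a length regulator so their
+# length matches the acoustic features, and asks the conditional flow-matching
+# decoder to reconstruct features produced by the frozen VQVAE quantizer.
+# A rough wiring sketch with illustrative, not shipped, arguments; the
+# InterpolateRegulator and ConditionalCFM names refer to length_regulator.py
+# and flow_matching.py in this same package and are assumed here:
+#
+#   flow = MaskedDiff(encoder=token_encoder,
+#                     length_regulator=InterpolateRegulator(...),
+#                     decoder=ConditionalCFM(...))
+#   feat = flow.inference(token, token_len, sample_rate=24000)
+#
+# `token` is a (1, T) tensor of semantic token ids and `feat` is the acoustic
+# feature map consumed by the downstream waveform generator.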
+import logging +import random +from typing import Dict, Optional +import torch +import torch.nn as nn +from torch.nn import functional as F +from omegaconf import DictConfig +from inspiremusic.utils.mask import make_pad_mask +from inspiremusic.music_tokenizer.vqvae import VQVAE + +class MaskedDiff(torch.nn.Module): + def __init__(self, + input_size: int = 512, + output_size: int = 128, + output_type: str = "mel", + vocab_size: int = 4096, + input_frame_rate: int = 50, + only_mask_loss: bool = True, + encoder: torch.nn.Module = None, + length_regulator: torch.nn.Module = None, + decoder: torch.nn.Module = None, + decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, + 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', + 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), + 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, + 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}, + mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 128, 'sampling_rate': 48000, + 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 48000}, + generator_model_dir: str = "pretrained_models/InspireMusic-Base/music_tokenizer", + num_codebooks: int = 4 + ): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.decoder_conf = decoder_conf + self.mel_feat_conf = mel_feat_conf + self.vocab_size = vocab_size + self.output_type = output_type + self.input_frame_rate = input_frame_rate + logging.info(f"input frame rate={self.input_frame_rate}") + self.input_embedding = nn.Embedding(vocab_size, input_size) + + self.encoder = encoder + self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size) + self.decoder = decoder + self.length_regulator = length_regulator + self.only_mask_loss = only_mask_loss + self.quantizer = VQVAE( f'{generator_model_dir}/config.json', + f'{generator_model_dir}/model.pt',with_encoder=True).quantizer + self.quantizer.eval() + self.num_codebooks = num_codebooks + self.cond = None + self.interpolate = False + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + + audio_token = batch['acoustic_token'].to(device) + audio_token_len = batch['acoustic_token_len'].to(device) + audio_token = audio_token.view(audio_token.size(0),-1,self.num_codebooks) + if "semantic_token" not in batch: + token = audio_token[:,:,0] + token_len = (audio_token_len/self.num_codebooks).long() + + else: + token = batch['semantic_token'].to(device) + token_len = batch['semantic_token_len'].to(device) + + with torch.no_grad(): + feat = self.quantizer.embed(audio_token) + feat_len = (audio_token_len/self.num_codebooks).long() + + token = self.input_embedding(token) + h, h_lengths = self.encoder(token, token_len) + h, h_lengths = self.length_regulator(h, feat_len) + + # get conditions + if self.cond: + conds = torch.zeros(feat.shape, device=token.device) + for i, j in enumerate(feat_len): + if random.random() < 0.5: + continue + index = random.randint(0, int(0.3 * j)) + conds[i, :index] = feat[i, :index] + conds = conds.transpose(1, 2) + else: + conds = None + + mask = (~make_pad_mask(feat_len)).to(h) + + loss, _ = self.decoder.compute_loss( + feat, + mask.unsqueeze(1), + h.transpose(1, 2).contiguous(), + None, + cond=conds + ) + + return {'loss': loss} + + @torch.inference_mode() + def inference(self, + token, + token_len, + sample_rate): + assert token.shape[0] == 1 + + token = 
self.input_embedding(torch.clamp(token, min=0)) + h, h_lengths = self.encoder(token, token_len) + + if sample_rate == 48000: + token_len = 2 * token_len + + h, h_lengths = self.length_regulator(h, token_len) + + # get conditions + conds = None + + mask = (~make_pad_mask(token_len)).to(h) + feat = self.decoder( + mu=h.transpose(1, 2).contiguous(), + mask=mask.unsqueeze(1), + spks=None, + cond=conds, + n_timesteps=10 + ) + return feat \ No newline at end of file diff --git a/inspiremusic/flow/flow_matching.py b/inspiremusic/flow/flow_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..e803c6c59f4c2daa1fd1d0a7ee374706507e78c3 --- /dev/null +++ b/inspiremusic/flow/flow_matching.py @@ -0,0 +1,167 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn.functional as F +from matcha.models.components.flow_matching import BASECFM + + +class ConditionalCFM(BASECFM): + def __init__(self, in_channels, cfm_params, estimator: torch.nn.Module = None): + super().__init__( + n_feats=in_channels, + cfm_params=cfm_params, + ) + self.t_scheduler = cfm_params.t_scheduler + self.training_cfg_rate = cfm_params.training_cfg_rate + self.inference_cfg_rate = cfm_params.inference_cfg_rate + # Just change the architecture of the estimator here + self.estimator = estimator + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + z = torch.randn_like(mu) * temperature + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) + if self.t_scheduler == 'cosine': + t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) + return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) + + def solve_euler(self, x, t_span, mu, mask, spks, cond): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + spks (torch.Tensor, optional): speaker ids. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + """ + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + t = t.unsqueeze(dim=0) + + # I am storing this because I can later plot it by putting a debugger here and saving it to a file + # Or in future might add like a return_all_steps flag + sol = [] + + for step in range(1, len(t_span)): + dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond) + # Classifier-Free Guidance inference introduced in VoiceBox + if self.inference_cfg_rate > 0: + cfg_dphi_dt = self.forward_estimator( + x, mask, + torch.zeros_like(mu), t, + torch.zeros_like(spks) if spks is not None else None, + torch.zeros_like(cond) if cond is not None else None + ) + dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - + self.inference_cfg_rate * cfg_dphi_dt) + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1] + + def forward_estimator(self, x, mask, mu, t, spks, cond): + if isinstance(self.estimator, torch.nn.Module): + return self.estimator.forward(x, mask, mu, t, spks, cond) + elif isinstance(self.estimator, onnxruntime.InferenceSession): + ort_inputs = { + 'x': x.cpu().numpy(), + 'mask': mask.cpu().numpy(), + 'mu': mu.cpu().numpy(), + 't': t.cpu().numpy(), + 'spks': spks.cpu().numpy(), + 'cond': cond.cpu().numpy() + } + output = self.estimator.run(None, ort_inputs)[0] + return torch.tensor(output, dtype=x.dtype, device=x.device) + else: + self.estimator.set_input_shape('x', (2, 80, x.size(2))) + self.estimator.set_input_shape('mask', (2, 1, x.size(2))) + self.estimator.set_input_shape('mu', (2, 80, x.size(2))) + self.estimator.set_input_shape('t', (2,)) + self.estimator.set_input_shape('spks', (2, 80)) + self.estimator.set_input_shape('cond', (2, 80, x.size(2))) + # run trt engine + self.estimator.execute_v2([x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr()]) + return x + + def compute_loss(self, x1, mask, mu, spks=None, cond=None): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_feats, mo) + mask (torch.Tensor): target mask + shape: (batch_size, 1, mel_timesteps) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + spks (torch.Tensor, optional): speaker embedding. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) + """ + b, _, t = mu.shape + + t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) + if self.t_scheduler == 'cosine': + t = 1 - torch.cos(t * 0.5 * torch.pi) + + z = torch.randn_like(x1) + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + + # during training, we randomly drop condition to trade off mode coverage and sample fidelity + if self.training_cfg_rate > 0: + cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate + mu = mu * cfg_mask.view(-1, 1, 1) + if cond is not None: + cond = cond * cfg_mask.view(-1, 1, 1) + + pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond) + loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1]) + return loss, y + diff --git a/inspiremusic/flow/length_regulator.py b/inspiremusic/flow/length_regulator.py new file mode 100644 index 0000000000000000000000000000000000000000..05b74a9403a526c65dd05f0e558c62084b1772fe --- /dev/null +++ b/inspiremusic/flow/length_regulator.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple +import torch.nn as nn +import torch +from torch.nn import functional as F +from inspiremusic.utils.mask import make_pad_mask + + +class InterpolateRegulator(nn.Module): + def __init__( + self, + channels: int, + sampling_ratios: Tuple, + out_channels: int = None, + groups: int = 1, + ): + super().__init__() + self.sampling_ratios = sampling_ratios + out_channels = out_channels or channels + model = nn.ModuleList([]) + if len(sampling_ratios) > 0: + for _ in sampling_ratios: + module = nn.Conv1d(channels, channels, 3, 1, 1) + norm = nn.GroupNorm(groups, channels) + act = nn.Mish() + model.extend([module, norm, act]) + model.append( + nn.Conv1d(channels, out_channels, 1, 1) + ) + self.model = nn.Sequential(*model) + + def forward(self, x, ylens=None): + # x in (B, T, D) + mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1) + x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear') + out = self.model(x).transpose(1, 2).contiguous() + olens = ylens + return out * mask, olens + + def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50): + # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel + # x in (B, T, D) + if x2.shape[1] > 40: + x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') + x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2, + mode='linear') + x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') + x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2) + else: + x2 = 
F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear') + if x1.shape[1] != 0: + x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear') + x = torch.concat([x1, x2], dim=2) + else: + x = x2 + out = self.model(x).transpose(1, 2).contiguous() + return out, mel_len1 + mel_len2 diff --git a/inspiremusic/hifigan/discriminator.py b/inspiremusic/hifigan/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..6fc784599f3493e20830290a9cd182789c0428d5 --- /dev/null +++ b/inspiremusic/hifigan/discriminator.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +from torch.nn.utils import weight_norm +from typing import List, Optional, Tuple +from einops import rearrange +from torchaudio.transforms import Spectrogram + + +class MultipleDiscriminator(nn.Module): + def __init__( + self, mpd: nn.Module, mrd: nn.Module + ): + super().__init__() + self.mpd = mpd + self.mrd = mrd + + def forward(self, y: torch.Tensor, y_hat: torch.Tensor): + y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], [] + this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1)) + y_d_rs += this_y_d_rs + y_d_gs += this_y_d_gs + fmap_rs += this_fmap_rs + fmap_gs += this_fmap_gs + this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat) + y_d_rs += this_y_d_rs + y_d_gs += this_y_d_gs + fmap_rs += this_fmap_rs + fmap_gs += this_fmap_gs + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class MultiResolutionDiscriminator(nn.Module): + def __init__( + self, + fft_sizes: Tuple[int, ...] = (2048, 1024, 512), + num_embeddings: Optional[int] = None, + ): + """ + Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec. + Additionally, it allows incorporating conditional information with a learned embeddings table. + + Args: + fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512). + num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. + Defaults to None. + """ + + super().__init__() + self.discriminators = nn.ModuleList( + [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes] + ) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for d in self.discriminators: + y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) + y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorR(nn.Module): + def __init__( + self, + window_length: int, + num_embeddings: Optional[int] = None, + channels: int = 32, + hop_factor: float = 0.25, + bands: Tuple[Tuple[float, float], ...] 
= ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)), + ): + super().__init__() + self.window_length = window_length + self.hop_factor = hop_factor + self.spec_fn = Spectrogram( + n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None + ) + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + convs = lambda: nn.ModuleList( + [ + weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + + if num_embeddings is not None: + self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels) + torch.nn.init.zeros_(self.emb.weight) + + self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))) + + def spectrogram(self, x): + # Remove DC offset + x = x - x.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + x = self.spec_fn(x) + x = torch.view_as_real(x) + x = rearrange(x, "b f t c -> b c t f") + # Split into bands + x_bands = [x[..., b[0]: b[1]] for b in self.bands] + return x_bands + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None): + x_bands = self.spectrogram(x) + fmap = [] + x = [] + for band, stack in zip(x_bands, self.band_convs): + for i, layer in enumerate(stack): + band = layer(band) + band = torch.nn.functional.leaky_relu(band, 0.1) + if i > 0: + fmap.append(band) + x.append(band) + x = torch.cat(x, dim=-1) + if cond_embedding_id is not None: + emb = self.emb(cond_embedding_id) + h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) + else: + h = 0 + x = self.conv_post(x) + fmap.append(x) + x += h + + return x, fmap diff --git a/inspiremusic/hifigan/f0_predictor.py b/inspiremusic/hifigan/f0_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..331394ddc5c9240a2f66186ab0ce263d80ceeac0 --- /dev/null +++ b/inspiremusic/hifigan/f0_predictor.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
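+
+# ConvRNNF0Predictor below estimates a frame-level F0 contour from mel
+# features. Despite the name, this implementation is purely convolutional:
+# five weight-normalised Conv1d + ELU blocks followed by a linear head, with
+# torch.abs() keeping the predicted pitch non-negative. Illustrative shapes:
+#
+#   x   : (batch, in_channels=80, frames)   # mel features
+#   out : (batch, frames)                   # predicted F0 per frame
+#
+# A predictor of this shape is what HiFTGenerator (generator.py) expects for
+# its f0_predictor argument, where the predicted F0 drives the neural
+# source-filter excitation.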
+import torch +import torch.nn as nn +from torch.nn.utils import weight_norm + + +class ConvRNNF0Predictor(nn.Module): + def __init__(self, + num_class: int = 1, + in_channels: int = 80, + cond_channels: int = 512 + ): + super().__init__() + + self.num_class = num_class + self.condnet = nn.Sequential( + weight_norm( + nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + ) + self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.condnet(x) + x = x.transpose(1, 2) + return torch.abs(self.classifier(x).squeeze(-1)) diff --git a/inspiremusic/hifigan/generator.py b/inspiremusic/hifigan/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..033758c05d55b9a1e23f79c7b551c6000762ee26 --- /dev/null +++ b/inspiremusic/hifigan/generator.py @@ -0,0 +1,411 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""HIFI-GAN""" + +from typing import Dict, Optional, List +import numpy as np +from scipy.signal import get_window +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Conv1d +from torch.nn import ConvTranspose1d +from torch.nn.utils import remove_weight_norm +from torch.nn.utils import weight_norm +from torch.distributions.uniform import Uniform + +from inspiremusic.transformer.activation import Snake +from inspiremusic.utils.common import get_padding +from inspiremusic.utils.common import init_weights + + +"""hifigan based generator implementation. 
+ +This code is modified from https://github.com/jik876/hifi-gan + ,https://github.com/kan-bayashi/ParallelWaveGAN and + https://github.com/NVIDIA/BigVGAN + +""" + + +class ResBlock(torch.nn.Module): + """Residual block module in HiFiGAN/BigVGAN.""" + def __init__( + self, + channels: int = 512, + kernel_size: int = 3, + dilations: List[int] = [1, 3, 5], + ): + super(ResBlock, self).__init__() + self.convs1 = nn.ModuleList() + self.convs2 = nn.ModuleList() + + for dilation in dilations: + self.convs1.append( + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation, + padding=get_padding(kernel_size, dilation) + ) + ) + ) + self.convs2.append( + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1) + ) + ) + ) + self.convs1.apply(init_weights) + self.convs2.apply(init_weights) + self.activations1 = nn.ModuleList([ + Snake(channels, alpha_logscale=False) + for _ in range(len(self.convs1)) + ]) + self.activations2 = nn.ModuleList([ + Snake(channels, alpha_logscale=False) + for _ in range(len(self.convs2)) + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for idx in range(len(self.convs1)): + xt = self.activations1[idx](x) + xt = self.convs1[idx](xt) + xt = self.activations2[idx](xt) + xt = self.convs2[idx](xt) + x = xt + x + return x + + def remove_weight_norm(self): + for idx in range(len(self.convs1)): + remove_weight_norm(self.convs1[idx]) + remove_weight_norm(self.convs2[idx]) + + +class SineGen(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + @torch.no_grad() + def forward(self, f0): + """ + :param f0: [B, 1, sample_len], Hz + :return: [B, 1, sample_len] + """ + + F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device) + for i in range(self.harmonic_num + 1): + F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate + + theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1) + u_dist = Uniform(low=-np.pi, high=np.pi) + phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device) + phase_vec[:, 0, :] = 0 + + # generate sine waveforms + sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec) + + # generate uv signal + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . 
for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + with torch.no_grad(): + sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2)) + sine_wavs = sine_wavs.transpose(1, 2) + uv = uv.transpose(1, 2) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + +class HiFTGenerator(nn.Module): + """ + HiFTNet Generator: Neural Source Filter + ISTFTNet + https://arxiv.org/abs/2309.09493 + """ + def __init__( + self, + in_channels: int = 80, + base_channels: int = 512, + nb_harmonics: int = 8, + sampling_rate: int = 22050, + nsf_alpha: float = 0.1, + nsf_sigma: float = 0.003, + nsf_voiced_threshold: float = 10, + upsample_rates: List[int] = [8, 8], + upsample_kernel_sizes: List[int] = [16, 16], + istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4}, + resblock_kernel_sizes: List[int] = [3, 7, 11], + resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + source_resblock_kernel_sizes: List[int] = [7, 11], + source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]], + lrelu_slope: float = 0.1, + audio_limit: float = 0.99, + f0_predictor: torch.nn.Module = None, + ): + super(HiFTGenerator, self).__init__() + + self.out_channels = 1 + self.nb_harmonics = nb_harmonics + self.sampling_rate = sampling_rate + self.istft_params = istft_params + self.lrelu_slope = lrelu_slope + self.audio_limit = audio_limit + + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.m_source = SourceModuleHnNSF( + sampling_rate=sampling_rate, + upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"], + harmonic_num=nb_harmonics, + sine_amp=nsf_alpha, + add_noise_std=nsf_sigma, + 
voiced_threshod=nsf_voiced_threshold) + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"]) + + self.conv_pre = weight_norm( + Conv1d(in_channels, base_channels, 7, 1, padding=3) + ) + + # Up + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + base_channels // (2**i), + base_channels // (2**(i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + # Down + self.source_downs = nn.ModuleList() + self.source_resblocks = nn.ModuleList() + downsample_rates = [1] + upsample_rates[::-1][:-1] + downsample_cum_rates = np.cumprod(downsample_rates) + for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)): + if u == 1: + self.source_downs.append( + Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1) + ) + else: + self.source_downs.append( + Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2)) + ) + + self.source_resblocks.append( + ResBlock(base_channels // (2 ** (i + 1)), k, d) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = base_channels // (2**(i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(ResBlock(ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.reflection_pad = nn.ReflectionPad1d((1, 0)) + self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32)) + self.f0_predictor = f0_predictor + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + self.m_source.remove_weight_norm() + for l in self.source_downs: + remove_weight_norm(l) + for l in self.source_resblocks: + l.remove_weight_norm() + + def _stft(self, x): + spec = torch.stft( + x, + self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device), + return_complex=True) + spec = torch.view_as_real(spec) # [B, F, TT, 2] + return spec[..., 0], spec[..., 1] + + def _istft(self, magnitude, phase): + magnitude = torch.clip(magnitude, max=1e2) + real = magnitude * torch.cos(phase) + img = magnitude * torch.sin(phase) + inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], + self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device)) + return inverse_transform + + def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: + s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) + s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) + + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, self.lrelu_slope) + x = self.ups[i](x) + + if i == self.num_upsamples - 1: + x = self.reflection_pad(x) + + # fusion + si = self.source_downs[i](s_stft) + si = self.source_resblocks[i](si) + x = x + si + + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + x = F.leaky_relu(x) + x = 
self.conv_post(x) + magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :]) + phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy + + x = self._istft(magnitude, phase) + x = torch.clamp(x, -self.audio_limit, self.audio_limit) + return x + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + speech_feat = batch['speech_feat'].transpose(1, 2).to(device) + # mel->f0 + f0 = self.f0_predictor(speech_feat) + # f0->source + s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + s, _, _ = self.m_source(s) + s = s.transpose(1, 2) + # mel+source->speech + generated_speech = self.decode(x=speech_feat, s=s) + return generated_speech, f0 + + @torch.inference_mode() + def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: + # mel->f0 + f0 = self.f0_predictor(speech_feat) + # f0->source + s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + s, _, _ = self.m_source(s) + s = s.transpose(1, 2) + # use cache_source to avoid glitch + if cache_source.shape[2] != 0: + s[:, :, :cache_source.shape[2]] = cache_source + generated_speech = self.decode(x=speech_feat, s=s) + return generated_speech, s diff --git a/inspiremusic/hifigan/hifigan.py b/inspiremusic/hifigan/hifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..8d7b612d1cd86569d430e8bf03bc7e3e0fa72957 --- /dev/null +++ b/inspiremusic/hifigan/hifigan.py @@ -0,0 +1,66 @@ +from typing import Dict, Optional +import torch +import torch.nn as nn +import torch.nn.functional as F +from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss +from inspiremusic.utils.losses import tpr_loss, mel_loss + +class HiFiGan(nn.Module): + def __init__(self, generator, discriminator, mel_spec_transform, + multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0, + tpr_loss_weight=1.0, tpr_loss_tau=0.04): + super(HiFiGan, self).__init__() + self.generator = generator + self.discriminator = discriminator + self.mel_spec_transform = mel_spec_transform + self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight + self.feat_match_loss_weight = feat_match_loss_weight + self.tpr_loss_weight = tpr_loss_weight + self.tpr_loss_tau = tpr_loss_tau + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + if batch['turn'] == 'generator': + return self.forward_generator(batch, device) + else: + return self.forward_discriminator(batch, device) + + def forward_generator(self, batch, device): + real_speech = batch['speech'].to(device) + pitch_feat = batch['pitch_feat'].to(device) + # 1. calculate generator outputs + generated_speech, generated_f0 = self.generator(batch, device) + # 2. calculate discriminator outputs + y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech) + # 3. 
calculate generator losses, feature loss, mel loss, tpr losses [Optional] + loss_gen, _ = generator_loss(y_d_gs) + loss_fm = feature_loss(fmap_rs, fmap_gs) + loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform) + if self.tpr_loss_weight != 0: + loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau) + else: + loss_tpr = torch.zeros(1).to(device) + loss_f0 = F.l1_loss(generated_f0, pitch_feat) + loss = loss_gen + self.feat_match_loss_weight * loss_fm + \ + self.multi_mel_spectral_recon_loss_weight * loss_mel + \ + self.tpr_loss_weight * loss_tpr + loss_f0 + return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0} + + def forward_discriminator(self, batch, device): + real_speech = batch['speech'].to(device) + # 1. calculate generator outputs + with torch.no_grad(): + generated_speech, generated_f0 = self.generator(batch, device) + # 2. calculate discriminator outputs + y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech) + # 3. calculate discriminator losses, tpr losses [Optional] + loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs) + if self.tpr_loss_weight != 0: + loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau) + else: + loss_tpr = torch.zeros(1).to(device) + loss = loss_disc + self.tpr_loss_weight * loss_tpr + return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr} diff --git a/inspiremusic/llm/llm.py b/inspiremusic/llm/llm.py new file mode 100644 index 0000000000000000000000000000000000000000..2e16055bbd9a58d819c3475017e50d2a864baf7e --- /dev/null +++ b/inspiremusic/llm/llm.py @@ -0,0 +1,409 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
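+
+# LLM below is the autoregressive token language model. Each sample is
+# flattened into a single embedding sequence, roughly (the time and chorus
+# embeddings repeat once per structure segment):
+#
+#   [sos/eos] [time_start][time_end][chorus] [text tokens] [task_id] [audio tokens]
+#
+# The cross-entropy targets mask the prompt portion with IGNORE_ID and
+# supervise only the audio tokens plus a final end-of-audio id
+# (index audio_token_size). Classifier-free guidance is handled by randomly
+# zeroing the text prompt for a fraction (train_cfg_ratio) of samples during
+# training via cfg_dropout(); infer_cfg_ratio is stored for guidance at
+# inference time. The exact sequence layout is built in pad_unpad_sequence()
+# and forward() further down; this comment is only a rough summary.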
+from typing import Dict, Optional, Callable, List, Generator +import torch +from torch import nn +from torch.nn.utils.rnn import pad_sequence, unpad_sequence +from inspiremusic.utils.common import IGNORE_ID +from inspiremusic.transformer.label_smoothing_loss import LabelSmoothingLoss +from inspiremusic.utils.common import th_accuracy +from torch import Tensor +from math import log +from einops import rearrange, reduce, repeat +import logging + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +class SinusoidalEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + device, half_dim = x.device, self.dim // 2 + emb = torch.tensor(log(10000) / (half_dim - 1), device=device) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j") + return torch.cat((emb.sin(), emb.cos()), dim=-1).to(torch.float16) + +class LLM(torch.nn.Module): + def __init__( + self, + text_encoder_input_size: int, + llm_input_size: int, + llm_output_size: int, + audio_token_size: int, + llm: torch.nn.Module, + sampling: Callable, + text_encoder_conf: Dict = None, + length_normalized_loss: bool = True, + lsm_weight: float = 0.0, + frozen_input_embed: bool = False, + **kwargs, + ): + super().__init__() + self.llm_input_size = llm_input_size + self.audio_token_size = audio_token_size + # 1. build text token inputs related modules + + if llm is None: + self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size) + else: + self.text_embedding = llm.model.model.embed_tokens + if frozen_input_embed: + print("Freezing input embedding layer") + for p in self.text_embedding.parameters(): + p.requires_grad = False + self.chorus_embedding = torch.nn.Embedding(5, llm_input_size) # intro, chorus, verse1, verse2 , outro + + self.text_encoder_conf = text_encoder_conf + self.text_encoder = self.build_encoder(text_encoder_conf) + self.infer_cfg_ratio = kwargs.get("infer_cfg_ratio", None) + logging.info(f"infer_cfg_ratio: {self.infer_cfg_ratio}") + self.train_cfg_ratio = kwargs.get("train_cfg_ratio", None) + logging.info(f"train_cfg_ratio: {self.train_cfg_ratio}") + # 2. build audio token language model related modules + self.sos_eos = 0 + self.task_id = 1 + + self.llm_embedding = torch.nn.Embedding(2, llm_input_size) + self.llm = llm + self.llm_decoder = nn.Linear(llm_output_size, audio_token_size + 1) + self.criterion_ce = LabelSmoothingLoss( + size=audio_token_size + 1, + padding_idx=IGNORE_ID, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + # 3. [Optional] build audio token related modules + self.speech_embedding = torch.nn.Embedding(audio_token_size, llm_input_size) + self.spk_embed_affine_layer = torch.nn.Linear(192, llm_input_size) + self.num_codebooks = 4 + # 4. 
sampling method + self.sampling = sampling + self.time_embedding = SinusoidalEmbedding(llm_input_size) + + def cfg_dropout(self, text_token, text_token_len, p): + # Classifier-Free Guidance Dropout + B = text_token.size(0) + num_samples_to_mask = int(p * B) + if num_samples_to_mask == 0: + num_samples_to_mask = 1 + indices_to_mask = torch.randperm(B, device=text_token.device)[:num_samples_to_mask] + text_token[indices_to_mask] = 0 + text_token_len[indices_to_mask] = 0 + + return text_token, text_token_len + + def build_encoder(self, encoder_conf=None): + if encoder_conf is None: + assert hasattr(self, "encoder_conf"), \ + "function param encoder_conf is None and model doesn't has encoder_conf attribute either." + encoder_conf = self.encoder_conf + + encoder_name = encoder_conf.pop("name", "transformer") + model = None + if encoder_name == "transformer": + from inspiremusic.transformer.encoder.conformer_encoder import ConformerEncoder + model = ConformerEncoder( + **encoder_conf, + input_size=self.input_size, + use_cnn_module=False, + macaron_style=False, + ) + elif encoder_name == "conformer": + from inspiremusic.transformer.encoder.conformer_encoder import ConformerEncoder + model = ConformerEncoder( + **encoder_conf, + input_size=self.input_size, + ) + elif encoder_name == "llama_encoder": + from inspiremusic.transformer.encoder.llama_encoder import LlamaEncoder + model = LlamaEncoder( + **encoder_conf, + input_size=self.input_size, + ) + elif encoder_name == "qwen2": + from inspiremusic.transformer.encoder.qwen_encoder import QwenEncoder + model = QwenEncoder( + **encoder_conf, + input_size=self.input_size, + ) + elif encoder_name == "qwen2.5": + from inspiremusic.transformer.encoder.qwen_encoder import QwenEncoder + model = QwenEncoder( + **encoder_conf, + input_size=self.input_size, + ) + + encoder_conf["name"] = encoder_name + + return model + + def encode(self, + text: torch.Tensor, + text_lengths: torch.Tensor): + if self.text_encoder is not None: + encoder_out, encoder_mask = self.text_encoder(text, text_lengths, + decoding_chunk_size=1, + num_decoding_left_chunks=-1) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + encoder_out = self.text_encoder_affine_layer(encoder_out) + else: + encoder_out, encoder_out_lens = text, text_lengths + return encoder_out, encoder_out_lens + + def pad_unpad_sequence(self, sos_eos_emb, embeddings, text_token, + text_token_len, task_id_emb, audio_token, + audio_token_len, seg_len): + text_token = unpad_sequence(text_token, text_token_len.cpu(), + batch_first=True) + + audio_token = unpad_sequence(audio_token, audio_token_len.cpu(), + batch_first=True) + + for i in range(len(embeddings)): + embeddings[i] = unpad_sequence(embeddings[i], seg_len.cpu(), batch_first=True) + + lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0)] + [embedding[i] for embedding in embeddings] + [text_token[i], task_id_emb.squeeze(dim=0), audio_token[i]], dim=0) for i in range(len(text_token))] + lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32) + lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID) + return lm_input, lm_input_len + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + """ + Args: + text: (B, L, D) + text_lengths: (B,) + audio: (B, T, N) or (B, T) + audio_lengths: (B,) + """ + mask = True + text_token = batch['text_token'].to(device) + text_token_len = batch['text_token_len'].to(device) + if "semantic_token" not in batch: + audio_token = 
batch['acoustic_token'].to(device) + audio_token_len = batch['acoustic_token_len'].to(device) + audio_token = audio_token.view(audio_token.size(0), -1, self.num_codebooks) + audio_token = audio_token[:, :, 0] + audio_token_len = (audio_token_len / self.num_codebooks).long() + + else: + audio_token = batch['semantic_token'].to(device) + audio_token_len = batch['semantic_token_len'].to(device) + + time_start = batch['time_start'].to(device) + time_end = batch['time_end'].to(device) + chorus = batch['chorus'].to(device) + # 1. encode text_token + + if self.train_cfg_ratio > 0: + # Classifier-Free Guidance + text_token, _ = self.cfg_dropout(text_token, text_token_len, self.train_cfg_ratio) + + # 2. Time Embedding & chorus embedding + text_token = self.text_embedding(text_token) + text_token, text_token_len = self.encode(text_token, text_token_len) + if mask: + time_mask = time_start != -1.0 + seg_len = time_mask.sum(-1) + time_start = time_start.masked_fill(~time_mask, 0.0) + time_end = time_end.masked_fill(~time_mask, 0.0) + chorus = chorus.masked_fill(~time_mask, 0) + time_start_embed = self.time_embedding(time_start.view(-1)).to(text_token.dtype) + time_end_embed = self.time_embedding(time_end.view(-1)).to(text_token.dtype) + time_start_embed = time_start_embed.view(chorus.size(0), chorus.size(1), -1) + time_end_embed = time_end_embed.view(chorus.size(0), chorus.size(1), -1) + chorus_embed = self.chorus_embedding(chorus) + lm_target = [torch.tensor([IGNORE_ID] * (1 + 3 * seg_len[i] + text_token_len[i]) + audio_token[i,:audio_token_len[i]].tolist() + [self.audio_token_size]) for i in range(text_token.size(0))] + else: + time_start_embed = self.time_embedding(time_start).to(text_token.dtype) + time_end_embed = self.time_embedding(time_end).to(text_token.dtype) + chorus_embed = self.chorus_embedding(chorus) + + lm_target = [torch.tensor( + [IGNORE_ID] * (4 + text_token_len[i]) + audio_token[i,:audio_token_len[i]].tolist() + [self.audio_token_size]) for i in range(text_token.size(0))] + + lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device) + + # 3. eos and task_id + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + + # 4. encode audio_token + audio_token = self.speech_embedding(audio_token) + + # 5. unpad and pad + lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, + [time_start_embed, + time_end_embed, + chorus_embed], + text_token, + text_token_len, + task_id_emb, + audio_token, + audio_token_len, + seg_len) + # 6. 
run lm forward + lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device)) + logits = self.llm_decoder(lm_output) + loss = self.criterion_ce(logits, lm_target) + + acc = th_accuracy(logits.view(-1, self.audio_token_size + 1), lm_target, ignore_label=IGNORE_ID) + + return {'loss': loss, 'acc': acc} + + def sampling_ids( + self, + weighted_scores: torch.Tensor, + decoded_tokens: List, + ignore_eos: bool = True, + ): + top_ids = self.sampling(weighted_scores, decoded_tokens) + return top_ids + + @torch.inference_mode() + def inference( + self, + text: torch.Tensor, + text_len: torch.Tensor, + audio_token: torch.Tensor, + audio_token_len: torch.Tensor, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_audio_token: torch.Tensor, + prompt_audio_token_len: torch.Tensor, + embeddings: List, + duration_to_gen: float = 300, + task: str = "continuation", + token_rate: int = 75, + limit_audio_prompt_len: int = 5, + ) -> Generator[torch.Tensor, None, None]: + device = text.device + + if text is not None: + text = torch.concat([prompt_text, text], dim=1) + text_len += prompt_text_len + infer_cfg = self.infer_cfg_ratio >= 0.0 + if infer_cfg: + text_cfg = self.text_embedding(text.new_zeros(text.shape)) + text = self.text_embedding(text) + + # 1. encode text + text, text_len = self.encode(text, text_len) + + # 2. encode embedding + if embeddings is not None: + time_start, time_end, chorus = embeddings + + if len(chorus.shape) == 1: + time_start_embed = self.time_embedding(time_start).reshape(1, 1, -1) # .half() + time_end_embed = self.time_embedding(time_end).reshape(1, 1, -1) # .half() + chorus_embed = self.chorus_embedding(chorus).reshape(1, 1, -1) # .half() + else: + time_start_embed = self.time_embedding( + time_start.view(-1)).reshape(1, chorus.size(1), -1) # .half() + time_end_embed = self.time_embedding(time_end.view(-1)).reshape(1, chorus.size(1), -1) # .half() + chorus_embed = self.chorus_embedding(chorus) # .half() + + # 3. concat llm_input + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + + if audio_token_len: + audio_token = audio_token[:, :(limit_audio_prompt_len * token_rate)] + audio_token_emb = self.speech_embedding(audio_token) + else: + audio_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) + + if prompt_audio_token_len: + prompt_audio_token_emb = self.speech_embedding(prompt_audio_token) + else: + prompt_audio_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) + # Check if removing prompt audio token will fail decoding. 
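# Classifier-free guidance at inference (see the branches below): when infer_cfg is enabled,
# an "unconditional" copy of the prompt is built with zero-id text tokens and randomized
# time/chorus embeddings, stacked with the conditional prompt along the batch dimension, and
# the two logit streams are later blended as
#   logits = infer_cfg_ratio * logits_cond + (1 - infer_cfg_ratio) * logits_uncond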
+ + if task == "continuation": + lm_input = torch.concat( + [sos_eos_emb, time_start_embed, time_end_embed, + chorus_embed, text, task_id_emb, audio_token_emb], dim=1) + + if infer_cfg: + audio_cfg = self.speech_embedding( + audio_token.new_zeros(audio_token.shape)) + lm_cf_input = torch.concat( + [sos_eos_emb, torch.rand_like(time_start_embed), + torch.rand_like(time_end_embed), + torch.rand_like(chorus_embed), text_cfg, task_id_emb, + audio_cfg], dim=1) + lm_input = torch.cat([lm_input, lm_cf_input], 0) + else: + lm_input = torch.concat( + [sos_eos_emb, time_start_embed, time_end_embed, + chorus_embed, text, task_id_emb], dim=1) + if infer_cfg: + lm_cf_input = torch.concat( + [sos_eos_emb, torch.rand_like(time_start_embed), + torch.rand_like(time_end_embed), + torch.rand_like(chorus_embed), text_cfg, task_id_emb], + dim=1) + lm_input = torch.cat([lm_input, lm_cf_input], 0) + + # 4. cal min/max_length + min_len = 0.9 * duration_to_gen * token_rate + max_len = duration_to_gen * token_rate + logging.info( + f"LLM generation sequence length: {max_len}, generate audio length {duration_to_gen}s.") + + # 5. step by step decode + out_tokens = [] + offset = 0 + state = None + + for i in range(int(max_len)): + y_pred, _, state = self.llm.forward_one_step(lm_input, torch.ones(lm_input.shape[0], lm_input.shape[1], device=lm_input.device).to(torch.bool), cache=state) + logits = self.llm_decoder(y_pred[:, -1]) + if infer_cfg: + # perform context free guidance + logits_cf = logits[1] + logits = logits[0] + infer_cfg_ratio = self.infer_cfg_ratio + logits = infer_cfg_ratio * logits + (1 - infer_cfg_ratio) * logits_cf + + logp = logits.log_softmax(dim=-1) + logp = logp.squeeze(dim=0) + + if i < int(min_len): + logp[self.audio_token_size] = torch.tensor(float('-inf'), dtype=torch.float16) + + if i < int(min_len): + logp[self.audio_token_size] = torch.tensor(float('-inf'), dtype=torch.float16) + + top_ids = self.sampling_ids(logp, out_tokens, ignore_eos=i < min_len).item() + + if top_ids == self.audio_token_size: + break + + # # in stream mode, yield token one by one + + yield torch.tensor([[top_ids]], dtype=torch.int64, device=device) + out_tokens.append(top_ids) + offset += lm_input.size(1) + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + if infer_cfg: + lm_input = lm_input.repeat(2, 1, 1) diff --git a/inspiremusic/metrics/clap_score.py b/inspiremusic/metrics/clap_score.py new file mode 100644 index 0000000000000000000000000000000000000000..d77b200323d9374f4ea64ee3a7eeeb1c0c21fecf --- /dev/null +++ b/inspiremusic/metrics/clap_score.py @@ -0,0 +1,135 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
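# Minimal usage sketch for clap_score() defined below; the id, prompt text and output
# directory here are placeholders, not values shipped with this repo:
#
#     from inspiremusic.metrics.clap_score import clap_score
#
#     id2text = {"0001": "Jazz music with drum beats."}
#     score = clap_score(id2text, "outputs/",
#                        clap_model="music_audioset_epoch_15_esc_90.14.pt")
#
# Keys of id2text must match the generated filenames (without extension) in audio_path.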
+import os +import requests +from tqdm import tqdm +import torch +import numpy as np +import laion_clap +from clap_module.factory import load_state_dict +import librosa +import pyloudnorm as pyln + +# following documentation from https://github.com/LAION-AI/CLAP +def int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + +def float32_to_int16(x): + x = np.clip(x, a_min=-1., a_max=1.) + return (x * 32767.).astype(np.int16) + + +def clap_score(id2text, audio_path, audio_files_extension='.wav', clap_model='music_audioset_epoch_15_esc_90.14.pt'): + """ + Cosine similarity is computed between the LAION-CLAP text embedding of the given prompt and + the LAION-CLAP audio embedding of the generated audio. LION-CLAP: https://github.com/LAION-AI/CLAP + + This evaluation script assumes that audio_path files are identified with the ids in id2text. + + clap_score() evaluates all ids in id2text. + + GPU-based computation. + + Select one of the following models from https://github.com/LAION-AI/CLAP: + - music_speech_audioset_epoch_15_esc_89.98.pt (used by musicgen) + - music_audioset_epoch_15_esc_90.14.pt + - music_speech_epoch_15_esc_89.25.pt + - 630k-audioset-fusion-best.pt (our default, with "fusion" to handle longer inputs) + + Params: + -- id2text: dictionary with the mapping between id (generated audio filenames in audio_path) + and text (prompt used to generate audio). clap_score() evaluates all ids in id2text. + -- audio_path: path where the generated audio files to evaluate are available. + -- audio_files_extension: files extension (default .wav) in eval_path. + -- clap_model: choose one of the above clap_models (default: '630k-audioset-fusion-best.pt'). + Returns: + -- CLAP-LION score + """ + # load model + if clap_model == 'music_speech_audioset_epoch_15_esc_89.98.pt': + url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_audioset_epoch_15_esc_89.98.pt' + clap_path = 'CLAP/music_speech_audioset_epoch_15_esc_89.98.pt' + model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda') + elif clap_model == 'music_audioset_epoch_15_esc_90.14.pt': + url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_audioset_epoch_15_esc_90.14.pt' + clap_path = 'CLAP/music_audioset_epoch_15_esc_90.14.pt' + model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda') + elif clap_model == 'music_speech_epoch_15_esc_89.25.pt': + url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_epoch_15_esc_89.25.pt' + clap_path = 'CLAP/music_speech_epoch_15_esc_89.25.pt' + model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda') + elif clap_model == '630k-audioset-fusion-best.pt': + url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/630k-audioset-fusion-best.pt' + clap_path = 'CLAP/630k-audioset-fusion-best.pt' + model = laion_clap.CLAP_Module(enable_fusion=True, device='cuda') + else: + raise ValueError('clap_model not implemented') + + # download clap_model if not already downloaded + if not os.path.exists(clap_path): + print('Downloading ', clap_model, '...') + os.makedirs(os.path.dirname(clap_path), exist_ok=True) + + response = requests.get(url, stream=True) + total_size = int(response.headers.get('content-length', 0)) + + with open(clap_path, 'wb') as file: + with tqdm(total=total_size, unit='B', unit_scale=True) as progress_bar: + for data in response.iter_content(chunk_size=8192): + file.write(data) + progress_bar.update(len(data)) + + # fixing CLAP-LION issue, 
see: https://github.com/LAION-AI/CLAP/issues/118 + pkg = load_state_dict(clap_path) + pkg.pop('text_branch.embeddings.position_ids', None) + model.model.load_state_dict(pkg) + model.eval() + + if not os.path.isdir(audio_path): + raise ValueError(f'audio_path: {audio_path} does not exist') + + if id2text: + print('[EXTRACTING TEXT EMBEDDINGS] ') + batch_size = 64 + text_emb = {} + for i in tqdm(range(0, len(id2text), batch_size)): + batch_ids = list(id2text.keys())[i:i+batch_size] + batch_texts = [id2text[id] for id in batch_ids] + with torch.no_grad(): + embeddings = model.get_text_embedding(batch_texts, use_tensor=True) + for id, emb in zip(batch_ids, embeddings): + text_emb[id] = emb + + else: + raise ValueError('Must specify id2text') + + print('[EVALUATING GENERATIONS] ', audio_path) + score = 0 + count = 0 + for id in tqdm(id2text.keys()): + file_path = os.path.join(audio_path, str(id)+audio_files_extension) + if os.path.isfile(file_path): + with torch.no_grad(): + audio, _ = librosa.load(file_path, sr=48000, mono=True) # sample rate should be 48000 + audio = pyln.normalize.peak(audio, -1.0) + audio = audio.reshape(1, -1) # unsqueeze (1,T) + audio = torch.from_numpy(int16_to_float32(float32_to_int16(audio))).float() + audio_embeddings = model.get_audio_embedding_from_data(x = audio, use_tensor=True) + cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_emb[id].unsqueeze(0), dim=1, eps=1e-8)[0] + print(f"{id} | CLAP score = {cosine_sim}") + score += cosine_sim + count += 1 + + return score / count if count > 0 else 0 + diff --git a/inspiremusic/metrics/openl3_fd.py b/inspiremusic/metrics/openl3_fd.py new file mode 100644 index 0000000000000000000000000000000000000000..78287970a8250dec2c6bbc77b5b4791122a02259 --- /dev/null +++ b/inspiremusic/metrics/openl3_fd.py @@ -0,0 +1,338 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import openl3 +import librosa +import numpy as np +from scipy import linalg +import glob +from tqdm import tqdm +import os +import soxr +import pyloudnorm as pyln + + +def calculate_embd_statistics(embd_lst): + if isinstance(embd_lst, list): + embd_lst = np.array(embd_lst) + mu = np.mean(embd_lst, axis=0) + sigma = np.cov(embd_lst, rowvar=False) + return mu, sigma + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """ + Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py + Adapted from: https://github.com/gudgud96/frechet-audio-distance/blob/main/frechet_audio_distance/fad.py + + Numpy implementation of the Frechet Distance. + + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Params: + -- mu1: Embedding's mean statistics for generated samples. + -- mu2: Embedding's mean statistics for reference samples. + -- sigma1: Covariance matrix over embeddings for generated samples. 
+ -- sigma2: Covariance matrix over embeddings for reference samples. + Returns: + -- Fréchet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) + + +def extract_embeddings(directory_path, channels, samplingrate, content_type, openl3_hop_size, batch_size=16): + """ + Given a list of files, compute their embeddings in batches. + + If channels == 1: stereo audio is downmixed to mono. Mono embeddings are of dim=512. + + If channels == 2: mono audio is "faked" to stereo by copying the mono channel. + Stereo embeddings are of dim=1024, since we concatenate L (dim=512) and R (dim=512) embeddings. + + Params: + -- directory_path: path where the generated audio files are available. + -- channels: 1 (mono), or 2 (stereo) to get mono or stereo embeddings. + -- samplingrate: max bandwidth at which we evaluate the given signals. Up to 48kHz. + -- content_type: 'music' or 'env' to select a content type specific openl3 model. + -- openl3_hop_size: analysis resolution of openl3 in seconds. Openl3's input window is 1 sec. + -- batch_size: number of audio files to process in each batch. 
+ Returns: + -- list of embeddings: [np.array[], ...], as expected by calculate_frechet_distance() + """ + _, extension = os.path.splitext(directory_path) + if extension.lower() == ".scp": + wav_files = [] + with open(directory_path, "r") as f: + for line in f: + sec = line.strip().split(" ") + wav_files.append(sec[1]) + else: + wav_files = glob.glob(directory_path) + if len(wav_files) == 0: + raise ValueError('No files with this extension in this path!') + model = openl3.models.load_audio_embedding_model(input_repr="mel256", content_type=content_type, embedding_size=512) + + first = True + for i in tqdm(range(0, len(wav_files), batch_size)): + batch_files = wav_files[i:i+batch_size] + batch_audio_l = [] + batch_audio_r = [] + batch_sr = [] + + for file in batch_files: + audio, sr = librosa.load(file, sr=None, mono=False) + audio = audio.T + audio = pyln.normalize.peak(audio, -1.0) + if audio.shape[0] < sr: + print('Audio shorter than 1 sec, openl3 will zero-pad it:', file, audio.shape, sr) + + # resample to the desired evaluation bandwidth + audio = soxr.resample(audio, sr, samplingrate) # mono/stereo <- mono/stereo, input sr, output sr + + # mono embeddings are stored in batch_audio_l (R channel not used) + if channels == 1: + batch_audio_l.append(audio) + + elif channels == 2: + if audio.ndim == 1: + # if mono, "fake" stereo by copying mono channel to L and R + batch_audio_l.append(audio) + batch_audio_r.append(audio) + elif audio.ndim == 2: + # if it's stereo separate channels for openl3 + batch_audio_l.append(audio[:,0]) + batch_audio_r.append(audio[:,1]) + + batch_sr.append(samplingrate) + + # extracting mono embeddings (dim=512) or the L channel for stereo embeddings + emb, _ = openl3.get_audio_embedding(batch_audio_l, batch_sr, model=model, verbose=False, hop_size=openl3_hop_size, batch_size=batch_size) + + # format mono embedding + if channels == 1: + emb = np.concatenate(emb,axis=0) + + # extracting stereo embeddings (dim=1024), since we concatenate L (dim=512) and R (dim=512) embeddings + elif channels == 2: + # extract the missing R channel + emb_r, _ = openl3.get_audio_embedding(batch_audio_r, batch_sr, model=model, verbose=False, hop_size=openl3_hop_size, batch_size=batch_size) + emb = [np.concatenate([l, r], axis=1) for l, r in zip(emb, emb_r)] + emb = np.concatenate(emb, axis=0) + + # concatenate embeddings + if first: + embeddings = emb + first = False + else: + embeddings = np.concatenate([embeddings, emb], axis=0) + + # return as a list of embeddings: [np.array[], ...] + return [e for e in embeddings] + + +def extract_embeddings_nobatching(directory_path, channels, samplingrate, content_type, openl3_hop_size): + """ + Given a list of files, compute their embeddings one by one. + + If channels == 1: stereo audio is downmixed to mono. Mono embeddings are of dim=512. + + If channels == 2: mono audio is "faked" to stereo by copying the mono channel. + Stereo embeddings are of dim=1024, since we concatenate L (dim=512) and R (dim=512) embeddings. + + Params: + -- directory_path: path where the generated audio files are available. + -- channels: 1 (mono), or 2 (stereo) to get mono or stereo embeddings. + -- samplingrate: max bandwidth at which we evaluate the given signals. Up to 48kHz. + -- content_type: 'music' or 'env' to select a content type specific openl3 model. + -- openl3_hop_size: analysis resolution of openl3 in seconds. Openl3's input window is 1 sec. 
+ Returns: + -- list of embeddings: [np.array[], ...], as expected by calculate_frechet_distance() + """ + _, extension = os.path.splitext(directory_path) + if extension.lower() == ".scp": + wav_files = [] + with open(directory_path, "r") as f: + for line in f: + sec = line.strip().split(" ") + wav_files.append(sec[1]) + else: + wav_files = glob.glob(directory_path) + if len(wav_files) == 0: + raise ValueError('No files with this extension in this path!') + model = openl3.models.load_audio_embedding_model(input_repr="mel256", content_type=content_type, embedding_size=512) + + first = True + for file in tqdm(wav_files): + audio, sr = librosa.load(file, sr=None) + audio = pyln.normalize.peak(audio, -1.0) + if audio.shape[0] < sr: + print('Audio shorter than 1 sec, openl3 will zero-pad it:', file, audio.shape, sr) + + # resample to the desired evaluation bandwidth + audio = soxr.resample(audio, sr, samplingrate) # mono/stereo <- mono/stereo, input sr, output sr + + # extracting stereo embeddings (dim=1024), since we concatenate L (dim=512) and R (dim=512) embeddings + if channels == 2: + if audio.ndim == 1: + audio_l3, sr_l3 = audio, samplingrate + elif audio.ndim == 2: + # if it's stereo separate channels for openl3 + audio_l3 = [audio[:,0], audio[:,1]] + sr_l3 = [samplingrate, samplingrate] + emb, _ = openl3.get_audio_embedding(audio_l3, sr_l3, model=model, verbose=False, hop_size=openl3_hop_size) + if audio.ndim == 1: + # if mono audio, "fake" stereo by concatenating mono embedding as L and R embeddings + emb = np.concatenate([emb, emb],axis=1) + elif audio.ndim == 2: + emb = np.concatenate(emb,axis=1) + + # or extracting mono embeddings (dim=512) + elif channels == 1: + emb, _ = openl3.get_audio_embedding(audio, samplingrate, model=model, verbose=False, hop_size=openl3_hop_size) + + # concatenate embeddings + if first: + embeddings = emb + first = False + else: + embeddings = np.concatenate([embeddings, emb], axis=0) + + # return as a list of embeddings: [np.array[], ...] + return [e for e in embeddings] + + +def openl3_fd(channels, samplingrate, content_type, openl3_hop_size, eval_path, + eval_files_extension='.wav', ref_path=None, ref_files_extension='.wav', load_ref_embeddings=None, batching=False): + """ + Compute the Fréchet Distance between files in eval_path and ref_path. + + Fréchet distance computed on top of openl3 embeddings. + + GPU-based computation. + + Extracting the embeddings is timeconsuming. After being computed once, we store them. + We store pre-computed reference embedding statistics in load/openl3_fd/ + To load those and save computation, just set the path in load_ref_embeddings. + If load_ref_embeddings is set, ref_path is not required. + + Params: + -- channels: 1 (mono), or 2 (stereo) to get the Fréchet Distance over mono or stereo embeddings. + -- samplingrate: max bandwith at wich we evaluate the given signals. Up to 48kHz. + -- content_type: 'music' or 'env' to select a content type for openl3. + -- openl3_hop_size: analysis resolution of openl3 in seconds. Openl3's input window is 1 sec. + -- eval_path: path where the generated audio files to evaluate are available. + -- eval_files_extenstion: files extension (default .wav) in eval_path. + -- ref_path: path where the reference audio files are available. (instead of load_ref_embeddings) + -- ref_files_extension: files extension (default .wav) in ref_path. + -- load_ref_embeddings: path to the reference embedding statistics. 
(inestead of ref_path) + -- batching: set batch size (with an int) or set to False (default False). + Returns: + -- Fréchet distance. + """ + + if not os.path.isdir(eval_path): + raise ValueError('eval_path does not exist') + + if load_ref_embeddings: + if not os.path.exists(load_ref_embeddings): + raise ValueError('load_ref_embeddings does not exist') + print('[LOADING REFERENCE EMBEDDINGS] ', load_ref_embeddings) + loaded = np.load(load_ref_embeddings) + mu_ref = loaded['mu_ref'] + sigma_ref = loaded['sigma_ref'] + + else: + if ref_path: + if not os.path.isdir(ref_path): + if not os.path.isfile(ref_path): + raise ValueError("ref_path does not exist") + if os.path.isfile(ref_path): + path = ref_path + else: + path = os.path.join(ref_path, '*'+ref_files_extension) + print('[EXTRACTING REFERENCE EMBEDDINGS] ', path) + if batching: + ref_embeddings = extract_embeddings(path, channels, samplingrate, content_type, openl3_hop_size, batch_size=batching) + else: + ref_embeddings = extract_embeddings_nobatching(path, channels, samplingrate, content_type, openl3_hop_size) + mu_ref, sigma_ref = calculate_embd_statistics(ref_embeddings) + + # store statistics to load later on + if not os.path.exists('load/openl3_fd'): + os.makedirs('load/openl3_fd/') + save_ref_embeddings_path = ( + 'load/openl3_fd/' + + path.replace('/', '_') + + '__channels' + str(channels) + + '__' + str(samplingrate) + + '__openl3' + str(content_type) + + '__openl3hopsize' + str(openl3_hop_size) + + '__batch' + str(batching) + + '.npz' + ) + np.savez(save_ref_embeddings_path, mu_ref=mu_ref, sigma_ref=sigma_ref) + print('[REFERENCE EMBEDDINGS][SAVED] ', save_ref_embeddings_path) + + else: + raise ValueError('Must specify ref_path or load_ref_embeddings') + + path = os.path.join(eval_path, '*'+eval_files_extension) + print('[EXTRACTING EVALUATION EMBEDDINGS] ', path) + if batching: + eval_embeddings = extract_embeddings(path, channels, samplingrate, content_type, openl3_hop_size, batch_size=batching) + else: + eval_embeddings = extract_embeddings_nobatching(path, channels, samplingrate, content_type, openl3_hop_size) + mu_eval, sigma_eval = calculate_embd_statistics(eval_embeddings) + + fd = calculate_frechet_distance(mu_eval, sigma_eval, mu_ref, sigma_ref) + if load_ref_embeddings: + print('[FRéCHET DISTANCE] ', eval_path, load_ref_embeddings, fd) + else: + print('[FRéCHET DISTANCE] ', eval_path, ref_path, fd) + + return fd \ No newline at end of file diff --git a/inspiremusic/metrics/passt_kld.py b/inspiremusic/metrics/passt_kld.py new file mode 100644 index 0000000000000000000000000000000000000000..aa27835ee82161f9e7cd1c7f9d99c07409bbbfc0 --- /dev/null +++ b/inspiremusic/metrics/passt_kld.py @@ -0,0 +1,232 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
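# Minimal usage sketch for passt_kld() defined below; ids and paths are placeholders:
#
#     from inspiremusic.metrics.passt_kld import passt_kld
#
#     ids = ["0001", "0002"]
#     kld = passt_kld(ids, eval_path="outputs/", ref_path="references/")
#
# Reference label probabilities are cached under load/passt_kld/ after the first run and can
# be reloaded via load_ref_probabilities instead of recomputing them from ref_path.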
+import warnings +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + +import os +import contextlib +from functools import partial +from tqdm import tqdm +import pickle +import numpy as np +import librosa +from hear21passt.base import get_basic_model +import pyloudnorm as pyln + +import torch +import torch.nn.functional as F + + +SAMPLING_RATE = 32000 + + +class _patch_passt_stft: + """ + From version 1.8.0, return_complex must always be given explicitly + for real inputs and return_complex=False has been deprecated. + + Decorator to patch torch.stft in PaSST that uses an old stft version. + + Adapted from: https://github.com/facebookresearch/audiocraft/blob/a2b96756956846e194c9255d0cdadc2b47c93f1b/audiocraft/metrics/kld.py + """ + def __init__(self): + self.old_stft = torch.stft + + def __enter__(self): + # return_complex is a mandatory parameter in latest torch versions. + # torch is throwing RuntimeErrors when not set. + # see: https://pytorch.org/docs/1.7.1/generated/torch.stft.html?highlight=stft#torch.stft + # see: https://github.com/kkoutini/passt_hear21/commit/dce83183674e559162b49924d666c0a916dc967a + torch.stft = partial(torch.stft, return_complex=False) + + def __exit__(self, *exc): + torch.stft = self.old_stft + + +def return_probabilities(model, audio_path, window_size=10, overlap=5, collect='mean'): + """ + Given an audio and the PaSST model, return the probabilities of each AudioSet class. + + Audio is converted to mono at 32kHz. + + PaSST model is trained with 10 sec inputs. We refer to this parameter as the window_size. + We set it to 10 sec for consistency with PaSST training. + + For longer audios, we split audio into overlapping analysis windows of window_size and overlap of 10 and 5 seconds. + PaSST supports 10, 20 or 30 sec inputs. Not longer inputs: https://github.com/kkoutini/PaSST/issues/19 + + Note that AudioSet taggers normally use sigmoid output layers. Yet, to compute the + KL we work with normalized probabilities by running a softmax over logits as in MusicGen: + https://github.com/facebookresearch/audiocraft/blob/a2b96756956846e194c9255d0cdadc2b47c93f1b/audiocraft/metrics/kld.py + + This implementation assumes run will be on GPU. + + Params: + -- model: PaSST model on a GPU. + -- audio_path: path to the audio to be loaded with librosa. + -- window_size (default=10 sec): analysis window (and receptive field) of PaSST. + -- overlap (default=5 sec): overlap of the running analysis window for inputs longar than window_size (10 sec). + -- collect (default='mean'): for longer inputs, aggregate/collect via 'mean' or 'max' pooling along logits vector. + Returns: + -- 527 probabilities (after softmax, no logarithm). 
+ """ + # load the audio using librosa + audio, _ = librosa.load(audio_path, sr=SAMPLING_RATE, mono=True) + audio = pyln.normalize.peak(audio, -1.0) + + # calculate the step size for the analysis windows with the specified overlap + step_size = int((window_size - overlap) * SAMPLING_RATE) + + # iterate over the audio, creating analysis windows + probabilities = [] + for i in range(0, max(step_size, len(audio) - step_size), step_size): + # extract the current analysis window + window = audio[i:i + int(window_size * SAMPLING_RATE)] + + # pad the window with zeros if it's shorter than the desired window size + if len(window) < int(window_size * SAMPLING_RATE): + # discard window if it's too small (avoid mostly zeros predicted as silence), as in MusicGen: + # https://github.com/facebookresearch/audiocraft/blob/a2b96756956846e194c9255d0cdadc2b47c93f1b/audiocraft/metrics/kld.py + if len(window) > int(window_size * SAMPLING_RATE * 0.15): + tmp = np.zeros(int(window_size * SAMPLING_RATE)) + tmp[:len(window)] = window + window = tmp + + # convert to a PyTorch tensor and move to GPU + audio_wave = torch.from_numpy(window.astype(np.float32)).unsqueeze(0).cuda() + + # get the probabilities for this analysis window + with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f): + with torch.no_grad(), _patch_passt_stft(): + logits = model(audio_wave) + probabilities.append(torch.squeeze(logits)) + + probabilities = torch.stack(probabilities) + if collect == 'mean': + probabilities = torch.mean(probabilities, dim=0) + elif collect == 'max': + probabilities, _ = torch.max(probabilities, dim=0) + + return F.softmax(probabilities, dim=0).squeeze().cpu() + + +def passt_kld(ids, eval_path, eval_files_extension='.wav', ref_path=None, ref_files_extension='.wav', load_ref_probabilities=None, no_ids=[], collect='mean'): + """ + Compute KL-divergence between the label probabilities of the generated audio with respect to the original audio. + Both generated audio (in eval_path) and original audio (in ref_path) are represented by the same prompt/description. + Audios are identified by an id, that is the name of the file in both directories and links the audio with the prompt/description. + segmenting the audio + + For inputs longer that the 10 sec PaSST was trained on, we aggregate/collect via 'mean' (default) or 'max' pooling along the logits vector. + We split the inpot into overlapping analysis windows. Subsequently, we aggregate/collect (accross windows) the generated logits and then apply a softmax. + + This evaluation script assumes that ids are in both ref_path and eval_path. + + We label probabilities via the PaSST model: https://github.com/kkoutini/PaSST + + GPU-based computation. + + Extracting the probabilities is timeconsuming. After being computed once, we store them. + We store pre-computed reference probabilities in load/ + To load those and save computation, just set the path in load_ref_probabilities. + If load_ref_probabilities is set, ref_path is not required. + + Params: + -- ids: list of ids present in both eval_path and ref_path. + -- eval_path: path where the generated audio files to evaluate are available. + -- eval_files_extenstion: files extension (default .wav) in eval_path. + -- ref_path: path where the reference audio files are available. (instead of load_ref_probabilities) + -- ref_files_extenstion: files extension (default .wav) in ref_path. + -- load_ref_probabilities: path to the reference probabilities. 
(inestead of ref_path) + -- no_ids: it is possible that some reference audio is corrupted or not present. Ignore some this list of ids. + -- collect (default='mean'): for longer inputs, aggregate/collect via 'mean' or 'max' pooling along the logits vector. + Returns: + -- KL divergence + """ + with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f): # capturing all useless outputs from passt + # load model + model = get_basic_model(mode="logits") + model.eval() + model = model.cuda() + + if not os.path.isdir(eval_path): + if not os.path.isfile(eval_path): + raise ValueError('eval_path does not exist') + + if load_ref_probabilities: + if not os.path.exists(load_ref_probabilities): + raise ValueError('load_ref_probabilities does not exist') + print('[LOADING REFERENCE PROBABILITIES] ', load_ref_probabilities) + with open(load_ref_probabilities, 'rb') as fp: + ref_p = pickle.load(fp) + + else: + if ref_path: + if not os.path.isdir(ref_path): + if os.path.isfile(ref_path): + id2utt = {} + with open(ref_path, "r") as f: + for line in f: + sec = line.strip().split(" ") + id2utt[sec[0]] = sec[1] + f.close() + else: + raise ValueError("ref_path does not exist") + print('[EXTRACTING REFERENCE PROBABILITIES] ', ref_path) + ref_p = {} + for id in tqdm(ids): + if id not in no_ids: + try: + if os.path.isfile(ref_path): + if id in id2utt.keys(): + audio_path = id2utt[id] + else: + raise ValueError(f"id: {id} not in {ref_path}!") + else: + audio_path = os.path.join(ref_path, str(id)+ref_files_extension) + if os.path.isfile(audio_path): + ref_p[id] = return_probabilities(model, audio_path, collect=collect) + except Exception as e: + print(f"An unexpected error occurred with {id}: {e}\nIf you failed to download it you can add it to no_ids list.") + + # store reference probabilities to load later on + if not os.path.exists('load/passt_kld/'): + os.makedirs('load/passt_kld/') + save_ref_probabilities_path = 'load/passt_kld/'+ref_path.replace('/', '_')+'_collect'+str(collect)+'__reference_probabilities.pkl' + with open(save_ref_probabilities_path, 'wb') as fp: + pickle.dump(ref_p, fp) + print('[REFERENCE EMBEDDINGS][SAVED] ', save_ref_probabilities_path) + + else: + raise ValueError('Must specify ref_path or load_ref_probabilities') + + print('[EVALUATING GENERATIONS] ', eval_path) + + passt_kl = 0 + count = 0 + for id in tqdm(ids): + if id not in no_ids: + try: + audio_path = os.path.join(eval_path, str(id)+eval_files_extension) + if os.path.isfile(audio_path): + eval_p = return_probabilities(model, audio_path, collect=collect) + # note: F.kl_div(x, y) is KL(y||x) + # see: https://github.com/pytorch/pytorch/issues/7337 + # see: https://discuss.pytorch.org/t/kl-divergence-different-results-from-tf/56903/2 + passt_kl += F.kl_div((ref_p[id] + 1e-6).log(), eval_p, reduction='sum', log_target=False) + count += 1 + except Exception as e: + print(f"An unexpected error occurred with {id}: {e}\nIf you failed to download it you can add it to no_ids list.") + return passt_kl / count if count > 0 else 0 diff --git a/inspiremusic/music_tokenizer/__init__.py b/inspiremusic/music_tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inspiremusic/music_tokenizer/env.py b/inspiremusic/music_tokenizer/env.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5843b8a244e9648bcd6f9e085dff9faa2e921a --- /dev/null +++ b/inspiremusic/music_tokenizer/env.py @@ -0,0 +1,29 @@ +# Copyright (c) 2024 Alibaba Inc +# +# 
Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/inspiremusic/music_tokenizer/meldataset.py b/inspiremusic/music_tokenizer/meldataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4a13b6247ec04b74559976687ff108c3380d42f1 --- /dev/null +++ b/inspiremusic/music_tokenizer/meldataset.py @@ -0,0 +1,226 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
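# mel_spectrogram() below follows the usual HiFi-GAN recipe: reflect-pad the waveform, run an
# STFT (return_complex=True wrapped in torch.view_as_real for torch >= 2.0), take magnitudes,
# project through a cached librosa mel filterbank, and apply log compression via
# spectral_normalize_torch(). Illustrative call (parameter values are placeholders):
#
#     mel = mel_spectrogram(y, n_fft=1024, num_mels=80, sampling_rate=24000,
#                           hop_size=240, win_size=1024, fmin=0, fmax=None)
#
# Note: librosa >= 0.10 expects keyword arguments for librosa.filters.mel, so the positional
# librosa_mel_fn(...) call below may need sr=/n_fft=/n_mels= keywords on newer versions.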
+ +# code based on https://github.com/b04901014/MQTTS +import math +import os +import random + +import librosa +import numpy as np +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +def load_wav(full_path, sr): + wav, sr = librosa.load(full_path, sr=sr) + return wav, sr + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + +mel_basis = {} +hann_window = {} + +## modified to get stft with return complex value = True for pytorch ver2.0 +def mel_spectrogram(y, + n_fft, + num_mels, + sampling_rate, + hop_size, + win_size, + fmin, + fmax, + center=False): + + global mel_basis, hann_window + if fmax not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[str(fmax) + '_' + + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int( + (n_fft - hop_size) / 2)), + mode='reflect') + y = y.squeeze(1) + + spec = torch.view_as_real(torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode='reflect', + normalized=False, + onesided=True, + return_complex=True + )) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +def get_dataset_filelist(a): + with open(a.input_training_file, 'r') as f: + training_files = [l.strip() for l in f] + with open(a.input_validation_file, 'r') as f: + validation_files = [l.strip() for l in f] + return training_files, validation_files + + +class MelDataset(torch.utils.data.Dataset): + def __init__(self, + training_files, + segment_size, + n_fft, + num_mels, + hop_size, + win_size, + sampling_rate, + fmin, + fmax, + split=True, + shuffle=True, + n_cache_reuse=1, + device=None, + fmax_loss=None, + fine_tuning=False, + base_mels_path=None): + self.audio_files = training_files + random.seed(1234) + if shuffle: + random.shuffle(self.audio_files) + self.segment_size = segment_size + self.sampling_rate = sampling_rate + self.split = split + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.fmax_loss = fmax_loss + self.cached_wav = None + self.n_cache_reuse = n_cache_reuse + self._cache_ref_count = 0 + self.device = device + self.fine_tuning = fine_tuning + self.base_mels_path = base_mels_path + + def __getitem__(self, index): + filename = self.audio_files[index] + if self._cache_ref_count == 0: + try: + # Note by yuantian: load with the sample_rate of config + audio, sampling_rate = load_wav(filename, sr=self.sampling_rate) + except Exception as e: + print(f"Error on audio: {filename}") + audio = np.random.normal(size=(160000, )) * 0.05 + sampling_rate = self.sampling_rate + self.cached_wav = audio + if 
sampling_rate != self.sampling_rate: + raise ValueError("{} SR doesn't match target {} SR".format( + sampling_rate, self.sampling_rate)) + self._cache_ref_count = self.n_cache_reuse + else: + audio = self.cached_wav + self._cache_ref_count -= 1 + + audio = torch.FloatTensor(audio) + audio = audio.unsqueeze(0) + + if not self.fine_tuning: + if self.split: + if audio.size(1) >= self.segment_size: + max_audio_start = audio.size(1) - self.segment_size + audio_start = random.randint(0, max_audio_start) + audio = audio[:, audio_start:audio_start + + self.segment_size] + else: + audio = torch.nn.functional.pad(audio, ( + 0, self.segment_size - audio.size(1)), 'constant') + + mel = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax, + center=False) + else: + mel = np.load( + os.path.join(self.base_mels_path, + os.path.splitext(os.path.split(filename)[-1])[0] + + '.npy')) + mel = torch.from_numpy(mel) + + if len(mel.shape) < 3: + mel = mel.unsqueeze(0) + + if self.split: + frames_per_seg = math.ceil(self.segment_size / self.hop_size) + + if audio.size(1) >= self.segment_size: + mel_start = random.randint(0, + mel.size(2) - frames_per_seg - 1) + mel = mel[:, :, mel_start:mel_start + frames_per_seg] + audio = audio[:, mel_start * self.hop_size:( + mel_start + frames_per_seg) * self.hop_size] + else: + mel = torch.nn.functional.pad(mel, ( + 0, frames_per_seg - mel.size(2)), 'constant') + audio = torch.nn.functional.pad(audio, ( + 0, self.segment_size - audio.size(1)), 'constant') + + mel_loss = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax_loss, + center=False) + + return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) + + def __len__(self): + return len(self.audio_files) diff --git a/inspiremusic/music_tokenizer/models.py b/inspiremusic/music_tokenizer/models.py new file mode 100644 index 0000000000000000000000000000000000000000..86302c699224252e72b05971c6a52a2ba7e8764d --- /dev/null +++ b/inspiremusic/music_tokenizer/models.py @@ -0,0 +1,548 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
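# This module implements the HiFi-GAN style codec behind the music tokenizer: Encoder
# downsamples the waveform into 512-dim frame features, Quantizer applies two residual stages
# of grouped vector quantization (n_code_groups codebooks per stage, i.e. 2 * n_code_groups
# indices per frame; 4 streams in this repo, matching num_codebooks in llm.py), and Generator
# upsamples the quantized features back to a waveform. MultiPeriodDiscriminator and
# MultiScaleDiscriminator provide the adversarial signal, with feature_loss(),
# generator_loss() and discriminator_loss() implementing the least-squares GAN objectives.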
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d +from torch.nn import Conv1d +from torch.nn import Conv2d +from torch.nn import ConvTranspose1d +from torch.nn.utils import remove_weight_norm +from torch.nn.utils import spectral_norm +from torch.nn.utils import weight_norm + +from inspiremusic.utils.tokenizer_utils import get_padding +from inspiremusic.utils.tokenizer_utils import init_weights + +LRELU_SLOPE = 0.1 + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1))), weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1))), weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm( + Conv1d(512, h.upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if h.resblock == '1' else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, + k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2**(i + 1)), + k, + u, + # padding=(u//2 + u%2), + padding=(k - u) // 2, + # output_padding=u%2 + ))) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2**(i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + 
self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, + use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f( + Conv2d( + 1, + 32, (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f( + Conv2d( + 32, + 128, (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f( + Conv2d( + 128, + 512, (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f( + Conv2d( + 512, + 1024, (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = 
self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList( + [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr)**2) + g_loss = torch.mean(dg**2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg)**2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +class Encoder(torch.nn.Module): + def __init__(self, h): + super(Encoder, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm(Conv1d(1, 32, 7, 1, padding=3)) + self.normalize = nn.ModuleList() + resblock = ResBlock1 if h.resblock == '1' else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate( + list( + reversed( + list(zip(h.upsample_rates, h.upsample_kernel_sizes))))): + self.ups.append( + weight_norm( + Conv1d( + 32 * (2**i), + 32 * (2**(i + 1)), + k, + u, + padding=((k - u) // 2) + # padding=(u//2 + u%2) + ))) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = 32 * (2**(i + 1)) + for j, (k, d) in enumerate( + zip( + list(reversed(h.resblock_kernel_sizes)), + list(reversed(h.resblock_dilation_sizes)))): + self.resblocks.append(resblock(h, ch, k, d)) + self.normalize.append( + torch.nn.GroupNorm(ch // 16, ch, eps=1e-6, affine=True)) + self.conv_post = Conv1d(512, 512, 3, 1, padding=1) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + xs = self.normalize[i * self.num_kernels + j](xs) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + xs = self.normalize[i * self.num_kernels + j](xs) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + + +class Quantizer_module(torch.nn.Module): + def __init__(self, n_e, e_dim): + super(Quantizer_module, self).__init__() + self.embedding = 
nn.Embedding(n_e, e_dim) + self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e) + + def forward(self, x): + # compute Euclidean distance + d = torch.sum(x ** 2, 1, keepdim=True) + torch.sum(self.embedding.weight ** 2, 1) \ + - 2 * torch.matmul(x, self.embedding.weight.T) + min_indicies = torch.argmin(d, 1) + z_q = self.embedding(min_indicies) + return z_q, min_indicies + + +class Quantizer(torch.nn.Module): + def __init__(self, h): + super(Quantizer, self).__init__() + assert 512 % h.n_code_groups == 0 + self.quantizer_modules = nn.ModuleList([ + Quantizer_module(h.n_codes, 512 // h.n_code_groups) + for _ in range(h.n_code_groups) + ]) + self.quantizer_modules2 = nn.ModuleList([ + Quantizer_module(h.n_codes, 512 // h.n_code_groups) + for _ in range(h.n_code_groups) + ]) + self.h = h + self.codebook_loss_lambda = self.h.codebook_loss_lambda # e.g., 1 + self.commitment_loss_lambda = self.h.commitment_loss_lambda # e.g., 0.25 + self.residul_layer = 2 + self.n_code_groups = h.n_code_groups + + def for_one_step(self, xin, idx): + xin = xin.transpose(1, 2) + x = xin.reshape(-1, 512) + x = torch.split(x, 512 // self.h.n_code_groups, dim=-1) + min_indicies = [] + z_q = [] + if idx == 0: + for _x, m in zip(x, self.quantizer_modules): + _z_q, _min_indicies = m(_x) + z_q.append(_z_q) + min_indicies.append(_min_indicies) #B * T, + z_q = torch.cat(z_q, -1).reshape(xin.shape) + # loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2) + loss = self.codebook_loss_lambda * torch.mean((z_q - xin.detach()) ** 2) \ + + self.commitment_loss_lambda * torch.mean((z_q.detach() - xin) ** 2) + z_q = xin + (z_q - xin).detach() + z_q = z_q.transpose(1, 2) + return z_q, loss, min_indicies + else: + for _x, m in zip(x, self.quantizer_modules2): + _z_q, _min_indicies = m(_x) + z_q.append(_z_q) + min_indicies.append(_min_indicies) #B * T, + z_q = torch.cat(z_q, -1).reshape(xin.shape) + # loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2) + loss = self.codebook_loss_lambda * torch.mean((z_q - xin.detach()) ** 2) \ + + self.commitment_loss_lambda * torch.mean((z_q.detach() - xin) ** 2) + z_q = xin + (z_q - xin).detach() + z_q = z_q.transpose(1, 2) + return z_q, loss, min_indicies + + def forward(self, xin): + #B, C, T + quantized_out = 0.0 + residual = xin + all_losses = [] + all_indices = [] + for i in range(self.residul_layer): + quantized, loss, indices = self.for_one_step(residual, i) # + residual = residual - quantized + quantized_out = quantized_out + quantized + all_indices.extend(indices) # + all_losses.append(loss) + all_losses = torch.stack(all_losses) + loss = torch.mean(all_losses) + return quantized_out, loss, all_indices + + def embed(self, x): + #idx: N, T, 4 + #print('x ', x.shape) + quantized_out = torch.tensor(0.0, device=x.device) + x = torch.split(x, 1, 2) # split, 将最后一个维度分开, 每个属于一个index group + #print('x.shape ', len(x),x[0].shape) + for i in range(self.residul_layer): + ret = [] + if i == 0: + for j in range(self.n_code_groups): + q = x[j] + embed = self.quantizer_modules[j] + q = embed.embedding(q.squeeze(-1).long()) + ret.append(q) + ret = torch.cat(ret, -1) + #print(ret.shape) + quantized_out = quantized_out + ret + else: + for j in range(self.n_code_groups): + q = x[j + self.n_code_groups] + embed = self.quantizer_modules2[j] + q = embed.embedding(q.squeeze(-1).long()) + ret.append(q) + ret = torch.cat(ret, -1) + quantized_out = quantized_out + ret + return quantized_out.transpose(1, 2) #N, C, T diff --git 
a/inspiremusic/music_tokenizer/vqvae.py b/inspiremusic/music_tokenizer/vqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..553275b3c25bb2f44bc9009ebaacaef7a346e206 --- /dev/null +++ b/inspiremusic/music_tokenizer/vqvae.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import torch +import torch.nn as nn +from inspiremusic.music_tokenizer.env import AttrDict +from inspiremusic.music_tokenizer.models import Encoder +from inspiremusic.music_tokenizer.models import Generator +from inspiremusic.music_tokenizer.models import Quantizer + + +class VQVAE(nn.Module): + def __init__(self, + config_path, + ckpt_path, + with_encoder=False): + super(VQVAE, self).__init__() + ckpt = torch.load(ckpt_path) + with open(config_path) as f: + data = f.read() + json_config = json.loads(data) + self.h = AttrDict(json_config) + self.quantizer = Quantizer(self.h) + self.generator = Generator(self.h) + self.generator.load_state_dict(ckpt['generator']) + self.quantizer.load_state_dict(ckpt['quantizer']) + if with_encoder: + self.encoder = Encoder(self.h) + self.encoder.load_state_dict(ckpt['encoder']) + + def forward(self, x): + # x is the codebook + # x.shape (B, T, Nq) + quant_emb = self.quantizer.embed(x) + return self.generator(quant_emb) + + def encode(self, x): + batch_size = x.size(0) + if len(x.shape) == 3 and x.shape[-1] == 1: + x = x.squeeze(-1) + c = self.encoder(x.unsqueeze(1)) + q, loss_q, c = self.quantizer(c) + c = [code.reshape(batch_size, -1) for code in c] + # shape: [N, T, 4] + return torch.stack(c, -1) diff --git a/inspiremusic/text/abs_tokenizer.py b/inspiremusic/text/abs_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..050e3a79bc9f195b6bc069b71327d69651fa2d1b --- /dev/null +++ b/inspiremusic/text/abs_tokenizer.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
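Before the tokenizer interfaces below, a minimal usage sketch of the VQVAE wrapper defined above (the config/checkpoint paths are placeholders, not names taken from this patch): it encodes a mono waveform into discrete codebook indices and reconstructs audio from them.

import torch
from inspiremusic.music_tokenizer.vqvae import VQVAE

# Placeholder paths; substitute the config.json and checkpoint of the
# downloaded music tokenizer model.
codec = VQVAE("path/to/config.json", "path/to/model.pt", with_encoder=True).eval()

wav = torch.randn(1, 24000)          # (B, T) mono waveform at the tokenizer's sample rate
with torch.no_grad():
    codes = codec.encode(wav)        # (B, T', Nq) discrete codebook indices
    recon = codec(codes)             # (B, 1, T_rec) waveform decoded by the Generator

The encode path runs Encoder -> Quantizer and returns only the indices; forward() embeds the indices with Quantizer.embed and renders audio with the HiFi-GAN-style Generator above.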
+ +from abc import ABC +from abc import abstractmethod +from typing import Iterable +from typing import List + + +class AbsTokenizer(ABC): + @abstractmethod + def text2tokens(self, line: str) -> List[str]: + raise NotImplementedError + + @abstractmethod + def tokens2text(self, tokens: Iterable[str]) -> str: + raise NotImplementedError + + + + def encode(self, line: str, **kwargs) -> List[str]: + + return self.text2tokens(line) \ No newline at end of file diff --git a/inspiremusic/text/tokenizer.py b/inspiremusic/text/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..4b2978783ce1833b797dad58c35ebebe64a492cf --- /dev/null +++ b/inspiremusic/text/tokenizer.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import os +import re +from typing import Iterable, List, Union +import numpy as np +import torch + +from inspiremusic.text.abs_tokenizer import AbsTokenizer +from transformers import AutoTokenizer + +def get_tokenizer(tokenizer_name, tokenizer_path): + if "qwen" in tokenizer_name: + return QwenTokenizer(tokenizer_path,skip_special_tokens=True) + else: + return None + +class QwenTokenizer(AbsTokenizer): + def __init__( + self, + token_path: str, + skip_special_tokens: bool = True, + ): + super().__init__() + # NOTE: non-chat model, all these special tokens keep randomly initialized. + special_tokens = { + 'eos_token': '<|endoftext|>', + 'pad_token': '<|endoftext|>', + 'additional_special_tokens': [ + '<|im_start|>', '<|im_end|>', '<|endofprompt|>', + '[breath]', '', '', '[noise]', + '[laughter]', '[cough]', '[clucking]', '[accent]', + '[quick_breath]', + ] + } + self.tokenizer = AutoTokenizer.from_pretrained(token_path) + self.tokenizer.add_special_tokens(special_tokens) + self.skip_special_tokens = skip_special_tokens + + def get_vocab_size(self): + return self.tokenizer.vocab_size + + def text2tokens(self, line: str) -> List: + tokens = self.tokenizer([line], return_tensors="pt") + tokens = tokens["input_ids"][0].cpu().tolist() + return tokens + + def tokens2text(self, tokens) -> str: + tokens = torch.tensor(tokens, dtype=torch.int64) + text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0] + return text + + + +def get_qwen_vocab_size(token_type: str): + if "qwen1.5" in token_type.lower() or "qwen2.0" in token_type.lower() or "qwen2.5" in token_type.lower(): + # 293 for special and extra tokens, including endoftext, im_start, im_end, endofprompt and others in the future. + # model.vocab_size = 151936, tokenizer.vocab_size = 151643 + # NOTE: the first three special tokens (endoftext, im_start, im_end) are trained in Chat series models, + # others are kept in random initialization state. 
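        # i.e. 151643 (tokenizer.vocab_size) + 293 (special/extra ids reserved above)
        #      = 151936, matching the model.vocab_size noted in the comment.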
+ return 151643 + 293 + else: + raise ValueError(f"Unknown tokenizer {token_type}") \ No newline at end of file diff --git a/inspiremusic/transformer/__init__.py b/inspiremusic/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inspiremusic/transformer/activation.py b/inspiremusic/transformer/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..b87727d9d1d0ba6a0aeea5e7df21674f99926787 --- /dev/null +++ b/inspiremusic/transformer/activation.py @@ -0,0 +1,84 @@ +# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe) +# 2020 Northwestern Polytechnical University (Pengcheng Guo) +# 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Swish() activation function for Conformer.""" + +import torch +from torch import nn, sin, pow +from torch.nn import Parameter + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Return Swish activation function.""" + return x * torch.sigmoid(x) + + +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. 
+ Snake ∶= x + 1/a * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x \ No newline at end of file diff --git a/inspiremusic/transformer/attention.py b/inspiremusic/transformer/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..ebf960e96ece8cc624c8c4b1ccd71a42910bfb62 --- /dev/null +++ b/inspiremusic/transformer/attention.py @@ -0,0 +1,328 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Multi-Head Attention layer definition.""" + +import math +from typing import Tuple + +import torch +from torch import nn + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, + n_head: int, + n_feat: int, + dropout_rate: float, + key_bias: bool = True): + """Construct an MultiHeadedAttention object.""" + super().__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor, size + (#batch, n_head, time1, d_k). + torch.Tensor: Transformed key tensor, size + (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor, size + (#batch, n_head, time2, d_k). + + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention( + self, + value: torch.Tensor, + scores: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) + ) -> torch.Tensor: + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value, size + (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score, size + (#batch, n_head, time1, time2). 
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + n_batch = value.size(0) + # NOTE(xcsong): When will `if mask.size(2) > 0` be True? + # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the + # 1st chunk to ease the onnx export.] + # 2. pytorch training + if mask.size(2) > 0: # time2 > 0 + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + # For last chunk, time2 might be larger than scores.size(-1) + mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2) + scores = scores.masked_fill(mask, -float('inf')) + attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0) # (batch, head, time1, time2) + # NOTE(xcsong): When will `if mask.size(2) > 0` be False? + # 1. onnx(16/-1, -1/-1, 16/0) + # 2. jit (16/-1, -1/-1, 16/0, 16/4) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = (x.transpose(1, 2).contiguous().view(n_batch, -1, + self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + 1.When applying cross attention between decoder and encoder, + the batch padding mask for input is in (#batch, 1, T) shape. + 2.When applying self attention of encoder, + the mask is in (#batch, T, T) shape. + 3.When applying self attention of decoder, + the mask is in (#batch, L, L) shape. + 4.If the different position in decoder see different block + of the encoder, such as Mocha, the passed in mask could be + in (#batch, L, T) shape. But there is no such case in current + InspireMusic. + cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + + """ + q, k, v = self.forward_qkv(query, key, value) + + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. 
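        # The mini-example below shows why this is safe: concatenating with, or
        # splitting, a zero-sized cache tensor is effectively a no-op.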
+ # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if cache.size(0) > 0: + key_cache, value_cache = torch.split(cache, + cache.size(-1) // 2, + dim=-1) + k = torch.cat([key_cache, k], dim=2) + v = torch.cat([value_cache, v], dim=2) + + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + new_cache = torch.cat((k, v), dim=-1) + + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask), new_cache + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + """ + + def __init__(self, + n_head: int, + n_feat: int, + dropout_rate: float, + key_bias: bool = True): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate, key_bias) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x: torch.Tensor) -> torch.Tensor: + """Compute relative positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + torch.Tensor: Output tensor. + + """ + zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1), + device=x.device, + dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(x.size()[0], + x.size()[1], + x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x)[ + :, :, :, : x.size(-1) // 2 + 1 + ] # only keep the positions from 0 to time2 + return x + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, time2, size). + cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). 
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if cache.size(0) > 0: + key_cache, value_cache = torch.split(cache, + cache.size(-1) // 2, + dim=-1) + k = torch.cat([key_cache, k], dim=2) + v = torch.cat([value_cache, v], dim=2) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + new_cache = torch.cat((k, v), dim=-1) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used + if matrix_ac.shape != matrix_bd.shape: + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask), new_cache diff --git a/inspiremusic/transformer/convolution.py b/inspiremusic/transformer/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5d96149154776000991a681a666fbe55e562fe --- /dev/null +++ b/inspiremusic/transformer/convolution.py @@ -0,0 +1,145 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
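Before the convolution module, a small shape-level sketch (hypothetical sizes) of the attention cache contract documented above: every call returns the concatenated key/value cache, which can be fed back in for the next chunk.

import torch
from inspiremusic.transformer.attention import MultiHeadedAttention

attn = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0).eval()

x1 = torch.randn(1, 8, 256)                   # first chunk (batch, time, size)
out1, cache = attn(x1, x1, x1)                # cache: (1, head, 8, d_k * 2)

x2 = torch.randn(1, 4, 256)                   # next chunk also attends to cached keys/values
out2, cache = attn(x2, x2, x2, cache=cache)   # cache grows to (1, head, 12, d_k * 2)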
+# Modified from ESPnet(https://github.com/espnet/espnet) +"""ConvolutionModule definition.""" + +from typing import Tuple + +import torch +from torch import nn + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model.""" + + def __init__(self, + channels: int, + kernel_size: int = 15, + activation: nn.Module = nn.ReLU(), + norm: str = "batch_norm", + causal: bool = False, + bias: bool = True): + """Construct an ConvolutionModule object. + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernel size of conv layers. + causal (int): Whether use causal convolution or not + """ + super().__init__() + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + # self.lorder is used to distinguish if it's a causal convolution, + # if self.lorder > 0: it's a causal convolution, the input will be + # padded with self.lorder frames on the left in forward. + # else: it's a symmetrical convolution + if causal: + padding = 0 + self.lorder = kernel_size - 1 + else: + # kernel_size should be an odd number for none causal convolution + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.lorder = 0 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + + assert norm in ['batch_norm', 'layer_norm'] + if norm == "batch_norm": + self.use_layer_norm = False + self.norm = nn.BatchNorm1d(channels) + else: + self.use_layer_norm = True + self.norm = nn.LayerNorm(channels) + + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = activation + + def forward( + self, + x: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + cache: torch.Tensor = torch.zeros((0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute convolution module. + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), + (0, 0, 0) means fake mask. + cache (torch.Tensor): left context cache, it is only + used in causal convolution (#batch, channels, cache_t), + (0, 0, 0) meas fake cache. + Returns: + torch.Tensor: Output tensor (#batch, time, channels). + """ + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) # (#batch, channels, time) + + # mask batch padding + if mask_pad.size(2) > 0: # time > 0 + x.masked_fill_(~mask_pad, 0.0) + + if self.lorder > 0: + if cache.size(2) == 0: # cache_t == 0 + x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) + else: + assert cache.size(0) == x.size(0) # equal batch + assert cache.size(1) == x.size(1) # equal channel + x = torch.cat((cache, x), dim=2) + assert (x.size(2) > self.lorder) + new_cache = x[:, :, -self.lorder:] + else: + # It's better we just return None if no cache is required, + # However, for JIT export, here we just fake one tensor instead of + # None. 
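            # A zero-sized tensor keeps the return type stable for TorchScript;
            # the `cache.size(2) == 0` check earlier in this forward treats it as empty.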
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.activation(self.norm(x)) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.pointwise_conv2(x) + # mask batch padding + if mask_pad.size(2) > 0: # time > 0 + x.masked_fill_(~mask_pad, 0.0) + + return x.transpose(1, 2), new_cache diff --git a/inspiremusic/transformer/decoder.py b/inspiremusic/transformer/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7ebf3b5109dc01149874d5c9f8c6414474d303bb --- /dev/null +++ b/inspiremusic/transformer/decoder.py @@ -0,0 +1,396 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Decoder definition.""" +from typing import Tuple, List, Optional + +import torch +import torch.utils.checkpoint as ckpt +import logging + +from inspiremusic.transformer.decoder_layer import DecoderLayer +from inspiremusic.transformer.positionwise_feed_forward import PositionwiseFeedForward +from inspiremusic.utils.class_utils import ( + INSPIREMUSIC_EMB_CLASSES, + INSPIREMUSIC_ATTENTION_CLASSES, + INSPIREMUSIC_ACTIVATION_CLASSES, +) +from inspiremusic.utils.mask import (subsequent_mask, make_pad_mask) + + +class TransformerDecoder(torch.nn.Module): + """Base class of Transfomer decoder module. + Args: + vocab_size: output dim + encoder_output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the hidden units number of position-wise feedforward + num_blocks: the number of decoder blocks + dropout_rate: dropout rate + self_attention_dropout_rate: dropout rate for attention + input_layer: input layer type + use_output_layer: whether to use output layer + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: + True: use layer_norm before each sub-block of a layer. + False: use layer_norm after each sub-block of a layer. + src_attention: if false, encoder-decoder cross attention is not + applied, such as CIF model + key_bias: whether use bias in attention.linear_k, False for whisper models. + gradient_checkpointing: rerunning a forward-pass segment for each + checkpointed segment during backward. 
+ tie_word_embedding: Tie or clone module weights depending of whether we are + using TorchScript or not + """ + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = "embed", + use_output_layer: bool = True, + normalize_before: bool = True, + src_attention: bool = True, + key_bias: bool = True, + activation_type: str = "relu", + gradient_checkpointing: bool = False, + tie_word_embedding: bool = False, + ): + super().__init__() + attention_dim = encoder_output_size + activation = INSPIREMUSIC_ACTIVATION_CLASSES[activation_type]() + + self.embed = torch.nn.Sequential( + torch.nn.Identity() if input_layer == "no_pos" else + torch.nn.Embedding(vocab_size, attention_dim), + INSPIREMUSIC_EMB_CLASSES[input_layer](attention_dim, + positional_dropout_rate), + ) + + self.normalize_before = normalize_before + self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5) + self.use_output_layer = use_output_layer + if use_output_layer: + self.output_layer = torch.nn.Linear(attention_dim, vocab_size) + else: + self.output_layer = torch.nn.Identity() + self.num_blocks = num_blocks + self.decoders = torch.nn.ModuleList([ + DecoderLayer( + attention_dim, + INSPIREMUSIC_ATTENTION_CLASSES["selfattn"]( + attention_heads, attention_dim, + self_attention_dropout_rate, key_bias), + INSPIREMUSIC_ATTENTION_CLASSES["selfattn"]( + attention_heads, attention_dim, src_attention_dropout_rate, + key_bias) if src_attention else None, + PositionwiseFeedForward(attention_dim, linear_units, + dropout_rate, activation), + dropout_rate, + normalize_before, + ) for _ in range(self.num_blocks) + ]) + + self.gradient_checkpointing = gradient_checkpointing + self.tie_word_embedding = tie_word_embedding + + def forward( + self, + memory: torch.Tensor, + memory_mask: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + r_ys_in_pad: torch.Tensor = torch.empty(0), + reverse_weight: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward decoder. + Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoder memory mask, (batch, 1, maxlen_in) + ys_in_pad: padded input token ids, int64 (batch, maxlen_out) + ys_in_lens: input lengths of this batch (batch) + r_ys_in_pad: not used in transformer decoder, in order to unify api + with bidirectional decoder + reverse_weight: not used in transformer decoder, in order to unify + api with bidirectional decode + Returns: + (tuple): tuple containing: + x: decoded token score before softmax (batch, maxlen_out, + vocab_size) if use_output_layer is True, + torch.tensor(0.0), in order to unify api with bidirectional decoder + olens: (batch, ) + NOTE(xcsong): + We pass the `__call__` method of the modules instead of `forward` to the + checkpointing API because `__call__` attaches all the hooks of the module. 
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + """ + tgt = ys_in_pad + maxlen = tgt.size(1) + # tgt_mask: (B, 1, L) + tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1) + tgt_mask = tgt_mask.to(tgt.device) + # m: (1, L, L) + m = subsequent_mask(tgt_mask.size(-1), + device=tgt_mask.device).unsqueeze(0) + # tgt_mask: (B, L, L) + tgt_mask = tgt_mask & m + x, _ = self.embed(tgt) + if self.gradient_checkpointing and self.training: + x = self.forward_layers_checkpointed(x, tgt_mask, memory, + memory_mask) + else: + x = self.forward_layers(x, tgt_mask, memory, memory_mask) + if self.normalize_before: + x = self.after_norm(x) + if self.use_output_layer: + x = self.output_layer(x) + olens = tgt_mask.sum(1) + return x, torch.tensor(0.0), olens + + def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor, + memory: torch.Tensor, + memory_mask: torch.Tensor) -> torch.Tensor: + for layer in self.decoders: + x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory, + memory_mask) + return x + + @torch.jit.unused + def forward_layers_checkpointed(self, x: torch.Tensor, + tgt_mask: torch.Tensor, + memory: torch.Tensor, + memory_mask: torch.Tensor) -> torch.Tensor: + for layer in self.decoders: + x, tgt_mask, memory, memory_mask = ckpt.checkpoint( + layer.__call__, x, tgt_mask, memory, memory_mask) + return x + + def forward_one_step( + self, + memory: torch.Tensor, + memory_mask: torch.Tensor, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + cache: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Forward one step. + This is only used for decoding. + Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoded memory mask, (batch, 1, maxlen_in) + tgt: input token ids, int64 (batch, maxlen_out) + tgt_mask: input token mask, (batch, maxlen_out) + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (include 1.2) + cache: cached output list of (batch, max_time_out-1, size) + Returns: + y, cache: NN output value and cache per `self.decoders`. + y.shape` is (batch, maxlen_out, token) + """ + x, _ = self.embed(tgt) + new_cache = [] + for i, decoder in enumerate(self.decoders): + if cache is None: + c = None + else: + c = cache[i] + x, tgt_mask, memory, memory_mask = decoder(x, + tgt_mask, + memory, + memory_mask, + cache=c) + new_cache.append(x) + if self.normalize_before: + y = self.after_norm(x[:, -1]) + else: + y = x[:, -1] + if self.use_output_layer: + y = torch.log_softmax(self.output_layer(y), dim=-1) + return y, new_cache + + def tie_or_clone_weights(self, jit_mode: bool = True): + """Tie or clone module weights (between word_emb and output_layer) + depending of whether we are using TorchScript or not""" + if not self.use_output_layer: + return + if jit_mode: + logging.info("clone emb.weight to output.weight") + self.output_layer.weight = torch.nn.Parameter( + self.embed[0].weight.clone()) + else: + logging.info("tie emb.weight with output.weight") + self.output_layer.weight = self.embed[0].weight + + if getattr(self.output_layer, "bias", None) is not None: + self.output_layer.bias.data = torch.nn.functional.pad( + self.output_layer.bias.data, + ( + 0, + self.output_layer.weight.shape[0] - + self.output_layer.bias.shape[0], + ), + "constant", + 0, + ) + + +class BiTransformerDecoder(torch.nn.Module): + """Base class of Transfomer decoder module. 
+ Args: + vocab_size: output dim + encoder_output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the hidden units number of position-wise feedforward + num_blocks: the number of decoder blocks + r_num_blocks: the number of right to left decoder blocks + dropout_rate: dropout rate + self_attention_dropout_rate: dropout rate for attention + input_layer: input layer type + use_output_layer: whether to use output layer + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: + True: use layer_norm before each sub-block of a layer. + False: use layer_norm after each sub-block of a layer. + key_bias: whether use bias in attention.linear_k, False for whisper models. + """ + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + r_num_blocks: int = 0, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = "embed", + use_output_layer: bool = True, + normalize_before: bool = True, + key_bias: bool = True, + gradient_checkpointing: bool = False, + tie_word_embedding: bool = False, + ): + + super().__init__() + self.tie_word_embedding = tie_word_embedding + self.left_decoder = TransformerDecoder( + vocab_size, + encoder_output_size, + attention_heads, + linear_units, + num_blocks, + dropout_rate, + positional_dropout_rate, + self_attention_dropout_rate, + src_attention_dropout_rate, + input_layer, + use_output_layer, + normalize_before, + key_bias=key_bias, + gradient_checkpointing=gradient_checkpointing, + tie_word_embedding=tie_word_embedding) + + self.right_decoder = TransformerDecoder( + vocab_size, + encoder_output_size, + attention_heads, + linear_units, + r_num_blocks, + dropout_rate, + positional_dropout_rate, + self_attention_dropout_rate, + src_attention_dropout_rate, + input_layer, + use_output_layer, + normalize_before, + key_bias=key_bias, + gradient_checkpointing=gradient_checkpointing, + tie_word_embedding=tie_word_embedding) + + def forward( + self, + memory: torch.Tensor, + memory_mask: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + r_ys_in_pad: torch.Tensor, + reverse_weight: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward decoder. 
+ Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoder memory mask, (batch, 1, maxlen_in) + ys_in_pad: padded input token ids, int64 (batch, maxlen_out) + ys_in_lens: input lengths of this batch (batch) + r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out), + used for right to left decoder + reverse_weight: used for right to left decoder + Returns: + (tuple): tuple containing: + x: decoded token score before softmax (batch, maxlen_out, + vocab_size) if use_output_layer is True, + r_x: x: decoded token score (right to left decoder) + before softmax (batch, maxlen_out, vocab_size) + if use_output_layer is True, + olens: (batch, ) + """ + l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad, + ys_in_lens) + r_x = torch.tensor(0.0) + if reverse_weight > 0.0: + r_x, _, olens = self.right_decoder(memory, memory_mask, + r_ys_in_pad, ys_in_lens) + return l_x, r_x, olens + + def forward_one_step( + self, + memory: torch.Tensor, + memory_mask: torch.Tensor, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + cache: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Forward one step. + This is only used for decoding. + Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoded memory mask, (batch, 1, maxlen_in) + tgt: input token ids, int64 (batch, maxlen_out) + tgt_mask: input token mask, (batch, maxlen_out) + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (include 1.2) + cache: cached output list of (batch, max_time_out-1, size) + Returns: + y, cache: NN output value and cache per `self.decoders`. + y.shape` is (batch, maxlen_out, token) + """ + return self.left_decoder.forward_one_step(memory, memory_mask, tgt, + tgt_mask, cache) + + def tie_or_clone_weights(self, jit_mode: bool = True): + """Tie or clone module weights (between word_emb and output_layer) + depending of whether we are using TorchScript or not""" + self.left_decoder.tie_or_clone_weights(jit_mode) + self.right_decoder.tie_or_clone_weights(jit_mode) diff --git a/inspiremusic/transformer/decoder_layer.py b/inspiremusic/transformer/decoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..91c7c5d7fb2a8e79cea7705646e5381016f73466 --- /dev/null +++ b/inspiremusic/transformer/decoder_layer.py @@ -0,0 +1,132 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Decoder self-attention layer definition.""" +from typing import Optional, Tuple + +import torch +from torch import nn + + +class DecoderLayer(nn.Module): + """Single decoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. + src_attn (torch.nn.Module): Inter-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. 
+ If `None` is passed, Inter-attention is not used, such as + CIF, GPT, and other decoder only model. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: to use layer_norm after each sub-block. + """ + + def __init__( + self, + size: int, + self_attn: nn.Module, + src_attn: Optional[nn.Module], + feed_forward: nn.Module, + dropout_rate: float, + normalize_before: bool = True, + ): + """Construct an DecoderLayer object.""" + super().__init__() + self.size = size + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.norm1 = nn.LayerNorm(size, eps=1e-5) + self.norm2 = nn.LayerNorm(size, eps=1e-5) + self.norm3 = nn.LayerNorm(size, eps=1e-5) + self.dropout = nn.Dropout(dropout_rate) + self.normalize_before = normalize_before + + def forward( + self, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + memory: torch.Tensor, + memory_mask: torch.Tensor, + cache: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute decoded features. + + Args: + tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). + tgt_mask (torch.Tensor): Mask for input tensor + (#batch, maxlen_out). + memory (torch.Tensor): Encoded memory + (#batch, maxlen_in, size). + memory_mask (torch.Tensor): Encoded memory mask + (#batch, maxlen_in). + cache (torch.Tensor): cached tensors. + (#batch, maxlen_out - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, maxlen_out, size). + torch.Tensor: Mask for output tensor (#batch, maxlen_out). + torch.Tensor: Encoded memory (#batch, maxlen_in, size). + torch.Tensor: Encoded memory mask (#batch, maxlen_in). + + """ + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + + if cache is None: + tgt_q = tgt + tgt_q_mask = tgt_mask + else: + # compute only the last frame query keeping dim: max_time_out -> 1 + assert cache.shape == ( + tgt.shape[0], + tgt.shape[1] - 1, + self.size, + ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" + tgt_q = tgt[:, -1:, :] + residual = residual[:, -1:, :] + tgt_q_mask = tgt_mask[:, -1:, :] + + x = residual + self.dropout( + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]) + if not self.normalize_before: + x = self.norm1(x) + + if self.src_attn is not None: + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + self.dropout( + self.src_attn(x, memory, memory, memory_mask)[0]) + if not self.normalize_before: + x = self.norm2(x) + + residual = x + if self.normalize_before: + x = self.norm3(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm3(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + return x, tgt_mask, memory, memory_mask diff --git a/inspiremusic/transformer/embedding.py b/inspiremusic/transformer/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..eae8c8ecabb15b4174cc3aa73c070ae702bb5f82 --- /dev/null +++ b/inspiremusic/transformer/embedding.py @@ -0,0 +1,294 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
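A shape-level sketch of driving the decoder stack above in teacher-forced mode (hypothetical sizes; it assumes the default "embed"/"selfattn"/"relu" entries of the class_utils registries imported in decoder.py resolve as usual):

import torch
from inspiremusic.transformer.decoder import TransformerDecoder

decoder = TransformerDecoder(vocab_size=1000, encoder_output_size=256).eval()

memory = torch.randn(2, 50, 256)                      # encoder output (B, T_in, D)
memory_mask = torch.ones(2, 1, 50, dtype=torch.bool)  # (B, 1, T_in)
ys_in_pad = torch.randint(0, 1000, (2, 10))           # padded target ids (B, T_out)
ys_in_lens = torch.tensor([10, 7])                    # valid target lengths per utterance

logits, _, olens = decoder(memory, memory_mask, ys_in_pad, ys_in_lens)
# logits: (B, T_out, vocab_size) pre-softmax token scores

For step-by-step decoding, forward_one_step takes the growing target prefix plus the per-layer cache list and returns log-probabilities for the last position only.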
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Positonal Encoding Module.""" + +import math +from typing import Tuple, Union + +import torch +import torch.nn.functional as F +import numpy as np + + +class PositionalEncoding(torch.nn.Module): + """Positional encoding. + + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + + PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) + PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) + """ + + def __init__(self, + d_model: int, + dropout_rate: float, + max_len: int = 5000, + reverse: bool = False): + """Construct an PositionalEncoding object.""" + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.max_len = max_len + + self.pe = torch.zeros(self.max_len, self.d_model) + position = torch.arange(0, self.max_len, + dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) * + -(math.log(10000.0) / self.d_model)) + self.pe[:, 0::2] = torch.sin(position * div_term) + self.pe[:, 1::2] = torch.cos(position * div_term) + self.pe = self.pe.unsqueeze(0) + + def forward(self, + x: torch.Tensor, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) + offset (int, torch.tensor): position offset + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + torch.Tensor: for compatibility to RelPositionalEncoding + """ + + self.pe = self.pe.to(x.device) + pos_emb = self.position_encoding(offset, x.size(1), False) + x = x * self.xscale + pos_emb + return self.dropout(x), self.dropout(pos_emb) + + def position_encoding(self, + offset: Union[int, torch.Tensor], + size: int, + apply_dropout: bool = True) -> torch.Tensor: + """ For getting encoding in a streaming fashion + + Attention!!!!! + we apply dropout only once at the whole utterance level in a none + streaming way, but will call this function several times with + increasing input size in a streaming scenario, so the dropout will + be applied several times. 
+ + Args: + offset (int or torch.tensor): start offset + size (int): required size of position encoding + + Returns: + torch.Tensor: Corresponding encoding + """ + # How to subscript a Union type: + # https://github.com/pytorch/pytorch/issues/69434 + if isinstance(offset, int): + assert offset + size <= self.max_len + pos_emb = self.pe[:, offset:offset + size] + elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar + assert offset + size <= self.max_len + pos_emb = self.pe[:, offset:offset + size] + else: # for batched streaming decoding on GPU + assert torch.max(offset) + size <= self.max_len + index = offset.unsqueeze(1) + \ + torch.arange(0, size).to(offset.device) # B X T + flag = index > 0 + # remove negative offset + index = index * flag + pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model + + if apply_dropout: + pos_emb = self.dropout(pos_emb) + return pos_emb + + +class RelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module. + See : Appendix B in https://arxiv.org/abs/1901.02860 + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): + """Initialize class.""" + super().__init__(d_model, dropout_rate, max_len, reverse=True) + + def forward(self, + x: torch.Tensor, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """Compute positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Positional embedding tensor (1, time, `*`). + """ + self.pe = self.pe.to(x.device) + x = x * self.xscale + pos_emb = self.position_encoding(offset, x.size(1), False) + return self.dropout(x), self.dropout(pos_emb) + + +class WhisperPositionalEncoding(PositionalEncoding): + """ Sinusoids position encoding used in openai-whisper.encoder + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500): + super().__init__(d_model, dropout_rate, max_len) + self.xscale = 1.0 + log_timescale_increment = np.log(10000) / (d_model // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * + torch.arange(d_model // 2)) + scaled_time = torch.arange(max_len)[:, np.newaxis] * \ + inv_timescales[np.newaxis, :] + pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + delattr(self, "pe") + self.register_buffer("pe", pe.unsqueeze(0)) + + +class LearnablePositionalEncoding(PositionalEncoding): + """ Learnable position encoding used in openai-whisper.decoder + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448): + super().__init__(d_model, dropout_rate, max_len) + # NOTE(xcsong): overwrite self.pe & self.xscale + self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model)) + self.xscale = 1.0 + + +class NoPositionalEncoding(torch.nn.Module): + """ No position encoding + """ + + def __init__(self, d_model: int, dropout_rate: float): + super().__init__() + self.d_model = d_model + self.dropout = torch.nn.Dropout(p=dropout_rate) + + def forward(self, + x: torch.Tensor, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """ Just return zero vector for interface compatibility + """ + pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device) + return self.dropout(x), pos_emb + + def position_encoding(self, offset: Union[int, torch.Tensor], + size: int) -> torch.Tensor: + return 
torch.zeros(1, size, self.d_model) + + +class EspnetRelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module (new implementation). + + Details can be found in https://github.com/espnet/espnet/pull/2816. + + See : Appendix B in https://arxiv.org/abs/1901.02860 + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): + """Construct an PositionalEncoding object.""" + super(EspnetRelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: torch.Tensor): + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[torch.Tensor, torch.Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.position_encoding(size=x.size(1), offset=offset) + return self.dropout(x), self.dropout(pos_emb) + + def position_encoding(self, + offset: Union[int, torch.Tensor], + size: int) -> torch.Tensor: + """ For getting encoding in a streaming fashion + + Attention!!!!! + we apply dropout only once at the whole utterance level in a none + streaming way, but will call this function several times with + increasing input size in a streaming scenario, so the dropout will + be applied several times. + + Args: + offset (int or torch.tensor): start offset + size (int): required size of position encoding + + Returns: + torch.Tensor: Corresponding encoding + """ + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size, + ] + return pos_emb diff --git a/inspiremusic/transformer/encoder.py b/inspiremusic/transformer/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a46b531778f3c012f0a7e4f5da3c1d6f707c358d --- /dev/null +++ b/inspiremusic/transformer/encoder.py @@ -0,0 +1,477 @@ +# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
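Closing out the positional-encoding module, a small sketch (hypothetical sizes) of the absolute sinusoidal variant defined above, which fills its table with PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)):

import torch
from inspiremusic.transformer.embedding import PositionalEncoding

pos_enc = PositionalEncoding(d_model=256, dropout_rate=0.0)

x = torch.randn(2, 100, 256)   # (batch, time, d_model)
y, pos_emb = pos_enc(x)        # y = x * sqrt(d_model) + PE[:, :100]; pos_emb: (1, 100, 256)

# Streaming-style lookup for a chunk starting at frame 100:
chunk_pe = pos_enc.position_encoding(offset=100, size=16, apply_dropout=False)  # (1, 16, 256)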
+# Modified from ESPnet(https://github.com/espnet/espnet) +"""Encoder definition.""" +from typing import Tuple + +import torch +import torch.utils.checkpoint as ckpt + +from inspiremusic.transformer.convolution import ConvolutionModule +from inspiremusic.transformer.encoder_layer import TransformerEncoderLayer +from inspiremusic.transformer.encoder_layer import ConformerEncoderLayer +from inspiremusic.transformer.positionwise_feed_forward import PositionwiseFeedForward +from inspiremusic.utils.class_utils import ( + INSPIREMUSIC_EMB_CLASSES, + INSPIREMUSIC_SUBSAMPLE_CLASSES, + INSPIREMUSIC_ATTENTION_CLASSES, + INSPIREMUSIC_ACTIVATION_CLASSES, +) +from inspiremusic.utils.mask import make_pad_mask +from inspiremusic.utils.mask import add_optional_chunk_mask + + +class BaseEncoder(torch.nn.Module): + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "abs_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + gradient_checkpointing: bool = False, + ): + """ + Args: + input_size (int): input dim + output_size (int): dimension of attention + attention_heads (int): the number of heads of multi head attention + linear_units (int): the hidden units number of position-wise feed + forward + num_blocks (int): the number of decoder blocks + dropout_rate (float): dropout rate + attention_dropout_rate (float): dropout rate in attention + positional_dropout_rate (float): dropout rate after adding + positional encoding + input_layer (str): input layer type. + optional [linear, conv2d, conv2d6, conv2d8] + pos_enc_layer_type (str): Encoder positional encoding layer type. + opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos] + normalize_before (bool): + True: use layer_norm before each sub-block of a layer. + False: use layer_norm after each sub-block of a layer. + static_chunk_size (int): chunk size for static chunk training and + decoding + use_dynamic_chunk (bool): whether use dynamic chunk size for + training or not, You can only use fixed chunk(chunk_size > 0) + or dyanmic chunk size(use_dynamic_chunk = True) + global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module + use_dynamic_left_chunk (bool): whether use dynamic left chunk in + dynamic chunk training + key_bias: whether use bias in attention.linear_k, False for whisper models. + gradient_checkpointing: rerunning a forward-pass segment for each + checkpointed segment during backward. 
+ """ + super().__init__() + self._output_size = output_size + + self.global_cmvn = global_cmvn + self.embed = INSPIREMUSIC_SUBSAMPLE_CLASSES[input_layer]( + input_size, + output_size, + dropout_rate, + INSPIREMUSIC_EMB_CLASSES[pos_enc_layer_type](output_size, + positional_dropout_rate), + ) + + self.normalize_before = normalize_before + self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5) + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk + self.gradient_checkpointing = gradient_checkpointing + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Embed positions in tensor. + + Args: + xs: padded input tensor (B, T, D) + xs_lens: input length (B) + decoding_chunk_size: decoding chunk size for dynamic chunk + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + Returns: + encoder output tensor xs, and subsampled masks + xs: padded output tensor (B, T' ~= T/subsample_rate, D) + masks: torch.Tensor batch padding mask after subsample + (B, 1, T' ~= T/subsample_rate) + NOTE(xcsong): + We pass the `__call__` method of the modules instead of `forward` to the + checkpointing API because `__call__` attaches all the hooks of the module. + https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + """ + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + xs, pos_emb, masks = self.embed(xs, masks) + mask_pad = masks # (B, 1, T/subsample_rate) + chunk_masks = add_optional_chunk_mask(xs, masks, + self.use_dynamic_chunk, + self.use_dynamic_left_chunk, + decoding_chunk_size, + self.static_chunk_size, + num_decoding_left_chunks) + + if self.gradient_checkpointing and self.training: + xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb, + mask_pad) + else: + xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) + if self.normalize_before: + xs = self.after_norm(xs) + # Here we assume the mask is not changed in encoder layers, so just + # return the masks before encoder layers, and the masks will be used + # for cross attention with decoder later + return xs, masks + + def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor) -> torch.Tensor: + for layer in self.encoders: + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + return xs + + @torch.jit.unused + def forward_layers_checkpointed(self, xs: torch.Tensor, + chunk_masks: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor) -> torch.Tensor: + for layer in self.encoders: + xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs, + chunk_masks, pos_emb, + mask_pad) + return xs + + @torch.jit.export + def forward_chunk( + self, + xs: torch.Tensor, + offset: int, + required_cache_size: int, + att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + ) -> 
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ Forward just one chunk + + Args: + xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + offset (int): current offset in encoder output time stamp + required_cache_size (int): cache size required for next chunk + compuation + >=0: actual cache size + <0: means all history cache is required + att_cache (torch.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (elayers, head, cache_t1, d_k * 2), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * num_decoding_left_chunks`. + cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, + (elayers, b=1, hidden-dim, cache_t2), where + `cache_t2 == cnn.lorder - 1` + + Returns: + torch.Tensor: output of current input xs, + with shape (b=1, chunk_size, hidden-dim). + torch.Tensor: new attention cache required for next chunk, with + dynamic shape (elayers, head, ?, d_k * 2) + depending on required_cache_size. + torch.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. + + """ + assert xs.size(0) == 1 + # tmp_masks is just for interface compatibility + tmp_masks = torch.ones(1, + xs.size(1), + device=xs.device, + dtype=torch.bool) + tmp_masks = tmp_masks.unsqueeze(1) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim) + + xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) + # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim) + elayers, cache_t1 = att_cache.size(0), att_cache.size(2) + chunk_size = xs.size(1) + attention_key_size = cache_t1 + chunk_size + pos_emb = self.embed.position_encoding(offset=offset - cache_t1, + size=attention_key_size) + if required_cache_size < 0: + next_cache_start = 0 + elif required_cache_size == 0: + next_cache_start = attention_key_size + else: + next_cache_start = max(attention_key_size - required_cache_size, 0) + r_att_cache = [] + r_cnn_cache = [] + for i, layer in enumerate(self.encoders): + # NOTE(xcsong): Before layer.forward + # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), + # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) + + xs, _, new_att_cache, new_cnn_cache = layer( + xs, + att_mask, + pos_emb, + att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, + cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache) + # NOTE(xcsong): After layer.forward + # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), + # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) + r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) + r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) + if self.normalize_before: + xs = self.after_norm(xs) + + # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2), + # ? 
may be larger than cache_t1, it depends on required_cache_size + r_att_cache = torch.cat(r_att_cache, dim=0) + # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2) + r_cnn_cache = torch.cat(r_cnn_cache, dim=0) + + return (xs, r_att_cache, r_cnn_cache) + + @torch.jit.unused + def forward_chunk_by_chunk( + self, + xs: torch.Tensor, + decoding_chunk_size: int, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ Forward input chunk by chunk with chunk_size like a streaming + fashion + + Here we should pay special attention to computation cache in the + streaming style forward chunk by chunk. Three things should be taken + into account for computation in the current network: + 1. transformer/conformer encoder layers output cache + 2. convolution in conformer + 3. convolution in subsampling + + However, we don't implement subsampling cache for: + 1. We can control subsampling module to output the right result by + overlapping input instead of cache left context, even though it + wastes some computation, but subsampling only takes a very + small fraction of computation in the whole model. + 2. Typically, there are several covolution layers with subsampling + in subsampling module, it is tricky and complicated to do cache + with different convolution layers with different subsampling + rate. + 3. Currently, nn.Sequential is used to stack all the convolution + layers in subsampling, we need to rewrite it to make it work + with cache, which is not preferred. + Args: + xs (torch.Tensor): (1, max_len, dim) + chunk_size (int): decoding chunk size + """ + assert decoding_chunk_size > 0 + # The model is trained by static or dynamic chunk + assert self.static_chunk_size > 0 or self.use_dynamic_chunk + subsampling = self.embed.subsampling_rate + context = self.embed.right_context + 1 # Add current frame + stride = subsampling * decoding_chunk_size + decoding_window = (decoding_chunk_size - 1) * subsampling + context + num_frames = xs.size(1) + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) + outputs = [] + offset = 0 + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + + # Feed forward overlap input step by step + for cur in range(0, num_frames - context + 1, stride): + end = min(cur + decoding_window, num_frames) + chunk_xs = xs[:, cur:end, :] + (y, att_cache, + cnn_cache) = self.forward_chunk(chunk_xs, offset, + required_cache_size, att_cache, + cnn_cache) + outputs.append(y) + offset += y.size(1) + ys = torch.cat(outputs, 1) + masks = torch.ones((1, 1, ys.size(1)), + device=ys.device, + dtype=torch.bool) + return ys, masks + + +class TransformerEncoder(BaseEncoder): + """Transformer encoder module.""" + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "abs_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + key_bias: bool = True, + selfattention_layer_type: str = "selfattn", + activation_type: str = "relu", + gradient_checkpointing: bool = False, + ): + """ Construct TransformerEncoder + + See Encoder for the meaning of each parameter. 
+ """ + super().__init__(input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, gradient_checkpointing) + activation = INSPIREMUSIC_ACTIVATION_CLASSES[activation_type]() + self.encoders = torch.nn.ModuleList([ + TransformerEncoderLayer( + output_size, + INSPIREMUSIC_ATTENTION_CLASSES[selfattention_layer_type](attention_heads, + output_size, + attention_dropout_rate, + key_bias), + PositionwiseFeedForward(output_size, linear_units, + dropout_rate, activation), + dropout_rate, normalize_before) for _ in range(num_blocks) + ]) + + +class ConformerEncoder(BaseEncoder): + """Conformer encoder module.""" + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "rel_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + positionwise_conv_kernel_size: int = 1, + macaron_style: bool = True, + selfattention_layer_type: str = "rel_selfattn", + activation_type: str = "swish", + use_cnn_module: bool = True, + cnn_module_kernel: int = 15, + causal: bool = False, + cnn_module_norm: str = "batch_norm", + key_bias: bool = True, + gradient_checkpointing: bool = False, + ): + """Construct ConformerEncoder + + Args: + input_size to use_dynamic_chunk, see in BaseEncoder + positionwise_conv_kernel_size (int): Kernel size of positionwise + conv1d layer. + macaron_style (bool): Whether to use macaron style for + positionwise layer. + selfattention_layer_type (str): Encoder attention layer type, + the parameter has no effect now, it's just for configure + compatibility. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernel size of convolution module. + causal (bool): whether to use causal convolution or not. + key_bias: whether use bias in attention.linear_k, False for whisper models. 
+ """ + super().__init__(input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, gradient_checkpointing) + activation = INSPIREMUSIC_ACTIVATION_CLASSES[activation_type]() + + # self-attention module definition + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + key_bias, + ) + # feed-forward module definition + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + ) + # convolution module definition + convolution_layer_args = (output_size, cnn_module_kernel, activation, + cnn_module_norm, causal) + + self.encoders = torch.nn.ModuleList([ + ConformerEncoderLayer( + output_size, + INSPIREMUSIC_ATTENTION_CLASSES[selfattention_layer_type]( + *encoder_selfattn_layer_args), + PositionwiseFeedForward(*positionwise_layer_args), + PositionwiseFeedForward( + *positionwise_layer_args) if macaron_style else None, + ConvolutionModule( + *convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + ) for _ in range(num_blocks) + ]) diff --git a/inspiremusic/transformer/encoder_layer.py b/inspiremusic/transformer/encoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb3a9d4e99d0f8f92ec1802a1a7620328e9353a --- /dev/null +++ b/inspiremusic/transformer/encoder_layer.py @@ -0,0 +1,235 @@ +# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Encoder self-attention layer definition.""" + +from typing import Optional, Tuple + +import torch +from torch import nn + + +class TransformerEncoderLayer(nn.Module): + """Encoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: to use layer_norm after each sub-block. 
+ """ + + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward: torch.nn.Module, + dropout_rate: float, + normalize_before: bool = True, + ): + """Construct an EncoderLayer object.""" + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = nn.LayerNorm(size, eps=1e-5) + self.norm2 = nn.LayerNorm(size, eps=1e-5) + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time,time), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): just for interface compatibility + to ConformerEncoderLayer + mask_pad (torch.Tensor): does not used in transformer layer, + just for unified api with conformer. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in conformer layer + (#batch=1, size, cache_t2), not used here, it's for interface + compatibility to ConformerEncoderLayer. + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2). + + """ + residual = x + if self.normalize_before: + x = self.norm1(x) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache) + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + return x, mask, new_att_cache, fake_cnn_cache + + +class ConformerEncoderLayer(nn.Module): + """Encoder layer module. + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + feed_forward_macaron (torch.nn.Module): Additional feed-forward module + instance. + `PositionwiseFeedForward` instance can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: use layer_norm after each sub-block. 
+ """ + + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward: Optional[nn.Module] = None, + feed_forward_macaron: Optional[nn.Module] = None, + conv_module: Optional[nn.Module] = None, + dropout_rate: float = 0.1, + normalize_before: bool = True, + ): + """Construct an EncoderLayer object.""" + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module + self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module + if feed_forward_macaron is not None: + self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module + self.norm_final = nn.LayerNorm( + size, eps=1e-5) # for the final output of the block + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time,time), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): positional encoding, must not be None + for ConformerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. + (#batch, 1,time), (0, 0, 0) means fake mask. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in conformer layer + (#batch=1, size, cache_t2) + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). 
+ """ + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + # multi-headed self-attention module + residual = x + if self.normalize_before: + x = self.norm_mha(x) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, + att_cache) + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + + # convolution module + # Fake new cnn cache here, and then change it in conv_module + new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) + x = residual + self.dropout(x) + + if not self.normalize_before: + x = self.norm_conv(x) + + # feed forward module + residual = x + if self.normalize_before: + x = self.norm_ff(x) + + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + return x, mask, new_att_cache, new_cnn_cache diff --git a/inspiremusic/transformer/label_smoothing_loss.py b/inspiremusic/transformer/label_smoothing_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..73ce09f35bfacb86730e39ef72a097f8a04e469b --- /dev/null +++ b/inspiremusic/transformer/label_smoothing_loss.py @@ -0,0 +1,97 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Label smoothing module.""" + +import torch +from torch import nn + + +class LabelSmoothingLoss(nn.Module): + """Label-smoothing loss. + + In a standard CE loss, the label's data distribution is: + [0,1,2] -> + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + + In the smoothing version CE Loss,some probabilities + are taken from the true label prob (1.0) and are divided + among other labels. + + e.g. 
+ smoothing=0.1 + [0,1,2] -> + [ + [0.9, 0.05, 0.05], + [0.05, 0.9, 0.05], + [0.05, 0.05, 0.9], + ] + + Args: + size (int): the number of class + padding_idx (int): padding class id which will be ignored for loss + smoothing (float): smoothing rate (0.0 means the conventional CE) + normalize_length (bool): + normalize loss by sequence length if True + normalize loss by batch size if False + """ + + def __init__(self, + size: int, + padding_idx: int, + smoothing: float, + normalize_length: bool = False): + """Construct an LabelSmoothingLoss object.""" + super(LabelSmoothingLoss, self).__init__() + self.criterion = nn.KLDivLoss(reduction="none") + self.padding_idx = padding_idx + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.size = size + self.normalize_length = normalize_length + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Compute loss between x and target. + + The model outputs and data labels tensors are flatten to + (batch*seqlen, class) shape and a mask is applied to the + padding part which should not be calculated for loss. + + Args: + x (torch.Tensor): prediction (batch, seqlen, class) + target (torch.Tensor): + target signal masked with self.padding_id (batch, seqlen) + Returns: + loss (torch.Tensor) : The KL loss, scalar float value + """ + assert x.size(2) == self.size + batch_size = x.size(0) + x = x.view(-1, self.size) + target = target.view(-1) + # use zeros_like instead of torch.no_grad() for true_dist, + # since no_grad() can not be exported by JIT + true_dist = torch.zeros_like(x) + true_dist.fill_(self.smoothing / (self.size - 1)) + ignore = target == self.padding_idx # (B,) + + total = len(target) - ignore.sum().item() + target = target.masked_fill(ignore, 0) # avoid -1 index + true_dist.scatter_(1, target.unsqueeze(1), self.confidence) + kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) + denom = total if self.normalize_length else batch_size + return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom diff --git a/inspiremusic/transformer/positionwise_feed_forward.py b/inspiremusic/transformer/positionwise_feed_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..b7a2cf6e7315e3a5ed2794423daff0a59cc5b208 --- /dev/null +++ b/inspiremusic/transformer/positionwise_feed_forward.py @@ -0,0 +1,115 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Positionwise feed forward layer definition.""" + +import torch + + +class PositionwiseFeedForward(torch.nn.Module): + """Positionwise feed forward layer. + + FeedForward are appied on each position of the sequence. + The output dim is same with the input dim. + + Args: + idim (int): Input dimenstion. + hidden_units (int): The number of hidden units. + dropout_rate (float): Dropout rate. 
+ activation (torch.nn.Module): Activation function + """ + + def __init__( + self, + idim: int, + hidden_units: int, + dropout_rate: float, + activation: torch.nn.Module = torch.nn.ReLU(), + ): + """Construct a PositionwiseFeedForward object.""" + super(PositionwiseFeedForward, self).__init__() + self.w_1 = torch.nn.Linear(idim, hidden_units) + self.activation = activation + self.dropout = torch.nn.Dropout(dropout_rate) + self.w_2 = torch.nn.Linear(hidden_units, idim) + + def forward(self, xs: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + xs: input tensor (B, L, D) + Returns: + output tensor, (B, L, D) + """ + return self.w_2(self.dropout(self.activation(self.w_1(xs)))) + + +class MoEFFNLayer(torch.nn.Module): + """ + Mixture of expert with Positionwise feed forward layer + See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf + The output dim is same with the input dim. + + Modified from https://github.com/Lightning-AI/lit-gpt/pull/823 + https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 + Args: + n_expert: number of expert. + n_expert_per_token: The actual number of experts used for each frame + idim (int): Input dimenstion. + hidden_units (int): The number of hidden units. + dropout_rate (float): Dropout rate. + activation (torch.nn.Module): Activation function + """ + + def __init__( + self, + n_expert: int, + n_expert_per_token: int, + idim: int, + hidden_units: int, + dropout_rate: float, + activation: torch.nn.Module = torch.nn.ReLU(), + ): + super(MoEFFNLayer, self).__init__() + self.gate = torch.nn.Linear(idim, n_expert, bias=False) + self.experts = torch.nn.ModuleList( + PositionwiseFeedForward(idim, hidden_units, dropout_rate, + activation) for _ in range(n_expert)) + self.n_expert_per_token = n_expert_per_token + + def forward(self, xs: torch.Tensor) -> torch.Tensor: + """Foward function. + Args: + xs: input tensor (B, L, D) + Returns: + output tensor, (B, L, D) + + """ + B, L, D = xs.size( + ) # batch size, sequence length, embedding dimension (idim) + xs = xs.view(-1, D) # (B*L, D) + router = self.gate(xs) # (B*L, n_expert) + logits, indices = torch.topk( + router, self.n_expert_per_token + ) # probs:(B*L, n_expert), indices: (B*L, n_expert) + weights = torch.nn.functional.softmax( + logits, dim=1, + dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token) + output = torch.zeros_like(xs) # (B*L, D) + for i, expert in enumerate(self.experts): + mask = indices == i + batch_idx, ith_expert = torch.where(mask) + output[batch_idx] += weights[batch_idx, ith_expert, None] * expert( + xs[batch_idx]) + return output.view(B, L, D) diff --git a/inspiremusic/transformer/qwen_encoder.py b/inspiremusic/transformer/qwen_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d1ec3b6afbcf77e74201a50d08ff6042d31e569f --- /dev/null +++ b/inspiremusic/transformer/qwen_encoder.py @@ -0,0 +1,169 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
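A quick illustration of the two feed-forward variants above (a sketch, not part of the diff): both map (B, L, D) to (B, L, D), so the MoE layer can stand in for the dense FFN inside an encoder block. The import path assumes the positionwise_feed_forward.py module added in this diff; all sizes are arbitrary.

import torch
from inspiremusic.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,
    MoEFFNLayer,
)

xs = torch.randn(2, 50, 256)   # (batch, length, idim)
ffn = PositionwiseFeedForward(idim=256, hidden_units=1024, dropout_rate=0.1)
moe = MoEFFNLayer(n_expert=4, n_expert_per_token=2,
                  idim=256, hidden_units=1024, dropout_rate=0.1)
print(ffn(xs).shape)           # torch.Size([2, 50, 256]) -- dense position-wise FFN
print(moe(xs).shape)           # torch.Size([2, 50, 256]) -- top-2 of 4 experts, reshaped back to (B, L, D)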
+ +import torch.nn as nn +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from inspiremusic.utils.mask import make_pad_mask +from inspiremusic.utils.hinter import hint_once + +class QwenEncoder(nn.Module): + def __init__( + self, + input_size: int, + pretrain_path: str = "Qwen/Qwen2.0-0.5B", + trainable: bool = False, + do_fusion_emb: bool = False, + fusion_drop_rate: float = 0.0, + ): + super(QwenEncoder, self).__init__() + self.input_size = input_size + self.trainable = trainable + self.model = AutoModelForCausalLM.from_pretrained(pretrain_path, device_map="cpu") + self._output_size = self.model.config.hidden_size + self.do_fusion_emb = do_fusion_emb + self.hidden_norm = torch.nn.LayerNorm(self._output_size) + self.fusion_dropout = nn.Dropout(fusion_drop_rate) + if do_fusion_emb: + self.fusion_layer = torch.nn.Linear(self._output_size * 2, self._output_size) + self.emb_norm = torch.nn.LayerNorm(self._output_size) + self.fusion_norm = torch.nn.LayerNorm(self._output_size) + from inspiremusic.transformer.activation import Swish + self.fusion_act = Swish(self) + + if not self.trainable: + self.model.eval() + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + input_ids: torch.Tensor, + ilens: torch.Tensor, + ): + device = input_ids.device + input_ids = torch.clamp(input_ids, min=0, max=None) + input_masks = (~make_pad_mask(ilens)).to(device).long() + if not self.trainable: + with torch.no_grad(): + model_outputs = self.model( + input_ids=input_ids, + attention_mask=input_masks, + output_hidden_states=True + ) + else: + model_outputs = self.model( + input_ids=input_ids, + attention_mask=input_masks, + output_hidden_states=True + ) + outs = model_outputs.hidden_states[-1] + outs = self.hidden_norm(outs) + if self.do_fusion_emb: + hint_once("fuse embedding and LM outputs", "fuse_emb") + outs = self.fusion_dropout(self.fusion_act(outs)) + emb = model_outputs.hidden_states[0] + emb = self.fusion_dropout(self.fusion_act(self.emb_norm(emb))) + outs = self.fusion_layer( + torch.cat([outs, emb], dim=-1) + ) + outs = self.fusion_act(self.fusion_norm(outs)) + + return outs, ilens + + +class QwenEmbeddingEncoder(nn.Module): + def __init__( + self, + input_size: int, + pretrain_path: str = "Qwen/Qwen2.0-0.5B", + ): + super(QwenEmbeddingEncoder, self).__init__() + self.input_size = input_size + from transformers import Qwen2ForCausalLM + # self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="cpu", attn_implementation="flash_attention_2") + self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path, + device_map="cpu") + self._output_size = self.model.config.hidden_size + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + input_embeds: torch.Tensor, + ilens: torch.Tensor, + ): + input_masks = (~make_pad_mask(ilens)).to(input_embeds.device).long() + + outs = self.model( + inputs_embeds=input_embeds, + attention_mask=input_masks, + output_hidden_states=True, + return_dict=True, + ) + + return outs.hidden_states[-1], input_masks + + def forward_one_step(self, xs, masks, cache=None): + + outs = self.model( + inputs_embeds=xs, + attention_mask=masks, + output_hidden_states=True, + return_dict=True, + use_cache=True, + past_key_values=cache, + ) + xs = outs.hidden_states[-1] + new_cache = outs.past_key_values + + return xs, masks, new_cache + + +class QwenInputOnlyEncoder(nn.Module): + def __init__( + self, + input_size: int, + pretrain_path: str = "Qwen/Qwen2.0-0.5B", + ): + 
super(QwenInputOnlyEncoder, self).__init__() + self.input_size = input_size + from transformers import Qwen2ForCausalLM + # model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="cpu", attn_implementation="flash_attention_2") + model = Qwen2ForCausalLM.from_pretrained(pretrain_path, + device_map="cpu") + self.embed = model.model.embed_tokens + for p in self.embed.parameters(): + p.requires_grad = False + # set text embedding to non-trainable + + # self.post_embed = model.model.rotary_emb + self._output_size = model.config.hidden_size + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + input_ids: torch.Tensor, + ilens: torch.Tensor, + ): + input_masks = (~make_pad_mask(ilens)).to(input_ids.device).long() + + outs = self.embed(input_ids) + + return outs, input_masks + \ No newline at end of file diff --git a/inspiremusic/transformer/subsampling.py b/inspiremusic/transformer/subsampling.py new file mode 100644 index 0000000000000000000000000000000000000000..08cf7d4224c0f4bf95853590f7e2f97b387f44f9 --- /dev/null +++ b/inspiremusic/transformer/subsampling.py @@ -0,0 +1,384 @@ +# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Subsampling layer definition.""" + +from typing import Tuple, Union + +import torch + + +class BaseSubsampling(torch.nn.Module): + + def __init__(self): + super().__init__() + self.right_context = 0 + self.subsampling_rate = 1 + + def position_encoding(self, offset: Union[int, torch.Tensor], + size: int) -> torch.Tensor: + return self.pos_enc.position_encoding(offset, size) + + +class EmbedinigNoSubsampling(BaseSubsampling): + """Embedding input without subsampling + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + super().__init__() + self.embed = torch.nn.Embedding(idim, odim) + self.pos_enc = pos_enc_class + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Input x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: linear input tensor (#batch, time', odim), + where time' = time . + torch.Tensor: linear input mask (#batch, 1, time'), + where time' = time . + + """ + x = self.embed(x) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask + + +class LinearNoSubsampling(BaseSubsampling): + """Linear transform the input without subsampling + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. 
+ + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an linear object.""" + super().__init__() + self.out = torch.nn.Sequential( + torch.nn.Linear(idim, odim), + torch.nn.LayerNorm(odim, eps=1e-5), + torch.nn.Dropout(dropout_rate), + ) + self.pos_enc = pos_enc_class + self.right_context = 0 + self.subsampling_rate = 1 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Input x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: linear input tensor (#batch, time', odim), + where time' = time . + torch.Tensor: linear input mask (#batch, 1, time'), + where time' = time . + + """ + x = self.out(x) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask + + +class Conv1dSubsampling2(BaseSubsampling): + """Convolutional 1D subsampling (to 1/2 length). + It is designed for Whisper, ref: + https://github.com/openai/whisper/blob/main/whisper/model.py + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv1dSubsampling2 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1), + torch.nn.GELU(), + torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1), + torch.nn.GELU(), + ) + self.pos_enc = pos_enc_class + # The right context for every conv layer is computed by: + # (kernel_size - 1) * frame_rate_of_this_layer + self.subsampling_rate = 2 + # 4 = (3 - 1) * 1 + (3 - 1) * 1 + self.right_context = 4 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 2. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 2. + torch.Tensor: positional encoding + + """ + time = x.size(1) + x = x.transpose(1, 2) # (b, f, t) + x = self.conv(x) + x = x.transpose(1, 2) # (b, t, f) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, (time + 1) % 2::2] + + +class Conv2dSubsampling4(BaseSubsampling): + """Convolutional 2D subsampling (to 1/4 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. 
+ + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling4 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) + self.pos_enc = pos_enc_class + # The right context for every conv layer is computed by: + # (kernel_size - 1) * frame_rate_of_this_layer + self.subsampling_rate = 4 + # 6 = (3 - 1) * 1 + (3 - 1) * 2 + self.right_context = 6 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 4. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 4. + torch.Tensor: positional encoding + + """ + x = x.unsqueeze(1) # (b, c=1, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2] + + +class Conv2dSubsampling6(BaseSubsampling): + """Convolutional 2D subsampling (to 1/6 length). + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc (torch.nn.Module): Custom position encoding layer. + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling6 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 5, 3), + torch.nn.ReLU(), + ) + self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), + odim) + self.pos_enc = pos_enc_class + # 10 = (3 - 1) * 1 + (5 - 1) * 2 + self.subsampling_rate = 6 + self.right_context = 10 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 6. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 6. + torch.Tensor: positional encoding + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3] + + +class Conv2dSubsampling8(BaseSubsampling): + """Convolutional 2D subsampling (to 1/8 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. 
+ + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling8 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.linear = torch.nn.Linear( + odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim) + self.pos_enc = pos_enc_class + self.subsampling_rate = 8 + # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4 + self.right_context = 14 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 8. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 8. + torch.Tensor: positional encoding + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2] + + +class LegacyLinearNoSubsampling(BaseSubsampling): + """Linear transform the input without subsampling + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an linear object.""" + super().__init__() + self.out = torch.nn.Sequential( + torch.nn.Linear(idim, odim), + torch.nn.LayerNorm(odim, eps=1e-5), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + ) + self.pos_enc = pos_enc_class + self.right_context = 0 + self.subsampling_rate = 1 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Input x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: linear input tensor (#batch, time', odim), + where time' = time . + torch.Tensor: linear input mask (#batch, 1, time'), + where time' = time . + + """ + + x = self.out(x) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask diff --git a/inspiremusic/utils/__init__.py b/inspiremusic/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inspiremusic/utils/audio_utils.py b/inspiremusic/utils/audio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c7b5d2afe3fd572b4f9fdfcece8b88d99870052a --- /dev/null +++ b/inspiremusic/utils/audio_utils.py @@ -0,0 +1,623 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import io +import logging +import re +import sys +import inspect +import random +import typing as tp +from functools import partial + +import omegaconf +import torch +import torchaudio +import numpy as np + +from typing_extensions import Literal +from typing import ( + Any, + Union, + Iterable, + List, + Dict, + Optional, + Tuple, +) + +from librosa.filters import mel as librosa_mel_fn +from scipy.io.wavfile import read + +_BoolLike_co = Union[bool, np.bool_] +_IntLike_co = Union[_BoolLike_co, int, "np.integer[Any]"] +_FloatLike_co = Union[_IntLike_co, float, "np.floating[Any]"] + +def process_audio(file_path, target_sample_rate=24000): + audio, sample_rate = torchaudio.load(file_path) + # Check if the audio needs to be resampled + if sample_rate != target_sample_rate: + audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)(audio) + # Convert stereo to mono (if necessary) + audio = audio.mean(dim=0, keepdim=True) if audio.size(0) == 2 else audio + return audio, target_sample_rate + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + # global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned + mel_basis = {} + hann_window = {} + if f"{str(fmax)}_{str(y.device)}" not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +def fade_out(audio: torch.Tensor, sample_rate: int, + fade_duration: float) -> torch.Tensor: + """ + Apply a linear fade-out effect to the given audio waveform. + + Parameters: + audio (torch.Tensor): The audio waveform tensor. + sample_rate (int): Sample rate of the audio. + fade_duration (float): Duration of the fade-out effect in seconds. 
+ + Returns: + torch.Tensor: The audio with the fade-out effect applied. + """ + fade_samples = int(fade_duration * sample_rate) + + if fade_samples > audio.shape[1]: + fade_samples = audio.shape[ + 1] # use the whole length of audio if necessary + + fade_out_envelope = torch.linspace(1.0, 0.0, fade_samples, + dtype=audio.dtype, device=audio.device) + + fade_section = audio[:, -fade_samples:].clone() + + fade_section *= fade_out_envelope + + faded_audio = audio.clone() + faded_audio[:, -fade_samples:] = fade_section + + return faded_audio + +def split_wav_into_chunks(num_samples, wav, max_chunk_size, minimum_chunk_size=720): + num_chunks = (num_samples + max_chunk_size - 1) // max_chunk_size # Ceiling division + wav_chunks = [] + for i in range(num_chunks): + start_idx = i * max_chunk_size + end_idx = min(start_idx + max_chunk_size, num_samples) + if (end_idx - start_idx) >= minimum_chunk_size: + if len(wav.shape) == 2: + chunk = wav[:,start_idx:end_idx] + else: + chunk = wav[start_idx:end_idx] + wav_chunks.append(chunk) + else: + print(f"{num_samples}:{num_chunks}, chunk size={(end_idx - start_idx)} is lower then minimum_chunk_size!") + return wav_chunks + +def tiny(x: Union[float, np.ndarray]) -> _FloatLike_co: + """Compute the tiny-value corresponding to an input's data type. + """ + # Make sure we have an array view + x = np.asarray(x) + + # Only floating types generate a tiny + if np.issubdtype(x.dtype, np.floating) or np.issubdtype( + x.dtype, np.complexfloating + ): + dtype = x.dtype + else: + dtype = np.dtype(np.float32) + + return np.finfo(dtype).tiny + +def detect_silence(audio, sample_rate, threshold=0.05, min_silence_duration=1): + """ + Detects the first occurrence of silence in the audio. + + Parameters: + audio (Tensor): The audio waveform. + sample_rate (int): The sample rate of the audio. + threshold (float): The threshold below which the signal is considered silent. + min_silence_duration (float): The minimum duration of silence in seconds. + + Returns: + int: The timestamp (in samples) where the silence starts. + """ + # Convert the audio to a numpy array for easier manipulation + audio_np = audio.numpy().flatten() + # Calculate the energy of the signal + energy = np.abs(audio_np) + # Find the indices where the energy is below the threshold + silent_indices = np.where(energy < threshold)[0] + # Find the start and end of contiguous silent regions + silent_regions = np.split(silent_indices, np.where(np.diff(silent_indices) != 1)[0] + 1) + # Filter out regions that are too short + min_silence_samples = int(min_silence_duration * sample_rate) + for region in silent_regions: + if len(region) >= min_silence_samples: + return region[0] + + # If no silence is found, return the length of the audio + return len(audio_np) + +def trim_audio(waveform, sample_rate=24000, threshold=0.05, min_silence_duration=1, minimum_silence_start_sample=24000): + """ + Trims the audio from the beginning to the first occurrence of silence. + + Parameters: + waveform (Tensor): The waveform data to the input audio file. + sample_rate (int): Sample rate of the input audio file. + threshold (float): The threshold below which the signal is considered silent. + min_silence_duration (float): The minimum duration of silence in seconds. 
+ """ + # Detect the first occurrence of silence + silence_start_sample = detect_silence(waveform, sample_rate, threshold, min_silence_duration) + if silence_start_sample > minimum_silence_start_sample : + trimmed_waveform = waveform[:silence_start_sample] + else: + trimmed_waveform = waveform[:minimum_silence_start_sample] + if isinstance(trimmed_waveform, torch.Tensor): + return trimmed_waveform + else: + return trimmed_waveform.unsqueeze() + +def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14, + loudness_compressor: bool = False, energy_floor: float = 2e-3): + """Normalize an input signal to a user loudness in dB LKFS. + Audio loudness is defined according to the ITU-R BS.1770-4 recommendation. + + Args: + wav (torch.Tensor): Input multichannel audio data. + sample_rate (int): Sample rate. + loudness_headroom_db (float): Target loudness of the output in dB LUFS. + loudness_compressor (bool): Uses tanh for soft clipping. + energy_floor (float): anything below that RMS level will not be rescaled. + Returns: + torch.Tensor: Loudness normalized output data. + """ + energy = wav.pow(2).mean().sqrt().item() + if energy < energy_floor: + return wav + transform = torchaudio.transforms.Loudness(sample_rate) + input_loudness_db = transform(wav).item() + # calculate the gain needed to scale to the desired loudness level + delta_loudness = -loudness_headroom_db - input_loudness_db + gain = 10.0 ** (delta_loudness / 20.0) + output = gain * wav + if loudness_compressor: + output = torch.tanh(output) + assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt()) + return output + +def normalize( + S: np.ndarray, + *, + norm: Optional[float] = np.inf, + axis: Optional[int] = 0, + threshold: Optional[_FloatLike_co] = None, + fill: Optional[bool] = None, +) -> np.ndarray: + """Normalize an array along a chosen axis. + """ + # Avoid div-by-zero + if threshold is None: + threshold = tiny(S) + + elif threshold <= 0: + raise ParameterError(f"threshold={threshold} must be strictly positive") + + if fill not in [None, False, True]: + raise ParameterError(f"fill={fill} must be None or boolean") + + if not np.isfinite(S).all(): + raise ParameterError("Input must be finite") + + # All norms only depend on magnitude, let's do that first + S = S.numpy() + mag = np.abs(S).astype(float) + + # For max/min norms, filling with 1 works + fill_norm = 1 + + if norm is None: + return S + + elif norm == np.inf: + length = np.max(mag, axis=axis, keepdims=True) + + elif norm == -np.inf: + length = np.min(mag, axis=axis, keepdims=True) + + elif norm == 0: + if fill is True: + raise ParameterError("Cannot normalize with norm=0 and fill=True") + + length = np.sum(mag > 0, axis=axis, keepdims=True, dtype=mag.dtype) + + elif np.issubdtype(type(norm), np.number) and norm > 0: + length = np.sum(mag**norm, axis=axis, keepdims=True) ** (1.0 / norm) + + if axis is None: + fill_norm = mag.size ** (-1.0 / norm) + else: + fill_norm = mag.shape[axis] ** (-1.0 / norm) + + else: + raise ParameterError(f"Unsupported norm: {repr(norm)}") + + # indices where norm is below the threshold + small_idx = length < threshold + + Snorm = np.empty_like(S) + if fill is None: + # Leave small indices un-normalized + length[small_idx] = 1.0 + Snorm[:] = S / length + + elif fill: + # If we have a non-zero fill value, we locate those entries by + # doing a nan-divide. 
+ # If S was finite, then length is finite (except for small positions) + length[small_idx] = np.nan + Snorm[:] = S / length + Snorm[np.isnan(Snorm)] = fill_norm + else: + # Set small values to zero by doing an inf-divide. + # This is safe (by IEEE-754) as long as S is finite. + length[small_idx] = np.inf + Snorm[:] = S / length + + return Snorm + +def normalize_audio(wav: torch.Tensor, normalize: bool = True, + strategy: str = 'peak', peak_clip_headroom_db: float = 1, + rms_headroom_db: float = 18, loudness_headroom_db: float = 14, + loudness_compressor: bool = False, log_clipping: bool = False, + sample_rate: tp.Optional[int] = None, + stem_name: tp.Optional[str] = None) -> torch.Tensor: + """Normalize the audio according to the prescribed strategy (see after). + + Args: + wav (torch.Tensor): Audio data. + normalize (bool): if `True` (default), normalizes according to the prescribed + strategy (see after). If `False`, the strategy is only used in case clipping + would happen. + strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak', + i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square + with extra headroom to avoid clipping. 'clip' just clips. + peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy. + rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger + than the `peak_clip` one to avoid further clipping. + loudness_headroom_db (float): Target loudness for loudness normalization. + loudness_compressor (bool): If True, uses tanh based soft clipping. + log_clipping (bool): If True, basic logging on stderr when clipping still + occurs despite strategy (only for 'rms'). + sample_rate (int): Sample rate for the audio data (required for loudness). + stem_name (str, optional): Stem name for clipping logging. + Returns: + torch.Tensor: Normalized audio. + """ + scale_peak = 10 ** (-peak_clip_headroom_db / 20) + scale_rms = 10 ** (-rms_headroom_db / 20) + if strategy == 'peak': + rescaling = (scale_peak / wav.abs().max()) + if normalize or rescaling < 1: + wav = wav * rescaling + elif strategy == 'clip': + wav = wav.clamp(-scale_peak, scale_peak) + elif strategy == 'rms': + mono = wav.mean(dim=0) + rescaling = scale_rms / mono.pow(2).mean().sqrt() + if normalize or rescaling < 1: + wav = wav * rescaling + _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name) + elif strategy == 'loudness': + assert sample_rate is not None, "Loudness normalization requires sample rate." + wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor) + _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name) + else: + assert wav.abs().max() < 1 + assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'" + return wav + + +def f32_pcm(wav: torch.Tensor) -> torch.Tensor: + """ + Convert audio to float 32 bits PCM format. + Args: + wav (torch.tensor): Input wav tensor + Returns: + same wav in float32 PCM format + """ + if wav.dtype.is_floating_point: + return wav + elif wav.dtype == torch.int16: + return wav.float() / 2**15 + elif wav.dtype == torch.int32: + return wav.float() / 2**31 + raise ValueError(f"Unsupported wav dtype: {wav.dtype}") + + +def i16_pcm(wav: torch.Tensor) -> torch.Tensor: + """Convert audio to int 16 bits PCM format. + + ..Warning:: There exist many formula for doing this conversion. None are perfect + due to the asymmetry of the int16 range. 
One either has possible clipping, DC offset,
+        or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
+        it is possible that `i16_pcm(f32_pcm(wav)) != wav`.
+    Args:
+        wav (torch.tensor): Input wav tensor
+    Returns:
+        same wav in int16 PCM format
+    """
+    if wav.dtype.is_floating_point:
+        assert wav.abs().max() <= 1
+        candidate = (wav * 2 ** 15).round()
+        if candidate.max() >= 2 ** 15:  # clipping would occur
+            candidate = (wav * (2 ** 15 - 1)).round()
+        return candidate.short()
+    else:
+        assert wav.dtype == torch.int16
+        return wav
+
+
+def compress(wav: torch.Tensor, sr: int,
+             target_format: tp.Literal["mp3", "ogg", "flac"] = "mp3",
+             bitrate: str = "128k") -> tp.Tuple[torch.Tensor, int]:
+    """Convert an audio waveform to a specified lossy format: mp3, ogg, flac.
+
+    Args:
+        wav (torch.Tensor): Input wav tensor.
+        sr (int): Sampling rate.
+        target_format (str): Compression format (e.g., 'mp3').
+        bitrate (str): Bitrate for compression.
+
+    Returns:
+        Tuple of compressed WAV tensor and sampling rate.
+    """
+
+    # Extract the bit rate from string (e.g., '128k')
+    match = re.search(r"\d+(\.\d+)?", str(bitrate))
+    parsed_bitrate = float(match.group()) if match else None
+    assert parsed_bitrate, f"Invalid bitrate specified (got {parsed_bitrate})"
+    try:
+        # Create a virtual file instead of saving to disk
+        buffer = io.BytesIO()
+
+        torchaudio.save(
+            buffer, wav, sr, format=target_format, bits_per_sample=parsed_bitrate,
+        )
+        # Move to the beginning of the file
+        buffer.seek(0)
+        compressed_wav, sr = torchaudio.load(buffer)
+        return compressed_wav, sr
+
+    except RuntimeError:
+        logger.warning(
+            f"compression failed, skipping compression: {target_format} {parsed_bitrate}"
+        )
+        return wav, sr
+
+
+def get_mp3(wav_tensor: torch.Tensor, sr: int, bitrate: str = "128k") -> torch.Tensor:
+    """Convert a batch of audio files to MP3 format, maintaining the original shape.
+
+    This function takes a batch of audio files represented as a PyTorch tensor, converts
+    them to MP3 format using the specified bitrate, and returns the batch in the same
+    shape as the input.
+
+    Args:
+        wav_tensor (torch.Tensor): Batch of audio files represented as a tensor.
+            Shape should be (batch_size, channels, length).
+        sr (int): Sampling rate of the audio.
+        bitrate (str): Bitrate for MP3 conversion, default is '128k'.
+
+    Returns:
+        torch.Tensor: Batch of audio files converted to MP3 format, with the same
+            shape as the input tensor.
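+
+    Example (illustrative; assumes torchaudio was built with an MP3-capable backend):
+        >>> batch = torch.zeros(1, 2, 48000)   # (batch, channels, samples), 2 s at 24 kHz
+        >>> roundtrip = get_mp3(batch, sr=24000, bitrate="128k")
+        >>> roundtrip.shape                     # trimmed/padded back to the input length
+        torch.Size([1, 2, 48000])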
+ """ + device = wav_tensor.device + batch_size, channels, original_length = wav_tensor.shape + + # Flatten tensor for conversion and move to CPU + wav_tensor_flat = wav_tensor.view(1, -1).cpu() + + # Convert to MP3 format with specified bitrate + wav_tensor_flat, _ = compress(wav_tensor_flat, sr, bitrate=bitrate) + + # Reshape back to original batch format and trim or pad if necessary + wav_tensor = wav_tensor_flat.view(batch_size, channels, -1) + compressed_length = wav_tensor.shape[-1] + if compressed_length > original_length: + wav_tensor = wav_tensor[:, :, :original_length] # Trim excess frames + elif compressed_length < original_length: + padding = torch.zeros( + batch_size, channels, original_length - compressed_length, device=device + ) + wav_tensor = torch.cat((wav_tensor, padding), dim=-1) # Pad with zeros + + # Move tensor back to the original device + return wav_tensor.to(device) + + +def get_aac( + wav_tensor: torch.Tensor, + sr: int, + bitrate: str = "128k", + lowpass_freq: tp.Optional[int] = None, +) -> torch.Tensor: + """Converts a batch of audio tensors to AAC format and then back to tensors. + + This function first saves the input tensor batch as WAV files, then uses FFmpeg to convert + these WAV files to AAC format. Finally, it loads the AAC files back into tensors. + + Args: + wav_tensor (torch.Tensor): A batch of audio files represented as a tensor. + Shape should be (batch_size, channels, length). + sr (int): Sampling rate of the audio. + bitrate (str): Bitrate for AAC conversion, default is '128k'. + lowpass_freq (Optional[int]): Frequency for a low-pass filter. If None, no filter is applied. + + Returns: + torch.Tensor: Batch of audio files converted to AAC and back, with the same + shape as the input tensor. + """ + import tempfile + import subprocess + + device = wav_tensor.device + batch_size, channels, original_length = wav_tensor.shape + + # Parse the bitrate value from the string + match = re.search(r"\d+(\.\d+)?", bitrate) + parsed_bitrate = ( + match.group() if match else "128" + ) # Default to 128 if parsing fails + + # Flatten tensor for conversion and move to CPU + wav_tensor_flat = wav_tensor.view(1, -1).cpu() + + with tempfile.NamedTemporaryFile( + suffix=".wav" + ) as f_in, tempfile.NamedTemporaryFile(suffix=".aac") as f_out: + input_path, output_path = f_in.name, f_out.name + + # Save the tensor as a WAV file + torchaudio.save(input_path, wav_tensor_flat, sr, backend="ffmpeg") + + # Prepare FFmpeg command for AAC conversion + command = [ + "ffmpeg", + "-y", + "-i", + input_path, + "-ar", + str(sr), + "-b:a", + f"{parsed_bitrate}k", + "-c:a", + "aac", + ] + if lowpass_freq is not None: + command += ["-cutoff", str(lowpass_freq)] + command.append(output_path) + + try: + # Run FFmpeg and suppress output + subprocess.run(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + # Load the AAC audio back into a tensor + aac_tensor, _ = torchaudio.load(output_path, backend="ffmpeg") + except Exception as exc: + raise RuntimeError( + "Failed to run command " ".join(command)} " + "(Often this means ffmpeg is not installed or the encoder is not supported, " + "make sure you installed an older version ffmpeg<5)" + ) from exc + + original_length_flat = batch_size * channels * original_length + compressed_length_flat = aac_tensor.shape[-1] + + # Trim excess frames + if compressed_length_flat > original_length_flat: + aac_tensor = aac_tensor[:, :original_length_flat] + + # Pad the shortedn frames + elif compressed_length_flat < original_length_flat: + 
padding = torch.zeros( + 1, original_length_flat - compressed_length_flat, device=device + ) + aac_tensor = torch.cat((aac_tensor, padding), dim=-1) + + # Reshape and adjust length to match original tensor + wav_tensor = aac_tensor.view(batch_size, channels, -1) + compressed_length = wav_tensor.shape[-1] + + assert compressed_length == original_length, ( + "AAC-compressed audio does not have the same frames as original one. " + "One reason can be ffmpeg is not installed and used as proper backed " + "for torchaudio, or the AAC encoder is not correct. Run " + "`torchaudio.utils.ffmpeg_utils.get_audio_encoders()` and make sure we see entry for" + "AAC in the output." + ) + return wav_tensor.to(device) \ No newline at end of file diff --git a/inspiremusic/utils/binary.py b/inspiremusic/utils/binary.py new file mode 100644 index 0000000000000000000000000000000000000000..862cb467850b9af8e8b035939c018984e590e79c --- /dev/null +++ b/inspiremusic/utils/binary.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`.""" +import io +import json +import struct +import typing as tp + +# format is `ECDC` magic code, followed by the header size as uint32. +# Then an uint8 indicates the protocol version (0.) +# The header is then provided as json and should contain all required +# informations for decoding. A raw stream of bytes is then provided +# and should be interpretable using the json header. +_encodec_header_struct = struct.Struct('!4sBI') +_ENCODEC_MAGIC = b'ECDC' + + +def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any): + meta_dumped = json.dumps(metadata).encode('utf-8') + version = 0 + header = _encodec_header_struct.pack(_ENCODEC_MAGIC, version, + len(meta_dumped)) + fo.write(header) + fo.write(meta_dumped) + fo.flush() + + +def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes: + buf = b"" + while len(buf) < size: + new_buf = fo.read(size) + if not new_buf: + raise EOFError("Impossible to read enough data from the stream, " + f"{size} bytes remaining.") + buf += new_buf + size -= len(new_buf) + return buf + + +def read_ecdc_header(fo: tp.IO[bytes]): + header_bytes = _read_exactly(fo, _encodec_header_struct.size) + magic, version, meta_size = _encodec_header_struct.unpack(header_bytes) + if magic != _ENCODEC_MAGIC: + raise ValueError("File is not in ECDC format.") + if version != 0: + raise ValueError("Version not supported.") + meta_bytes = _read_exactly(fo, meta_size) + return json.loads(meta_bytes.decode('utf-8')) + + +class BitPacker: + """Simple bit packer to handle ints with a non standard width, e.g. 10 bits. + Note that for some bandwidth (1.5, 3), the codebook representation + will not cover an integer number of bytes. + + Args: + bits (int): number of bits per value that will be pushed. + fo (IO[bytes]): file-object to push the bytes to. + """ + + def __init__(self, bits: int, fo: tp.IO[bytes]): + self._current_value = 0 + self._current_bits = 0 + self.bits = bits + self.fo = fo + + def push(self, value: int): + """Push a new value to the stream. 
This will immediately + write as many uint8 as possible to the underlying file-object.""" + self._current_value += (value << self._current_bits) + self._current_bits += self.bits + while self._current_bits >= 8: + lower_8bits = self._current_value & 0xff + self._current_bits -= 8 + self._current_value >>= 8 + self.fo.write(bytes([lower_8bits])) + + def flush(self): + """Flushes the remaining partial uint8, call this at the end + of the stream to encode.""" + if self._current_bits: + self.fo.write(bytes([self._current_value])) + self._current_value = 0 + self._current_bits = 0 + self.fo.flush() + + +class BitUnpacker: + """BitUnpacker does the opposite of `BitPacker`. + + Args: + bits (int): number of bits of the values to decode. + fo (IO[bytes]): file-object to push the bytes to. + """ + + def __init__(self, bits: int, fo: tp.IO[bytes]): + self.bits = bits + self.fo = fo + self._mask = (1 << bits) - 1 + self._current_value = 0 + self._current_bits = 0 + + def pull(self) -> tp.Optional[int]: + """ + Pull a single value from the stream, potentially reading some + extra bytes from the underlying file-object. + Returns `None` when reaching the end of the stream. + """ + while self._current_bits < self.bits: + buf = self.fo.read(1) + if not buf: + return None + character = buf[0] + self._current_value += character << self._current_bits + self._current_bits += 8 + + out = self._current_value & self._mask + self._current_value >>= self.bits + self._current_bits -= self.bits + return out + + +def test(): + import torch + torch.manual_seed(1234) + for rep in range(4): + length: int = torch.randint(10, 2_000, (1, )).item() + bits: int = torch.randint(1, 16, (1, )).item() + tokens: tp.List[int] = torch.randint(2**bits, (length, )).tolist() + rebuilt: tp.List[int] = [] + buf = io.BytesIO() + packer = BitPacker(bits, buf) + for token in tokens: + packer.push(token) + packer.flush() + buf.seek(0) + unpacker = BitUnpacker(bits, buf) + while True: + value = unpacker.pull() + if value is None: + break + rebuilt.append(value) + assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens)) + # The flushing mechanism might lead to "ghost" values at the end of the stream. + assert len(rebuilt) <= len(tokens) + 8 // bits, (len(rebuilt), + len(tokens), bits) + for idx, (a, b) in enumerate(zip(tokens, rebuilt)): + assert a == b, (idx, a, b) + + +if __name__ == '__main__': + test() diff --git a/inspiremusic/utils/class_utils.py b/inspiremusic/utils/class_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2a6ddd08863b54afd24e92268dbd1faf15114b3e --- /dev/null +++ b/inspiremusic/utils/class_utils.py @@ -0,0 +1,71 @@ +# Copyright [2023-11-28] +# 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
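+
+# The registries defined below map string identifiers used in the model
+# configuration (e.g. inspiremusic.yaml) to concrete classes. Illustrative
+# lookup only; the exact config keys are assumptions, not defined here:
+#
+#     activation = INSPIREMUSIC_ACTIVATION_CLASSES["swish"]()
+#     pos_emb_cls = INSPIREMUSIC_EMB_CLASSES["rel_pos_espnet"]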
+import torch + +from inspiremusic.transformer.activation import Swish +from inspiremusic.transformer.subsampling import ( + LinearNoSubsampling, + EmbedinigNoSubsampling, + Conv1dSubsampling2, + Conv2dSubsampling4, + Conv2dSubsampling6, + Conv2dSubsampling8, +) +from inspiremusic.transformer.embedding import (PositionalEncoding, + RelPositionalEncoding, + WhisperPositionalEncoding, + LearnablePositionalEncoding, + NoPositionalEncoding) +from inspiremusic.transformer.attention import (MultiHeadedAttention, + RelPositionMultiHeadedAttention) +from inspiremusic.transformer.embedding import EspnetRelPositionalEncoding +from inspiremusic.transformer.subsampling import LegacyLinearNoSubsampling + + +INSPIREMUSIC_ACTIVATION_CLASSES = { + "hardtanh": torch.nn.Hardtanh, + "tanh": torch.nn.Tanh, + "relu": torch.nn.ReLU, + "selu": torch.nn.SELU, + "swish": getattr(torch.nn, "SiLU", Swish), + "gelu": torch.nn.GELU, +} + +INSPIREMUSIC_SUBSAMPLE_CLASSES = { + "linear": LinearNoSubsampling, + "linear_legacy": LegacyLinearNoSubsampling, + "embed": EmbedinigNoSubsampling, + "conv1d2": Conv1dSubsampling2, + "conv2d": Conv2dSubsampling4, + "conv2d6": Conv2dSubsampling6, + "conv2d8": Conv2dSubsampling8, + 'paraformer_dummy': torch.nn.Identity +} + +INSPIREMUSIC_EMB_CLASSES = { + "embed": PositionalEncoding, + "abs_pos": PositionalEncoding, + "rel_pos": RelPositionalEncoding, + "rel_pos_espnet": EspnetRelPositionalEncoding, + "no_pos": NoPositionalEncoding, + "abs_pos_whisper": WhisperPositionalEncoding, + "embed_learnable_pe": LearnablePositionalEncoding, +} + +INSPIREMUSIC_ATTENTION_CLASSES = { + "selfattn": MultiHeadedAttention, + "rel_selfattn": RelPositionMultiHeadedAttention, +} + diff --git a/inspiremusic/utils/common.py b/inspiremusic/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..d888fa4b71743142a968d70b7c0bcd09a32de70b --- /dev/null +++ b/inspiremusic/utils/common.py @@ -0,0 +1,173 @@ +# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Unility functions for Transformer.""" + +from typing import List + +import torch +IGNORE_ID = -1 + +MUSIC_STRUCTURE_LABELS = ["intro", "verse1", "chorus", "verse2", "outro"] + +def pad_list(xs: List[torch.Tensor], pad_value: int): + """Perform padding for the list of tensors. + + Args: + xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. + pad_value (float): Value for padding. + + Returns: + Tensor: Padded tensor (B, Tmax, `*`). 
+ + Examples: + >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] + >>> x + [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] + >>> pad_list(x, 0) + tensor([[1., 1., 1., 1.], + [1., 1., 0., 0.], + [1., 0., 0., 0.]]) + + """ + max_len = max([len(item) for item in xs]) + batchs = len(xs) + ndim = xs[0].ndim + if ndim == 1: + pad_res = torch.zeros(batchs, + max_len, + dtype=xs[0].dtype, + device=xs[0].device) + elif ndim == 2: + pad_res = torch.zeros(batchs, + max_len, + xs[0].shape[1], + dtype=xs[0].dtype, + device=xs[0].device) + elif ndim == 3: + pad_res = torch.zeros(batchs, + max_len, + xs[0].shape[1], + xs[0].shape[2], + dtype=xs[0].dtype, + device=xs[0].device) + else: + raise ValueError(f"Unsupported ndim: {ndim}") + pad_res.fill_(pad_value) + for i in range(batchs): + pad_res[i, :len(xs[i])] = xs[i] + return pad_res + + +def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor, + ignore_label: int) -> torch.Tensor: + """Calculate accuracy. + + Args: + pad_outputs (Tensor): Prediction tensors (B * Lmax, D). + pad_targets (LongTensor): Target label tensors (B, Lmax). + ignore_label (int): Ignore label id. + + Returns: + torch.Tensor: Accuracy value (0.0 - 1.0). + + """ + pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1), + pad_outputs.size(1)).argmax(2) + mask = pad_targets != ignore_label + numerator = torch.sum( + pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) + denominator = torch.sum(mask) + return (numerator / denominator).detach() + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + +def topk_sampling(weighted_scores, decoded_tokens, top_k=25): + zeros = weighted_scores.new_ones(weighted_scores.shape) * float('-inf') + values,indices = torch.topk(weighted_scores,top_k) + zeros.scatter_(-1, indices, values) + return random_sampling(zeros,decoded_tokens) + +# Repetition Aware Sampling in VALL-E 2 + +def ras_sampling(weighted_scores, decoded_tokens, top_p=0.8, top_k=25, win_size=10, tau_r=0.1): + top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) + rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item() + if rep_num >= win_size * tau_r: + top_ids = random_sampling(weighted_scores, decoded_tokens) + return top_ids + +def caras_sampling(weighted_scores, decoded_tokens, top_p=0.8, top_k=25, win_size=10, tau_r=0.1): + weighted_scores, cfg_weighted_scores = weighted_scores + top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) + rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item() + if rep_num >= win_size * tau_r: + top_ids = random_sampling(cfg_weighted_scores, decoded_tokens) + return top_ids + +def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): + prob, indices = [], [] + cum_prob = 0.0 + sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True) + for i in range(len(sorted_idx)): + # sampling both top-p and numbers. 
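+        # Keep accumulating candidates while the cumulative probability mass is
+        # below top_p and fewer than top_k tokens have been collected; the final
+        # token is then drawn from this truncated set via multinomial sampling.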
+ if cum_prob < top_p and len(prob) < top_k: + cum_prob += sorted_value[i] + prob.append(sorted_value[i]) + indices.append(sorted_idx[i]) + else: + break + prob = torch.tensor(prob).to(weighted_scores) + indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device) + top_ids = indices[prob.multinomial(1, replacement=True)] + return top_ids + + +def random_sampling(weighted_scores, decoded_tokens): + top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True) + return top_ids + + +def fade_in_out(fade_in_mel, fade_out_mel, window): + device = fade_in_mel.device + fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu() + mel_overlap_len = int(window.shape[0] / 2) + fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + \ + fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:] + return fade_in_mel.to(device) + +def set_all_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + assert mask.dtype == torch.bool + assert dtype in [torch.float32, torch.bfloat16, torch.float16] + mask = mask.to(dtype) + # attention mask bias + # NOTE(Mddct): torch.finfo jit issues + # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min + mask = (1.0 - mask) * torch.finfo(dtype).min + return mask \ No newline at end of file diff --git a/inspiremusic/utils/data_utils.py b/inspiremusic/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..094a280f25ac50be81253854e62790ce50e5e7de --- /dev/null +++ b/inspiremusic/utils/data_utils.py @@ -0,0 +1,104 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
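+
+# rich_captions() below assembles structured caption strings of the form
+# <|start|><|chorus|>[<|tags|>][<|text|>][<|lyrics|><|lyrics text|>]<|end|>, e.g.:
+#     rich_captions("Jazz music with drum beats.", None, None, "intro", 0.0, 30.0)
+#     -> "<|0.0|><|intro|><|Jazz music with drum beats.|><|30.0|>"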
+from torch.utils.data import DataLoader +from inspiremusic.dataset.dataset import Dataset +import numpy as np +import librosa + +def audio_process_dataset_and_dataloader(args, configs): + input_dataset = Dataset(args.input_data, data_pipeline=configs['data_pipeline'], mode='processing', shuffle=True, partition=True) + # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts + input_data_loader = DataLoader(input_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch) + return input_dataset, input_data_loader + +def is_silent(wav_path, threshold=0.01, frame_length=2048, hop_length=512): + y, sr = librosa.load(wav_path, sr=None) + rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0] + silent_frames = np.sum(rms < threshold) / len(rms) + silence_fraction_threshold = 0.95 + return silent_frames >= silence_fraction_threshold + +def rich_captions(text=None, tags=None, lyrics=None, chorus="verse", start_time=0.0, end_time=30.0): + if text is None and tags is None and lyrics is None: + return None + else: + if start_time is None: + start_time = 0.0 + if end_time is None: + end_time = 30.0 + if chorus is None: + chorus = "verse" + captions = f"<|{start_time:.1f}|><|{chorus}|>" + if tags is not None: + captions += f"<|{tags}|>" + if text is not None: + captions += f"<|{text}|>" + if lyrics is not None: + captions += f"<|lyrics|><|{lyrics}|>" + captions += f"<|{end_time:.1f}|>" + return captions + +def process_tags(infile, outfile, timefile = None): + key_list = [] + with open(infile, "r") as f: + for line in f: + sec = line.strip() + key_list.append(sec) + f.close() + if timefile is None: + with open(outfile, 'w') as f: + for k in key_list: + parts = k.rsplit('_', 1) + text = parts[0].replace('_', ' ') + ', ' + parts[1] + caption = rich_captions(text, None, None) + if caption is not None: + f.write("%s\t%s\n" %(k, caption)) + f.close() + else: + times = {} + with open(timefile, "r") as f: + for line in f: + sec = line.strip().split("\t") + if len(sec) == 2 : + times[sec[0]] = sec[1] + f.close() + + with open(outfile, 'w') as f: + for k in key_list: + parts = k.rsplit('_', 1) + text = parts[0].replace('_', ' ') + ', ' + parts[1] + if k in times.keys(): + caption = rich_captions(text, None, None, "verse", 0.0, float(times[k])) + if caption is not None: + f.write("%s\t%s\n" %(k, caption)) + f.close() + +def process_trans(infile, outfile): + trans = {} + with open(infile, "r") as f: + for line in f: + sec = line.strip().split("\t") + if len(sec) == 2: + trans[sec[0]] = sec[1] + else: + print(line) + f.close() + with open(outfile, 'w') as f: + for k, v in trans.items(): + f.write("%s\t%s\n" %(k, rich_captions(v))) + f.close() \ No newline at end of file diff --git a/inspiremusic/utils/executor.py b/inspiremusic/utils/executor.py new file mode 100644 index 0000000000000000000000000000000000000000..da933e4942c6095da87fb26e964c39c809925ad3 --- /dev/null +++ b/inspiremusic/utils/executor.py @@ -0,0 +1,121 @@ +# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from contextlib import nullcontext +import os + +import torch +import torch.distributed as dist + +from inspiremusic.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, inspiremusic_join +from torch.cuda.amp import GradScaler, autocast + +class Executor: + + def __init__(self): + self.step = 0 + self.epoch = 0 + self.rank = int(os.environ.get('RANK', 0)) + self.device = torch.device('cuda:{}'.format(self.rank)) + + def train_one_epoch(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join, scaler=None): + ''' Train one epoch + ''' + + lr = optimizer.param_groups[0]['lr'] + logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank)) + logging.info('using accumulate grad, new batch size is {} times' + ' larger than before'.format(info_dict['accum_grad'])) + # A context manager to be used in conjunction with an instance of + # torch.nn.parallel.DistributedDataParallel to be able to train + # with uneven inputs across participating processes. + model.train() + model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext + with model_context(): + for batch_idx, batch_dict in enumerate(train_data_loader): + info_dict["tag"] = "TRAIN" + info_dict["step"] = self.step + info_dict["epoch"] = self.epoch + info_dict["batch_idx"] = batch_idx + if inspiremusic_join(group_join, info_dict): + break + + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0: + context = model.no_sync + # Used for single gpu training and DDP gradient synchronization + # processes. 
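+                # In both branches gradients are accumulated into .grad; under
+                # no_sync the cross-rank all-reduce is simply deferred until the
+                # micro-batch where (batch_idx + 1) % accum_grad == 0.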
+ else: + context = nullcontext + + with context(): + with autocast(enabled=scaler is not None): + info_dict = batch_forward(model, batch_dict, info_dict, scaler) + info_dict = batch_backward(model, info_dict, scaler) + + info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict, scaler) + log_per_step(writer, info_dict) + # NOTE specify save_per_step in inspiremusic.yaml if you want to enable step save + if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \ + (batch_idx + 1) % info_dict["accum_grad"] == 0: + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False, scaler=scaler) + model.train() + if (batch_idx + 1) % info_dict["accum_grad"] == 0: + self.step += 1 + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True, scaler=scaler) + + @torch.inference_mode() + def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True, capped_at=5, scaler=None): + ''' Cross validation on + ''' + logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank)) + model.eval() + total_num_utts, total_loss_dict = 0, {} # avoid division by 0 + stop = capped_at + for batch_idx, batch_dict in enumerate(cv_data_loader): + info_dict["tag"] = "CV" + info_dict["step"] = self.step + info_dict["epoch"] = self.epoch + info_dict["batch_idx"] = batch_idx + + num_utts = len(batch_dict["utts"]) + total_num_utts += num_utts + + if capped_at>0: + if stop <= 0: + continue + else: + stop -= 1 + + with autocast(enabled=scaler is not None): + info_dict = batch_forward(model, batch_dict, info_dict, scaler) + + for k, v in info_dict['loss_dict'].items(): + if k not in total_loss_dict: + total_loss_dict[k] = [] + total_loss_dict[k].append(v.item() * num_utts) + log_per_step(None, info_dict) + + for k, v in total_loss_dict.items(): + total_loss_dict[k] = sum(v) / total_num_utts + info_dict['loss_dict'] = total_loss_dict + log_per_save(writer, info_dict) + model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1) + save_model(model, model_name, info_dict) diff --git a/inspiremusic/utils/file_utils.py b/inspiremusic/utils/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..56019b4fc015681e2b4d7a79cc2e186c6007d079 --- /dev/null +++ b/inspiremusic/utils/file_utils.py @@ -0,0 +1,79 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import json
+import torchaudio
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+logging.basicConfig(level=logging.DEBUG,
+                    format='%(asctime)s %(levelname)s %(message)s')
+
+def read_trans(list_file):
+    trans = {}
+    with open(list_file, 'r', encoding='utf8') as fin:
+        for line in fin:
+            sec = line.strip().split("\t")
+            if len(sec) > 1:
+                if sec[0] not in trans.keys():
+                    trans[sec[0]] = sec[1]
+    return trans
+
+def read_scp(list_file):
+    scp = {}
+    with open(list_file, 'r', encoding='utf8') as fin:
+        for line in fin:
+            sec = line.strip().split(" ")
+            if len(sec) > 1:
+                if sec[0] not in scp.keys():
+                    scp[sec[0]] = sec[1]
+    return scp
+
+def read_lists(list_file):
+    lists = []
+    with open(list_file, 'r', encoding='utf8') as fin:
+        for line in fin:
+            lists.append(line.strip())
+    return lists
+
+
+def read_json_lists(list_file):
+    lists = read_lists(list_file)
+    results = {}
+    for fn in lists:
+        with open(fn, 'r', encoding='utf8') as fin:
+            results.update(json.load(fin))
+    return results
+
+
+def load_wav(wav, target_sr):
+    audio, sample_rate = torchaudio.load(wav)
+    # downmix to mono before optional resampling
+    audio = audio.mean(dim=0, keepdim=True)
+    if sample_rate != target_sr:
+        assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
+        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(audio)
+    return audio
+
+
+def speed_change(waveform, sample_rate, speed_factor: str):
+    effects = [
+        ["tempo", speed_factor],  # speed_factor
+        ["rate", f"{sample_rate}"]
+    ]
+    augmented_waveform, new_sample_rate = torchaudio.sox_effects.apply_effects_tensor(
+        waveform,
+        sample_rate,
+        effects
+    )
+    return augmented_waveform, new_sample_rate
diff --git a/inspiremusic/utils/frontend_utils.py b/inspiremusic/utils/frontend_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..45239b854718afe0dc3ce194aa37a0ac6bb760eb
--- /dev/null
+++ b/inspiremusic/utils/frontend_utils.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import re +chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+') + + +# whether contain chinese character +def contains_chinese(text): + return bool(chinese_char_pattern.search(text)) + + +# replace special symbol +def replace_corner_mark(text): + text = text.replace('²', '平方') + text = text.replace('³', '立方') + return text + + +# remove meaningless symbol +def remove_bracket(text): + text = text.replace('(', '').replace(')', '') + text = text.replace('【', '').replace('】', '') + text = text.replace('`', '').replace('`', '') + text = text.replace("——", " ") + return text + + +# spell Arabic numerals +def spell_out_number(text: str, inflect_parser): + new_text = [] + st = None + for i, c in enumerate(text): + if not c.isdigit(): + if st is not None: + num_str = inflect_parser.number_to_words(text[st: i]) + new_text.append(num_str) + st = None + new_text.append(c) + else: + if st is None: + st = i + if st is not None and st < len(text): + num_str = inflect_parser.number_to_words(text[st:]) + new_text.append(num_str) + return ''.join(new_text) + + +# split paragrah logic: +# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len +# 2. cal sentence len according to lang +# 3. split sentence according to puncatation +def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False): + def calc_utt_length(_text: str): + if lang == "zh": + return len(_text) + else: + return len(tokenize(_text)) + + def should_merge(_text: str): + if lang == "zh": + return len(_text) < merge_len + else: + return len(tokenize(_text)) < merge_len + + if lang == "zh": + pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';'] + else: + pounc = ['.', '?', '!', ';', ':'] + if comma_split: + pounc.extend([',', ',']) + st = 0 + utts = [] + for i, c in enumerate(text): + if c in pounc: + if len(text[st: i]) > 0: + utts.append(text[st: i] + c) + if i + 1 < len(text) and text[i + 1] in ['"', '”']: + tmp = utts.pop(-1) + utts.append(tmp + text[i + 1]) + st = i + 2 + else: + st = i + 1 + if len(utts) == 0: + if lang == "zh": + utts.append(text + '。') + else: + utts.append(text + '.') + final_utts = [] + cur_utt = "" + for utt in utts: + if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n: + final_utts.append(cur_utt) + cur_utt = "" + cur_utt = cur_utt + utt + if len(cur_utt) > 0: + if should_merge(cur_utt) and len(final_utts) != 0: + final_utts[-1] = final_utts[-1] + cur_utt + else: + final_utts.append(cur_utt) + + return final_utts + + +# remove blank between chinese character +def replace_blank(text: str): + out_str = [] + for i, c in enumerate(text): + if c == " ": + if ((text[i + 1].isascii() and text[i + 1] != " ") and + (text[i - 1].isascii() and text[i - 1] != " ")): + out_str.append(c) + else: + out_str.append(c) + return "".join(out_str) diff --git a/inspiremusic/utils/hinter.py b/inspiremusic/utils/hinter.py new file mode 100644 index 0000000000000000000000000000000000000000..6b32194336ed680a12772e2e3d3ed48d13cbcf59 --- /dev/null +++ b/inspiremusic/utils/hinter.py @@ -0,0 +1,12 @@ +import sys +import torch.distributed +import logging + +HINTED = set() + + +def hint_once(content, uid, rank=None): + if (rank is None) or (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == rank: + if uid not in HINTED: + logging.info(content, stacklevel=3) + HINTED.add(uid) \ No newline at end of file diff --git a/inspiremusic/utils/losses.py 
b/inspiremusic/utils/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..78efd3b72ff4c61971f4732626c43613f812761d --- /dev/null +++ b/inspiremusic/utils/losses.py @@ -0,0 +1,20 @@ +import torch +import torch.nn.functional as F + + +def tpr_loss(disc_real_outputs, disc_generated_outputs, tau): + loss = 0 + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + m_DG = torch.median((dr - dg)) + L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG]) + loss += tau - F.relu(tau - L_rel) + return loss + + +def mel_loss(real_speech, generated_speech, mel_transforms): + loss = 0 + for transform in mel_transforms: + mel_r = transform(real_speech) + mel_g = transform(generated_speech) + loss += F.l1_loss(mel_g, mel_r) + return loss diff --git a/inspiremusic/utils/mask.py b/inspiremusic/utils/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..b3d767dfe0b9a37295c49d6670e5600e60fe075e --- /dev/null +++ b/inspiremusic/utils/mask.py @@ -0,0 +1,227 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +''' +def subsequent_mask( + size: int, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size). + + This mask is used only in decoder which works in an auto-regressive mode. + This means the current step could only do attention with its left steps. + + In encoder, fully attention is used when streaming is not necessary and + the sequence is not long. In this case, no attention mask is needed. + + When streaming is need, chunk-based attention is used in encoder. See + subsequent_chunk_mask for the chunk-based attention mask. + + Args: + size (int): size of mask + str device (str): "cpu" or "cuda" or torch.Tensor.device + dtype (torch.device): result dtype + + Returns: + torch.Tensor: mask + + Examples: + >>> subsequent_mask(3) + [[1, 0, 0], + [1, 1, 0], + [1, 1, 1]] + """ + ret = torch.ones(size, size, device=device, dtype=torch.bool) + return torch.tril(ret) +''' + + +def subsequent_mask( + size: int, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size). + + This mask is used only in decoder which works in an auto-regressive mode. + This means the current step could only do attention with its left steps. + + In encoder, fully attention is used when streaming is not necessary and + the sequence is not long. In this case, no attention mask is needed. + + When streaming is need, chunk-based attention is used in encoder. See + subsequent_chunk_mask for the chunk-based attention mask. 
+ + Args: + size (int): size of mask + str device (str): "cpu" or "cuda" or torch.Tensor.device + dtype (torch.device): result dtype + + Returns: + torch.Tensor: mask + + Examples: + >>> subsequent_mask(3) + [[1, 0, 0], + [1, 1, 0], + [1, 1, 1]] + """ + arange = torch.arange(size, device=device) + mask = arange.expand(size, size) + arange = arange.unsqueeze(-1) + mask = mask <= arange + return mask + + +def subsequent_chunk_mask( + size: int, + chunk_size: int, + num_left_chunks: int = -1, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size) with chunk size, + this is for streaming encoder + + Args: + size (int): size of mask + chunk_size (int): size of chunk + num_left_chunks (int): number of left chunks + <0: use full chunk + >=0: use num_left_chunks + device (torch.device): "cpu" or "cuda" or torch.Tensor.device + + Returns: + torch.Tensor: mask + + Examples: + >>> subsequent_chunk_mask(4, 2) + [[1, 1, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 1], + [1, 1, 1, 1]] + """ + ret = torch.zeros(size, size, device=device, dtype=torch.bool) + for i in range(size): + if num_left_chunks < 0: + start = 0 + else: + start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) + ending = min((i // chunk_size + 1) * chunk_size, size) + ret[i, start:ending] = True + return ret + + +def add_optional_chunk_mask(xs: torch.Tensor, + masks: torch.Tensor, + use_dynamic_chunk: bool, + use_dynamic_left_chunk: bool, + decoding_chunk_size: int, + static_chunk_size: int, + num_decoding_left_chunks: int, + enable_full_context: bool = True): + """ Apply optional mask for encoder. + + Args: + xs (torch.Tensor): padded input, (B, L, D), L for max length + mask (torch.Tensor): mask for xs, (B, 1, L) + use_dynamic_chunk (bool): whether to use dynamic chunk or not + use_dynamic_left_chunk (bool): whether to use dynamic left chunk for + training. + decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + static_chunk_size (int): chunk size for static chunk training/decoding + if it's greater than 0, if use_dynamic_chunk is true, + this parameter will be ignored + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + enable_full_context (bool): + True: chunk size is either [1, 25] or full context(max_len) + False: chunk size ~ U[1, 25] + + Returns: + torch.Tensor: chunk mask of the input xs. + """ + # Whether to use chunk mask or not + if use_dynamic_chunk: + max_len = xs.size(1) + if decoding_chunk_size < 0: + chunk_size = max_len + num_left_chunks = -1 + elif decoding_chunk_size > 0: + chunk_size = decoding_chunk_size + num_left_chunks = num_decoding_left_chunks + else: + # chunk size is either [1, 25] or full context(max_len). + # Since we use 4 times subsampling and allow up to 1s(100 frames) + # delay, the maximum frame is 100 / 4 = 25. 
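+                # A candidate chunk size is drawn uniformly from [1, max_len);
+                # when enable_full_context is True, draws above max_len // 2 are
+                # promoted to full context, the rest are folded into [1, 25].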
+ chunk_size = torch.randint(1, max_len, (1, )).item() + num_left_chunks = -1 + if chunk_size > max_len // 2 and enable_full_context: + chunk_size = max_len + else: + chunk_size = chunk_size % 25 + 1 + if use_dynamic_left_chunk: + max_left_chunks = (max_len - 1) // chunk_size + num_left_chunks = torch.randint(0, max_left_chunks, + (1, )).item() + chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size, + num_left_chunks, + xs.device) # (L, L) + chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) + chunk_masks = masks & chunk_masks # (B, L, L) + elif static_chunk_size > 0: + num_left_chunks = num_decoding_left_chunks + chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size, + num_left_chunks, + xs.device) # (L, L) + chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) + chunk_masks = masks & chunk_masks # (B, L, L) + else: + chunk_masks = masks + return chunk_masks + + +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """Make mask tensor containing indices of padded part. + + See description of make_non_pad_mask. + + Args: + lengths (torch.Tensor): Batch of lengths (B,). + Returns: + torch.Tensor: Mask tensor containing indices of padded part. + + Examples: + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + """ + batch_size = lengths.size(0) + max_len = max_len if max_len > 0 else lengths.max().item() + seq_range = torch.arange(0, + max_len, + dtype=torch.int64, + device=lengths.device) + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + seq_length_expand = lengths.unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + return mask diff --git a/inspiremusic/utils/scheduler.py b/inspiremusic/utils/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..ea5f3a8b00ca9328a0fe917b34d73e21e9c25b2f --- /dev/null +++ b/inspiremusic/utils/scheduler.py @@ -0,0 +1,738 @@ +# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +# 2022 Ximalaya Inc (Yuguang Yang) +# 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +# NeMo(https://github.com/NVIDIA/NeMo) + +from typing import Union + +import math +import warnings +import torch +from torch.optim.lr_scheduler import _LRScheduler + + +class WarmupLR(_LRScheduler): + """The WarmupLR scheduler + + This scheduler is almost same as NoamLR Scheduler except for following + difference: + + NoamLR: + lr = optimizer.lr * model_size ** -0.5 + * min(step ** -0.5, step * warmup_step ** -1.5) + WarmupLR: + lr = optimizer.lr * warmup_step ** 0.5 + * min(step ** -0.5, step * warmup_step ** -1.5) + + Note that the maximum lr equals to optimizer.lr in this scheduler. 
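+
+    Example (illustrative): with optimizer.lr = 1e-3 and warmup_steps = 25000,
+    the learning rate ramps up linearly to 1e-3 at step 25000 and afterwards
+    decays in proportion to step ** -0.5.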
+ + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + warmup_steps: Union[int, float] = 25000, + last_epoch: int = -1, + ): + self.warmup_steps = warmup_steps + + # __init__() must be invoked before setting field + # because step() is also invoked in __init__() + super().__init__(optimizer, last_epoch) + + def __repr__(self): + return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})" + + def get_lr(self): + step_num = self.last_epoch + 1 + if self.warmup_steps == 0: + return [lr * step_num**-0.5 for lr in self.base_lrs] + else: + return [ + lr * self.warmup_steps**0.5 * + min(step_num**-0.5, step_num * self.warmup_steps**-1.5) + for lr in self.base_lrs + ] + + def set_step(self, step: int): + self.last_epoch = step + + +class WarmupPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + """ + + def __init__(self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1): + assert not (warmup_steps is not None and warmup_ratio is not None),\ + "Either use particular number of step or ratio" + assert warmup_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. + self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning, + stacklevel=2) + + step = self.last_epoch + + if step <= self.warmup_steps and self.warmup_steps > 0: + return self._get_warmup_lr(step) + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + def _get_warmup_lr(self, step): + lr_val = (step + 1) / (self.warmup_steps + 1) + return [initial_lr * lr_val for initial_lr in self.base_lrs] + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs + + +class SquareRootConstantPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + """ + + def __init__(self, + optimizer, + *, + constant_steps=None, + constant_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1): + assert not (constant_steps is not None + and constant_ratio is not None), \ + "Either use particular number of step or ratio" + assert constant_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. 
+ self.max_steps = max_steps + if constant_steps is not None: + self.constant_steps = constant_steps + elif constant_ratio is not None: + self.constant_steps = int(constant_ratio * max_steps) + else: + self.constant_steps = 0 + + self.constant_lr = 1 / (constant_steps**0.5) + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning, + stacklevel=2) + + step = self.last_epoch + + if step <= self.constant_steps: + return [self.constant_lr for _ in self.base_lrs] + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs + + +class WarmupHoldPolicy(WarmupPolicy): + """Variant of WarmupPolicy which maintains high + learning rate for a defined number of steps. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + hold_steps: Number of training steps to + hold the learning rate after warm up + hold_ratio: Ratio of hold steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + """ + + def __init__( + self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + hold_steps=None, + hold_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1, + ): + assert not (hold_steps is not None and hold_ratio is not None), \ + "Either use particular number of step or ratio" + assert hold_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + self.min_lr = min_lr + self._last_warmup_lr = 0.0 + + # Necessary to duplicate as class attributes are hidden in inner class + self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + if hold_steps is not None: + self.hold_steps = hold_steps + self.warmup_steps + elif hold_ratio is not None: + self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps + else: + self.hold_steps = 0 + + super().__init__( + optimizer, + warmup_steps=warmup_steps, + warmup_ratio=warmup_ratio, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + ) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler," + " " + "please use `get_last_lr()`.", + UserWarning, + stacklevel=2) + + step = self.last_epoch + + # Warmup phase + if step <= self.warmup_steps and self.warmup_steps > 0: + return self._get_warmup_lr(step) + + # Hold phase + if (step >= self.warmup_steps) and (step < self.hold_steps): + return self.base_lrs + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + +class WarmupAnnealHoldPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + min_lr: Minimum lr to hold the learning rate after decay at. + constant_steps: Number of steps to keep lr constant at. 
+ constant_ratio: Ratio of steps to keep lr constant. + """ + + def __init__( + self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + constant_steps=None, + constant_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1, + ): + assert not (warmup_steps is not None + and warmup_ratio is not None), \ + "Either use particular number of step or ratio" + assert not (constant_steps is not None + and constant_ratio is not None), \ + "Either use constant_steps or constant_ratio" + assert warmup_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. + self.max_steps = max_steps + + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + if constant_steps is not None: + self.constant_steps = constant_steps + elif constant_ratio is not None: + self.constant_steps = int(constant_ratio * max_steps) + else: + self.constant_steps = 0 + + self.decay_steps = max_steps - (self.constant_steps + + self.warmup_steps) + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning, + stacklevel=2) + + step = self.last_epoch + + # Warmup steps + if self.warmup_steps > 0 and step <= self.warmup_steps: + return self._get_warmup_lr(step) + + # Constant steps after warmup and decay + if self.constant_steps > 0 and ( + self.warmup_steps + self.decay_steps) < step <= self.max_steps: + return self._get_constant_lr(step) + + # Min lr after max steps of updates + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + def _get_warmup_lr(self, step): + lr_val = (step + 1) / (self.warmup_steps + 1) + return [initial_lr * lr_val for initial_lr in self.base_lrs] + + def _get_constant_lr(self, step): + return [self.min_lr for _ in self.base_lrs] + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs + + +def _squareroot_annealing(initial_lr, step, max_steps, min_lr): + mult = ((max_steps - step) / max_steps)**0.5 + out_lr = initial_lr * mult + out_lr = max(out_lr, min_lr) + return out_lr + + +def _square_annealing(initial_lr, step, max_steps, min_lr): + mult = ((max_steps - step) / max_steps)**2 + out_lr = initial_lr * mult + out_lr = max(out_lr, min_lr) + return out_lr + + +def _cosine_annealing(initial_lr, step, max_steps, min_lr): + mult = 0.5 * (1 + math.cos(math.pi * step / max_steps)) + out_lr = (initial_lr - min_lr) * mult + min_lr + return out_lr + + +def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step, + decay_steps, min_lr): + assert max_lr > min_lr + # Use linear warmup for the initial part. + if warmup_steps > 0 and step <= warmup_steps: + return max_lr * float(step) / float(warmup_steps) + + # For any steps larger than `decay_steps`, use `min_lr`. + if step > warmup_steps + decay_steps: + return min_lr + + # If we are done with the warmup period, use the decay style. 
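+    # i.e. lr = min_lr + 0.5 * (max_lr - min_lr)
+    #               * (1 + cos(pi * (step - warmup_steps) / decay_steps))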
+ num_steps_ = step - warmup_steps + decay_steps_ = decay_steps + decay_ratio = float(num_steps_) / float(decay_steps_) + assert decay_ratio >= 0.0 + assert decay_ratio <= 1.0 + delta_lr = max_lr - min_lr + + coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) + + return min_lr + coeff * delta_lr + + +def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle): + if cycle: + multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps) + decay_steps *= multiplier + else: + step = min(step, decay_steps) + p = step / decay_steps + lr = (initial_lr - min_lr) * math.pow(1.0 - p, power) + lr += min_lr + return lr + + +def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps, + decay_rate, min_lr): + # hold_steps = total number of steps + # to hold the LR, not the warmup + hold steps. + T_warmup_decay = max(1, warmup_steps**decay_rate) + T_hold_decay = max(1, (step - hold_steps)**decay_rate) + lr = (initial_lr * T_warmup_decay) / T_hold_decay + lr = max(lr, min_lr) + return lr + + +class SquareAnnealing(WarmupPolicy): + + def __init__(self, + optimizer, + *, + max_steps, + min_lr=1e-5, + last_epoch=-1, + **kwargs): + super().__init__(optimizer=optimizer, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + **kwargs) + + def _get_lr(self, step): + new_lrs = [ + _square_annealing( + initial_lr=initial_lr, + step=step - self.warmup_steps, + max_steps=self.max_steps - self.warmup_steps, + min_lr=self.min_lr, + ) for initial_lr in self.base_lrs + ] + return new_lrs + + +class SquareRootAnnealing(WarmupPolicy): + + def __init__(self, + optimizer, + *, + max_steps, + min_lr=0, + last_epoch=-1, + **kwargs): + super().__init__(optimizer=optimizer, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + **kwargs) + + def _get_lr(self, step): + new_lrs = [ + _squareroot_annealing(initial_lr=initial_lr, + step=step, + max_steps=self.max_steps, + min_lr=self.min_lr) + for initial_lr in self.base_lrs + ] + return new_lrs + + +class CosineAnnealing(WarmupAnnealHoldPolicy): + + def __init__(self, + optimizer, + *, + max_steps, + min_lr=0, + last_epoch=-1, + **kwargs): + super().__init__(optimizer=optimizer, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + **kwargs) + + def _get_lr(self, step): + for initial_lr in self.base_lrs: + if initial_lr < self.min_lr: + raise ValueError( + f"{self} received an initial learning rate " + f"that was lower than the minimum learning rate.") + + if self.constant_steps is None or self.constant_steps == 0: + new_lrs = [ + _cosine_annealing( + initial_lr=initial_lr, + step=step - self.warmup_steps, + max_steps=self.max_steps - self.warmup_steps, + min_lr=self.min_lr, + ) for initial_lr in self.base_lrs + ] + else: + new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step) + return new_lrs + + def _get_warmup_lr(self, step): + if self.constant_steps is None or self.constant_steps == 0: + return super()._get_warmup_lr(step) + else: + # Use linear warmup for the initial part. + return self._get_linear_warmup_with_cosine_annealing_lr(step) + + def _get_constant_lr(self, step): + # Only called when `constant_steps` > 0. + return self._get_linear_warmup_with_cosine_annealing_lr(step) + + def _get_linear_warmup_with_cosine_annealing_lr(self, step): + # Cosine Schedule for Megatron LM, + # slightly different warmup schedule + constant LR at the end. 
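+        # Note: the helper below is fed base_lrs[0] as the peak LR for every entry in
+        # base_lrs, so all parameter groups share one schedule in this Megatron-style path.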
+ new_lrs = [ + _linear_warmup_with_cosine_annealing( + max_lr=self.base_lrs[0], + warmup_steps=self.warmup_steps, + step=step, + decay_steps=self.decay_steps, + min_lr=self.min_lr, + ) for _ in self.base_lrs + ] + return new_lrs + + +class NoamAnnealing(_LRScheduler): + + def __init__(self, + optimizer, + *, + d_model, + warmup_steps=None, + warmup_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1): + self._normalize = d_model**(-0.5) + assert not (warmup_steps is not None and warmup_ratio is not None), \ + "Either use particular number of step or ratio" + assert warmup_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. + self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning, + stacklevel=2) + + step = max(1, self.last_epoch) + + for initial_lr in self.base_lrs: + if initial_lr < self.min_lr: + raise ValueError( + f"{self} received an initial learning rate " + f"that was lower than the minimum learning rate.") + + new_lrs = [ + self._noam_annealing(initial_lr=initial_lr, step=step) + for initial_lr in self.base_lrs + ] + return new_lrs + + def _noam_annealing(self, initial_lr, step): + if self.warmup_steps > 0: + mult = self._normalize * min(step**(-0.5), + step * (self.warmup_steps**(-1.5))) + else: + mult = self._normalize * step**(-0.5) + + out_lr = initial_lr * mult + if step > self.warmup_steps: + out_lr = max(out_lr, self.min_lr) + return out_lr + + +class NoamHoldAnnealing(WarmupHoldPolicy): + + def __init__(self, + optimizer, + *, + max_steps, + decay_rate=0.5, + min_lr=0.0, + last_epoch=-1, + **kwargs): + """ + From Nemo: + Implementation of the Noam Hold Annealing policy + from the SqueezeFormer paper. + + Unlike NoamAnnealing, the peak learning rate + can be explicitly set for this scheduler. + The schedule first performs linear warmup, + then holds the peak LR, then decays with some schedule for + the remainder of the steps. + Therefore the min-lr is still dependent + on the hyper parameters selected. + + It's schedule is determined by three factors- + + Warmup Steps: Initial stage, where linear warmup + occurs uptil the peak LR is reached. Unlike NoamAnnealing, + the peak LR is explicitly stated here instead of a scaling factor. + + Hold Steps: Intermediate stage, where the peak LR + is maintained for some number of steps. In this region, + the high peak LR allows the model to converge faster + if training is stable. However the high LR + may also cause instability during training. + Should usually be a significant fraction of training + steps (around 30-40% of the entire training steps). + + Decay Steps: Final stage, where the LR rapidly decays + with some scaling rate (set by decay rate). + To attain Noam decay, use 0.5, + for Squeezeformer recommended decay, use 1.0. + The fast decay after prolonged high LR during + hold phase allows for rapid convergence. 
+ + References: + - [Squeezeformer: + An Efficient Transformer for Automatic Speech Recognition] + (https://arxiv.org/abs/2206.00888) + + Args: + optimizer: Pytorch compatible Optimizer object. + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + hold_steps: Number of training steps to + hold the learning rate after warm up + hold_ratio: Ratio of hold steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + decay_rate: Float value describing the polynomial decay + after the hold period. Default value + of 0.5 corresponds to Noam decay. + min_lr: Minimum learning rate. + """ + self.decay_rate = decay_rate + super().__init__(optimizer=optimizer, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + **kwargs) + + def _get_lr(self, step): + if self.warmup_steps is None or self.warmup_steps == 0: + raise ValueError( + "Noam scheduler cannot be used without warmup steps") + + if self.hold_steps > 0: + hold_steps = self.hold_steps - self.warmup_steps + else: + hold_steps = 0 + + new_lrs = [ + _noam_hold_annealing( + initial_lr, + step=step, + warmup_steps=self.warmup_steps, + hold_steps=hold_steps, + decay_rate=self.decay_rate, + min_lr=self.min_lr, + ) for initial_lr in self.base_lrs + ] + return new_lrs + + def set_step(self, step: int): + self.last_epoch = step + + +class ConstantLR(_LRScheduler): + """The ConstantLR scheduler + + This scheduler keeps a constant lr + + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + ): + # __init__() must be invoked before setting field + # because step() is also invoked in __init__() + super().__init__(optimizer) + + def get_lr(self): + return self.base_lrs + + def set_step(self, step: int): + self.last_epoch = step diff --git a/inspiremusic/utils/tokenizer_utils.py b/inspiremusic/utils/tokenizer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..095757f5d0f8a6da8598f06ceed819e659f8bb61 --- /dev/null +++ b/inspiremusic/utils/tokenizer_utils.py @@ -0,0 +1,221 @@ +import glob +import json +import os +import random +import sys +import time +import warnings + +import matplotlib +import numpy as np +import torch +import yaml +from torch import distributed as dist +from torch.nn.utils import weight_norm +matplotlib.use("Agg") +import matplotlib.pylab as plt +import re +import pathlib + + +def seed_everything(seed, cudnn_deterministic=False): + """ + Function that sets seed for pseudo-random number generators in: + pytorch, numpy, python.random + + Args: + seed: the integer value seed for global random state + """ + if seed is not None: + # print(f"Global seed set to {seed}") + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + # if cudnn_deterministic: + # torch.backends.cudnn.deterministic = True + # warnings.warn('You have chosen to seed training. ' + # 'This will turn on the CUDNN deterministic setting, ' + # 'which can slow down your training considerably! 
' + # 'You may see unexpected behavior when restarting ' + # 'from checkpoints.') + + +def is_primary(): + return get_rank() == 0 + + +def get_rank(): + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + + return dist.get_rank() + + +def load_yaml_config(path): + with open(path) as f: + config = yaml.full_load(f) + return config + + +def save_config_to_yaml(config, path): + assert path.endswith('.yaml') + with open(path, 'w') as f: + f.write(yaml.dump(config)) + f.close() + + +def save_dict_to_json(d, path, indent=None): + json.dump(d, open(path, 'w'), indent=indent) + + +def load_dict_from_json(path): + return json.load(open(path, 'r')) + + +def write_args(args, path): + args_dict = dict((name, getattr(args, name)) for name in dir(args) + if not name.startswith('_')) + with open(path, 'a') as args_file: + args_file.write('==> torch version: {}\n'.format(torch.__version__)) + args_file.write( + '==> cudnn version: {}\n'.format(torch.backends.cudnn.version())) + args_file.write('==> Cmd:\n') + args_file.write(str(sys.argv)) + args_file.write('\n==> args:\n') + for k, v in sorted(args_dict.items()): + args_file.write(' %s: %s\n' % (str(k), str(v))) + args_file.close() + + +class Logger(object): + def __init__(self, args): + self.args = args + self.save_dir = args.save_dir + self.is_primary = is_primary() + + if self.is_primary: + os.makedirs(self.save_dir, exist_ok=True) + + # save the args and config + self.config_dir = os.path.join(self.save_dir, 'configs') + os.makedirs(self.config_dir, exist_ok=True) + file_name = os.path.join(self.config_dir, 'args.txt') + write_args(args, file_name) + + log_dir = os.path.join(self.save_dir, 'logs') + if not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + self.text_writer = open(os.path.join(log_dir, 'log.txt'), + 'a') # 'w') + if args.tensorboard: + self.log_info('using tensorboard') + self.tb_writer = torch.utils.tensorboard.SummaryWriter( + log_dir=log_dir + ) # tensorboard.SummaryWriter(log_dir=log_dir) + else: + self.tb_writer = None + + def save_config(self, config): + if self.is_primary: + save_config_to_yaml(config, + os.path.join(self.config_dir, 'config.yaml')) + + def log_info(self, info, check_primary=True): + if self.is_primary or (not check_primary): + print(info) + if self.is_primary: + info = str(info) + time_str = time.strftime('%Y-%m-%d-%H-%M') + info = '{}: {}'.format(time_str, info) + if not info.endswith('\n'): + info += '\n' + self.text_writer.write(info) + self.text_writer.flush() + + def add_scalar(self, **kargs): + """Log a scalar variable.""" + if self.is_primary: + if self.tb_writer is not None: + self.tb_writer.add_scalar(**kargs) + + def add_scalars(self, **kargs): + """Log a scalar variable.""" + if self.is_primary: + if self.tb_writer is not None: + self.tb_writer.add_scalars(**kargs) + + def add_image(self, **kargs): + """Log a scalar variable.""" + if self.is_primary: + if self.tb_writer is not None: + self.tb_writer.add_image(**kargs) + + def add_images(self, **kargs): + """Log a scalar variable.""" + if self.is_primary: + if self.tb_writer is not None: + self.tb_writer.add_images(**kargs) + + def close(self): + if self.is_primary: + self.text_writer.close() + self.tb_writer.close() + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow( + spectrogram, aspect="auto", origin="lower", interpolation='none') + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + 
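+    """Initialize Conv* layer weights from a normal distribution N(mean, std); other modules are left untouched."""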
classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj, num_ckpt_keep=5): + name = re.match(r'(do|g)_\d+', pathlib.Path(filepath).name).group(1) + ckpts = sorted(pathlib.Path(filepath).parent.glob(f'{name}_*')) + if len(ckpts) > num_ckpt_keep: + [os.remove(c) for c in ckpts[:-num_ckpt_keep]] + print("Saving checkpoint to {}".format(filepath)) + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] + diff --git a/inspiremusic/utils/train_utils.py b/inspiremusic/utils/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9312eb5a83f37388cd4de547beef1bc0f95b6d28 --- /dev/null +++ b/inspiremusic/utils/train_utils.py @@ -0,0 +1,300 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# 2023 Horizon Inc. (authors: Xingchen Song) +# 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from contextlib import nullcontext +import logging +import os +import torch +import json +import re +import datetime +import yaml + +import deepspeed +import torch.optim as optim +import torch.distributed as dist + +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DataLoader +from torch.nn.utils import clip_grad_norm_ + +from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live + +from inspiremusic.dataset.dataset import Dataset +from inspiremusic.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR + + +def init_distributed(args): + world_size = int(os.environ.get('WORLD_SIZE', 1)) + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + rank = int(os.environ.get('RANK', 0)) + logging.info('training on multiple gpus, this gpu {}'.format(local_rank) + + ', rank {}, world_size {}'.format(rank, world_size)) + if args.train_engine == 'torch_ddp': + torch.cuda.set_device(local_rank) + dist.init_process_group(args.dist_backend) + else: + deepspeed.init_distributed(dist_backend=args.dist_backend) + return world_size, local_rank, rank + + +def init_dataset_and_dataloader(args, configs): + gan = False + data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline'] + train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', shuffle=True, partition=True) + cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', shuffle=False, partition=False) + + # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts + train_data_loader = DataLoader(train_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch, + timeout=60) + cv_data_loader = DataLoader(cv_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch, + timeout=60) + return train_dataset, cv_dataset, train_data_loader, cv_data_loader + + +def check_modify_and_save_config(args, configs): + if args.train_engine == "torch_ddp": + configs['train_conf']["dtype"] = 'fp32' + else: + with open(args.deepspeed_config, 'r') as fin: + ds_configs = json.load(fin) + if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]: + configs['train_conf']["dtype"] = "fp16" + elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]: + configs['train_conf']["dtype"] = "bf16" + else: + configs['train_conf']["dtype"] = "fp32" + assert ds_configs["train_micro_batch_size_per_gpu"] == 1 + # if use deepspeed, override ddp config + configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] * + configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"]) + configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"] + configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"] + configs['train_conf']['log_interval'] = ds_configs["steps_per_print"] + return configs + + +def wrap_cuda_model(args, model): + local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1)) + world_size = int(os.environ.get('WORLD_SIZE', 1)) + if args.train_engine == "torch_ddp": # native pytorch ddp + assert (torch.cuda.is_available()) + model.cuda() + model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True) + else: + if int(os.environ.get('RANK', 0)) == 0: + logging.info("Estimating model states memory needs (zero2)...") + 
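+            # DeepSpeed sizing helper: prints an estimate of the CPU/GPU memory needed for
+            # ZeRO stage-2 model states (parameters, gradients, optimizer states) for this
+            # model and node/GPU layout; it only logs and does not modify the model.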
estimate_zero2_model_states_mem_needs_all_live( + model, + num_gpus_per_node=local_world_size, + num_nodes=world_size // local_world_size) + return model + +def init_optimizer_and_scheduler(args, configs, model): + if configs['train_conf']['optim'] == 'adam': + optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf']) + elif configs['train_conf']['optim'] == 'adamw': + optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf']) + else: + raise ValueError("unknown optimizer: " + configs['train_conf']) + + if configs['train_conf']['scheduler'] == 'warmuplr': + scheduler_type = WarmupLR + scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing': + scheduler_type = NoamHoldAnnealing + scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'constantlr': + scheduler_type = ConstantLR + scheduler = ConstantLR(optimizer) + else: + raise ValueError("unknown scheduler: " + configs['train_conf']) + + # use deepspeed optimizer for speedup + if args.train_engine == "deepspeed": + def scheduler(opt): + return scheduler_type(opt, **configs['train_conf']['scheduler_conf']) + model, optimizer, _, scheduler = deepspeed.initialize( + args=args, + model=model, + optimizer=None, + lr_scheduler=scheduler, + model_parameters=model.parameters()) + + return model, optimizer, scheduler + + +def init_summarywriter(args): + writer = None + if int(os.environ.get('RANK', 0)) == 0: + os.makedirs(args.model_dir, exist_ok=True) + writer = SummaryWriter(args.tensorboard_dir) + return writer + + +def save_model(model, model_name, info_dict): + rank = int(os.environ.get('RANK', 0)) + model_dir = info_dict["model_dir"] + save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name)) + + if info_dict["train_engine"] == "torch_ddp": + if rank == 0: + torch.save(model.module.state_dict(), save_model_path) + else: + with torch.no_grad(): + model.save_checkpoint(save_dir=model_dir, + tag=model_name, + client_state=info_dict) + if rank == 0: + info_path = re.sub('.pt$', '.yaml', save_model_path) + info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S') + with open(info_path, 'w') as fout: + data = yaml.dump(info_dict) + fout.write(data) + logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path)) + + +def inspiremusic_join(group_join, info_dict): + world_size = int(os.environ.get('WORLD_SIZE', 1)) + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + rank = int(os.environ.get('RANK', 0)) + + if info_dict["batch_idx"] != 0: + # we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr + try: + dist.monitored_barrier(group=group_join, + timeout=group_join.options._timeout) + return False + except RuntimeError as e: + logging.info("Detected uneven workload distribution: {}\n".format(e) + + "Break current worker to manually join all workers, " + + "world_size {}, current rank {}, current local_rank {}\n". 
+ format(world_size, rank, local_rank)) + return True + else: + return False + + +def batch_forward(model, batch, info_dict, scaler): + device = int(os.environ.get('LOCAL_RANK', 0)) + + dtype = info_dict["dtype"] + if dtype == "fp16": + dtype = torch.float16 + elif dtype == "bf16": + dtype = torch.bfloat16 + else: # fp32 + dtype = torch.float32 + + if info_dict['train_engine'] == 'torch_ddp': + autocast = torch.cuda.amp.autocast(enabled=scaler is not None) + else: + autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False) + + with autocast: + info_dict['loss_dict'] = model(batch, device) + return info_dict + + +def batch_backward(model, info_dict, scaler): + if info_dict["train_engine"] == "deepspeed": + scaled_loss = model.backward(info_dict['loss_dict']['loss']) + else: + scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad'] + if scaler is not None: + scaler.scale(scaled_loss).backward() + else: + scaled_loss.backward() + + info_dict['loss_dict']['loss'] = scaled_loss + return info_dict + +def update_parameter_and_lr(model, optimizer, scheduler, info_dict, scaler=None): + grad_norm = 0.0 + if info_dict['train_engine'] == "deepspeed": + info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary() + model.step() + grad_norm = model.get_global_grad_norm() + elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0: + if scaler is not None: + scaler.unscale_(optimizer) # Unscale gradients before clipping + grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip']) + scaler.step(optimizer) + scaler.update() + else: + grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip']) + if torch.isfinite(grad_norm): + optimizer.step() + optimizer.zero_grad() + scheduler.step() + info_dict["lr"] = optimizer.param_groups[0]['lr'] + info_dict["grad_norm"] = grad_norm + return info_dict + + +def log_per_step(writer, info_dict): + tag = info_dict["tag"] + epoch = info_dict.get('epoch', 0) + step = info_dict["step"] + batch_idx = info_dict["batch_idx"] + loss_dict = info_dict['loss_dict'] + rank = int(os.environ.get('RANK', 0)) + + # only rank 0 write to tensorboard to avoid multi-process write + if writer is not None: + if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \ + (info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0): + for k in ['epoch', 'lr', 'grad_norm']: + writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1) + for k, v in loss_dict.items(): + writer.add_scalar('{}/{}'.format(tag, k), v, step + 1) + + # TRAIN & CV, Shell log (stdout) + if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0: + log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1) + for name, value in loss_dict.items(): + log_str += '{} {:.6f} '.format(name, value.item()) + if tag == "TRAIN": + log_str += 'lr {:.8f} grad_norm {:.6f}'.format( + info_dict["lr"], info_dict['grad_norm']) + log_str += ' rank {}'.format(rank) + logging.debug(log_str) + + +def log_per_save(writer, info_dict): + tag = info_dict["tag"] + epoch = info_dict["epoch"] + step = info_dict["step"] + loss_dict = info_dict["loss_dict"] + lr = info_dict['lr'] + rank = int(os.environ.get('RANK', 0)) + logging.info( + 'Epoch {} Step {} CV info lr {} {} rank {}'.format( + epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()]))) + + if writer is not None: + for k in ['epoch', 'lr']: + 
writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1) + for k, v in loss_dict.items(): + writer.add_scalar('{}/{}'.format(tag, k), v, step + 1) diff --git a/inspiremusic/utils/utils.py b/inspiremusic/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..db1d8b41c94be6c41b8424d3d036e17a6471234d --- /dev/null +++ b/inspiremusic/utils/utils.py @@ -0,0 +1,22 @@ +import os +import sys + +def align_trans_scp_file(trans, scp): + trans_dict = {} + with open(trans, 'r') as f: + for line in f: + sec = line.strip().split("\t") + trans_dict[sec[0]] = sec[1] + scp_dict = {} + with open(scp, 'r') as f: + for line in f: + sec = line.strip().split(" ") + scp_dict[sec[0]] = sec[1] + with open("text", "w") as f: + for k, v in scp_dict.items(): + f.write("%s\t%s\n"%(k,trans_dict[k])) + +if __name__ == '__main__': + trans = sys.argv[1] + scp = sys.argv[2] + align_trans_scp_file(trans, scp) \ No newline at end of file diff --git a/inspiremusic/version.txt b/inspiremusic/version.txt new file mode 100644 index 0000000000000000000000000000000000000000..aa3386889155607fc99e0c091986218549648a89 --- /dev/null +++ b/inspiremusic/version.txt @@ -0,0 +1 @@ +v0.1 \ No newline at end of file diff --git a/inspiremusic/wavtokenizer/.DS_Store b/inspiremusic/wavtokenizer/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..edf661a8d23d97390558a7d72a84297bf0f2897f Binary files /dev/null and b/inspiremusic/wavtokenizer/.DS_Store differ diff --git a/inspiremusic/wavtokenizer/__init__.py b/inspiremusic/wavtokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inspiremusic/wavtokenizer/decoder/__init__.py b/inspiremusic/wavtokenizer/decoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inspiremusic/wavtokenizer/decoder/dataset.py b/inspiremusic/wavtokenizer/decoder/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..345a1617291648fe5f8671e7ea897c539fdcb2f5 --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/dataset.py @@ -0,0 +1,124 @@ +from dataclasses import dataclass + +import numpy as np +import torch +import torchaudio +from pytorch_lightning import LightningDataModule +from torch.utils.data import Dataset, DataLoader + +import soundfile +# import librosa +import random + +torch.set_num_threads(1) + + +@dataclass +class DataConfig: + filelist_path: str + sampling_rate: int + num_samples: int + batch_size: int + num_workers: int + +def collate_fn(batch): + batch = [item for item in batch if item is not None] + return torch.stack(batch, dim=0) + +class VocosDataModule(LightningDataModule): + def __init__(self, train_params: DataConfig, val_params: DataConfig): + super().__init__() + self.train_config = train_params + self.val_config = val_params + + def _get_dataloder(self, cfg: DataConfig, train: bool): + dataset = VocosDataset(cfg, train=train) + dataloader = DataLoader( + dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=train, pin_memory=True, collate_fn=collate_fn + ) + return dataloader + + def train_dataloader(self) -> DataLoader: + return self._get_dataloder(self.train_config, train=True) + + def val_dataloader(self) -> DataLoader: + return self._get_dataloder(self.val_config, train=False) + + +class VocosDataset(Dataset): + def __init__(self, cfg: DataConfig, train: bool): + with open(cfg.filelist_path) as f: + 
self.filelist = f.read().splitlines() + self.sampling_rate = cfg.sampling_rate + self.num_samples = cfg.num_samples + self.train = train + + def __len__(self) -> int: + return len(self.filelist) + + def __getitem__(self, index: int) -> torch.Tensor: + audio_path = self.filelist[index] + # y, sr = torchaudio.load(audio_path) + # print(audio_path,"111") + try: + y1, sr = soundfile.read(audio_path) + # y1, sr = librosa.load(audio_path,sr=None) + y = torch.tensor(y1).float().unsqueeze(0) + # if y.size(0) > 1: + # # mix to mono + # y = y.mean(dim=0, keepdim=True) + if y.ndim > 2: + # mix to mono + # print("有问题哈,数据处理部分") + # y = y.mean(dim=-1, keepdim=False) + random_channel = random.randint(0, y.size(-1) - 1) + y = y[:, :, random_channel] + + gain = np.random.uniform(-1, -6) if self.train else -3 + y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]]) + if sr != self.sampling_rate: + y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate) + if y.size(-1) < self.num_samples: + pad_length = self.num_samples - y.size(-1) + padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1)) + y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1) + elif self.train: + start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1) + y = y[:, start : start + self.num_samples] + else: + # During validation, take always the first segment for determinism + y = y[:, : self.num_samples] + + return y[0] + except Exception as e: + print(f"Error processing file {audio_path} at index {index}: {e}") + # 这里可以继续选择抛出异常,或者返回一个 None 表示无效数据 + return None + + # def __getitem__(self, index: int) -> torch.Tensor: + # audio_path = self.filelist[index] + # try: + # y, sr = torchaudio.load(audio_path) + # if y.size(0) > 1: + # # 随机选择一个通道 + # random_channel = random.randint(0, y.size(0) - 1) + # y = y[random_channel, :].unsqueeze(0) # 保持返回值为 (1, T) 的形式 + # # gain = np.random.uniform(-1, -6) if self.train else -3 + # # y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]]) + # if sr != self.sampling_rate: + # y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate) + # if y.size(-1) < self.num_samples: + # pad_length = self.num_samples - y.size(-1) + # padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1)) + # y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1) + # elif self.train: + # start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1) + # y = y[:, start: start + self.num_samples] + # else: + # # During validation, take always the first segment for determinism + # y = y[:, :self.num_samples] + # return y[0] + # except Exception as e: + # print(f"Error processing file {audio_path} at index {index}: {e}") + # # 这里可以继续选择抛出异常,或者返回一个 None 表示无效数据 + # return None \ No newline at end of file diff --git a/inspiremusic/wavtokenizer/decoder/discriminator_dac.py b/inspiremusic/wavtokenizer/decoder/discriminator_dac.py new file mode 100644 index 0000000000000000000000000000000000000000..33ef3258a3a4a3ab11f17ca9c3ad381de95b6477 --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/discriminator_dac.py @@ -0,0 +1,249 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +# from audiotools import AudioSignal +# from audiotools import ml +# from audiotools import STFTParams +from einops import rearrange +from torch.nn.utils import weight_norm + +from collections import namedtuple + +STFTParams = namedtuple( + "STFTParams", + ["window_length", "hop_length", "window_type", 
"match_stride", "padding_type"], +) + +STFTParams.__new__.__defaults__ = (None, None, None, None, None) + + +def WNConv1d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv1d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +def WNConv2d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv2d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +class MPD(nn.Module): + def __init__(self, period): + super().__init__() + self.period = period + self.convs = nn.ModuleList( + [ + WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), + ] + ) + self.conv_post = WNConv2d( + 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False + ) + + def pad_to_period(self, x): + t = x.shape[-1] + x = F.pad(x, (0, self.period - t % self.period), mode="reflect") + return x + + def forward(self, x): + fmap = [] + + x = self.pad_to_period(x) + x = rearrange(x, "b c (l p) -> b c l p", p=self.period) + + for layer in self.convs: + x = layer(x) + fmap.append(x) + + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class MSD(nn.Module): + def __init__(self, rate: int = 1, sample_rate: int = 48000): + super().__init__() + self.convs = nn.ModuleList( + [ + WNConv1d(1, 16, 15, 1, padding=7), + WNConv1d(16, 64, 41, 4, groups=4, padding=20), + WNConv1d(64, 256, 41, 4, groups=16, padding=20), + WNConv1d(256, 1024, 41, 4, groups=64, padding=20), + WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), + WNConv1d(1024, 1024, 5, 1, padding=2), + ] + ) + self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) + self.sample_rate = sample_rate + self.rate = rate + + def forward(self, x): + # x = AudioSignal(x, self.sample_rate) + # x.resample(self.sample_rate // self.rate) + # x = x.audio_data + + fmap = [] + + for l in self.convs: + x = l(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] + + +class MRD(nn.Module): + def __init__( + self, + window_length: int, + hop_factor: float = 0.25, + sample_rate: int = 24000, + bands: list = BANDS, + ): + """Complex multi-band spectrogram discriminator. + Parameters + ---------- + window_length : int + Window length of STFT. + hop_factor : float, optional + Hop factor of the STFT, defaults to ``0.25 * window_length``. + sample_rate : int, optional + Sampling rate of audio in Hz, by default 24000 + bands : list, optional + Bands to run discriminator over. 
+ """ + super().__init__() + + self.window_length = window_length + self.hop_factor = hop_factor + self.sample_rate = sample_rate + self.stft_params = STFTParams( + window_length=window_length, + hop_length=int(window_length * hop_factor), + match_stride=True, + ) + + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + self.n_fft = window_length + + ch = 32 + convs = lambda: nn.ModuleList( + [ + WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) + + def spectrogram(self, x): + # x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) + # x = torch.view_as_real(x.stft()) + + # x.squeeze(0).stft(n_fft=1024,win_length=1024,return_complex=True).size() + # breakpoint() + if x.size(0)==1: + # x = torch.view_as_real(x.squeeze(0).stft(n_fft=self.window_length,return_complex=True).unsqueeze(0)) + x = torch.view_as_real(x.squeeze(0).stft(n_fft=self.n_fft,return_complex=True).unsqueeze(0)) + else: + # x = torch.view_as_real(x.squeeze(1).stft(n_fft=self.window_length,return_complex=True).unsqueeze(1)) + x = torch.view_as_real(x.squeeze(1).stft(n_fft=self.n_fft,return_complex=True).unsqueeze(1)) + x = rearrange(x, "b 1 f t c -> (b 1) c t f") + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x): + x_bands = self.spectrogram(x) + fmap = [] + + x = [] + for band, stack in zip(x_bands, self.band_convs): + for layer in stack: + band = layer(band) + fmap.append(band) + x.append(band) + + x = torch.cat(x, dim=-1) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +# class DACDiscriminator(ml.BaseModel): +class DACDiscriminator(nn.Module): + def __init__( + self, + rates: list = [], + periods: list = [2, 3, 5, 7, 11], + fft_sizes: list = [2048, 1024, 512], + sample_rate: int = 24000, + bands: list = BANDS, + ): + """Discriminator that combines multiple discriminators. + + Parameters + ---------- + rates : list, optional + sampling rates (in Hz) to run MSD at, by default [] + If empty, MSD is not used. 
+ periods : list, optional + periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] + fft_sizes : list, optional + Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] + sample_rate : int, optional + Sampling rate of audio in Hz, by default 24000 + bands : list, optional + Bands to run MRD at, by default `BANDS` + """ + super().__init__() + discs = [] + discs += [MPD(p) for p in periods] + discs += [MSD(r, sample_rate=sample_rate) for r in rates] + discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] + self.discriminators = nn.ModuleList(discs) + + def preprocess(self, y): + # Remove DC offset + y = y - y.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + return y + + def forward(self, x): + x = self.preprocess(x) + fmaps = [d(x) for d in self.discriminators] + return fmaps + + +if __name__ == "__main__": + disc = DACDiscriminator() + x = torch.zeros(1, 1, 24000) + results = disc(x) + breakpoint() + for i, result in enumerate(results): + print(f"disc{i}") + for i, r in enumerate(result): + print(r.shape, r.mean(), r.min(), r.max()) + print("00") diff --git a/inspiremusic/wavtokenizer/decoder/discriminators.py b/inspiremusic/wavtokenizer/decoder/discriminators.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6dece570b3d181f3fd2206a4dae2549b9e0fa3 --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/discriminators.py @@ -0,0 +1,202 @@ +from typing import Tuple, List + +import torch +from torch import nn +from torch.nn import Conv2d +from torch.nn.utils import weight_norm + + +class MultiPeriodDiscriminator(nn.Module): + """ + Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan. + Additionally, it allows incorporating conditional information with a learned embeddings table. + + Args: + periods (tuple[int]): Tuple of periods for each discriminator. + num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. + Defaults to None. 
+ """ + + def __init__(self, periods: Tuple[int] = (2, 3, 5, 7, 11), num_embeddings: int = None): + super().__init__() + self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods]) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for d in self.discriminators: + y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) + y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorP(nn.Module): + def __init__( + self, + period: int, + in_channels: int = 1, + kernel_size: int = 5, + stride: int = 3, + lrelu_slope: float = 0.1, + num_embeddings: int = None, + ): + super().__init__() + self.period = period + self.convs = nn.ModuleList( + [ + weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), + weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), + weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), + weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), + weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))), + ] + ) + if num_embeddings is not None: + self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=1024) + torch.nn.init.zeros_(self.emb.weight) + + self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + self.lrelu_slope = lrelu_slope + + def forward( + self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + x = x.unsqueeze(1) + fmap = [] + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = torch.nn.functional.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for i, l in enumerate(self.convs): + x = l(x) + x = torch.nn.functional.leaky_relu(x, self.lrelu_slope) + if i > 0: + fmap.append(x) + if cond_embedding_id is not None: + emb = self.emb(cond_embedding_id) + h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) + else: + h = 0 + x = self.conv_post(x) + fmap.append(x) + x += h + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiResolutionDiscriminator(nn.Module): + def __init__( + self, + resolutions: Tuple[Tuple[int, int, int]] = ((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)), + num_embeddings: int = None, + ): + """ + Multi-Resolution Discriminator module adapted from https://github.com/mindslab-ai/univnet. + Additionally, it allows incorporating conditional information with a learned embeddings table. + + Args: + resolutions (tuple[tuple[int, int, int]]): Tuple of resolutions for each discriminator. + Each resolution should be a tuple of (n_fft, hop_length, win_length). + num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. + Defaults to None. 
+ """ + super().__init__() + self.discriminators = nn.ModuleList( + [DiscriminatorR(resolution=r, num_embeddings=num_embeddings) for r in resolutions] + ) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for d in self.discriminators: + y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) + y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorR(nn.Module): + def __init__( + self, + resolution: Tuple[int, int, int], + channels: int = 64, + in_channels: int = 1, + num_embeddings: int = None, + lrelu_slope: float = 0.1, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.lrelu_slope = lrelu_slope + self.convs = nn.ModuleList( + [ + weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))), + weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))), + weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))), + weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)), + weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)), + ] + ) + if num_embeddings is not None: + self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels) + torch.nn.init.zeros_(self.emb.weight) + self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1))) + + def forward( + self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + fmap = [] + x = self.spectrogram(x) + x = x.unsqueeze(1) + for l in self.convs: + x = l(x) + x = torch.nn.functional.leaky_relu(x, self.lrelu_slope) + fmap.append(x) + if cond_embedding_id is not None: + emb = self.emb(cond_embedding_id) + h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) + else: + h = 0 + x = self.conv_post(x) + fmap.append(x) + x += h + x = torch.flatten(x, 1, -1) + + return x, fmap + + def spectrogram(self, x: torch.Tensor) -> torch.Tensor: + n_fft, hop_length, win_length = self.resolution + magnitude_spectrogram = torch.stft( + x, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=None, # interestingly rectangular window kind of works here + center=True, + return_complex=True, + ).abs() + + return magnitude_spectrogram diff --git a/inspiremusic/wavtokenizer/decoder/experiment.py b/inspiremusic/wavtokenizer/decoder/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..f5557a218aedf77cc43ba3e774d6accb6003e45f --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/experiment.py @@ -0,0 +1,474 @@ +import math + +import numpy as np +import pytorch_lightning as pl +import torch +import torchaudio +import transformers +import yaml + +from decoder.discriminator_dac import DACDiscriminator + +from decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator +from decoder.feature_extractors import FeatureExtractor +from decoder.heads import FourierHead +from decoder.helpers import plot_spectrogram_to_numpy +from decoder.loss import DiscriminatorLoss, GeneratorLoss, FeatureMatchingLoss, MelSpecReconstructionLoss, 
DACGANLoss +from decoder.models import Backbone +from decoder.modules import safe_log +from decoder.pretrained_model import instantiate_class + + +class VocosExp(pl.LightningModule): + # noinspection PyUnusedLocal + def __init__( + self, + feature_extractor: FeatureExtractor, + backbone: Backbone, + head: FourierHead, + resume_config: str, + resume_model: str, + sample_rate: int = 24000, + initial_learning_rate: float = 2e-4, + num_warmup_steps: int = 0, + mel_loss_coeff: float = 45, + mrd_loss_coeff: float = 1.0, + pretrain_mel_steps: int = 0, + decay_mel_coeff: bool = False, + evaluate_utmos: bool = False, + evaluate_pesq: bool = False, + evaluate_periodicty: bool = False, + resume: bool = False, + ): + """ + Args: + feature_extractor (FeatureExtractor): An instance of FeatureExtractor to extract features from audio signals. + backbone (Backbone): An instance of Backbone model. + head (FourierHead): An instance of Fourier head to generate spectral coefficients and reconstruct a waveform. + sample_rate (int): Sampling rate of the audio signals. + initial_learning_rate (float): Initial learning rate for the optimizer. + num_warmup_steps (int): Number of steps for the warmup phase of learning rate scheduler. Default is 0. + mel_loss_coeff (float, optional): Coefficient for Mel-spectrogram loss in the loss function. Default is 45. + mrd_loss_coeff (float, optional): Coefficient for Multi Resolution Discriminator loss. Default is 1.0. + pretrain_mel_steps (int, optional): Number of steps to pre-train the model without the GAN objective. Default is 0. + decay_mel_coeff (bool, optional): If True, the Mel-spectrogram loss coefficient is decayed during training. Default is False. + evaluate_utmos (bool, optional): If True, UTMOS scores are computed for each validation run. + evaluate_pesq (bool, optional): If True, PESQ scores are computed for each validation run. + evaluate_periodicty (bool, optional): If True, periodicity scores are computed for each validation run. 
+ """ + super().__init__() + self.save_hyperparameters(ignore=["feature_extractor", "backbone", "head"]) + + self.feature_extractor = feature_extractor + self.backbone = backbone + self.head = head + + self.resume_config = resume_config + self.resume_model = resume_model + self.resume = resume + + self.multiperioddisc = MultiPeriodDiscriminator() + self.multiresddisc = MultiResolutionDiscriminator() + + + self.dac = DACDiscriminator() + + self.dacdiscriminator = DACGANLoss(self.dac) + + self.disc_loss = DiscriminatorLoss() + self.gen_loss = GeneratorLoss() + self.feat_matching_loss = FeatureMatchingLoss() + self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate) + + self.train_discriminator = False + self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff + + def configure_optimizers(self): + disc_params = [ + {"params": self.multiperioddisc.parameters()}, + {"params": self.multiresddisc.parameters()}, + {"params": self.dac.parameters()}, + ] + gen_params = [ + {"params": self.feature_extractor.parameters()}, + {"params": self.backbone.parameters()}, + {"params": self.head.parameters()}, + ] + + opt_disc = torch.optim.AdamW(disc_params, lr=self.hparams.initial_learning_rate) + opt_gen = torch.optim.AdamW(gen_params, lr=self.hparams.initial_learning_rate) + + max_steps = self.trainer.max_steps // 2 # Max steps per optimizer + scheduler_disc = transformers.get_cosine_schedule_with_warmup( + opt_disc, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps, + ) + scheduler_gen = transformers.get_cosine_schedule_with_warmup( + opt_gen, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps, + ) + + return ( + [opt_disc, opt_gen], + [{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}], + ) + + def forward(self, audio_input, **kwargs): + features, _, commit_loss = self.feature_extractor(audio_input, **kwargs) + # print('1111', self.feature_extractor.state_dict()['encodec.decoder.model.3.convtr.convtr.weight_g']) + x = self.backbone(features, **kwargs) + audio_output = self.head(x) + return audio_output, commit_loss + + def training_step(self, batch, batch_idx, optimizer_idx, **kwargs): + audio_input = batch + + # train discriminator + if optimizer_idx == 0 and self.train_discriminator: + with torch.no_grad(): + audio_hat, _ = self(audio_input, **kwargs) + + + loss_dac=self.dacdiscriminator.discriminator_loss(audio_hat.unsqueeze(1),audio_input.unsqueeze(1)) + + real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(y=audio_input, y_hat=audio_hat, **kwargs,) + real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(y=audio_input, y_hat=audio_hat, **kwargs,) + loss_mp, loss_mp_real, _ = self.disc_loss( + disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp + ) + loss_mrd, loss_mrd_real, _ = self.disc_loss( + disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd + ) + loss_mp /= len(loss_mp_real) + loss_mrd /= len(loss_mrd_real) + loss = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd + loss_dac + + self.log("discriminator/total", loss, prog_bar=True) + self.log("discriminator/multi_period_loss", loss_mp) + self.log("discriminator/multi_res_loss", loss_mrd) + self.log("discriminator/dac", loss_dac) + return loss + + # train generator + if optimizer_idx == 1: + audio_hat, commit_loss = self(audio_input, **kwargs) + if self.train_discriminator: + + loss_dac_1,loss_dac_2 = self.dacdiscriminator.generator_loss(audio_hat.unsqueeze(1),audio_input.unsqueeze(1)) 
+ _, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc( + y=audio_input, y_hat=audio_hat, **kwargs, + ) + _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc( + y=audio_input, y_hat=audio_hat, **kwargs, + ) + loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp) + loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd) + loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp) + loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd) + loss_fm_mp = self.feat_matching_loss(fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp) / len(fmap_rs_mp) + loss_fm_mrd = self.feat_matching_loss(fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd) / len(fmap_rs_mrd) + + self.log("generator/multi_period_loss", loss_gen_mp) + self.log("generator/multi_res_loss", loss_gen_mrd) + self.log("generator/feature_matching_mp", loss_fm_mp) + self.log("generator/feature_matching_mrd", loss_fm_mrd) + self.log("generator/loss_dac_1", loss_dac_1) + self.log("generator/loss_dac_2", loss_dac_2) + else: + loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = 0 + + mel_loss = self.melspec_loss(audio_hat, audio_input) + loss = ( + loss_gen_mp + + self.hparams.mrd_loss_coeff * loss_gen_mrd + + loss_fm_mp + + self.hparams.mrd_loss_coeff * loss_fm_mrd + + self.mel_loss_coeff * mel_loss + + 1000 * commit_loss + + loss_dac_1 + + loss_dac_2 + ) + + self.log("generator/total_loss", loss, prog_bar=True) + self.log("mel_loss_coeff", self.mel_loss_coeff) + self.log("generator/mel_loss", mel_loss) + self.log("commit_loss", commit_loss) + + if self.global_step % 1000 == 0 and self.global_rank == 0: + self.logger.experiment.add_audio( + "train/audio_in", audio_input[0].data.cpu(), self.global_step, self.hparams.sample_rate + ) + self.logger.experiment.add_audio( + "train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.hparams.sample_rate + ) + with torch.no_grad(): + mel = safe_log(self.melspec_loss.mel_spec(audio_input[0])) + mel_hat = safe_log(self.melspec_loss.mel_spec(audio_hat[0])) + self.logger.experiment.add_image( + "train/mel_target", + plot_spectrogram_to_numpy(mel.data.cpu().numpy()), + self.global_step, + dataformats="HWC", + ) + self.logger.experiment.add_image( + "train/mel_pred", + plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()), + self.global_step, + dataformats="HWC", + ) + + return loss + + def on_validation_epoch_start(self): + if self.hparams.evaluate_utmos: + from metrics.UTMOS import UTMOSScore + + if not hasattr(self, "utmos_model"): + self.utmos_model = UTMOSScore(device=self.device) + + def validation_step(self, batch, batch_idx, **kwargs): + audio_input = batch + audio_hat, commit_loss = self(audio_input, **kwargs) + + audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.hparams.sample_rate, new_freq=16000) + audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.hparams.sample_rate, new_freq=16000) + + if self.hparams.evaluate_periodicty: + from metrics.periodicity import calculate_periodicity_metrics + + periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(audio_16_khz, audio_hat_16khz) + else: + periodicity_loss = pitch_loss = f1_score = 0 + + if self.hparams.evaluate_utmos: + utmos_score = self.utmos_model.score(audio_hat_16khz.unsqueeze(1)).mean() + else: + utmos_score = torch.zeros(1, device=self.device) + + if self.hparams.evaluate_pesq: + from pesq import pesq + + pesq_score = 0 + for ref, deg in zip(audio_16_khz.cpu().numpy(), audio_hat_16khz.cpu().numpy()): + pesq_score += pesq(16000, ref, deg, "wb", 
on_error=1) + pesq_score /= len(audio_16_khz) + pesq_score = torch.tensor(pesq_score) + else: + pesq_score = torch.zeros(1, device=self.device) + + mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1)) + total_loss = mel_loss + (5 - utmos_score) + (5 - pesq_score) + 1000 * commit_loss + + return { + "val_loss": total_loss, + "mel_loss": mel_loss, + "utmos_score": utmos_score, + "pesq_score": pesq_score, + "periodicity_loss": periodicity_loss, + "pitch_loss": pitch_loss, + "f1_score": f1_score, + "audio_input": audio_input[0], + "audio_pred": audio_hat[0], + } + + def validation_epoch_end(self, outputs): + if self.global_rank == 0: + *_, audio_in, audio_pred = outputs[0].values() + self.logger.experiment.add_audio( + "val_in", audio_in.data.cpu().numpy(), self.global_step, self.hparams.sample_rate + ) + self.logger.experiment.add_audio( + "val_pred", audio_pred.data.cpu().numpy(), self.global_step, self.hparams.sample_rate + ) + mel_target = safe_log(self.melspec_loss.mel_spec(audio_in)) + mel_hat = safe_log(self.melspec_loss.mel_spec(audio_pred)) + self.logger.experiment.add_image( + "val_mel_target", + plot_spectrogram_to_numpy(mel_target.data.cpu().numpy()), + self.global_step, + dataformats="HWC", + ) + self.logger.experiment.add_image( + "val_mel_hat", + plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()), + self.global_step, + dataformats="HWC", + ) + avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() + mel_loss = torch.stack([x["mel_loss"] for x in outputs]).mean() + utmos_score = torch.stack([x["utmos_score"] for x in outputs]).mean() + pesq_score = torch.stack([x["pesq_score"] for x in outputs]).mean() + periodicity_loss = np.array([x["periodicity_loss"] for x in outputs]).mean() + pitch_loss = np.array([x["pitch_loss"] for x in outputs]).mean() + f1_score = np.array([x["f1_score"] for x in outputs]).mean() + + self.log("val_loss", avg_loss, sync_dist=True) + self.log("val/mel_loss", mel_loss, sync_dist=True) + self.log("val/utmos_score", utmos_score, sync_dist=True) + self.log("val/pesq_score", pesq_score, sync_dist=True) + self.log("val/periodicity_loss", periodicity_loss, sync_dist=True) + self.log("val/pitch_loss", pitch_loss, sync_dist=True) + self.log("val/f1_score", f1_score, sync_dist=True) + + @property + def global_step(self): + """ + Override global_step so that it returns the total number of batches processed + """ + return self.trainer.fit_loop.epoch_loop.total_batch_idx + + def on_train_batch_start(self, *args): + if self.global_step >= self.hparams.pretrain_mel_steps: + self.train_discriminator = True + else: + self.train_discriminator = False + + def on_train_batch_end(self, *args): + def mel_loss_coeff_decay(current_step, num_cycles=0.5): + max_steps = self.trainer.max_steps // 2 + if current_step < self.hparams.num_warmup_steps: + return 1.0 + progress = float(current_step - self.hparams.num_warmup_steps) / float( + max(1, max_steps - self.hparams.num_warmup_steps) + ) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + if self.hparams.decay_mel_coeff: + self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(self.global_step + 1) + + +class WavTokenizer(VocosExp): + """ + WavTokenizer is a subclass of VocosExp that overrides the parent experiment to function as a conditional GAN. + It manages an additional `bandwidth_id` attribute, which denotes a learnable embedding corresponding to + a specific bandwidth value of EnCodec. 
During training, a random bandwidth_id is generated for each step, + while during validation, a fixed bandwidth_id is used. + """ + + def __init__( + self, + feature_extractor: FeatureExtractor, + backbone: Backbone, + head: FourierHead, + resume_config: str, + resume_model: str, + sample_rate: int = 24000, + initial_learning_rate: float = 2e-4, + num_warmup_steps: int = 0, + mel_loss_coeff: float = 45, + mrd_loss_coeff: float = 1.0, + pretrain_mel_steps: int = 0, + decay_mel_coeff: bool = False, + evaluate_utmos: bool = False, + evaluate_pesq: bool = False, + evaluate_periodicty: bool = False, + resume: bool = False, + ): + super().__init__( + feature_extractor, + backbone, + head, + resume_config, + resume_model, + sample_rate, + initial_learning_rate, + num_warmup_steps, + mel_loss_coeff, + mrd_loss_coeff, + pretrain_mel_steps, + decay_mel_coeff, + evaluate_utmos, + evaluate_pesq, + evaluate_periodicty, + resume + ) + # Override with conditional discriminators + # VocosExp.__init__(self, feature_extractor, backbone, head, resume_config, resume_model) + # if self.resume: + # VocosExp.load_from_checkpoint(self.resume_model) + self.multiperioddisc = MultiPeriodDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths)) + self.multiresddisc = MultiResolutionDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths)) + self.dac = DACDiscriminator() + if self.resume: + print('加载预训练模型:', self.resume_model) + # with open(self.resume_config, "r") as f: + # config = yaml.safe_load(f) + # feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"]) + # backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"]) + # head = instantiate_class(args=(), init=config['model']['init_args']["head"]) + + # 不加载量化器部分权重 + state_dict_raw = torch.load(self.resume_model, map_location=self.device)['state_dict'] + state_dict_fa_qa = dict() + state_dict_fa_en = dict() + state_dict_fa_de = dict() + state_dict_bb = dict() + state_dict_hd = dict() + state_dict_mp = dict() + state_dict_mr = dict() + state_dict_dac = dict() + for k, v in state_dict_raw.items(): + # breakpoint() + if k.startswith('feature_extractor.encodec.quantizer'): + # breakpoint() + # print("*****",k) + ss = k[46:48] + if ss[-1] == '.': + num = int(ss[0]) + # print("num,k",num,k[36:]) + if num <= 7: + state_dict_fa_qa[k[36:]] = v + if k.startswith('feature_extractor.encodec.encoder'): + state_dict_fa_en[k[34:]] = v + if k.startswith('feature_extractor.encodec.decoder'): + state_dict_fa_de[k[34:]] = v + if k.startswith('backbone.'): + state_dict_bb[k[9:]] = v + if k.startswith('head.'): + state_dict_hd[k[5:]] = v + if k.startswith('multiperioddisc.'): + state_dict_mp[k[16:]] = v + if k.startswith('multiresddisc.'): + state_dict_mr[k[14:]] = v + if k.startswith('dac.'): + state_dict_dac[k[4:]] = v + # breakpoint() + # feature_extractor.encodec.quantizer.load_state_dict(state_dict_fa_qa, strict=True) + feature_extractor.encodec.encoder.load_state_dict(state_dict_fa_en, strict=True) + feature_extractor.encodec.decoder.load_state_dict(state_dict_fa_de, strict=True) + feature_extractor.encodec.quantizer.load_state_dict(state_dict_fa_qa, strict=True) + backbone.load_state_dict(state_dict_bb, strict=True) + head.load_state_dict(state_dict_hd, strict=True) + self.feature_extractor = feature_extractor.to(self.device) + self.backbone = backbone.to(self.device) + self.head = head.to(self.device) + self.multiperioddisc.load_state_dict(state_dict_mp, strict=True) + 
self.multiresddisc.load_state_dict(state_dict_mr, strict=True) + self.dac.load_state_dict(state_dict_dac, strict=True) + + def training_step(self, *args): + # print('-------------------train--------------------') + # if self.global_rank == 0 and self.resume: + # config_path = self.resume_config + # model_path = self.resume_model + # self.pretrained_load(config_path, model_path) + # print('加载预训练模型:', model_path) + bandwidth_id = torch.randint(low=0, high=len(self.feature_extractor.bandwidths), size=(1,), device=self.device,) + output = super().training_step(*args, bandwidth_id=bandwidth_id) + return output + + def validation_step(self, *args): + # print('-------------------valid--------------------') + bandwidth_id = torch.tensor([0], device=self.device) + output = super().validation_step(*args, bandwidth_id=bandwidth_id) + return output + + def validation_epoch_end(self, outputs): + if self.global_rank == 0: + *_, audio_in, _ = outputs[0].values() + # Resynthesis with encodec for reference + self.feature_extractor.encodec.set_target_bandwidth(self.feature_extractor.bandwidths[0]) + encodec_audio = self.feature_extractor.encodec(audio_in[None, None, :]) + self.logger.experiment.add_audio( + "encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.hparams.sample_rate, + ) + + super().validation_epoch_end(outputs) diff --git a/inspiremusic/wavtokenizer/decoder/feature_extractors.py b/inspiremusic/wavtokenizer/decoder/feature_extractors.py new file mode 100644 index 0000000000000000000000000000000000000000..d4672d141ba89b88ecdec4d48464252cb524fb9f --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/feature_extractors.py @@ -0,0 +1,176 @@ +from typing import List + +import torch +import torchaudio +from torch import nn +import math +# from inspiremusic.wavtokenizer.decoder.modules import safe_log +from inspiremusic.wavtokenizer.encoder.modules import SEANetEncoder, SEANetDecoder +from inspiremusic.wavtokenizer.encoder import EncodecModel +from inspiremusic.wavtokenizer.encoder.quantization import ResidualVectorQuantizer + + +def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: + """ + Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. + + Args: + x (Tensor): Input tensor. + clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. + + Returns: + Tensor: Element-wise logarithm of the input tensor with clipping applied. + """ + return torch.log(torch.clip(x, min=clip_val)) + + +def symlog(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * torch.log1p(x.abs()) + + +def symexp(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * (torch.exp(x.abs()) - 1) + + +class FeatureExtractor(nn.Module): + """Base class for feature extractors.""" + + def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Extract features from the given audio. + + Args: + audio (Tensor): Input audio waveform. + + Returns: + Tensor: Extracted features of shape (B, C, L), where B is the batch size, + C denotes output features, and L is the sequence length. 
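# --- Illustration (not part of the diff): the resume logic above recovers each submodule's
# weights from a single flat Lightning `state_dict` by stripping the module prefix before
# calling load_state_dict. This is a minimal sketch of that idea; the checkpoint path and
# prefix list are placeholders.
import torch

def split_by_prefix(state_dict: dict, prefix: str) -> dict:
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

def load_submodules(ckpt_path: str, modules: dict) -> None:
    raw = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    for prefix, module in modules.items():
        module.load_state_dict(split_by_prefix(raw, prefix), strict=True)

# Example (names are placeholders):
# load_submodules("wavtokenizer.ckpt", {"backbone.": backbone, "head.": head,
#                                       "multiperioddisc.": mpd, "multiresddisc.": mrd})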
+ """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class MelSpectrogramFeatures(FeatureExtractor): + def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.mel_spec = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, + n_fft=n_fft, + hop_length=hop_length, + n_mels=n_mels, + center=padding == "center", + power=1, + ) + + def forward(self, audio, **kwargs): + if self.padding == "same": + pad = self.mel_spec.win_length - self.mel_spec.hop_length + audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect") + mel = self.mel_spec(audio) + features = safe_log(mel) + return features + + +class EncodecFeatures(FeatureExtractor): + def __init__( + self, + encodec_model: str = "encodec_24khz", + bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0], + train_codebooks: bool = False, + num_quantizers: int = 1, + dowmsamples: List[int] = [6, 5, 5, 4], + vq_bins: int = 16384, + vq_kmeans: int = 800, + ): + super().__init__() + + # breakpoint() + self.frame_rate = 25 # not use + # n_q = int(bandwidths[-1]*1000/(math.log2(2048) * self.frame_rate)) + n_q = num_quantizers # important + encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2, + dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU', + kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2, + true_skip=False, compress=2) + decoder = SEANetDecoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2, + dimension=512, channels=1, n_filters=32, ratios=[8, 5, 4, 2], activation='ELU', + kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2, + true_skip=False, compress=2) + quantizer = ResidualVectorQuantizer(dimension=512, n_q=n_q, bins=vq_bins, kmeans_iters=vq_kmeans, + decay=0.99, kmeans_init=True) + + # breakpoint() + if encodec_model == "encodec_24khz": + self.encodec = EncodecModel(encoder=encoder, decoder=decoder, quantizer=quantizer, + target_bandwidths=bandwidths, sample_rate=24000, channels=1) + else: + raise ValueError( + f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz'." 
+ ) + for param in self.encodec.parameters(): + param.requires_grad = True + # self.num_q = n_q + # codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0) + # self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks) + self.bandwidths = bandwidths + + # @torch.no_grad() + # def get_encodec_codes(self, audio): + # audio = audio.unsqueeze(1) + # emb = self.encodec.encoder(audio) + # codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth) + # return codes + + def forward(self, audio: torch.Tensor, bandwidth_id: torch.Tensor = torch.tensor(0)): + if self.training: + self.encodec.train() + + audio = audio.unsqueeze(1) # audio(16,24000) + + # breakpoint() + + emb = self.encodec.encoder(audio) + q_res = self.encodec.quantizer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id]) + quantized = q_res.quantized + codes = q_res.codes + commit_loss = q_res.penalty # codes(8,16,75),features(16,128,75) + + return quantized, codes, commit_loss + + # codes = self.get_encodec_codes(audio) + # # Instead of summing in the loop, it stores subsequent VQ dictionaries in a single `self.codebook_weights` + # # with offsets given by the number of bins, and finally summed in a vectorized operation. + # offsets = torch.arange( + # 0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device + # ) + # embeddings_idxs = codes + offsets.view(-1, 1, 1) + # features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0) + # return features.transpose(1, 2) + + def infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor): + if self.training: + self.encodec.train() + + audio = audio.unsqueeze(1) # audio(16,24000) + emb = self.encodec.encoder(audio) + q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id]) + quantized = q_res.quantized + codes = q_res.codes + commit_loss = q_res.penalty # codes(8,16,75),features(16,128,75) + + return quantized, codes, commit_loss + + def _infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor = torch.tensor(0)): + if self.training: + self.encodec.train() + + audio = audio.unsqueeze(1) # audio(16,24000) + emb = self.encodec.encoder(audio) + q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id]) + quantized = q_res.quantized + codes = q_res.codes + commit_loss = q_res.penalty # codes(8,16,75),features(16,128,75) + + return quantized, codes, commit_loss \ No newline at end of file diff --git a/inspiremusic/wavtokenizer/decoder/heads.py b/inspiremusic/wavtokenizer/decoder/heads.py new file mode 100644 index 0000000000000000000000000000000000000000..eb3b9f85bda23ae73c09e462cb584bc9878faca9 --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/heads.py @@ -0,0 +1,159 @@ +import torch +from torch import nn +from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz + +from inspiremusic.wavtokenizer.decoder.spectral_ops import IMDCT, ISTFT + +def symexp(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * (torch.exp(x.abs()) - 1) + +class FourierHead(nn.Module): + """Base class for inverse fourier modules.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. 
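# --- Illustration (not part of the diff): a toy residual vector quantizer showing what the
# ResidualVectorQuantizer wired into EncodecFeatures does conceptually. Each stage quantizes
# the residual left by the previous stage with a nearest-neighbour lookup, and the per-stage
# indices are the discrete codes. Didactic sketch only; shapes and sizes are arbitrary.
import torch

def toy_rvq(x, codebooks):
    """x: (B, T, D); each codebook: (bins, D). Returns (quantized, codes)."""
    residual, quantized, codes = x, torch.zeros_like(x), []
    for cb in codebooks:
        dists = torch.cdist(residual, cb.unsqueeze(0).expand(x.size(0), -1, -1))  # (B, T, bins)
        idx = dists.argmin(dim=-1)             # nearest code per frame
        q = cb[idx]                            # (B, T, D) quantized residual
        quantized = quantized + q
        residual = residual - q
        codes.append(idx)
    return quantized, torch.stack(codes)       # codes: (n_q, B, T)

emb = torch.randn(2, 75, 512)                  # stand-in for encoder output frames
books = [torch.randn(16384, 512) for _ in range(4)]
quantized, codes = toy_rvq(emb, books)
print(quantized.shape, codes.shape)            # (2, 75, 512) (4, 2, 75)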
+ + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class ISTFTHead(FourierHead): + """ + ISTFT Head module for predicting STFT complex coefficients. + + Args: + dim (int): Hidden dimension of the model. + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames, which should align with + the resolution of the input features. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): + super().__init__() + out_dim = n_fft + 2 + self.out = torch.nn.Linear(dim, out_dim) + self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the ISTFTHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x).transpose(1, 2) + mag, p = x.chunk(2, dim=1) + mag = torch.exp(mag) + mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes + # wrapping happens here. These two lines produce real and imaginary value + x = torch.cos(p) + y = torch.sin(p) + # recalculating phase here does not produce anything new + # only costs time + # phase = torch.atan2(y, x) + # S = mag * torch.exp(phase * 1j) + # better directly produce the complex value + + S = mag * (x + 1j * y) + + audio = self.istft(S) + return audio + +class IMDCTSymExpHead(FourierHead): + """ + IMDCT Head module for predicting MDCT coefficients with symmetric exponential function + + Args: + dim (int): Hidden dimension of the model. + mdct_frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized + based on perceptual scaling. Defaults to None. + clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. + """ + + def __init__( + self, dim: int, mdct_frame_len: int, padding: str = "same", sample_rate: int = None, clip_audio: bool = False, + ): + super().__init__() + out_dim = mdct_frame_len // 2 + self.out = nn.Linear(dim, out_dim) + self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) + self.clip_audio = clip_audio + + if sample_rate is not None: + # optionally init the last layer following mel-scale + m_max = _hz_to_mel(sample_rate // 2) + m_pts = torch.linspace(0, m_max, out_dim) + f_pts = _mel_to_hz(m_pts) + scale = 1 - (f_pts / f_pts.max()) + + with torch.no_grad(): + self.out.weight.mul_(scale.view(-1, 1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the IMDCTSymExpHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 
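# --- Illustration (not part of the diff): the core of ISTFTHead above, written as a small
# standalone function. It splits the projection into log-magnitude and phase, exponentiates
# and clips the magnitude, builds the complex spectrogram, and inverts it. For brevity this
# uses torch.istft with center padding instead of the custom ISTFT in spectral_ops.py; the
# sizes are illustrative.
import torch

def istft_head_sketch(x: torch.Tensor, n_fft: int = 1024, hop: int = 256) -> torch.Tensor:
    """x: (B, L, n_fft + 2) -> audio of shape (B, T)."""
    x = x.transpose(1, 2)                      # (B, n_fft + 2, L)
    mag, phase = x.chunk(2, dim=1)             # each (B, n_fft // 2 + 1, L)
    mag = torch.exp(mag).clip(max=1e2)         # safeguard against huge magnitudes
    spec = mag * (torch.cos(phase) + 1j * torch.sin(phase))
    window = torch.hann_window(n_fft)
    return torch.istft(spec, n_fft, hop_length=hop, win_length=n_fft,
                       window=window, center=True)

audio = istft_head_sketch(torch.randn(2, 50, 1026))
print(audio.shape)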
+ """ + x = self.out(x) + x = symexp(x) + x = torch.clip(x, min=-1e2, max=1e2) # safeguard to prevent excessively large magnitudes + audio = self.imdct(x) + if self.clip_audio: + audio = torch.clip(x, min=-1.0, max=1.0) + + return audio + + +class IMDCTCosHead(FourierHead): + """ + IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p) + + Args: + dim (int): Hidden dimension of the model. + mdct_frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. + """ + + def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False): + super().__init__() + self.clip_audio = clip_audio + self.out = nn.Linear(dim, mdct_frame_len) + self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the IMDCTCosHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x) + m, p = x.chunk(2, dim=2) + m = torch.exp(m).clip(max=1e2) # safeguard to prevent excessively large magnitudes + audio = self.imdct(m * torch.cos(p)) + if self.clip_audio: + audio = torch.clip(x, min=-1.0, max=1.0) + return audio diff --git a/inspiremusic/wavtokenizer/decoder/helpers.py b/inspiremusic/wavtokenizer/decoder/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..3d303010352ad59dde2996605f124128ee17db36 --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/helpers.py @@ -0,0 +1,71 @@ +import matplotlib +import numpy as np +import torch +from matplotlib import pyplot as plt +from pytorch_lightning import Callback + +matplotlib.use("Agg") + + +def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray: + """ + Save a matplotlib figure to a numpy array. + + Args: + fig (Figure): Matplotlib figure object. + + Returns: + ndarray: Numpy array representing the figure. + """ + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + return data + + +def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray: + """ + Plot a spectrogram and convert it to a numpy array. + + Args: + spectrogram (ndarray): Spectrogram data. + + Returns: + ndarray: Numpy array representing the plotted spectrogram. + """ + spectrogram = spectrogram.astype(np.float32) + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + +class GradNormCallback(Callback): + """ + Callback to log the gradient norm. + """ + + def on_after_backward(self, trainer, model): + model.log("grad_norm", gradient_norm(model)) + + +def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor: + """ + Compute the gradient norm. + + Args: + model (Module): PyTorch model. + norm_type (float, optional): Type of the norm. Defaults to 2.0. + + Returns: + Tensor: Gradient norm. 
+ """ + grads = [p.grad for p in model.parameters() if p.grad is not None] + total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type) + return total_norm diff --git a/inspiremusic/wavtokenizer/decoder/loss.py b/inspiremusic/wavtokenizer/decoder/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..30f32ccf9a3f5373335ddb8da1334f508c16f752 --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/loss.py @@ -0,0 +1,159 @@ +from typing import List, Tuple + +import torch +import torchaudio +from torch import nn + +from decoder.modules import safe_log + +import torch.nn.functional as F + + +class MelSpecReconstructionLoss(nn.Module): + """ + L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample + """ + + def __init__( + self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100, + ): + super().__init__() + self.mel_spec = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1, + ) + + def forward(self, y_hat, y) -> torch.Tensor: + """ + Args: + y_hat (Tensor): Predicted audio waveform. + y (Tensor): Ground truth audio waveform. + + Returns: + Tensor: L1 loss between the mel-scaled magnitude spectrograms. + """ + mel_hat = safe_log(self.mel_spec(y_hat)) + mel = safe_log(self.mel_spec(y)) + + loss = torch.nn.functional.l1_loss(mel, mel_hat) + + return loss + + +class GeneratorLoss(nn.Module): + """ + Generator Loss module. Calculates the loss for the generator based on discriminator outputs. + """ + + def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + disc_outputs (List[Tensor]): List of discriminator outputs. + + Returns: + Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from + the sub-discriminators + """ + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean(torch.clamp(1 - dg, min=0)) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +class DiscriminatorLoss(nn.Module): + """ + Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs. + """ + + def forward( + self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor] + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + """ + Args: + disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples. + disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples. + + Returns: + Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from + the sub-discriminators for real outputs, and a list of + loss values for generated outputs. + """ + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean(torch.clamp(1 - dr, min=0)) + g_loss = torch.mean(torch.clamp(1 + dg, min=0)) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +class FeatureMatchingLoss(nn.Module): + """ + Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators. 
+ """ + + def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor: + """ + Args: + fmap_r (List[List[Tensor]]): List of feature maps from real samples. + fmap_g (List[List[Tensor]]): List of feature maps from generated samples. + + Returns: + Tensor: The calculated feature matching loss. + """ + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss + +class DACGANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. + """ + + def __init__(self, discriminator): + super().__init__() + self.discriminator = discriminator + + def forward(self, fake, real): + # d_fake = self.discriminator(fake.audio_data) + # d_real = self.discriminator(real.audio_data) + d_fake = self.discriminator(fake) + d_real = self.discriminator(real) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature + diff --git a/inspiremusic/wavtokenizer/decoder/models.py b/inspiremusic/wavtokenizer/decoder/models.py new file mode 100644 index 0000000000000000000000000000000000000000..f7ce3d99c57feb48946039a7501c638874afdf62 --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/models.py @@ -0,0 +1,266 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn.utils import weight_norm + +from inspiremusic.wavtokenizer.decoder.modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv1d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv1d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv1d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv1d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb=None): + h = x + h = self.norm1(h) + h = 
nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv1d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv1d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv1d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv1d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h = q.shape + q = q.permute(0, 2, 1) # b,hw,c + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + + h_ = self.proj_out(h_) + + return x + h_ + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + + +class Backbone(nn.Module): + """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, + C denotes output features, and L is the sequence length. + + Returns: + Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, + and H denotes the model dimension. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class VocosBackbone(Backbone): + """ + Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. + num_layers (int): Number of ConvNeXtBlock layers. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional model. Defaults to None. 
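# --- Illustration (not part of the diff): a shape walk-through of the 1-D self-attention in
# AttnBlock above, written with plain tensor ops so the bmm bookkeeping is explicit. The
# channel and sequence sizes are arbitrary examples.
import torch

B, C, T = 2, 512, 75
q, k, v = torch.randn(B, C, T), torch.randn(B, C, T), torch.randn(B, C, T)

attn = torch.bmm(q.permute(0, 2, 1), k) * C ** -0.5   # (B, T, T) scaled dot products
attn = torch.softmax(attn, dim=2)
out = torch.bmm(v, attn.permute(0, 2, 1))             # (B, C, T) attended values
print(out.shape)                                      # torch.Size([2, 512, 75])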
+ """ + + def __init__( + self, + input_channels: int, + dim: int, + intermediate_dim: int, + num_layers: int, + layer_scale_init_value: Optional[float] = None, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + layer_scale_init_value = layer_scale_init_value or 1 / num_layers + self.convnext = nn.ModuleList( + [ + ConvNeXtBlock( + dim=dim, + intermediate_dim=intermediate_dim, + layer_scale_init_value=layer_scale_init_value, + adanorm_num_embeddings=adanorm_num_embeddings, + ) + for _ in range(num_layers) + ] + ) + self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) + self.apply(self._init_weights) + + self.temb_ch = 0 + block_in = dim + dropout = 0.1 + attn_type="vanilla" + + pos_net : tp.List[nn.Module] = [ + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + make_attn(block_in, attn_type=attn_type), + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + Normalize(block_in) + ] + + self.pos_net = nn.Sequential(*pos_net) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None) -> torch.Tensor: + x = self.embed(x) + x = self.pos_net(x) + if self.adanorm: + # assert bandwidth_id is not None + if bandwidth_id is None: + bandwidth_id = torch.tensor(0, device='cuda') + x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id) + else: + x = self.norm(x.transpose(1, 2)) + x = x.transpose(1, 2) + for conv_block in self.convnext: + x = conv_block(x, cond_embedding_id=bandwidth_id) + x = self.final_layer_norm(x.transpose(1, 2)) + return x + + +class VocosResNetBackbone(Backbone): + """ + Vocos backbone module built with ResBlocks. + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + num_blocks (int): Number of ResBlock1 blocks. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None. 
+ """ + + def __init__( + self, input_channels, dim, num_blocks, layer_scale_init_value=None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)) + layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3 + self.resnet = nn.Sequential( + *[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)] + ) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + x = self.embed(x) + x = self.resnet(x) + x = x.transpose(1, 2) + return x diff --git a/inspiremusic/wavtokenizer/decoder/modules.py b/inspiremusic/wavtokenizer/decoder/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..799a61fb94a2adc26c6e7a39e4ff3285f6556975 --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/modules.py @@ -0,0 +1,214 @@ +from typing import Optional +from typing import Tuple + +import torch +from torch import nn +from torch.nn.utils import weight_norm, remove_weight_norm + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional LayerNorm. Defaults to None. + """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + layer_scale_init_value: Optional[float] = None, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + + def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + if self.adanorm: + assert cond_embedding_id is not None + x = self.norm(x, cond_embedding_id) + else: + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x + + +class AdaLayerNorm(nn.Module): + """ + Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes + + Args: + num_embeddings (int): Number of embeddings. + embedding_dim (int): Dimension of the embeddings. 
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = embedding_dim + self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + torch.nn.init.ones_(self.scale.weight) + torch.nn.init.zeros_(self.shift.weight) + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: + scale = self.scale(cond_embedding_id) + shift = self.shift(cond_embedding_id) + x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) + x = x * scale + shift + return x + + +class ResBlock1(nn.Module): + """ + ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, + but without upsampling layers. + + Args: + dim (int): Number of input channels. + kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. + dilation (tuple[int], optional): Dilation factors for the dilated convolutions. + Defaults to (1, 3, 5). + lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. + Defaults to 0.1. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + """ + + def __init__( + self, + dim: int, + kernel_size: int = 3, + dilation: Tuple[int, ...] = (1, 3, 5), + lrelu_slope: float = 0.1, + layer_scale_init_value: float = None, + ): + super().__init__() + self.lrelu_slope = lrelu_slope + self.convs1 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[0], + padding=self.get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[1], + padding=self.get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[2], + padding=self.get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + ] + ) + + self.gamma = nn.ParameterList( + [ + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): + xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) + xt = c1(xt) + xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) + xt = c2(xt) + if gamma is not None: + xt = gamma * xt + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + @staticmethod + def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> 
torch.Tensor: + """ + Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. + + Args: + x (Tensor): Input tensor. + clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. + + Returns: + Tensor: Element-wise logarithm of the input tensor with clipping applied. + """ + return torch.log(torch.clip(x, min=clip_val)) + + +def symlog(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * torch.log1p(x.abs()) + + +def symexp(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * (torch.exp(x.abs()) - 1) diff --git a/inspiremusic/wavtokenizer/decoder/pretrained.py b/inspiremusic/wavtokenizer/decoder/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..1231a05d4945e2f2debe620b1347cad3e6c7ca76 --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/pretrained.py @@ -0,0 +1,253 @@ +import os +from typing import Tuple, Any, Union, Dict + +import torch +import yaml +from huggingface_hub import hf_hub_download +from torch import nn +from inspiremusic.wavtokenizer.decoder.feature_extractors import FeatureExtractor, EncodecFeatures +from inspiremusic.wavtokenizer.decoder.heads import FourierHead +from inspiremusic.wavtokenizer.decoder.models import Backbone + + +def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any: + """Instantiates a class with the given args and init. + + Args: + args: Positional arguments required for instantiation. + init: Dict of the form {"class_path":...,"init_args":...}. + + Returns: + The instantiated class object. + """ + kwargs = init.get("init_args", {}) + if not isinstance(args, tuple): + args = (args,) + class_module, class_name = init["class_path"].rsplit(".", 1) + module = __import__(class_module, fromlist=[class_name]) + args_class = getattr(module, class_name) + return args_class(*args, **kwargs) + + +class WavTokenizer(nn.Module): + """ + The Vocos class represents a Fourier-based neural vocoder for audio synthesis. + This class is primarily designed for inference, with support for loading from pretrained + model checkpoints. It consists of three main components: a feature extractor, + a backbone, and a head. + """ + + def __init__( + self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead, + ): + super().__init__() + self.feature_extractor = feature_extractor + self.backbone = backbone + self.head = head + + @classmethod + def from_hparams(cls, config_path: str) -> "Vocos": + """ + Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. + """ + with open(config_path, "r") as f: + config = yaml.safe_load(f) + feature_extractor = instantiate_class(args=(), init=config["feature_extractor"]) + backbone = instantiate_class(args=(), init=config["backbone"]) + head = instantiate_class(args=(), init=config["head"]) + model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head) + return model + + @classmethod + def from_pretrained(self, repo_id: str) -> "Vocos": + """ + Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. + """ + config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml") + model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin") + model = self.from_hparams(config_path) + state_dict = torch.load(model_path, map_location="cpu") + if isinstance(model.feature_extractor, EncodecFeatures): + encodec_parameters = { + "feature_extractor.encodec." 
+ key: value + for key, value in model.feature_extractor.encodec.state_dict().items() + } + state_dict.update(encodec_parameters) + model.load_state_dict(state_dict) + model.eval() + return model + + + @classmethod + def from_hparams_feat(cls, config_path: str) -> "Vocos": + """ + Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. + """ + with open(config_path, "r") as f: + config = yaml.safe_load(f) + feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"]) + backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"]) + head = instantiate_class(args=(), init=config['model']['init_args']["head"]) + model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head) + return model + + + @classmethod + def from_pretrained_feat(self, config_path, model_path): + """ + Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. + """ + model = self.from_hparams_feat(config_path) + state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict'] + state_dict = dict() + for k, v in state_dict_raw.items(): + if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'): + state_dict[k] = v + + model.load_state_dict(state_dict) + model.eval() + return model + + @classmethod + def estimator(self, config_path, model_path): + """ + Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. + """ + model = self.from_hparams_feat(config_path) + state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict'] + state_dict = dict() + for k, v in state_dict_raw.items(): + if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'): + state_dict[k] = v + model.load_state_dict(state_dict) + model.eval() + return model + + @classmethod + def from_pretrained0911(self, config_path, model_folder_path): + """ + Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. + """ + model = self.from_hparams0802(config_path) + + models = os.listdir(model_folder_path) + val_loss = [] + for item in models: + if not item.startswith('vocos_'): + continue + val_loss.append(item[-11:-5]) + val_loss.sort() + val_loss = val_loss[:3] # 取前3性能较好的模型平均 + state_dict = dict() + state_dicts = [] + for item in models: + if not item.startswith('vocos_'): + continue + ll = item[-11:-5] + if ll not in val_loss: + continue + model_path = model_folder_path + '/' + item + state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict'] + state_dict_single = dict() + for k, v in state_dict_raw.items(): + if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'): + state_dict_single[k] = v + state_dicts.append(state_dict_single) + for kk in state_dicts[0].keys(): + vv = state_dicts[0][kk] + for i in range(1, len(state_dicts)): + ss = state_dicts[i] + vv += ss[kk] + vm = vv/len(state_dicts) + state_dict[kk] = vm + model.load_state_dict(state_dict) + model.eval() + return model + + + @torch.inference_mode() + def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """ + Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input, + which is then passed through the backbone and the head to reconstruct the audio output. 
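# --- Illustration (not part of the diff): a standalone sketch of the checkpoint-averaging
# idea behind from_pretrained0911 above: select the best few checkpoints by validation loss
# and average their inference weights (backbone/head/feature_extractor) key by key. The file
# names and the "best three" choice are illustrative.
import torch

def average_checkpoints(paths, keep_prefixes=("backbone.", "head.", "feature_extractor.")):
    dicts = []
    for p in paths:
        raw = torch.load(p, map_location="cpu")["state_dict"]
        dicts.append({k: v for k, v in raw.items() if k.startswith(keep_prefixes)})
    return {k: torch.stack([d[k].float() for d in dicts]).mean(dim=0) for k in dicts[0]}

# Example (placeholder file names):
# averaged = average_checkpoints(["vocos_val=0.1234.ckpt", "vocos_val=0.1210.ckpt",
#                                 "vocos_val=0.1188.ckpt"])
# model.load_state_dict(averaged)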
+ + Args: + audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T), + where B is the batch size and L is the waveform length. + + + Returns: + Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T). + """ + features, _, _ = self.feature_extractor(audio_input, **kwargs) # 0818 + audio_output = self.decode(features, **kwargs) + return audio_output + + + # 0818 + @torch.inference_mode() + def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + features, discrete_codes, _ = self.feature_extractor(audio_input, **kwargs) + return features,discrete_codes + + + # 0818 + @torch.inference_mode() + def encode_infer(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + features, discrete_codes, _ = self.feature_extractor.infer(audio_input, **kwargs) + return features,discrete_codes + + @torch.inference_mode() + def infer(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + _, discrete_codes, _ = self.feature_extractor._infer(audio_input, **kwargs) + discrete_codes = discrete_codes.clamp(min=0, max=16383) + return discrete_codes + + @torch.inference_mode() + def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """ + Method to decode audio waveform from already calculated features. The features input is passed through + the backbone and the head to reconstruct the audio output. + + Args: + features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size, + C denotes the feature dimension, and L is the sequence length. + + Returns: + Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T). + """ + x = self.backbone(features_input, **kwargs) + audio_output = self.head(x) + return audio_output + + @torch.inference_mode() + def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor: + """ + Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's + codebook weights. + + Args: + codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L), + where K is the number of codebooks, B is the batch size and L is the sequence length. + + Returns: + Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension, + and L is the sequence length. 
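# --- Illustration (not part of the diff): round-trip usage sketch for the inference-side
# WavTokenizer defined above in pretrained.py: encode a waveform into discrete codes, map the
# codes back to feature embeddings, and decode them to audio. The config/checkpoint paths are
# placeholders, and bandwidth_id=0 simply selects the first configured bandwidth.
import torch
from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer

tokenizer = WavTokenizer.from_pretrained_feat("config.yaml", "model.ckpt")  # placeholder paths
audio = torch.randn(1, 24000)                              # (B, T): one second at 24 kHz
bandwidth_id = torch.tensor([0])

_, codes = tokenizer.encode_infer(audio, bandwidth_id=bandwidth_id)   # discrete tokens
features = tokenizer.codes_to_features(codes)                         # (B, C, L)
recon = tokenizer.decode(features, bandwidth_id=bandwidth_id)         # (B, T') waveform
print(codes.shape, recon.shape)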
+ """ + assert isinstance( + self.feature_extractor, EncodecFeatures + ), "Feature extractor should be an instance of EncodecFeatures" + + if codes.dim() == 2: + codes = codes.unsqueeze(1) + + n_bins = self.feature_extractor.encodec.quantizer.bins + offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device) + embeddings_idxs = codes + offsets.view(-1, 1, 1) + + tmp=torch.cat([vq.codebook for vq in self.feature_extractor.encodec.quantizer.vq.layers],dim=0) + # features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0) + features = torch.nn.functional.embedding(embeddings_idxs, tmp).sum(dim=0) + features = features.transpose(1, 2) + + return features diff --git a/inspiremusic/wavtokenizer/decoder/pretrained_model.py b/inspiremusic/wavtokenizer/decoder/pretrained_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c919bb25685d78522c6b638cd46310c7ae5edc0d --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/pretrained_model.py @@ -0,0 +1,192 @@ +from typing import Tuple, Any, Union, Dict + +import torch +import yaml +from huggingface_hub import hf_hub_download +from torch import nn +from inspiremusic.wavtokenizer.decoder.feature_extractors import FeatureExtractor, EncodecFeatures +from inspiremusic.wavtokenizer.decoder.heads import FourierHead +from inspiremusic.wavtokenizer.decoder.models import Backbone +from inspiremusic.wavtokenizer.decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator + + +def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any: + """Instantiates a class with the given args and init. + + Args: + args: Positional arguments required for instantiation. + init: Dict of the form {"class_path":...,"init_args":...}. + + Returns: + The instantiated class object. + """ + kwargs = init.get("init_args", {}) + if not isinstance(args, tuple): + args = (args,) + class_module, class_name = init["class_path"].rsplit(".", 1) + module = __import__(class_module, fromlist=[class_name]) + args_class = getattr(module, class_name) + return args_class(*args, **kwargs) + + +class WavTokenizer(nn.Module): + """ + The Vocos class represents a Fourier-based neural vocoder for audio synthesis. + This class is primarily designed for inference, with support for loading from pretrained + model checkpoints. It consists of three main components: a feature extractor, + a backbone, and a head. + """ + + def __init__( + self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead, + multiperioddisc: MultiPeriodDiscriminator, multiresddisc: MultiResolutionDiscriminator, + ): + super().__init__() + self.feature_extractor = feature_extractor + self.backbone = backbone + self.head = head + + self.multiperioddisc = multiperioddisc + self.multiresddisc = multiresddisc + + @classmethod + def from_hparams0828(cls, config_path: str) -> "Vocos": + """ + Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. 
+ """ + with open(config_path, "r") as f: + config = yaml.safe_load(f) + feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"]) + backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"]) + head = instantiate_class(args=(), init=config['model']['init_args']["head"]) + model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head, + multiperioddisc=MultiPeriodDiscriminator(num_embeddings=4), + multiresddisc=MultiResolutionDiscriminator(num_embeddings=4)) + return model + + @classmethod + def from_pretrained0828(self, config_path, model_path): + """ + Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. + """ + model = self.from_hparams0828(config_path) + state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict'] + state_dict = dict() + for k, v in state_dict_raw.items(): + if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.') \ + or k.startswith('multiperioddisc.') or k.startswith('multiresddisc.'): + state_dict[k] = v + # if isinstance(model.feature_extractor, EncodecFeatures): + # encodec_parameters = { + # "feature_extractor.encodec." + key: value + # for key, value in model.feature_extractor.encodec.state_dict().items() + # } + # state_dict.update(encodec_parameters) + model.load_state_dict(state_dict) + return model + + @classmethod + def from_hparams0802(cls, config_path: str) -> "Vocos": + """ + Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. + """ + with open(config_path, "r") as f: + config = yaml.safe_load(f) + feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"]) + backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"]) + head = instantiate_class(args=(), init=config['model']['init_args']["head"]) + model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head) + return model + + @classmethod + def from_pretrained0802(self, config_path, model_path): + """ + Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. + """ + model = self.from_hparams0802(config_path) + state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict'] + state_dict = dict() + for k, v in state_dict_raw.items(): + if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'): + state_dict[k] = v + # if isinstance(model.feature_extractor, EncodecFeatures): + # encodec_parameters = { + # "feature_extractor.encodec." + key: value + # for key, value in model.feature_extractor.encodec.state_dict().items() + # } + # state_dict.update(encodec_parameters) + model.load_state_dict(state_dict) + model.eval() + return model + + @torch.inference_mode() + def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """ + Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input, + which is then passed through the backbone and the head to reconstruct the audio output. + + Args: + audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T), + where B is the batch size and L is the waveform length. + + + Returns: + Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T). 
+ """ + features, _, _ = self.feature_extractor(audio_input, **kwargs) # 0818 + audio_output = self.decode(features, **kwargs) + return audio_output + + + # 0818 + @torch.inference_mode() + def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + features, _, _ = self.feature_extractor(audio_input, **kwargs) + return features + + + @torch.inference_mode() + def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """ + Method to decode audio waveform from already calculated features. The features input is passed through + the backbone and the head to reconstruct the audio output. + + Args: + features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size, + C denotes the feature dimension, and L is the sequence length. + + Returns: + Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T). + """ + x = self.backbone(features_input, **kwargs) + audio_output = self.head(x) + return audio_output + + @torch.inference_mode() + def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor: + """ + Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's + codebook weights. + + Args: + codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L), + where K is the number of codebooks, B is the batch size and L is the sequence length. + + Returns: + Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension, + and L is the sequence length. + """ + assert isinstance( + self.feature_extractor, EncodecFeatures + ), "Feature extractor should be an instance of EncodecFeatures" + + if codes.dim() == 2: + codes = codes.unsqueeze(1) + + n_bins = self.feature_extractor.encodec.quantizer.bins + offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device) + embeddings_idxs = codes + offsets.view(-1, 1, 1) + features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0) + features = features.transpose(1, 2) + + return features diff --git a/inspiremusic/wavtokenizer/decoder/spectral_ops.py b/inspiremusic/wavtokenizer/decoder/spectral_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..8b062d5ff4c1d82124afa2752cea62132434790d --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/spectral_ops.py @@ -0,0 +1,242 @@ +import numpy as np +import scipy +import torch +from torch import nn, view_as_real, view_as_complex +import pdb + +class ISTFT(nn.Module): + """ + Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with + windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. + See issue: https://github.com/pytorch/pytorch/issues/62323 + Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. + The NOLA constraint is met as we trim padded samples anyway. + + Args: + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames. + win_length (int): The size of window frame and STFT filter. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 
+ """ + + def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + window = torch.hann_window(win_length) + self.register_buffer("window", window) + + def forward(self, spec: torch.Tensor) -> torch.Tensor: + """ + Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. + + Args: + spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, + N is the number of frequency bins, and T is the number of time frames. + + Returns: + Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. + """ + if self.padding == "center": + # Fallback to pytorch native implementation + return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True) + elif self.padding == "same": + pad = (self.win_length - self.hop_length) // 2 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + assert spec.dim() == 3, "Expected a 3D tensor as input" + B, N, T = spec.shape + + # Inverse FFT + ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") + ifft = ifft * self.window[None, :, None] + + # Overlap and Add + output_size = (T - 1) * self.hop_length + self.win_length + y = torch.nn.functional.fold( + ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), + )[:, 0, 0, pad:-pad] + + # Window envelope + window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) + window_envelope = torch.nn.functional.fold( + window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), + ).squeeze()[pad:-pad] + + # Normalize + # assert (window_envelope > 1e-11).all() + if not torch.all(window_envelope > 1e-11): + window_envelope = torch.clamp(window_envelope, min=1e-11) + + y = y / window_envelope + + return y + + def onnx_forward(self, spec: torch.Tensor) -> torch.Tensor: + """ + Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. + + Args: + spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, + N is the number of frequency bins, and T is the number of time frames. + + Returns: + Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. 
+        """
+        if self.padding == "center":
+            # Fallback to pytorch native implementation
+            return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
+        elif self.padding == "same":
+            pad = (self.win_length - self.hop_length) // 2
+        else:
+            raise ValueError("Padding must be 'center' or 'same'.")
+
+        assert spec.dim() == 3, "Expected a 3D tensor as input"
+        B, N, T = spec.shape
+
+        # Inverse FFT
+        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
+        ifft = ifft * self.window[None, :, None]
+
+        # Overlap and Add
+        output_size = (T - 1) * self.hop_length + self.win_length
+        y = torch.nn.functional.fold(
+            ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
+        )[:, 0, 0, pad:-pad]
+
+        # Window envelope
+        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
+        window_envelope = torch.nn.functional.fold(
+            window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
+        ).squeeze()[pad:-pad]
+
+        # Normalize
+        # assert (window_envelope > 1e-11).all()
+        if not torch.all(window_envelope > 1e-11):
+            window_envelope = torch.clamp(window_envelope, min=1e-11)
+
+        y = y / window_envelope
+
+        return y
+
+
+class MDCT(nn.Module):
+    """
+    Modified Discrete Cosine Transform (MDCT) module.
+
+    Args:
+        frame_len (int): Length of the MDCT frame.
+        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+    """
+
+    def __init__(self, frame_len: int, padding: str = "same"):
+        super().__init__()
+        if padding not in ["center", "same"]:
+            raise ValueError("Padding must be 'center' or 'same'.")
+        self.padding = padding
+        self.frame_len = frame_len
+        N = frame_len // 2
+        n0 = (N + 1) / 2
+        window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
+        self.register_buffer("window", window)
+
+        pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
+        post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
+        # view_as_real: NCCL Backend does not support ComplexFloat data type
+        # https://github.com/pytorch/pytorch/issues/71613
+        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
+        self.register_buffer("post_twiddle", view_as_real(post_twiddle))
+
+    def forward(self, audio: torch.Tensor) -> torch.Tensor:
+        """
+        Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
+
+        Args:
+            audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
+                and T is the length of the audio.
+
+        Returns:
+            Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
+                and N is the number of frequency bins.
+        """
+        if self.padding == "center":
+            audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2))
+        elif self.padding == "same":
+            # hop_length is 1/2 frame_len
+            audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4))
+        else:
+            raise ValueError("Padding must be 'center' or 'same'.")
+
+        x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
+        N = self.frame_len // 2
+        x = x * self.window.expand(x.shape)
+        X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N]
+        res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
+        return torch.real(res) * np.sqrt(2)
+
+
+class IMDCT(nn.Module):
+    """
+    Inverse Modified Discrete Cosine Transform (IMDCT) module.
+
+    Args:
+        frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, frame_len: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.frame_len = frame_len + N = frame_len // 2 + n0 = (N + 1) / 2 + window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() + self.register_buffer("window", window) + + pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N) + post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2)) + self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) + self.register_buffer("post_twiddle", view_as_real(post_twiddle)) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + """ + Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients. + + Args: + X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size, + L is the number of frames, and N is the number of frequency bins. + + Returns: + Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio. + """ + B, L, N = X.shape + Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device) + Y[..., :N] = X + Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,))) + y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1) + y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2) + result = y * self.window.expand(y.shape) + output_size = (1, (L + 1) * N) + audio = torch.nn.functional.fold( + result.transpose(1, 2), + output_size=output_size, + kernel_size=(1, self.frame_len), + stride=(1, self.frame_len // 2), + )[:, 0, 0, :] + + if self.padding == "center": + pad = self.frame_len // 2 + elif self.padding == "same": + pad = self.frame_len // 4 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + audio = audio[:, pad:-pad] + return audio diff --git a/inspiremusic/wavtokenizer/encoder/__init__.py b/inspiremusic/wavtokenizer/encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ff8fd2ada59e0e15d4df2854052edf150e5238e3 --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# flake8: noqa + +"""EnCodec neural audio codec.""" + +__version__ = "0.1.2a3" + +from .model import EncodecModel diff --git a/inspiremusic/wavtokenizer/encoder/distrib.py b/inspiremusic/wavtokenizer/encoder/distrib.py new file mode 100644 index 0000000000000000000000000000000000000000..b1662d8085cf2878c4cd058537d0f097de91d158 --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/distrib.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
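Quick usage sketch for the "same"-padding ISTFT defined above. This is an editorial illustration, not part of the patch: the hop/window sizes, the batch shape, and the analysis-side padding are assumptions chosen so the round trip lines up; only the `ISTFT` module and its import path come from the diff.

import torch
from inspiremusic.wavtokenizer.decoder.spectral_ops import ISTFT

n_fft, hop, win = 1024, 256, 1024
istft = ISTFT(n_fft=n_fft, hop_length=hop, win_length=win, padding="same")

# Analysis side: pad by (win - hop) // 2 and run a non-centered STFT so that the
# "same"-padding inverse below trims back to the original length.
x = torch.randn(2, 256 * 100)                       # (B, T)
pad = (win - hop) // 2
xp = torch.nn.functional.pad(x, (pad, pad))
spec = torch.stft(xp, n_fft, hop_length=hop, win_length=win,
                  window=torch.hann_window(win), center=False, return_complex=True)

y = istft(spec)                                     # back to (B, T): (2, 25600)
print(spec.shape, y.shape)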
+
+"""Torch distributed utilities."""
+
+import typing as tp
+
+import torch
+
+
+def rank():
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_rank()
+    else:
+        return 0
+
+
+def world_size():
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_world_size()
+    else:
+        return 1
+
+
+def is_distributed():
+    return world_size() > 1
+
+
+def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
+    if is_distributed():
+        return torch.distributed.all_reduce(tensor, op)
+
+
+def _is_complex_or_float(tensor):
+    return torch.is_floating_point(tensor) or torch.is_complex(tensor)
+
+
+def _check_number_of_params(params: tp.List[torch.Tensor]):
+    # utility function to check that the number of params in all workers is the same,
+    # and thus avoid a deadlock with distributed all reduce.
+    if not is_distributed() or not params:
+        return
+    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
+    all_reduce(tensor)
+    if tensor.item() != len(params) * world_size():
+        # If not all the workers have the same number, for at least one of them,
+        # this inequality will be verified.
+        raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
+                           "at least one worker has a different one.")
+
+
+def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
+    """Broadcast the tensors from the given parameters to all workers.
+    This can be used to ensure that all workers have the same model to start with.
+    """
+    if not is_distributed():
+        return
+    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
+    _check_number_of_params(tensors)
+    handles = []
+    for tensor in tensors:
+        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
+        handles.append(handle)
+    for handle in handles:
+        handle.wait()
+
+
+def sync_buffer(buffers, average=True):
+    """
+    Sync buffers across workers. If average is False, broadcast instead of averaging.
+    """
+    if not is_distributed():
+        return
+    handles = []
+    for buffer in buffers:
+        if torch.is_floating_point(buffer.data):
+            if average:
+                handle = torch.distributed.all_reduce(
+                    buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
+            else:
+                handle = torch.distributed.broadcast(
+                    buffer.data, src=0, async_op=True)
+            handles.append((buffer, handle))
+    for buffer, handle in handles:
+        handle.wait()
+        if average:
+            buffer.data /= world_size()
+
+
+def sync_grad(params):
+    """
+    Simpler alternative to DistributedDataParallel, that doesn't rely
+    on any black magic. For simple models it can also be as fast.
+    Just call this on your model parameters after the call to backward!
+    """
+    if not is_distributed():
+        return
+    handles = []
+    for p in params:
+        if p.grad is not None:
+            handle = torch.distributed.all_reduce(
+                p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
+            handles.append((p, handle))
+    for p, handle in handles:
+        handle.wait()
+        p.grad.data /= world_size()
+
+
+def average_metrics(metrics: tp.Dict[str, float], count=1.):
+    """Average a dictionary of metrics across all workers, using the optional
+    `count` as unnormalized weight.
+ """ + if not is_distributed(): + return metrics + keys, values = zip(*metrics.items()) + device = 'cuda' if torch.cuda.is_available() else 'cpu' + tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) + tensor *= count + all_reduce(tensor) + averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() + return dict(zip(keys, averaged)) diff --git a/inspiremusic/wavtokenizer/encoder/model.py b/inspiremusic/wavtokenizer/encoder/model.py new file mode 100644 index 0000000000000000000000000000000000000000..33be28de408112b0f54f062df43ac13953e170ea --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/model.py @@ -0,0 +1,324 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""EnCodec model implementation.""" + +import math +from pathlib import Path +import typing as tp + +import numpy as np +import torch +from torch import nn + +from . import quantization as qt +from . import modules as m +from .utils import _check_checksum, _linear_overlap_add, _get_checkpoint_url + + +ROOT_URL = 'https://dl.fbaipublicfiles.com/encodec/v0/' + +EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]] + + +class LMModel(nn.Module): + """Language Model to estimate probabilities of each codebook entry. + We predict all codebooks in parallel for a given time step. + + Args: + n_q (int): number of codebooks. + card (int): codebook cardinality. + dim (int): transformer dimension. + **kwargs: passed to `encoder.modules.transformer.StreamingTransformerEncoder`. + """ + def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, **kwargs): + super().__init__() + self.card = card + self.n_q = n_q + self.dim = dim + self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs) + self.emb = nn.ModuleList([nn.Embedding(card + 1, dim) for _ in range(n_q)]) + self.linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(n_q)]) + + def forward(self, indices: torch.Tensor, + states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0): + """ + Args: + indices (torch.Tensor): indices from the previous time step. Indices + should be 1 + actual index in the codebook. The value 0 is reserved for + when the index is missing (i.e. first time step). Shape should be + `[B, n_q, T]`. + states: state for the streaming decoding. + offset: offset of the current time step. + + Returns a 3-tuple `(probabilities, new_states, new_offset)` with probabilities + with a shape `[B, card, n_q, T]`. + + """ + B, K, T = indices.shape + input_ = sum([self.emb[k](indices[:, k]) for k in range(K)]) + out, states, offset = self.transformer(input_, states, offset) + logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1).permute(0, 3, 1, 2) + return torch.softmax(logits, dim=1), states, offset + + +class EncodecModel(nn.Module): + """EnCodec model operating on the raw waveform. + Args: + target_bandwidths (list of float): Target bandwidths. + encoder (nn.Module): Encoder network. + decoder (nn.Module): Decoder network. + sample_rate (int): Audio sample rate. + channels (int): Number of audio channels. + normalize (bool): Whether to apply audio normalization. + segment (float or None): segment duration in sec. when doing overlap-add. + overlap (float): overlap between segment, given as a fraction of the segment duration. + name (str): name of the model, used as metadata when compressing audio. 
+ """ + def __init__(self, + encoder: m.SEANetEncoder, + decoder: m.SEANetDecoder, + quantizer: qt.ResidualVectorQuantizer, + target_bandwidths: tp.List[float], + sample_rate: int, + channels: int, + normalize: bool = False, + segment: tp.Optional[float] = None, + overlap: float = 0.01, + name: str = 'unset'): + super().__init__() + self.bandwidth: tp.Optional[float] = None + self.target_bandwidths = target_bandwidths + self.encoder = encoder + self.quantizer = quantizer + self.decoder = decoder + self.sample_rate = sample_rate + self.channels = channels + self.normalize = normalize + self.segment = segment + self.overlap = overlap + self.frame_rate = math.ceil(self.sample_rate / np.prod(self.encoder.ratios)) + self.name = name + self.bits_per_codebook = int(math.log2(self.quantizer.bins)) + assert 2 ** self.bits_per_codebook == self.quantizer.bins, \ + "quantizer bins must be a power of 2." + + @property + def segment_length(self) -> tp.Optional[int]: + if self.segment is None: + return None + return int(self.segment * self.sample_rate) + + @property + def segment_stride(self) -> tp.Optional[int]: + segment_length = self.segment_length + if segment_length is None: + return None + return max(1, int((1 - self.overlap) * segment_length)) + + def encode(self, x: torch.Tensor) -> tp.List[EncodedFrame]: + """Given a tensor `x`, returns a list of frames containing + the discrete encoded codes for `x`, along with rescaling factors + for each segment, when `self.normalize` is True. + + Each frames is a tuple `(codebook, scale)`, with `codebook` of + shape `[B, K, T]`, with `K` the number of codebooks. + """ + assert x.dim() == 3 + _, channels, length = x.shape + assert channels > 0 and channels <= 2 + segment_length = self.segment_length + if segment_length is None: + segment_length = length + stride = length + else: + stride = self.segment_stride # type: ignore + assert stride is not None + + encoded_frames: tp.List[EncodedFrame] = [] + for offset in range(0, length, stride): + frame = x[:, :, offset: offset + segment_length] + encoded_frames.append(self._encode_frame(frame)) + return encoded_frames + + def _encode_frame(self, x: torch.Tensor) -> EncodedFrame: + length = x.shape[-1] + duration = length / self.sample_rate + assert self.segment is None or duration <= 1e-5 + self.segment + + if self.normalize: + mono = x.mean(dim=1, keepdim=True) + volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt() + scale = 1e-8 + volume + x = x / scale + scale = scale.view(-1, 1) + else: + scale = None + + emb = self.encoder(x) + codes = self.quantizer.encode(emb, self.frame_rate, self.bandwidth) + codes = codes.transpose(0, 1) + # codes is [B, K, T], with T frames, K nb of codebooks. + return codes, scale + + def decode(self, encoded_frames: tp.List[EncodedFrame]) -> torch.Tensor: + """Decode the given frames into a waveform. + Note that the output might be a bit bigger than the input. In that case, + any extra steps at the end can be trimmed. 
+ """ + segment_length = self.segment_length + if segment_length is None: + assert len(encoded_frames) == 1 + return self._decode_frame(encoded_frames[0]) + + frames = [self._decode_frame(frame) for frame in encoded_frames] + return _linear_overlap_add(frames, self.segment_stride or 1) + + def _decode_frame(self, encoded_frame: EncodedFrame) -> torch.Tensor: + codes, scale = encoded_frame + codes = codes.transpose(0, 1) + emb = self.quantizer.decode(codes) + out = self.decoder(emb) + if scale is not None: + out = out * scale.view(-1, 1, 1) + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + frames = self.encode(x) + return self.decode(frames)[:, :, :x.shape[-1]] + + def set_target_bandwidth(self, bandwidth: float): + if bandwidth not in self.target_bandwidths: + raise ValueError(f"This model doesn't support the bandwidth {bandwidth}. " + f"Select one of {self.target_bandwidths}.") + self.bandwidth = bandwidth + + def get_lm_model(self) -> LMModel: + """Return the associated LM model to improve the compression rate. + """ + device = next(self.parameters()).device + lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200, + past_context=int(3.5 * self.frame_rate)).to(device) + checkpoints = { + 'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th', + 'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th', + } + try: + checkpoint_name = checkpoints[self.name] + except KeyError: + raise RuntimeError("No LM pre-trained for the current Encodec model.") + url = _get_checkpoint_url(ROOT_URL, checkpoint_name) + state = torch.hub.load_state_dict_from_url( + url, map_location='cpu', check_hash=True) # type: ignore + lm.load_state_dict(state) + lm.eval() + return lm + + @staticmethod + def _get_model(target_bandwidths: tp.List[float], + sample_rate: int = 24_000, + channels: int = 1, + causal: bool = True, + model_norm: str = 'weight_norm', + audio_normalize: bool = False, + segment: tp.Optional[float] = None, + name: str = 'unset'): + encoder = m.SEANetEncoder(channels=channels, norm=model_norm, causal=causal) + decoder = m.SEANetDecoder(channels=channels, norm=model_norm, causal=causal) + n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / encoder.hop_length) * 10)) + quantizer = qt.ResidualVectorQuantizer( + dimension=encoder.dimension, + n_q=n_q, + bins=1024, + ) + model = EncodecModel( + encoder, + decoder, + quantizer, + target_bandwidths, + sample_rate, + channels, + normalize=audio_normalize, + segment=segment, + name=name, + ) + return model + + @staticmethod + def _get_pretrained(checkpoint_name: str, repository: tp.Optional[Path] = None): + if repository is not None: + if not repository.is_dir(): + raise ValueError(f"{repository} must exist and be a directory.") + file = repository / checkpoint_name + checksum = file.stem.split('-')[1] + _check_checksum(file, checksum) + return torch.load(file) + else: + url = _get_checkpoint_url(ROOT_URL, checkpoint_name) + return torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True) # type:ignore + + @staticmethod + def encodec_model_24khz(pretrained: bool = True, repository: tp.Optional[Path] = None): + """Return the pretrained causal 24khz model. + """ + if repository: + assert pretrained + target_bandwidths = [1.5, 3., 6, 12., 24.] 
+ checkpoint_name = 'encodec_24khz-d7cc33bc.th' + sample_rate = 24_000 + channels = 1 + model = EncodecModel._get_model( + target_bandwidths, sample_rate, channels, + causal=True, model_norm='weight_norm', audio_normalize=False, + name='encodec_24khz' if pretrained else 'unset') + if pretrained: + state_dict = EncodecModel._get_pretrained(checkpoint_name, repository) + model.load_state_dict(state_dict) + model.eval() + return model + + @staticmethod + def encodec_model_48khz(pretrained: bool = True, repository: tp.Optional[Path] = None): + """Return the pretrained 48khz model. + """ + if repository: + assert pretrained + target_bandwidths = [3., 6., 12., 24.] + checkpoint_name = 'encodec_48khz-7e698e3e.th' + sample_rate = 48_000 + channels = 2 + model = EncodecModel._get_model( + target_bandwidths, sample_rate, channels, + causal=False, model_norm='time_group_norm', audio_normalize=True, + segment=1., name='encodec_48khz' if pretrained else 'unset') + if pretrained: + state_dict = EncodecModel._get_pretrained(checkpoint_name, repository) + model.load_state_dict(state_dict) + model.eval() + return model + + +def test(): + from itertools import product + import torchaudio + bandwidths = [3, 6, 12, 24] + models = { + 'encodec_24khz': EncodecModel.encodec_model_24khz, + 'encodec_48khz': EncodecModel.encodec_model_48khz + } + for model_name, bw in product(models.keys(), bandwidths): + model = models[model_name]() + model.set_target_bandwidth(bw) + audio_suffix = model_name.split('_')[1][:3] + wav, sr = torchaudio.load(f"test_{audio_suffix}.wav") + wav = wav[:, :model.sample_rate * 2] + wav_in = wav.unsqueeze(0) + wav_dec = model(wav_in)[0] + assert wav.shape == wav_dec.shape, (wav.shape, wav_dec.shape) + + +if __name__ == '__main__': + test() diff --git a/inspiremusic/wavtokenizer/encoder/modules/__init__.py b/inspiremusic/wavtokenizer/encoder/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b8e2f987aafa3abf9b882fe15ca5a3b6e150ea32 --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/modules/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Torch modules.""" + +# flake8: noqa +from .conv import ( + pad1d, + unpad1d, + NormConv1d, + NormConvTranspose1d, + NormConv2d, + NormConvTranspose2d, + SConv1d, + SConvTranspose1d, +) +from .lstm import SLSTM +from .seanet import SEANetEncoder, SEANetDecoder +from .transformer import StreamingTransformerEncoder diff --git a/inspiremusic/wavtokenizer/encoder/modules/conv.py b/inspiremusic/wavtokenizer/encoder/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..e83ae84d20ad2082c6e83bb7fc73bb22ac58cf13 --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/modules/conv.py @@ -0,0 +1,253 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
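A minimal offline sketch of the EncodecModel API defined above (editorial, not part of the patch): `pretrained=False` skips the weight download, so the reconstructed audio is from an untrained model, but the encode/decode round trip and the tensor shapes are the real ones. The number of codebooks K follows from the selected bandwidth; treat the printed sizes as illustrative.

import torch
from inspiremusic.wavtokenizer.encoder.model import EncodecModel

# pretrained=False keeps the sketch offline: weights are random, so the output audio
# is meaningless, but the shapes and the encode/decode API match the code above.
model = EncodecModel.encodec_model_24khz(pretrained=False)
model.set_target_bandwidth(6.0)      # must be one of model.target_bandwidths

wav = torch.randn(1, 1, 24000)       # (B, channels, T): one second of mono 24 kHz audio
with torch.no_grad():
    frames = model.encode(wav)       # list of (codes, scale); codes is [B, K, T_frames]
    codes, scale = frames[0]
    recon = model.decode(frames)[:, :, :wav.shape[-1]]

print(codes.shape, recon.shape)      # e.g. torch.Size([1, 8, 75]) torch.Size([1, 1, 24000])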
+ +"""Convolutional layers wrappers and utilities.""" + +import math +import typing as tp +import warnings + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm, weight_norm + +from .norm import ConvLayerNorm + + +CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', + 'time_layer_norm', 'layer_norm', 'time_group_norm']) + + +def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module: + assert norm in CONV_NORMALIZATIONS + if norm == 'weight_norm': + return weight_norm(module) + elif norm == 'spectral_norm': + return spectral_norm(module) + else: + # We already check was in CONV_NORMALIZATION, so any other choice + # doesn't need reparametrization. + return module + + +def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module: + """Return the proper normalization module. If causal is True, this will ensure the returned + module is causal, or return an error if the normalization doesn't support causal evaluation. + """ + assert norm in CONV_NORMALIZATIONS + if norm == 'layer_norm': + assert isinstance(module, nn.modules.conv._ConvNd) + return ConvLayerNorm(module.out_channels, **norm_kwargs) + elif norm == 'time_group_norm': + if causal: + raise ValueError("GroupNorm doesn't support causal evaluation.") + assert isinstance(module, nn.modules.conv._ConvNd) + return nn.GroupNorm(1, module.out_channels, **norm_kwargs) + else: + return nn.Identity() + + +def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, + padding_total: int = 0) -> int: + """See `pad_for_conv1d`. + """ + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. 
+ """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left: end] + + +class NormConv1d(nn.Module): + """Wrapper around Conv1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConv2d(nn.Module): + """Wrapper around Conv2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConvTranspose1d(nn.Module): + """Wrapper around ConvTranspose1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class NormConvTranspose2d(nn.Module): + """Wrapper around ConvTranspose2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class SConv1d(nn.Module): + """Conv1d with some builtin handling of asymmetric or causal padding + and normalization. 
+ """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, dilation: int = 1, + groups: int = 1, bias: bool = True, causal: bool = False, + norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = 'reflect'): + super().__init__() + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1' + f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).') + self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, + dilation=dilation, groups=groups, bias=bias, causal=causal, + norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.pad_mode = pad_mode + + def forward(self, x): + B, C, T = x.shape + kernel_size = self.conv.conv.kernel_size[0] + stride = self.conv.conv.stride[0] + dilation = self.conv.conv.dilation[0] + kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + padding_total = kernel_size - stride + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + if self.causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) + return self.conv(x) + + +class SConvTranspose1d(nn.Module): + """ConvTranspose1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, causal: bool = False, + norm: str = 'none', trim_right_ratio: float = 1., + norm_kwargs: tp.Dict[str, tp.Any] = {}): + super().__init__() + self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, + causal=causal, norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert self.causal or self.trim_right_ratio == 1., \ + "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. + + def forward(self, x): + kernel_size = self.convtr.convtr.kernel_size[0] + stride = self.convtr.convtr.stride[0] + padding_total = kernel_size - stride + + y = self.convtr(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. 
+        if self.causal:
+            # Trim the padding on the right according to the specified ratio
+            # if trim_right_ratio = 1.0, trim everything from right
+            padding_right = math.ceil(padding_total * self.trim_right_ratio)
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        return y
diff --git a/inspiremusic/wavtokenizer/encoder/modules/lstm.py b/inspiremusic/wavtokenizer/encoder/modules/lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..49908198953deed173bed6eed5199eb74b99e5f8
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/modules/lstm.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""LSTM layers module."""
+
+from torch import nn
+
+
+class SLSTM(nn.Module):
+    """
+    LSTM without worrying about the hidden state, nor the layout of the data.
+    Expects input as convolutional layout.
+    """
+    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
+        super().__init__()
+        self.skip = skip
+        self.lstm = nn.LSTM(dimension, dimension, num_layers)
+
+    # def forward(self, x):
+    #     x = x.permute(2, 0, 1)
+    #     y, _ = self.lstm(x)
+    #     if self.skip:
+    #         y = y + x
+    #     y = y.permute(1, 2, 0)
+    #     return y
+
+    # Changed the transpose order compared to the original forward above:
+    # the skip connection is added after permuting back to the (B, C, T) layout.
+    def forward(self, x):
+        # # insert a reshape here if needed
+        # x = x.reshape(x.shape)
+        x1 = x.permute(2, 0, 1)
+        y, _ = self.lstm(x1)
+        y = y.permute(1, 2, 0)
+        if self.skip:
+            y = y + x
+        return y
diff --git a/inspiremusic/wavtokenizer/encoder/modules/norm.py b/inspiremusic/wavtokenizer/encoder/modules/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..19970e0a21ea1c10461cb56d776619dd5f64ff36
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/modules/norm.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Normalization modules."""
+
+import typing as tp
+
+import einops
+import torch
+from torch import nn
+
+
+class ConvLayerNorm(nn.LayerNorm):
+    """
+    Convolution-friendly LayerNorm that moves channels to last dimensions
+    before running the normalization and moves them back to original position right after.
+    """
+    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
+        super().__init__(normalized_shape, **kwargs)
+
+    def forward(self, x):
+        x = einops.rearrange(x, 'b ... t -> b t ...')
+        x = super().forward(x)
+        x = einops.rearrange(x, 'b t ... -> b ... t')
+        return x
diff --git a/inspiremusic/wavtokenizer/encoder/modules/seanet.py b/inspiremusic/wavtokenizer/encoder/modules/seanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea1c02d508cbffce0613a637d4c7943d936b09db
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/modules/seanet.py
@@ -0,0 +1,253 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Encodec SEANet-based encoder and decoder implementation."""
+
+import typing as tp
+
+import numpy as np
+import torch.nn as nn
+
+from .
import ( + SConv1d, + SConvTranspose1d, + SLSTM +) + + +class SEANetResnetBlock(nn.Module): + """Residual block from SEANet model. + Args: + dim (int): Dimension of the input/output + kernel_sizes (list): List of kernel sizes for the convolutions. + dilations (list): List of dilations for the convolutions. + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + compress (int): Reduced dimensionality in residual branches (from Demucs v3) + true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection. + """ + def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1], + activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, + norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False, + pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True): + super().__init__() + assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations' + act = getattr(nn, activation) + hidden = dim // compress + block = [] + for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): + in_chs = dim if i == 0 else hidden + out_chs = dim if i == len(kernel_sizes) - 1 else hidden + block += [ + act(**activation_params), + SConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation, + norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode), + ] + self.block = nn.Sequential(*block) + self.shortcut: nn.Module + if true_skip: + self.shortcut = nn.Identity() + else: + self.shortcut = SConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + + def forward(self, x): + return self.shortcut(x) + self.block(x) + + +class SEANetEncoder(nn.Module): + """SEANet encoder. + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of + upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here + that must match the decoder order + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. 
+ """ + def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1, + ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, + norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, + last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, + pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2): + super().__init__() + self.channels = channels + self.dimension = dimension + self.n_filters = n_filters + self.ratios = list(reversed(ratios)) + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + + act = getattr(nn, activation) + mult = 1 + model: tp.List[nn.Module] = [ + SConv1d(channels, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + ] + # Downsample to raw audio scale + for i, ratio in enumerate(self.ratios): + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base ** j, 1], + norm=norm, norm_params=norm_params, + activation=activation, activation_params=activation_params, + causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)] + + # Add downsampling layers + model += [ + act(**activation_params), + SConv1d(mult * n_filters, mult * n_filters * 2, + kernel_size=ratio * 2, stride=ratio, + norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode), + ] + mult *= 2 + + if lstm: + model += [SLSTM(mult * n_filters, num_layers=lstm)] + + model += [ + act(**activation_params), + SConv1d(mult * n_filters, dimension, last_kernel_size, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + ] + + self.model = nn.Sequential(*model) + + def forward(self, x): + return self.model(x) + + +class SEANetDecoder(nn.Module): + """SEANet decoder. + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + final_activation (str): Final activation function after all convolutions. + final_activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. + trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. 
+ If equal to 1.0, it means that all the trimming is done at the right. + """ + def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1, + ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, + final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None, + norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, + last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, + pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2, + trim_right_ratio: float = 1.0): + super().__init__() + self.dimension = dimension + self.channels = channels + self.n_filters = n_filters + self.ratios = ratios + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + + act = getattr(nn, activation) + mult = int(2 ** len(self.ratios)) + model: tp.List[nn.Module] = [ + SConv1d(dimension, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + ] + + if lstm: + model += [SLSTM(mult * n_filters, num_layers=lstm)] + + # Upsample to raw audio scale + for i, ratio in enumerate(self.ratios): + # Add upsampling layers + model += [ + act(**activation_params), + SConvTranspose1d(mult * n_filters, mult * n_filters // 2, + kernel_size=ratio * 2, stride=ratio, + norm=norm, norm_kwargs=norm_params, + causal=causal, trim_right_ratio=trim_right_ratio), + ] + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base ** j, 1], + activation=activation, activation_params=activation_params, + norm=norm, norm_params=norm_params, causal=causal, + pad_mode=pad_mode, compress=compress, true_skip=true_skip)] + + mult //= 2 + + # Add final layers + model += [ + act(**activation_params), + SConv1d(n_filters, channels, last_kernel_size, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + ] + # Add optional final activation to decoder (eg. tanh) + if final_activation is not None: + final_act = getattr(nn, final_activation) + final_activation_params = final_activation_params or {} + model += [ + final_act(**final_activation_params) + ] + self.model = nn.Sequential(*model) + + def forward(self, z): + y = self.model(z) + return y + + +def test(): + import torch + encoder = SEANetEncoder() + decoder = SEANetDecoder() + x = torch.randn(1, 1, 24000) + z = encoder(x) + assert list(z.shape) == [1, 128, 75], z.shape + y = decoder(z) + assert y.shape == x.shape, (x.shape, y.shape) + + +if __name__ == '__main__': + test() diff --git a/inspiremusic/wavtokenizer/encoder/modules/transformer.py b/inspiremusic/wavtokenizer/encoder/modules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..44b47918f84aa47021c0d6f5bd58364641088541 --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/modules/transformer.py @@ -0,0 +1,119 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
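Before the streaming transformer, a shape sketch for the SEANet pair above (editorial, not part of the patch; it mirrors the module's own test()). The default ratios [8, 5, 4, 2] give a hop of 320 samples, i.e. 75 latent frames per second of 24 kHz audio, and the decoder mirrors the same ratios to upsample.

import torch
from inspiremusic.wavtokenizer.encoder.modules.seanet import SEANetEncoder, SEANetDecoder

enc = SEANetEncoder(dimension=128, n_filters=32, ratios=[8, 5, 4, 2])
dec = SEANetDecoder(dimension=128, n_filters=32, ratios=[8, 5, 4, 2])
print(enc.hop_length)          # prod(ratios) == 320

x = torch.randn(1, 1, 24000)   # (B, channels, T): one second of mono audio
z = enc(x)                     # (1, 128, 75) latent frames, 24000 / 320
y = dec(z)                     # (1, 1, 24000), same length as the input
print(z.shape, y.shape)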
+ +"""A streamable transformer.""" + +import typing as tp + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000): + """Create time embedding for the given positions, target dimension `dim`. + """ + # We aim for BTC format + assert dim % 2 == 0 + half_dim = dim // 2 + adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1) + phase = positions / (max_period ** (adim / (half_dim - 1))) + return torch.cat([ + torch.cos(phase), + torch.sin(phase), + ], dim=-1) + + +class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer): + def forward(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int): # type: ignore + if self.norm_first: + sa_input = self.norm1(x) + x = x + self._sa_block(sa_input, x_past, past_context) + x = x + self._ff_block(self.norm2(x)) + else: + sa_input = x + x = self.norm1(x + self._sa_block(sa_input, x_past, past_context)) + x = self.norm2(x + self._ff_block(x)) + + return x, sa_input + + # self-attention block + def _sa_block(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int): # type: ignore + _, T, _ = x.shape + _, H, _ = x_past.shape + + queries = x + keys = torch.cat([x_past, x], dim=1) + values = keys + + queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1) + keys_pos = torch.arange(T + H, device=x.device).view(1, -1) + delta = queries_pos - keys_pos + valid_access = (delta >= 0) & (delta <= past_context) + x = self.self_attn(queries, keys, values, + attn_mask=~valid_access, + need_weights=False)[0] + return self.dropout1(x) + + +class StreamingTransformerEncoder(nn.Module): + """TransformerEncoder with streaming support. + + Args: + dim (int): dimension of the data. + hidden_scale (int): intermediate dimension of FF module is this times the dimension. + num_heads (int): number of heads. + num_layers (int): number of layers. + max_period (float): maxium period of cosines in the positional embedding. + past_context (int or None): receptive field for the causal mask, infinite if None. + gelu (bool): if true uses GeLUs, otherwise use ReLUs. + norm_in (bool): normalize the input. + dropout (float): dropout probability. + **kwargs: See `nn.TransformerEncoderLayer`. 
+ """ + def __init__(self, dim, hidden_scale: float = 4., num_heads: int = 8, num_layers: int = 5, + max_period: float = 10000, past_context: int = 1000, gelu: bool = True, + norm_in: bool = True, dropout: float = 0., **kwargs): + super().__init__() + assert dim % num_heads == 0 + hidden_dim = int(dim * hidden_scale) + + self.max_period = max_period + self.past_context = past_context + activation: tp.Any = F.gelu if gelu else F.relu + + self.norm_in: nn.Module + if norm_in: + self.norm_in = nn.LayerNorm(dim) + else: + self.norm_in = nn.Identity() + + self.layers = nn.ModuleList() + for idx in range(num_layers): + self.layers.append( + StreamingTransformerEncoderLayer( + dim, num_heads, hidden_dim, + activation=activation, batch_first=True, dropout=dropout, **kwargs)) + + def forward(self, x: torch.Tensor, + states: tp.Optional[tp.List[torch.Tensor]] = None, + offset: tp.Union[int, torch.Tensor] = 0): + B, T, C = x.shape + if states is None: + states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))] + + positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset + pos_emb = create_sin_embedding(positions, C, max_period=self.max_period) + + new_state: tp.List[torch.Tensor] = [] + x = self.norm_in(x) + x = x + pos_emb + + for layer_state, layer in zip(states, self.layers): + x, new_layer_state = layer(x, layer_state, self.past_context) + new_layer_state = torch.cat([layer_state, new_layer_state], dim=1) + new_state.append(new_layer_state[:, -self.past_context:, :]) + return x, new_state, offset + T diff --git a/inspiremusic/wavtokenizer/encoder/msstftd.py b/inspiremusic/wavtokenizer/encoder/msstftd.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d3242a57e1e20e99bc2fa86e363cc5ec92cbf7 --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/msstftd.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""MS-STFT discriminator, provided here for reference.""" + +import typing as tp + +import torchaudio +import torch +from torch import nn +from einops import rearrange + +from .modules import NormConv2d + + +FeatureMapType = tp.List[torch.Tensor] +LogitsType = torch.Tensor +DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]] + + +def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)): + return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2) + + +class DiscriminatorSTFT(nn.Module): + """STFT sub-discriminator. + Args: + filters (int): Number of filters in convolutions + in_channels (int): Number of input channels. Default: 1 + out_channels (int): Number of output channels. Default: 1 + n_fft (int): Size of FFT for each scale. Default: 1024 + hop_length (int): Length of hop between STFT windows for each scale. Default: 256 + kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)`` + stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)`` + dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]`` + win_length (int): Window size for each scale. Default: 1024 + normalized (bool): Whether to normalize by magnitude after stft. Default: True + norm (str): Normalization method. Default: `'weight_norm'` + activation (str): Activation function. 
Default: `'LeakyReLU'` + activation_params (dict): Parameters to provide to the activation function. + growth (int): Growth factor for the filters. Default: 1 + """ + def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, + n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024, + filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4], + stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm', + activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}): + super().__init__() + assert len(kernel_size) == 2 + assert len(stride) == 2 + self.filters = filters + self.in_channels = in_channels + self.out_channels = out_channels + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.normalized = normalized + self.activation = getattr(torch.nn, activation)(**activation_params) + self.spec_transform = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window, + normalized=self.normalized, center=False, pad_mode=None, power=None) + spec_channels = 2 * self.in_channels + self.convs = nn.ModuleList() + self.convs.append( + NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size)) + ) + in_chs = min(filters_scale * self.filters, max_filters) + for i, dilation in enumerate(dilations): + out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters) + self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, + dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)), + norm=norm)) + in_chs = out_chs + out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters) + self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]), + padding=get_2d_padding((kernel_size[0], kernel_size[0])), + norm=norm)) + self.conv_post = NormConv2d(out_chs, self.out_channels, + kernel_size=(kernel_size[0], kernel_size[0]), + padding=get_2d_padding((kernel_size[0], kernel_size[0])), + norm=norm) + + def forward(self, x: torch.Tensor): + fmap = [] + z = self.spec_transform(x) # [B, 2, Freq, Frames, 2] + z = torch.cat([z.real, z.imag], dim=1) + z = rearrange(z, 'b c w t -> b c t w') + for i, layer in enumerate(self.convs): + z = layer(z) + z = self.activation(z) + fmap.append(z) + z = self.conv_post(z) + return z, fmap + + +class MultiScaleSTFTDiscriminator(nn.Module): + """Multi-Scale STFT (MS-STFT) discriminator. + Args: + filters (int): Number of filters in convolutions + in_channels (int): Number of input channels. Default: 1 + out_channels (int): Number of output channels. 
Default: 1 + n_ffts (Sequence[int]): Size of FFT for each scale + hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale + win_lengths (Sequence[int]): Window size for each scale + **kwargs: additional args for STFTDiscriminator + """ + def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, + n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128], + win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs): + super().__init__() + assert len(n_ffts) == len(hop_lengths) == len(win_lengths) + self.discriminators = nn.ModuleList([ + DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels, + n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs) + for i in range(len(n_ffts)) + ]) + self.num_discriminators = len(self.discriminators) + + def forward(self, x: torch.Tensor) -> DiscriminatorOutput: + logits = [] + fmaps = [] + for disc in self.discriminators: + logit, fmap = disc(x) + logits.append(logit) + fmaps.append(fmap) + return logits, fmaps + + +def test(): + disc = MultiScaleSTFTDiscriminator(filters=32) + y = torch.randn(1, 1, 24000) + y_hat = torch.randn(1, 1, 24000) + + y_disc_r, fmap_r = disc(y) + y_disc_gen, fmap_gen = disc(y_hat) + assert len(y_disc_r) == len(y_disc_gen) == len(fmap_r) == len(fmap_gen) == disc.num_discriminators + + assert all([len(fm) == 5 for fm in fmap_r + fmap_gen]) + assert all([list(f.shape)[:2] == [1, 32] for fm in fmap_r + fmap_gen for f in fm]) + assert all([len(logits.shape) == 4 for logits in y_disc_r + y_disc_gen]) + + +if __name__ == '__main__': + test() diff --git a/inspiremusic/wavtokenizer/encoder/quantization/__init__.py b/inspiremusic/wavtokenizer/encoder/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bfabe52b8cb6f260cdda6137b34df2f4736bd02f --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/quantization/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# flake8: noqa +from .vq import QuantizedResult, ResidualVectorQuantizer diff --git a/inspiremusic/wavtokenizer/encoder/quantization/ac.py b/inspiremusic/wavtokenizer/encoder/quantization/ac.py new file mode 100644 index 0000000000000000000000000000000000000000..f0f3e5dcd385cd273a145effa3f53ce7ccfdc74c --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/quantization/ac.py @@ -0,0 +1,292 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Arithmetic coder.""" + +import io +import math +import random +import typing as tp +import torch + +from ..binary import BitPacker, BitUnpacker + + +def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int, + roundoff: float = 1e-8, min_range: int = 2, + check: bool = True) -> torch.Tensor: + """Turn the given PDF into a quantized CDF that splits + [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional + to the PDF. + + Args: + pdf (torch.Tensor): probability distribution, shape should be `[N]`. + total_range_bits (int): see `ArithmeticCoder`, the typical range we expect + during the coding process is `[0, 2 ** total_range_bits - 1]`. 
+        roundoff (float): will round the pdf up to that level to remove difference coming
+            from e.g. evaluating the Language Model on different architectures.
+        min_range (int): minimum range width. Should always be at least 2 for numerical
+            stability. Use this to avoid pathological behavior when a value
+            that is expected to be rare actually happens in real life.
+        check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
+    """
+    pdf = pdf.detach()
+    if roundoff:
+        pdf = (pdf / roundoff).floor() * roundoff
+    # interpolate with uniform distribution to achieve desired minimum probability.
+    total_range = 2 ** total_range_bits
+    cardinality = len(pdf)
+    alpha = min_range * cardinality / total_range
+    assert alpha <= 1, "you must reduce min_range"
+    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
+    ranges += min_range
+    quantized_cdf = torch.cumsum(ranges, dim=-1)
+    if min_range < 2:
+        raise ValueError("min_range must be at least 2.")
+    if check:
+        assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
+        if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
+            raise ValueError("You must increase your total_range_bits.")
+    return quantized_cdf
+
+
+class ArithmeticCoder:
+    """ArithmeticCoder,
+    Let us take a distribution `p` over `N` symbols, and assume we have a stream
+    of random variables `s_t` sampled from `p`. Let us assume that we have a budget
+    of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
+    corresponding to the range `[0, 2 ** B - 1]`. We can map each of those numbers to a single
+    sequence `(s_t)` by doing the following:
+
+    1) Initialize the current range to `[0, 2 ** B - 1]`.
+    2) For each time step t, split the current range into contiguous chunks,
+       one for each possible outcome, with size roughly proportional to `p`.
+       For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
+       would be `{[0, 2], [3, 3]}`.
+    3) Select the chunk corresponding to `s_t`, and replace the current range with this.
+    4) When done encoding all the values, just select any value remaining in the range.
+
+    You will notice that this procedure can fail: for instance if at any point in time
+    the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
+    possible outcome. Intuitively, the more likely a value is, the less the range width
+    will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
+    coding scheme, likely outcomes would take fewer bits, and more of them can be coded
+    with a fixed budget.
+
+    In practice, we do not know `B` ahead of time, but we have a way to inject new bits
+    when the current range decreases below a given limit (given by `total_range_bits`), without
+    having to redo all the computations. If we mostly encode likely values, we will seldom
+    need to inject new bits, but a single rare value can deplete our stock of entropy!
+
+    In this explanation, we assumed that the distribution `p` was constant. In fact, the present
+    code works for any sequence `(p_t)` possibly different for each timestep.
+    We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
+    the KL between the true distribution and `p_t`, the more efficient the coding will be.
+
+    Args:
+        fo (IO[bytes]): file-like object to which the bytes will be written.
+        total_range_bits (int): the range `M` described above is `2 ** total_range_bits`.
+            Any time the current range width falls under this limit, new bits will
+            be injected to rescale the initial range.
+    """
+
+    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
+        assert total_range_bits <= 30
+        self.total_range_bits = total_range_bits
+        self.packer = BitPacker(bits=1, fo=fo)  # we push single bits at a time.
+        self.low: int = 0
+        self.high: int = 0
+        self.max_bit: int = -1
+        self._dbg: tp.List[tp.Any] = []
+        self._dbg2: tp.List[tp.Any] = []
+
+    @property
+    def delta(self) -> int:
+        """Return the current range width."""
+        return self.high - self.low + 1
+
+    def _flush_common_prefix(self):
+        # If self.low and self.high start with the same bits,
+        # those won't change anymore as we always just increase the range
+        # by powers of 2, and we can flush them out to the bit stream.
+        assert self.high >= self.low, (self.low, self.high)
+        assert self.high < 2 ** (self.max_bit + 1)
+        while self.max_bit >= 0:
+            b1 = self.low >> self.max_bit
+            b2 = self.high >> self.max_bit
+            if b1 == b2:
+                self.low -= (b1 << self.max_bit)
+                self.high -= (b1 << self.max_bit)
+                assert self.high >= self.low, (self.high, self.low, self.max_bit)
+                assert self.low >= 0
+                self.max_bit -= 1
+                self.packer.push(b1)
+            else:
+                break
+
+    def push(self, symbol: int, quantized_cdf: torch.Tensor):
+        """Push the given symbol on the stream, flushing out bits
+        if possible.
+
+        Args:
+            symbol (int): symbol to encode with the AC.
+            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
+                to build this from your pdf estimate.
+        """
+        while self.delta < 2 ** self.total_range_bits:
+            self.low *= 2
+            self.high = self.high * 2 + 1
+            self.max_bit += 1
+
+        range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
+        range_high = quantized_cdf[symbol].item() - 1
+        effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
+        effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
+        assert self.low <= self.high
+        self.high = self.low + effective_high
+        self.low = self.low + effective_low
+        assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
+        self._dbg.append((self.low, self.high))
+        self._dbg2.append((self.low, self.high))
+        outs = self._flush_common_prefix()
+        assert self.low <= self.high
+        assert self.max_bit >= -1
+        assert self.max_bit <= 61, self.max_bit
+        return outs
+
+    def flush(self):
+        """Flush the remaining information to the stream.
+        """
+        while self.max_bit >= 0:
+            b1 = (self.low >> self.max_bit) & 1
+            self.packer.push(b1)
+            self.max_bit -= 1
+        self.packer.flush()
+
+
+class ArithmeticDecoder:
+    """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
+
+    Note that this must be called with **exactly** the same parameters and sequence
+    of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
+
+    If the AC encoder current range is [L, H], with `L` and `H` having some common
+    prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
+    For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
+    `[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
+    for a specific sequence of symbols and a binary-search allows us to decode those symbols.
+    At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
+    and we will need to read new bits from the stream and repeat the process.
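+
+    A minimal encode/decode round trip, as an illustrative sketch only (it mirrors the
+    `test()` helper at the bottom of this file; the pdf and symbol below are made up):
+
+        import io, torch
+        pdf = torch.tensor([0.75, 0.25])
+        fo = io.BytesIO()
+        coder = ArithmeticCoder(fo)
+        cdf = build_stable_quantized_cdf(pdf, coder.total_range_bits)
+        coder.push(1, cdf)                   # encode the symbol `1` under `pdf`
+        coder.flush()
+        fo.seek(0)
+        decoder = ArithmeticDecoder(fo)
+        assert decoder.pull(cdf) == 1        # the same cdf sequence must be reused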
+ + """ + def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): + self.total_range_bits = total_range_bits + self.low: int = 0 + self.high: int = 0 + self.current: int = 0 + self.max_bit: int = -1 + self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time. + # Following is for debugging + self._dbg: tp.List[tp.Any] = [] + self._dbg2: tp.List[tp.Any] = [] + self._last: tp.Any = None + + @property + def delta(self) -> int: + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # Given the current range [L, H], if both have a common prefix, + # we know we can remove it from our representation to avoid handling large numbers. + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= (b1 << self.max_bit) + self.high -= (b1 << self.max_bit) + self.current -= (b1 << self.max_bit) + assert self.high >= self.low + assert self.low >= 0 + self.max_bit -= 1 + else: + break + + def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]: + """Pull a symbol, reading as many bits from the stream as required. + This returns `None` when the stream has been exhausted. + + Args: + quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. This must be **exatly** + the same cdf as the one used at encoding time. + """ + while self.delta < 2 ** self.total_range_bits: + bit = self.unpacker.pull() + if bit is None: + return None + self.low *= 2 + self.high = self.high * 2 + 1 + self.current = self.current * 2 + bit + self.max_bit += 1 + + def bin_search(low_idx: int, high_idx: int): + # Binary search is not just for coding interviews :) + if high_idx < low_idx: + raise RuntimeError("Binary search failed") + mid = (low_idx + high_idx) // 2 + range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 + range_high = quantized_cdf[mid].item() - 1 + effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))) + effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))) + low = effective_low + self.low + high = effective_high + self.low + if self.current >= low: + if self.current <= high: + return (mid, low, high, self.current) + else: + return bin_search(mid + 1, high_idx) + else: + return bin_search(low_idx, mid - 1) + + self._last = (self.low, self.high, self.current, self.max_bit) + sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1) + self._dbg.append((self.low, self.high, self.current)) + self._flush_common_prefix() + self._dbg2.append((self.low, self.high, self.current)) + + return sym + + +def test(): + torch.manual_seed(1234) + random.seed(1234) + for _ in range(4): + pdfs = [] + cardinality = random.randrange(4000) + steps = random.randrange(100, 500) + fo = io.BytesIO() + encoder = ArithmeticCoder(fo) + symbols = [] + for step in range(steps): + pdf = torch.softmax(torch.randn(cardinality), dim=0) + pdfs.append(pdf) + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + symbol = torch.multinomial(pdf, 1).item() + symbols.append(symbol) + encoder.push(symbol, q_cdf) + encoder.flush() + + fo.seek(0) + decoder = ArithmeticDecoder(fo) + for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)): + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + decoded_symbol = decoder.pull(q_cdf) + assert decoded_symbol == symbol, idx + assert decoder.pull(torch.zeros(1)) is None + + +if __name__ == "__main__": + test() diff --git 
a/inspiremusic/wavtokenizer/encoder/quantization/core_vq.py b/inspiremusic/wavtokenizer/encoder/quantization/core_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..774781c2947622e6c0c7a55c6eded26a2813b7c7 --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/quantization/core_vq.py @@ -0,0 +1,421 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Core vector quantization implementation.""" + +import typing as tp +import warnings + +from einops import rearrange, repeat +import torch +from torch import nn +import torch.nn.functional as F + +from .. 
import distrib + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + diffs = rearrange(samples, "n d -> n () d") - rearrange( + means, "c d -> () c d" + ) + dists = -(diffs ** 2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. 
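+
+    Shape sketch (illustrative only; the sizes are made up, not taken from this repo):
+    with `dim=128` and `codebook_size=1024`, `encode` maps vectors of shape `[..., 128]`
+    to integer indices of shape `[...]` in `[0, 1023]`, and `decode` maps them back:
+
+        codebook = EuclideanCodebook(dim=128, codebook_size=1024)
+        x = torch.randn(8, 75, 128)      # [batch, frames, dim]
+        codes = codebook.encode(x)       # -> [8, 75] integer indices
+        recon = codebook.decode(codes)   # -> [8, 75, 128] codebook vectors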
+ """ + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) #data不变 + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + distrib.broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where( + mask[..., None], sample_vectors(samples, self.codebook_size), self.embed + ) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + distrib.broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -( + x.pow(2).sum(1, keepdim=True) + - 2 * x @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) + + self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = self.postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. + self.expire_codes_(x) + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = x.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) + * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. 
+ Currently supports only euclidean distance. + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + commitment_weight (float): Weight for commitment loss. + """ + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + decay: float = 0.99, + epsilon: float = 1e-5, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + commitment_weight: float = 1., + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) + self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, + kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, + decay=decay, epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code) + self.codebook_size = codebook_size + + @property + def codebook(self): + return self._codebook.embed + + def encode(self, x): + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def forward(self, x): + + # breakpoint() + device = x.device + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + quantize, embed_ind = self._codebook(x) + if self.training: + quantize = x + (quantize - x).detach() + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + # warnings.warn('When using RVQ in training model, first check ' + # 'https://github.com/facebookresearch/encodec/issues/25 . ' + # 'The bug wasn\'t fixed here for reproducibility.') + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf
+    """
+    def __init__(self, *, num_quantizers, **kwargs):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+        )
+
+    def forward(self, x, n_q: tp.Optional[int] = None):
+        quantized_out = 0.0
+        residual = x
+
+        all_losses = []
+        all_indices = []
+
+        n_q = n_q or len(self.layers)
+        for layer in self.layers[:n_q]:
+            quantized, indices, loss = layer(residual)
+            residual = residual - quantized.detach()
+            quantized_out = quantized_out + quantized
+            all_indices.append(indices)
+            all_losses.append(loss)
+
+        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+        return quantized_out, out_indices, out_losses
+
+    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+        residual = x
+        all_indices = []
+        n_q = n_q or len(self.layers)
+        for layer in self.layers[:n_q]:
+            indices = layer.encode(residual)
+            all_indices.append(indices)
+            quantized = layer.decode(indices)
+            residual = residual - quantized.detach()
+        out_indices = torch.stack(all_indices)
+        return out_indices
+
+    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+        quantized_out = torch.tensor(0.0, device=q_indices.device)
+        for i, indices in enumerate(q_indices):
+            layer = self.layers[i]
+            quantized = layer.decode(indices)
+            quantized_out = quantized_out + quantized
+        return quantized_out
+
+
+class LanguageVectorQuantization(nn.Module):
+    """Vector quantization variant used for language-model style tokens.
+    Unlike `ResidualVectorQuantization` (Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf),
+    every layer quantizes the same input without residual subtraction, and the output of the
+    last applied layer is returned.
+    """
+    def __init__(self, *, num_quantizers, **kwargs):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+        )
+
+    def forward(self, x, n_q: tp.Optional[int] = None):
+        # x: [b, t, c], e.g. [64, 75, 128]
+        quantized_out = 0.0
+        residual = x
+
+        all_losses = []
+        all_indices = []
+
+        n_q = n_q or len(self.layers)
+
+        for layer in self.layers[:n_q]:
+            # quantized output, indices and loss of this layer; indices shape e.g. [64, 75]
+            quantized_out, indices, loss = layer(residual)
+            # residual = residual - quantized.detach()
+            # quantized_out = quantized_out + quantized
+            all_indices.append(indices)
+            all_losses.append(loss)
+
+        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+        return quantized_out, out_indices, out_losses
+
+    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+        residual = x
+        all_indices = []
+        n_q = n_q or len(self.layers)
+        for layer in self.layers[:n_q]:
+            indices = layer.encode(residual)
+            all_indices.append(indices)
+            quantized = layer.decode(indices)
+            residual = residual - quantized.detach()
+        out_indices = torch.stack(all_indices)
+        return out_indices
+
+    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+        quantized_out = torch.tensor(0.0, device=q_indices.device)
+        for i, indices in enumerate(q_indices):
+            layer = self.layers[i]
+            quantized = layer.decode(indices)
+            quantized_out = quantized_out + quantized
+        return quantized_out
\ No newline at end of file
diff --git a/inspiremusic/wavtokenizer/encoder/quantization/vq.py b/inspiremusic/wavtokenizer/encoder/quantization/vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8e316b4bf912c2a743cd27fe038a17e85bceb13
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/quantization/vq.py
@@ -0,0 +1,172 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Residual vector quantizer implementation.""" + +from dataclasses import dataclass, field +import math +import typing as tp + +import torch +from torch import nn + +from .core_vq import ResidualVectorQuantization,LanguageVectorQuantization + + +@dataclass +class QuantizedResult: + quantized: torch.Tensor + codes: torch.Tensor + bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. + penalty: tp.Optional[torch.Tensor] = None + metrics: dict = field(default_factory=dict) + + +class ResidualVectorQuantizer(nn.Module): + """Residual Vector Quantizer. + Args: + dimension (int): Dimension of the codebooks. + n_q (int): Number of residual vector quantizers used. + bins (int): Codebook size. + decay (float): Decay for exponential moving average over the codebooks. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + def __init__( + self, + dimension: int = 256, + n_q: int = 8, + bins: int = 1024, + decay: float = 0.99, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.n_q = n_q + self.dimension = dimension + self.bins = bins + self.decay = decay + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.threshold_ema_dead_code = threshold_ema_dead_code + + # print(self.bins) + + # breakpoint() + + self.vq = LanguageVectorQuantization( + dim=self.dimension, + codebook_size=self.bins, + num_quantizers=self.n_q, + decay=self.decay, + kmeans_init=self.kmeans_init, + kmeans_iters=self.kmeans_iters, + threshold_ema_dead_code=self.threshold_ema_dead_code, + ) + # self.vq = ResidualVectorQuantization( + # dim=self.dimension, + # codebook_size=self.bins, + # num_quantizers=self.n_q, + # decay=self.decay, + # kmeans_init=self.kmeans_init, + # kmeans_iters=self.kmeans_iters, + # threshold_ema_dead_code=self.threshold_ema_dead_code, + # ) + + + def forward(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult: + """Residual vector quantization on the given input tensor. + Args: + x (torch.Tensor): Input tensor. + frame_rate (int): Sample rate of the input tensor. + bandwidth (float): Target bandwidth. + Returns: + QuantizedResult: + The quantized (or approximately quantized) representation with + the associated bandwidth and any penalty term for the loss. 
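+
+        A rough usage sketch with made-up numbers (illustrative only, not taken from the
+        InspireMusic configs): with `bins=1024` each quantizer spends `log2(1024) = 10` bits
+        per frame, so at `frame_rate=75` one quantizer costs 750 bits/s and a target
+        `bandwidth=6.0` (kbps) maps to `n_q = floor(6000 / 750) = 8`; in training mode the
+        code below then re-samples `n_q` from `nq_choice = [4, 6, 8]`.
+
+            rvq = ResidualVectorQuantizer(dimension=128, n_q=8, bins=1024).eval()
+            x = torch.randn(2, 128, 75)                 # [batch, dimension, frames]
+            out = rvq(x, frame_rate=75, bandwidth=6.0)
+            # out.quantized: [2, 128, 75]; out.codes: [8, 2, 75]
+            # out.bandwidth == 8 * log2(1024) * 75, as computed by get_bandwidth_per_quantizer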
+ """ + # breakpoint() + + + bw_per_q = self.get_bandwidth_per_quantizer(frame_rate) + n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth) + # assert n_q==4 + # breakpoint() + # nq_choice=[3,4,8] + nq_choice=[4,6,8] + if self.training: + # choice = int(torch.randint(0, 3, (1,)).item()) + choice = int(torch.randint(0, 3, (1,)).item()) + # breakpoint() + n_q=nq_choice[choice] + # breakpoint() + # n_q=8 + quantized, codes, commit_loss = self.vq(x, n_q=n_q) + bw = torch.tensor(n_q * bw_per_q).to(x) + return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) + + def infer(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult: + """Residual vector quantization on the given input tensor. + Args: + x (torch.Tensor): Input tensor. + frame_rate (int): Sample rate of the input tensor. + bandwidth (float): Target bandwidth. + Returns: + QuantizedResult: + The quantized (or approximately quantized) representation with + the associated bandwidth and any penalty term for the loss. + """ + bw_per_q = self.get_bandwidth_per_quantizer(frame_rate) + # n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth) + # # assert n_q==4 + # # breakpoint() + # # nq_choice=[3,4,8] + # nq_choice=[3,4,5,6,7,8] + # if self.training: + # # choice = int(torch.randint(0, 3, (1,)).item()) + # choice = int(torch.randint(0, 6, (1,)).item()) + # # breakpoint() + # n_q=nq_choice[choice] + n_q=1 + quantized, codes, commit_loss = self.vq(x, n_q=n_q) + bw = torch.tensor(n_q * bw_per_q).to(x) + return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) + + def get_num_quantizers_for_bandwidth(self, frame_rate: int, bandwidth: tp.Optional[float] = None) -> int: + """Return n_q based on specified target bandwidth. + """ + bw_per_q = self.get_bandwidth_per_quantizer(frame_rate) + n_q = self.n_q + if bandwidth and bandwidth > 0.: + # bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented as + # bandwidth == 6.0 + n_q = int(max(1, math.floor(bandwidth * 1000 / bw_per_q))) + return n_q + + def get_bandwidth_per_quantizer(self, frame_rate: int): + """Return bandwidth per quantizer for a given input frame rate. + Each quantizer encodes a frame with lg(bins) bits. + """ + return math.log2(self.bins) * frame_rate + + def encode(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor: + """Encode a given input tensor with the specified frame rate at the given bandwidth. + The RVQ encode method sets the appropriate number of quantizers to use + and returns indices for each quantizer. + """ + n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth) + codes = self.vq.encode(x, n_q=n_q) + return codes + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation. + """ + quantized = self.vq.decode(codes) + return quantized diff --git a/inspiremusic/wavtokenizer/encoder/utils.py b/inspiremusic/wavtokenizer/encoder/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f0f9e9bcb37f2267b2f8adefabfc3672453dc5 --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/utils.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Various utilities.""" + +from hashlib import sha256 +from pathlib import Path +import typing as tp + +import torch +import torchaudio + + +def _linear_overlap_add(frames: tp.List[torch.Tensor], stride: int): + # Generic overlap add, with linear fade-in/fade-out, supporting complex scenario + # e.g., more than 2 frames per position. + # The core idea is to use a weight function that is a triangle, + # with a maximum value at the middle of the segment. + # We use this weighting when summing the frames, and divide by the sum of weights + # for each positions at the end. Thus: + # - if a frame is the only one to cover a position, the weighting is a no-op. + # - if 2 frames cover a position: + # ... ... + # / \/ \ + # / /\ \ + # S T , i.e. S offset of second frame starts, T end of first frame. + # Then the weight function for each one is: (t - S), (T - t), with `t` a given offset. + # After the final normalization, the weight of the second frame at position `t` is + # (t - S) / (t - S + (T - t)) = (t - S) / (T - S), which is exactly what we want. + # + # - if more than 2 frames overlap at a given point, we hope that by induction + # something sensible happens. + assert len(frames) + device = frames[0].device + dtype = frames[0].dtype + shape = frames[0].shape[:-1] + total_size = stride * (len(frames) - 1) + frames[-1].shape[-1] + + frame_length = frames[0].shape[-1] + t = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1: -1] + weight = 0.5 - (t - 0.5).abs() + + sum_weight = torch.zeros(total_size, device=device, dtype=dtype) + out = torch.zeros(*shape, total_size, device=device, dtype=dtype) + offset: int = 0 + + for frame in frames: + frame_length = frame.shape[-1] + out[..., offset:offset + frame_length] += weight[:frame_length] * frame + sum_weight[offset:offset + frame_length] += weight[:frame_length] + offset += stride + assert sum_weight.min() > 0 + return out / sum_weight + + +def _get_checkpoint_url(root_url: str, checkpoint: str): + if not root_url.endswith('/'): + root_url += '/' + return root_url + checkpoint + + +def _check_checksum(path: Path, checksum: str): + sha = sha256() + with open(path, 'rb') as file: + while True: + buf = file.read(2**20) + if not buf: + break + sha.update(buf) + actual_checksum = sha.hexdigest()[:len(checksum)] + if actual_checksum != checksum: + raise RuntimeError(f'Invalid checksum for file {path}, ' + f'expected {checksum} but got {actual_checksum}') + + +def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int): + assert wav.dim() >= 2, "Audio tensor must have at least 2 dimensions" + assert wav.shape[-2] in [1, 2], "Audio must be mono or stereo." 
+ *shape, channels, length = wav.shape + if target_channels == 1: + wav = wav.mean(-2, keepdim=True) + elif target_channels == 2: + wav = wav.expand(*shape, target_channels, length) + elif channels == 1: + wav = wav.expand(target_channels, -1) + else: + raise RuntimeError(f"Impossible to convert from {channels} to {target_channels}") + wav = torchaudio.transforms.Resample(sr, target_sr)(wav) + return wav + + +def save_audio(wav: torch.Tensor, path: tp.Union[Path, str], + sample_rate: int, rescale: bool = False): + limit = 0.99 + mx = wav.abs().max() + if rescale: + wav = wav * min(limit / mx, 1) + else: + wav = wav.clamp(-limit, limit) + torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16) diff --git a/install.sh b/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..8cb74c576d15e436e7fef942073a514fcf606cd1 --- /dev/null +++ b/install.sh @@ -0,0 +1,16 @@ +#pip install flash-attn --no-build-isolation +#git submodule update --init --recursive +mkdir pretrained_models +cd pretrained_models +git clone https://huggingface.co/FunAudioLLM/InspireMusic-Base.git & +git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-Long.git & +git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B.git & +git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-24kHz.git & +git clone https://huggingface.co/FunAudioLLM/InspireMusic-Base-24kHz.git & +wait + +for i in InspireMusic-Base InspireMusic-Base-24kHz InspireMusic-1.5B InspireMusic-1.5B-24kHz InspireMusic-1.5B-Long +do + sed -i -e "s/\.\.\/\.\.\///g" ${i}/inspiremusic.yaml +done +cd .. diff --git a/path.sh b/path.sh new file mode 100644 index 0000000000000000000000000000000000000000..1e0c63627f7691b4cd411586e3795239b9d0e163 --- /dev/null +++ b/path.sh @@ -0,0 +1,6 @@ +#!/bin/bash +export PYTHONIOENCODING=UTF-8 +export MAIN_ROOT=`realpath ${PWD}/` + +export PYTHONPATH=${MAIN_ROOT}:${MAIN_ROOT}/third_party/Matcha-TTS:${PYTHONPATH} +export BIN_DIR=${MAIN_ROOT}/inspiremusic diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4a7692ffbabe9dd207326320e858b71918e5682c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,31 @@ +--extra-index-url https://download.pytorch.org/whl/cu117 +--extra-index-url https://pypi.nvidia.com +--trusted-host pypi.nvidia.com +conformer==0.3.2 +gdown==5.1.0 +grpcio==1.57.0 +grpcio-tools==1.57.0 +hydra-core==1.3.2 +HyperPyYAML==1.2.2 +inflect==7.3.1 +librosa==0.10.2 +lightning==2.2.4 +matplotlib==3.7.5 +modelscope==1.17.1 +networkx==3.1 +omegaconf==2.3.0 +protobuf==4.25 +pydantic==2.7.0 +rich==13.7.1 +soundfile==0.12.1 +tensorboard==2.14.0 +torch +torchaudio +transformers==4.40.1 +uvicorn==0.30.0 +wget==3.2 +WeTextProcessing==1.0.3 +accelerate +huggingface-hub +soundfile==0.12.1 +diffusers diff --git a/third_party/.DS_Store b/third_party/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..b5ef7dc97d9221f109f3cb33bfb3ea80996336a8 Binary files /dev/null and b/third_party/.DS_Store differ diff --git a/third_party/Matcha-TTS/.DS_Store b/third_party/Matcha-TTS/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..a834b8fafd6d9b3777afdeddd678b1c7eb7dde65 Binary files /dev/null and b/third_party/Matcha-TTS/.DS_Store differ diff --git a/third_party/Matcha-TTS/.env.example b/third_party/Matcha-TTS/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..a790e320464ebc778ca07f5bcd826a9c8412ed0e --- /dev/null +++ 
b/third_party/Matcha-TTS/.env.example @@ -0,0 +1,6 @@ +# example of file for storing private and user specific environment variables, like keys or system paths +# rename it to ".env" (excluded from version control by default) +# .env is loaded by train.py automatically +# hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR} + +MY_VAR="/home/user/my/system/path" diff --git a/third_party/Matcha-TTS/.github/PULL_REQUEST_TEMPLATE.md b/third_party/Matcha-TTS/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000000000000000000000000000000000000..410bcd87a45297ab8f0d369574a032858b6b1811 --- /dev/null +++ b/third_party/Matcha-TTS/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,22 @@ +## What does this PR do? + + + +Fixes #\ + +## Before submitting + +- [ ] Did you make sure **title is self-explanatory** and **the description concisely explains the PR**? +- [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together? +- [ ] Did you list all the **breaking changes** introduced by this pull request? +- [ ] Did you **test your PR locally** with `pytest` command? +- [ ] Did you **run pre-commit hooks** with `pre-commit run -a` command? + +## Did you have fun? + +Make sure you had fun coding 🙃 diff --git a/third_party/Matcha-TTS/.github/codecov.yml b/third_party/Matcha-TTS/.github/codecov.yml new file mode 100644 index 0000000000000000000000000000000000000000..c66853c4bd9991f730da5dda7dc9881986779558 --- /dev/null +++ b/third_party/Matcha-TTS/.github/codecov.yml @@ -0,0 +1,15 @@ +coverage: + status: + # measures overall project coverage + project: + default: + threshold: 100% # how much decrease in coverage is needed to not consider success + + # measures PR or single commit coverage + patch: + default: + threshold: 100% # how much decrease in coverage is needed to not consider success + + + # project: off + # patch: off diff --git a/third_party/Matcha-TTS/.github/dependabot.yml b/third_party/Matcha-TTS/.github/dependabot.yml new file mode 100644 index 0000000000000000000000000000000000000000..b19ccab12a3c573025ce6ba6d9068b062b29cc1b --- /dev/null +++ b/third_party/Matcha-TTS/.github/dependabot.yml @@ -0,0 +1,17 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. 
+# Please see the documentation for all configuration options: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + - package-ecosystem: "pip" # See documentation for possible values + directory: "/" # Location of package manifests + target-branch: "dev" + schedule: + interval: "daily" + ignore: + - dependency-name: "pytorch-lightning" + update-types: ["version-update:semver-patch"] + - dependency-name: "torchmetrics" + update-types: ["version-update:semver-patch"] diff --git a/third_party/Matcha-TTS/.github/release-drafter.yml b/third_party/Matcha-TTS/.github/release-drafter.yml new file mode 100644 index 0000000000000000000000000000000000000000..59af159f671abe75311eb626c8ec92ca6ea09d3c --- /dev/null +++ b/third_party/Matcha-TTS/.github/release-drafter.yml @@ -0,0 +1,44 @@ +name-template: "v$RESOLVED_VERSION" +tag-template: "v$RESOLVED_VERSION" + +categories: + - title: "🚀 Features" + labels: + - "feature" + - "enhancement" + - title: "🐛 Bug Fixes" + labels: + - "fix" + - "bugfix" + - "bug" + - title: "🧹 Maintenance" + labels: + - "maintenance" + - "dependencies" + - "refactoring" + - "cosmetic" + - "chore" + - title: "📝️ Documentation" + labels: + - "documentation" + - "docs" + +change-template: "- $TITLE @$AUTHOR (#$NUMBER)" +change-title-escapes: '\<*_&' # You can add # and @ to disable mentions + +version-resolver: + major: + labels: + - "major" + minor: + labels: + - "minor" + patch: + labels: + - "patch" + default: patch + +template: | + ## Changes + + $CHANGES diff --git a/third_party/Matcha-TTS/.gitignore b/third_party/Matcha-TTS/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..cbec8b43a0414bbbf4cc9feae49b9dc091a60c92 --- /dev/null +++ b/third_party/Matcha-TTS/.gitignore @@ -0,0 +1,163 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +### VisualStudioCode +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +*.code-workspace +**/.vscode + +# JetBrains +.idea/ + +# Data & Models +*.h5 +*.tar +*.tar.gz + +# Lightning-Hydra-Template +configs/local/default.yaml +/data/ +/logs/ +.env + +# Aim logging +.aim + +# Cython complied files +matcha/utils/monotonic_align/core.c + +# Ignoring hifigan checkpoint +generator_v1 +g_02500000 +gradio_cached_examples/ +synth_output/ diff --git a/third_party/Matcha-TTS/.pre-commit-config.yaml b/third_party/Matcha-TTS/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..761867141e9f3a59316ab9f0b6eec6191d29900e --- /dev/null +++ b/third_party/Matcha-TTS/.pre-commit-config.yaml @@ -0,0 +1,59 @@ +default_language_version: + python: python3.11 + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + # list of supported hooks: https://pre-commit.com/hooks.html + - id: trailing-whitespace + - id: end-of-file-fixer + # - id: check-docstring-first + - id: check-yaml + - id: debug-statements + - id: detect-private-key + - id: check-toml + - id: check-case-conflict + - id: check-added-large-files + + # python code formatting + - repo: https://github.com/psf/black + rev: 23.12.1 + hooks: + - id: black + args: [--line-length, "120"] + + # python import sorting + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files"] + + # python upgrading syntax to newer version + - repo: https://github.com/asottile/pyupgrade + rev: v3.15.0 + hooks: + - id: pyupgrade + args: [--py38-plus] + + # python check (PEP8), programming errors and code complexity + - repo: https://github.com/PyCQA/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + args: + [ + "--max-line-length", "120", + "--extend-ignore", + "E203,E402,E501,F401,F841,RST2,RST301", + "--exclude", + "logs/*,data/*,matcha/hifigan/*", + ] + additional_dependencies: [flake8-rst-docstrings==0.3.0] + + # pylint + - repo: https://github.com/pycqa/pylint + rev: v3.0.3 + hooks: + - id: pylint diff --git a/third_party/Matcha-TTS/.project-root b/third_party/Matcha-TTS/.project-root new file mode 100644 index 0000000000000000000000000000000000000000..63eab774b9e36aa1a46cbd31b59cbd373bc5477f --- /dev/null +++ b/third_party/Matcha-TTS/.project-root @@ -0,0 +1,2 @@ +# this file is required for inferring the project root directory +# do not delete diff --git a/third_party/Matcha-TTS/.pylintrc b/third_party/Matcha-TTS/.pylintrc new file mode 100644 index 0000000000000000000000000000000000000000..962864189eab99a66b315b80f5a9976e7a423d4a --- /dev/null +++ b/third_party/Matcha-TTS/.pylintrc @@ -0,0 +1,525 @@ +[MASTER] + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-whitelist= + +# Add files or directories to the blacklist. They should be base names, not +# paths. 
+ignore=CVS + +# Add files or directories matching the regex patterns to the blacklist. The +# regex matches against base names, not paths. +ignore-patterns= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Specify a configuration file. +#rcfile= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. +confidence= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=missing-docstring, + too-many-public-methods, + too-many-lines, + bare-except, + ## for avoiding weird p3.6 CI linter error + ## TODO: see later if we can remove this + assigning-non-slot, + unsupported-assignment-operation, + ## end + line-too-long, + fixme, + wrong-import-order, + ungrouped-imports, + wrong-import-position, + import-error, + invalid-name, + too-many-instance-attributes, + arguments-differ, + arguments-renamed, + no-name-in-module, + no-member, + unsubscriptable-object, + raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + useless-object-inheritance, + too-few-public-methods, + too-many-branches, + too-many-arguments, + too-many-locals, + too-many-statements, + duplicate-code, + not-callable, + import-outside-toplevel, + logging-fstring-interpolation, + logging-not-lazy, + unused-argument, + no-else-return, + chained-comparison, + redefined-outer-name + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[REPORTS] + +# Python expression which should return a note less than 10 (10 is the highest +# note). 
You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +#msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit + + +[LOGGING] + +# Format style used to check logging format string. `old` means using % +# formatting, while `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package.. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members=numpy.*,torch.* + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. 
+ignore-on-opaque-inference=yes + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=120 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[SIMILARITIES] + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. +argument-rgx=[a-z_][a-z0-9_]{0,30}$ + +# Naming style matching correct attribute names. 
+attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. +#class-attribute-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + x, + ex, + Run, + _ + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. +variable-rgx=[a-z_][a-z0-9_]{0,30}$ + + +[STRING] + +# This flag controls whether the implicit-str-concat-in-sequence should +# generate a warning on implicit string concatenation in sequences defined over +# several lines. +check-str-concat-over-line-jumps=no + + +[IMPORTS] + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules=optparse,tkinter.tix + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled). +ext-import-graph= + +# Create a graph of every (i.e. 
internal and external) dependencies in the +# given file (report RP0402 must not be disabled). +import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=cls + + +[DESIGN] + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement. +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=15 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "BaseException, Exception". +overgeneral-exceptions=builtins.BaseException, + builtins.Exception diff --git a/third_party/Matcha-TTS/LICENSE b/third_party/Matcha-TTS/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..858018e750da7be7b271bb7307e68d159ed67ef6 --- /dev/null +++ b/third_party/Matcha-TTS/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Shivam Mehta + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
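As a small illustration of the `evaluation` expression near the top of the pylintrc above: pylint evaluates it as a plain Python expression to produce the 0–10 score, substituting its own message counts for the variable names. The counts below are invented purely to show the arithmetic:

```python
# Hypothetical message counts -- pylint supplies the real values at report time.
error, warning, refactor, convention, statement = 2, 5, 3, 10, 400

# Same expression as the pylintrc `evaluation` setting above.
score = 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
print(f"score: {score:.2f}/10")  # 10 - (28 / 400) * 10 = 9.30
```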
diff --git a/third_party/Matcha-TTS/MANIFEST.in b/third_party/Matcha-TTS/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..c013140cdfb9de19c4d4e73c73a44e33f33fa871 --- /dev/null +++ b/third_party/Matcha-TTS/MANIFEST.in @@ -0,0 +1,14 @@ +include README.md +include LICENSE.txt +include requirements.*.txt +include *.cff +include requirements.txt +include matcha/VERSION +recursive-include matcha *.json +recursive-include matcha *.html +recursive-include matcha *.png +recursive-include matcha *.md +recursive-include matcha *.py +recursive-include matcha *.pyx +recursive-exclude tests * +prune tests* diff --git a/third_party/Matcha-TTS/Makefile b/third_party/Matcha-TTS/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..4b523dd17b13a19617c9cc9d9dad7f7d8d4c24a0 --- /dev/null +++ b/third_party/Matcha-TTS/Makefile @@ -0,0 +1,42 @@ + +help: ## Show help + @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +clean: ## Clean autogenerated files + rm -rf dist + find . -type f -name "*.DS_Store" -ls -delete + find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf + find . | grep -E ".pytest_cache" | xargs rm -rf + find . | grep -E ".ipynb_checkpoints" | xargs rm -rf + rm -f .coverage + +clean-logs: ## Clean logs + rm -rf logs/** + +create-package: ## Create wheel and tar gz + rm -rf dist/ + python setup.py bdist_wheel --plat-name=manylinux1_x86_64 + python setup.py sdist + python -m twine upload dist/* --verbose --skip-existing + +format: ## Run pre-commit hooks + pre-commit run -a + +sync: ## Merge changes from main branch to your current branch + git pull + git pull origin main + +test: ## Run not slow tests + pytest -k "not slow" + +test-full: ## Run all tests + pytest + +train-ljspeech: ## Train the model + python matcha/train.py experiment=ljspeech + +train-ljspeech-min: ## Train the model with minimum memory + python matcha/train.py experiment=ljspeech_min_memory + +start_app: ## Start the app + python matcha/app.py diff --git a/third_party/Matcha-TTS/README.md b/third_party/Matcha-TTS/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a5867d0a46568e545524073b435ad84d929b8d73 --- /dev/null +++ b/third_party/Matcha-TTS/README.md @@ -0,0 +1,315 @@ +
+ +# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching + +### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/) + +[![python](https://img.shields.io/badge/-Python_3.10-blue?logo=python&logoColor=white)](https://www.python.org/downloads/release/python-3100/) +[![pytorch](https://img.shields.io/badge/PyTorch_2.0+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/) +[![lightning](https://img.shields.io/badge/-Lightning_2.0+-792ee5?logo=pytorchlightning&logoColor=white)](https://pytorchlightning.ai/) +[![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/) +[![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/) +[![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) + +
+<p align="center"><img src="https://shivammehta25.github.io/Matcha-TTS/images/logo.png" alt="Matcha-TTS logo"/></p>
+
+ +> This is the official code implementation of 🍵 Matcha-TTS [ICASSP 2024]. + +We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses [conditional flow matching](https://arxiv.org/abs/2210.02747) (similar to [rectified flows](https://arxiv.org/abs/2209.03003)) to speed up ODE-based speech synthesis. Our method: + +- Is probabilistic +- Has compact memory footprint +- Sounds highly natural +- Is very fast to synthesise from + +Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS) and read [our ICASSP 2024 paper](https://arxiv.org/abs/2309.03199) for more details. + +[Pre-trained models](https://drive.google.com/drive/folders/17C_gYgEHOxI5ZypcfE_k1piKCtyR0isJ?usp=sharing) will be automatically downloaded with the CLI or gradio interface. + +You can also [try 🍵 Matcha-TTS in your browser on HuggingFace 🤗 spaces](https://huggingface.co/spaces/shivammehta25/Matcha-TTS). + +## Teaser video + +[![Watch the video](https://img.youtube.com/vi/xmvJkz3bqw0/hqdefault.jpg)](https://youtu.be/xmvJkz3bqw0) + +## Installation + +1. Create an environment (suggested but optional) + +``` +conda create -n matcha-tts python=3.10 -y +conda activate matcha-tts +``` + +2. Install Matcha TTS using pip or from source + +```bash +pip install matcha-tts +``` + +from source + +```bash +pip install git+https://github.com/shivammehta25/Matcha-TTS.git +cd Matcha-TTS +pip install -e . +``` + +3. Run CLI / gradio app / jupyter notebook + +```bash +# This will download the required models +matcha-tts --text "" +``` + +or + +```bash +matcha-tts-app +``` + +or open `synthesis.ipynb` on jupyter notebook + +### CLI Arguments + +- To synthesise from given text, run: + +```bash +matcha-tts --text "" +``` + +- To synthesise from a file, run: + +```bash +matcha-tts --file +``` + +- To batch synthesise from a file, run: + +```bash +matcha-tts --file --batched +``` + +Additional arguments + +- Speaking rate + +```bash +matcha-tts --text "" --speaking_rate 1.0 +``` + +- Sampling temperature + +```bash +matcha-tts --text "" --temperature 0.667 +``` + +- Euler ODE solver steps + +```bash +matcha-tts --text "" --steps 10 +``` + +## Train with your own dataset + +Let's assume we are training with LJ Speech + +1. Download the dataset from [here](https://keithito.com/LJ-Speech-Dataset/), extract it to `data/LJSpeech-1.1`, and prepare the file lists to point to the extracted data like for [item 5 in the setup of the NVIDIA Tacotron 2 repo](https://github.com/NVIDIA/tacotron2#setup). + +2. Clone and enter the Matcha-TTS repository + +```bash +git clone https://github.com/shivammehta25/Matcha-TTS.git +cd Matcha-TTS +``` + +3. Install the package from source + +```bash +pip install -e . +``` + +4. Go to `configs/data/ljspeech.yaml` and change + +```yaml +train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt +valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt +``` + +5. Generate normalisation statistics with the yaml file of dataset configuration + +```bash +matcha-data-stats -i ljspeech.yaml +# Output: +#{'mel_mean': -5.53662231756592, 'mel_std': 2.1161014277038574} +``` + +Update these values in `configs/data/ljspeech.yaml` under `data_statistics` key. + +```bash +data_statistics: # Computed for ljspeech dataset + mel_mean: -5.536622 + mel_std: 2.116101 +``` + +to the paths of your train and validation filelists. + +6. 
Run the training script + +```bash +make train-ljspeech +``` + +or + +```bash +python matcha/train.py experiment=ljspeech +``` + +- for a minimum memory run + +```bash +python matcha/train.py experiment=ljspeech_min_memory +``` + +- for multi-gpu training, run + +```bash +python matcha/train.py experiment=ljspeech trainer.devices=[0,1] +``` + +7. Synthesise from the custom trained model + +```bash +matcha-tts --text "" --checkpoint_path +``` + +## ONNX support + +> Special thanks to [@mush42](https://github.com/mush42) for implementing ONNX export and inference support. + +It is possible to export Matcha checkpoints to [ONNX](https://onnx.ai/), and run inference on the exported ONNX graph. + +### ONNX export + +To export a checkpoint to ONNX, first install ONNX with + +```bash +pip install onnx +``` + +then run the following: + +```bash +python3 -m matcha.onnx.export matcha.ckpt model.onnx --n-timesteps 5 +``` + +Optionally, the ONNX exporter accepts **vocoder-name** and **vocoder-checkpoint** arguments. This enables you to embed the vocoder in the exported graph and generate waveforms in a single run (similar to end-to-end TTS systems). + +**Note** that `n_timesteps` is treated as a hyper-parameter rather than a model input. This means you should specify it during export (not during inference). If not specified, `n_timesteps` is set to **5**. + +**Important**: for now, torch>=2.1.0 is needed for export since the `scaled_product_attention` operator is not exportable in older versions. Until the final version is released, those who want to export their models must install torch>=2.1.0 manually as a pre-release. + +### ONNX Inference + +To run inference on the exported model, first install `onnxruntime` using + +```bash +pip install onnxruntime +pip install onnxruntime-gpu # for GPU inference +``` + +then use the following: + +```bash +python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs +``` + +You can also control synthesis parameters: + +```bash +python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --temperature 0.4 --speaking_rate 0.9 --spk 0 +``` + +To run inference on **GPU**, make sure to install **onnxruntime-gpu** package, and then pass `--gpu` to the inference command: + +```bash +python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --gpu +``` + +If you exported only Matcha to ONNX, this will write mel-spectrogram as graphs and `numpy` arrays to the output directory. +If you embedded the vocoder in the exported graph, this will write `.wav` audio files to the output directory. + +If you exported only Matcha to ONNX, and you want to run a full TTS pipeline, you can pass a path to a vocoder model in `ONNX` format: + +```bash +python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --vocoder hifigan.small.onnx +``` + +This will write `.wav` audio files to the output directory. 
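For reference, a rough Python sketch (not part of this repository) of inspecting an exported graph with the `onnxruntime` API directly, rather than going through `matcha.onnx.infer`. The file name `model.onnx` is taken from the export command above; the actual input names should be read from the session rather than hard-coded:

```python
# Minimal onnxruntime sketch: inspect an exported Matcha graph before running it.
# Assumes "model.onnx" was produced by the export command shown above.
import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# Print the graph's declared inputs and outputs; this also reveals whether a
# vocoder was embedded at export time (waveform output) or not (mel output).
for inp in session.get_inputs():
    print("input :", inp.name, inp.shape, inp.type)
for out in session.get_outputs():
    print("output:", out.name, out.shape, out.type)

# To synthesise, build a feed dict mapping the printed input names to numpy
# arrays (phoneme ids from Matcha's text front-end, their lengths, etc.) and
# call session.run(None, feed); the matcha.onnx.infer CLI above remains the
# supported path.
```

For GPU inference the same sketch applies with `providers=["CUDAExecutionProvider", "CPUExecutionProvider"]` and the `onnxruntime-gpu` package installed, mirroring the `--gpu` flag of the CLI.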
+ +## Extract phoneme alignments from Matcha-TTS + +If the dataset is structured as + +```bash +data/ +└── LJSpeech-1.1 + ├── metadata.csv + ├── README + ├── test.txt + ├── train.txt + ├── val.txt + └── wavs +``` +Then you can extract the phoneme level alignments from a Trained Matcha-TTS model using: +```bash +python matcha/utils/get_durations_from_trained_model.py -i dataset_yaml -c +``` +Example: +```bash +python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c matcha_ljspeech.ckpt +``` +or simply: +```bash +matcha-tts-get-durations -i ljspeech.yaml -c matcha_ljspeech.ckpt +``` +--- +## Train using extracted alignments + +In the datasetconfig turn on load duration. +Example: `ljspeech.yaml` +``` +load_durations: True +``` +or see an examples in configs/experiment/ljspeech_from_durations.yaml + + +## Citation information + +If you use our code or otherwise find this work useful, please cite our paper: + +```text +@inproceedings{mehta2024matcha, + title={Matcha-{TTS}: A fast {TTS} architecture with conditional flow matching}, + author={Mehta, Shivam and Tu, Ruibo and Beskow, Jonas and Sz{\'e}kely, {\'E}va and Henter, Gustav Eje}, + booktitle={Proc. ICASSP}, + year={2024} +} +``` + +## Acknowledgements + +Since this code uses [Lightning-Hydra-Template](https://github.com/ashleve/lightning-hydra-template), you have all the powers that come with it. + +Other source code we would like to acknowledge: + +- [Coqui-TTS](https://github.com/coqui-ai/TTS/tree/dev): For helping me figure out how to make cython binaries pip installable and encouragement +- [Hugging Face Diffusers](https://huggingface.co/): For their awesome diffusers library and its components +- [Grad-TTS](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS): For the monotonic alignment search source code +- [torchdyn](https://github.com/DiffEqML/torchdyn): Useful for trying other ODE solvers during research and development +- [labml.ai](https://nn.labml.ai/transformers/rope/index.html): For the RoPE implementation diff --git a/third_party/Matcha-TTS/configs/__init__.py b/third_party/Matcha-TTS/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56bf7f4aa4906bc0f997132708cc0826c198e4aa --- /dev/null +++ b/third_party/Matcha-TTS/configs/__init__.py @@ -0,0 +1 @@ +# this file is needed here to include configs when building project as a package diff --git a/third_party/Matcha-TTS/configs/callbacks/default.yaml b/third_party/Matcha-TTS/configs/callbacks/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ebaa3ed31a7f626bc62f90184dc4b25b631e52a9 --- /dev/null +++ b/third_party/Matcha-TTS/configs/callbacks/default.yaml @@ -0,0 +1,5 @@ +defaults: + - model_checkpoint.yaml + - model_summary.yaml + - rich_progress_bar.yaml + - _self_ diff --git a/third_party/Matcha-TTS/configs/callbacks/model_checkpoint.yaml b/third_party/Matcha-TTS/configs/callbacks/model_checkpoint.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d085c711a8521b6b98ad6401b686bb601ceacd6 --- /dev/null +++ b/third_party/Matcha-TTS/configs/callbacks/model_checkpoint.yaml @@ -0,0 +1,17 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html + +model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${paths.output_dir}/checkpoints # directory to save the model file + filename: checkpoint_{epoch:03d} # checkpoint filename + monitor: epoch # name of the logged metric which determines 
when model is improving + verbose: False # verbosity mode + save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 10 # save k best models (determined by above metric) + mode: "max" # "max" means higher metric value is better, can be also "min" + auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name + save_weights_only: False # if True, then only the model’s weights will be saved + every_n_train_steps: null # number of training steps between checkpoints + train_time_interval: null # checkpoints are monitored at the specified time interval + every_n_epochs: 100 # number of epochs between checkpoints + save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation diff --git a/third_party/Matcha-TTS/configs/callbacks/model_summary.yaml b/third_party/Matcha-TTS/configs/callbacks/model_summary.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6e5368d0e94298cce6d5421365b4583bd763ba92 --- /dev/null +++ b/third_party/Matcha-TTS/configs/callbacks/model_summary.yaml @@ -0,0 +1,5 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html + +model_summary: + _target_: lightning.pytorch.callbacks.RichModelSummary + max_depth: 3 # the maximum depth of layer nesting that the summary will include diff --git a/third_party/Matcha-TTS/configs/callbacks/none.yaml b/third_party/Matcha-TTS/configs/callbacks/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/configs/callbacks/rich_progress_bar.yaml b/third_party/Matcha-TTS/configs/callbacks/rich_progress_bar.yaml new file mode 100644 index 0000000000000000000000000000000000000000..de6f1ccb11205a4db93645fb6f297e50205de172 --- /dev/null +++ b/third_party/Matcha-TTS/configs/callbacks/rich_progress_bar.yaml @@ -0,0 +1,4 @@ +# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html + +rich_progress_bar: + _target_: lightning.pytorch.callbacks.RichProgressBar diff --git a/third_party/Matcha-TTS/configs/data/hi-fi_en-US_female.yaml b/third_party/Matcha-TTS/configs/data/hi-fi_en-US_female.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1269f9b3b421d27a204bb0697e2f27a0fa0864a3 --- /dev/null +++ b/third_party/Matcha-TTS/configs/data/hi-fi_en-US_female.yaml @@ -0,0 +1,14 @@ +defaults: + - ljspeech + - _self_ + +# Dataset URL: https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/ +_target_: matcha.data.text_mel_datamodule.TextMelDataModule +name: hi-fi_en-US_female +train_filelist_path: data/filelists/hi-fi-captain-en-us-female_train.txt +valid_filelist_path: data/filelists/hi-fi-captain-en-us-female_val.txt +batch_size: 32 +cleaners: [english_cleaners_piper] +data_statistics: # Computed for this dataset + mel_mean: -6.38385 + mel_std: 2.541796 diff --git a/third_party/Matcha-TTS/configs/data/ljspeech.yaml b/third_party/Matcha-TTS/configs/data/ljspeech.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee87a6a76a2c344b3a90196f87bb48f205e2e48d --- /dev/null +++ b/third_party/Matcha-TTS/configs/data/ljspeech.yaml @@ -0,0 +1,22 @@ +_target_: matcha.data.text_mel_datamodule.TextMelDataModule +name: ljspeech +train_filelist_path: data/LJSpeech-1.1/train.txt +valid_filelist_path: data/LJSpeech-1.1/val.txt +batch_size: 32 +num_workers: 20 +pin_memory: True +cleaners: 
[english_cleaners2] +add_blank: True +n_spks: 1 +n_fft: 1024 +n_feats: 80 +sample_rate: 22050 +hop_length: 256 +win_length: 1024 +f_min: 0 +f_max: 8000 +data_statistics: # Computed for ljspeech dataset + mel_mean: -5.536622 + mel_std: 2.116101 +seed: ${seed} +load_durations: false diff --git a/third_party/Matcha-TTS/configs/data/vctk.yaml b/third_party/Matcha-TTS/configs/data/vctk.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ba11cc63371ad6308d6711513268de7efe50eed9 --- /dev/null +++ b/third_party/Matcha-TTS/configs/data/vctk.yaml @@ -0,0 +1,14 @@ +defaults: + - ljspeech + - _self_ + +_target_: matcha.data.text_mel_datamodule.TextMelDataModule +name: vctk +train_filelist_path: data/filelists/vctk_audio_sid_text_train_filelist.txt +valid_filelist_path: data/filelists/vctk_audio_sid_text_val_filelist.txt +batch_size: 32 +add_blank: True +n_spks: 109 +data_statistics: # Computed for vctk dataset + mel_mean: -6.630575 + mel_std: 2.482914 diff --git a/third_party/Matcha-TTS/configs/debug/default.yaml b/third_party/Matcha-TTS/configs/debug/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e3932c82585fbe44047c1569a5cfe9ee9895c71a --- /dev/null +++ b/third_party/Matcha-TTS/configs/debug/default.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +# default debugging setup, runs 1 full epoch +# other debugging configs can inherit from this one + +# overwrite task name so debugging logs are stored in separate folder +task_name: "debug" + +# disable callbacks and loggers during debugging +# callbacks: null +# logger: null + +extras: + ignore_warnings: False + enforce_tags: False + +# sets level of all command line loggers to 'DEBUG' +# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ +hydra: + job_logging: + root: + level: DEBUG + + # use this to also set hydra loggers to 'DEBUG' + # verbose: True + +trainer: + max_epochs: 1 + accelerator: cpu # debuggers don't like gpus + devices: 1 # debuggers don't like multiprocessing + detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor + +data: + num_workers: 0 # debuggers don't like multiprocessing + pin_memory: False # disable gpu memory pin diff --git a/third_party/Matcha-TTS/configs/debug/fdr.yaml b/third_party/Matcha-TTS/configs/debug/fdr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7f2d34fa37c31017e749d5a4fc5ae6763e688b46 --- /dev/null +++ b/third_party/Matcha-TTS/configs/debug/fdr.yaml @@ -0,0 +1,9 @@ +# @package _global_ + +# runs 1 train, 1 validation and 1 test step + +defaults: + - default + +trainer: + fast_dev_run: true diff --git a/third_party/Matcha-TTS/configs/debug/limit.yaml b/third_party/Matcha-TTS/configs/debug/limit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..514d77fbd1475b03fff0372e3da3c2fa7ea7d190 --- /dev/null +++ b/third_party/Matcha-TTS/configs/debug/limit.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +# uses only 1% of the training data and 5% of validation/test data + +defaults: + - default + +trainer: + max_epochs: 3 + limit_train_batches: 0.01 + limit_val_batches: 0.05 + limit_test_batches: 0.05 diff --git a/third_party/Matcha-TTS/configs/debug/overfit.yaml b/third_party/Matcha-TTS/configs/debug/overfit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9906586a67a12aa81ff69138f589a366dbe2222f --- /dev/null +++ b/third_party/Matcha-TTS/configs/debug/overfit.yaml @@ -0,0 +1,13 @@ +# @package _global_ + +# overfits to 3 batches + +defaults: + - default + 
+trainer: + max_epochs: 20 + overfit_batches: 3 + +# model ckpt and early stopping need to be disabled during overfitting +callbacks: null diff --git a/third_party/Matcha-TTS/configs/debug/profiler.yaml b/third_party/Matcha-TTS/configs/debug/profiler.yaml new file mode 100644 index 0000000000000000000000000000000000000000..266295f15e0166e1d1b58b88caa7673f4b6493b5 --- /dev/null +++ b/third_party/Matcha-TTS/configs/debug/profiler.yaml @@ -0,0 +1,15 @@ +# @package _global_ + +# runs with execution time profiling + +defaults: + - default + +trainer: + max_epochs: 1 + # profiler: "simple" + profiler: "advanced" + # profiler: "pytorch" + accelerator: gpu + + limit_train_batches: 0.02 diff --git a/third_party/Matcha-TTS/configs/eval.yaml b/third_party/Matcha-TTS/configs/eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..be312992b2a486b04d83a54dbd8f670d94979709 --- /dev/null +++ b/third_party/Matcha-TTS/configs/eval.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +defaults: + - _self_ + - data: mnist # choose datamodule with `test_dataloader()` for evaluation + - model: mnist + - logger: null + - trainer: default + - paths: default + - extras: default + - hydra: default + +task_name: "eval" + +tags: ["dev"] + +# passing checkpoint path is necessary for evaluation +ckpt_path: ??? diff --git a/third_party/Matcha-TTS/configs/experiment/hifi_dataset_piper_phonemizer.yaml b/third_party/Matcha-TTS/configs/experiment/hifi_dataset_piper_phonemizer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7e6c57a0d0a399f7463f4ff2d96e1928c435779b --- /dev/null +++ b/third_party/Matcha-TTS/configs/experiment/hifi_dataset_piper_phonemizer.yaml @@ -0,0 +1,14 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=multispeaker + +defaults: + - override /data: hi-fi_en-US_female.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["hi-fi", "single_speaker", "piper_phonemizer", "en_US", "female"] + +run_name: hi-fi_en-US_female_piper_phonemizer diff --git a/third_party/Matcha-TTS/configs/experiment/ljspeech.yaml b/third_party/Matcha-TTS/configs/experiment/ljspeech.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d5723f42cf3552226c42bd91202cc18818b685f0 --- /dev/null +++ b/third_party/Matcha-TTS/configs/experiment/ljspeech.yaml @@ -0,0 +1,14 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=multispeaker + +defaults: + - override /data: ljspeech.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["ljspeech"] + +run_name: ljspeech diff --git a/third_party/Matcha-TTS/configs/experiment/ljspeech_from_durations.yaml b/third_party/Matcha-TTS/configs/experiment/ljspeech_from_durations.yaml new file mode 100644 index 0000000000000000000000000000000000000000..63f7d298280245b8ae4d3403f8540d0d2e8ada4f --- /dev/null +++ b/third_party/Matcha-TTS/configs/experiment/ljspeech_from_durations.yaml @@ -0,0 +1,19 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=multispeaker + +defaults: + - override /data: ljspeech.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["ljspeech"] + +run_name: ljspeech + + +data: + 
load_durations: True + batch_size: 64 diff --git a/third_party/Matcha-TTS/configs/experiment/ljspeech_min_memory.yaml b/third_party/Matcha-TTS/configs/experiment/ljspeech_min_memory.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ef554dc633c392b1592d90d9d7734f2329264fdd --- /dev/null +++ b/third_party/Matcha-TTS/configs/experiment/ljspeech_min_memory.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=multispeaker + +defaults: + - override /data: ljspeech.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["ljspeech"] + +run_name: ljspeech_min + + +model: + out_size: 172 diff --git a/third_party/Matcha-TTS/configs/experiment/multispeaker.yaml b/third_party/Matcha-TTS/configs/experiment/multispeaker.yaml new file mode 100644 index 0000000000000000000000000000000000000000..553842f4e2168db0fee4e44db11b5d086295b044 --- /dev/null +++ b/third_party/Matcha-TTS/configs/experiment/multispeaker.yaml @@ -0,0 +1,14 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=multispeaker + +defaults: + - override /data: vctk.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["multispeaker"] + +run_name: multispeaker diff --git a/third_party/Matcha-TTS/configs/extras/default.yaml b/third_party/Matcha-TTS/configs/extras/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9c6b622283a647fbc513166fc14f016cc3ed8a0 --- /dev/null +++ b/third_party/Matcha-TTS/configs/extras/default.yaml @@ -0,0 +1,8 @@ +# disable python warnings if they annoy you +ignore_warnings: False + +# ask user for tags if none are provided in the config +enforce_tags: True + +# pretty print config tree at the start of the run using Rich library +print_config: True diff --git a/third_party/Matcha-TTS/configs/hparams_search/mnist_optuna.yaml b/third_party/Matcha-TTS/configs/hparams_search/mnist_optuna.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1391183ebcdec3d8f5eb61374e0719d13c7545da --- /dev/null +++ b/third_party/Matcha-TTS/configs/hparams_search/mnist_optuna.yaml @@ -0,0 +1,52 @@ +# @package _global_ + +# example hyperparameter optimization of some experiment with Optuna: +# python train.py -m hparams_search=mnist_optuna experiment=example + +defaults: + - override /hydra/sweeper: optuna + +# choose metric which will be optimized by Optuna +# make sure this is the correct name of some metric logged in lightning module! 
+optimized_metric: "val/acc_best" + +# here we define Optuna hyperparameter search +# it optimizes for value returned from function with @hydra.main decorator +# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper +hydra: + mode: "MULTIRUN" # set hydra to multirun by default if this config is attached + + sweeper: + _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper + + # storage URL to persist optimization results + # for example, you can use SQLite if you set 'sqlite:///example.db' + storage: null + + # name of the study to persist optimization results + study_name: null + + # number of parallel workers + n_jobs: 1 + + # 'minimize' or 'maximize' the objective + direction: maximize + + # total number of runs that will be executed + n_trials: 20 + + # choose Optuna hyperparameter sampler + # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others + # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html + sampler: + _target_: optuna.samplers.TPESampler + seed: 1234 + n_startup_trials: 10 # number of random sampling runs before optimization starts + + # define hyperparameter search space + params: + model.optimizer.lr: interval(0.0001, 0.1) + data.batch_size: choice(32, 64, 128, 256) + model.net.lin1_size: choice(64, 128, 256) + model.net.lin2_size: choice(64, 128, 256) + model.net.lin3_size: choice(32, 64, 128, 256) diff --git a/third_party/Matcha-TTS/configs/hydra/default.yaml b/third_party/Matcha-TTS/configs/hydra/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1533136b22802a4f81e5387b74e407289edce94d --- /dev/null +++ b/third_party/Matcha-TTS/configs/hydra/default.yaml @@ -0,0 +1,19 @@ +# https://hydra.cc/docs/configure_hydra/intro/ + +# enable color logging +defaults: + - override hydra_logging: colorlog + - override job_logging: colorlog + +# output directory, generated dynamically on each run +run: + dir: ${paths.log_dir}/${task_name}/${run_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} +sweep: + dir: ${paths.log_dir}/${task_name}/${run_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num} + +job_logging: + handlers: + file: + # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 + filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log diff --git a/third_party/Matcha-TTS/configs/local/.gitkeep b/third_party/Matcha-TTS/configs/local/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/configs/logger/aim.yaml b/third_party/Matcha-TTS/configs/logger/aim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8f9f6adad7feb2780c2efd5ddb0ed053621e05f8 --- /dev/null +++ b/third_party/Matcha-TTS/configs/logger/aim.yaml @@ -0,0 +1,28 @@ +# https://aimstack.io/ + +# example usage in lightning module: +# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py + +# open the Aim UI with the following command (run in the folder containing the `.aim` folder): +# `aim up` + +aim: + _target_: aim.pytorch_lightning.AimLogger + repo: ${paths.root_dir} # .aim folder will be created here + # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# + + # aim allows to group runs under experiment name + experiment: null # any string, set to "default" if not 
specified + + train_metric_prefix: "train/" + val_metric_prefix: "val/" + test_metric_prefix: "test/" + + # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) + system_tracking_interval: 10 # set to null to disable system metrics tracking + + # enable/disable logging of system params such as installed packages, git info, env vars, etc. + log_system_params: true + + # enable/disable tracking console logs (default value is true) + capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 diff --git a/third_party/Matcha-TTS/configs/logger/comet.yaml b/third_party/Matcha-TTS/configs/logger/comet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e0789274e2137ee6c97ca37a5d56c2b8abaf0aaa --- /dev/null +++ b/third_party/Matcha-TTS/configs/logger/comet.yaml @@ -0,0 +1,12 @@ +# https://www.comet.ml + +comet: + _target_: lightning.pytorch.loggers.comet.CometLogger + api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable + save_dir: "${paths.output_dir}" + project_name: "lightning-hydra-template" + rest_api_key: null + # experiment_name: "" + experiment_key: null # set to resume experiment + offline: False + prefix: "" diff --git a/third_party/Matcha-TTS/configs/logger/csv.yaml b/third_party/Matcha-TTS/configs/logger/csv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fa028e9c146430c319101ffdfce466514338591c --- /dev/null +++ b/third_party/Matcha-TTS/configs/logger/csv.yaml @@ -0,0 +1,7 @@ +# csv logger built in lightning + +csv: + _target_: lightning.pytorch.loggers.csv_logs.CSVLogger + save_dir: "${paths.output_dir}" + name: "csv/" + prefix: "" diff --git a/third_party/Matcha-TTS/configs/logger/many_loggers.yaml b/third_party/Matcha-TTS/configs/logger/many_loggers.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dd586800bdccb4e8f4b0236a181b7ddd756ba9ab --- /dev/null +++ b/third_party/Matcha-TTS/configs/logger/many_loggers.yaml @@ -0,0 +1,9 @@ +# train with many loggers at once + +defaults: + # - comet + - csv + # - mlflow + # - neptune + - tensorboard + - wandb diff --git a/third_party/Matcha-TTS/configs/logger/mlflow.yaml b/third_party/Matcha-TTS/configs/logger/mlflow.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f8fb7e685fa27fc8141387a421b90a0b9b492d9e --- /dev/null +++ b/third_party/Matcha-TTS/configs/logger/mlflow.yaml @@ -0,0 +1,12 @@ +# https://mlflow.org + +mlflow: + _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger + # experiment_name: "" + # run_name: "" + tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI + tags: null + # save_dir: "./mlruns" + prefix: "" + artifact_location: null + # run_id: "" diff --git a/third_party/Matcha-TTS/configs/logger/neptune.yaml b/third_party/Matcha-TTS/configs/logger/neptune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8233c140018ecce6ab62971beed269991d31c89b --- /dev/null +++ b/third_party/Matcha-TTS/configs/logger/neptune.yaml @@ -0,0 +1,9 @@ +# https://neptune.ai + +neptune: + _target_: lightning.pytorch.loggers.neptune.NeptuneLogger + api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable + project: username/lightning-hydra-template + # name: "" + log_model_checkpoints: True + prefix: "" diff --git a/third_party/Matcha-TTS/configs/logger/tensorboard.yaml 
b/third_party/Matcha-TTS/configs/logger/tensorboard.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2bd31f6d8ba68d1f5c36a804885d5b9f9c1a9302 --- /dev/null +++ b/third_party/Matcha-TTS/configs/logger/tensorboard.yaml @@ -0,0 +1,10 @@ +# https://www.tensorflow.org/tensorboard/ + +tensorboard: + _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger + save_dir: "${paths.output_dir}/tensorboard/" + name: null + log_graph: False + default_hp_metric: True + prefix: "" + # version: "" diff --git a/third_party/Matcha-TTS/configs/logger/wandb.yaml b/third_party/Matcha-TTS/configs/logger/wandb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ece165889b3d0d9dc750a8f3c7454188cfdf12b7 --- /dev/null +++ b/third_party/Matcha-TTS/configs/logger/wandb.yaml @@ -0,0 +1,16 @@ +# https://wandb.ai + +wandb: + _target_: lightning.pytorch.loggers.wandb.WandbLogger + # name: "" # name of the run (normally generated by wandb) + save_dir: "${paths.output_dir}" + offline: False + id: null # pass correct id to resume experiment! + anonymous: null # enable anonymous logging + project: "lightning-hydra-template" + log_model: False # upload lightning ckpts + prefix: "" # a string to put at the beginning of metric keys + # entity: "" # set to name of your wandb team + group: "" + tags: [] + job_type: "" diff --git a/third_party/Matcha-TTS/configs/model/cfm/default.yaml b/third_party/Matcha-TTS/configs/model/cfm/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d1d9609e2d05c7b0a12a26115520340ac18e584 --- /dev/null +++ b/third_party/Matcha-TTS/configs/model/cfm/default.yaml @@ -0,0 +1,3 @@ +name: CFM +solver: euler +sigma_min: 1e-4 diff --git a/third_party/Matcha-TTS/configs/model/decoder/default.yaml b/third_party/Matcha-TTS/configs/model/decoder/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aaa00e63402ade5c76247a2f1d6b294ec3c61e63 --- /dev/null +++ b/third_party/Matcha-TTS/configs/model/decoder/default.yaml @@ -0,0 +1,7 @@ +channels: [256, 256] +dropout: 0.05 +attention_head_dim: 64 +n_blocks: 1 +num_mid_blocks: 2 +num_heads: 2 +act_fn: snakebeta diff --git a/third_party/Matcha-TTS/configs/model/encoder/default.yaml b/third_party/Matcha-TTS/configs/model/encoder/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d4d5e5adee8f707bd384b682a3ad9a116c40c6ed --- /dev/null +++ b/third_party/Matcha-TTS/configs/model/encoder/default.yaml @@ -0,0 +1,18 @@ +encoder_type: RoPE Encoder +encoder_params: + n_feats: ${model.n_feats} + n_channels: 192 + filter_channels: 768 + filter_channels_dp: 256 + n_heads: 2 + n_layers: 6 + kernel_size: 3 + p_dropout: 0.1 + spk_emb_dim: 64 + n_spks: 1 + prenet: true + +duration_predictor_params: + filter_channels_dp: ${model.encoder.encoder_params.filter_channels_dp} + kernel_size: 3 + p_dropout: ${model.encoder.encoder_params.p_dropout} diff --git a/third_party/Matcha-TTS/configs/model/matcha.yaml b/third_party/Matcha-TTS/configs/model/matcha.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e2b5c78ddeb98fcca85093deba1cea3b1d1074e1 --- /dev/null +++ b/third_party/Matcha-TTS/configs/model/matcha.yaml @@ -0,0 +1,16 @@ +defaults: + - _self_ + - encoder: default.yaml + - decoder: default.yaml + - cfm: default.yaml + - optimizer: adam.yaml + +_target_: matcha.models.matcha_tts.MatchaTTS +n_vocab: 178 +n_spks: ${data.n_spks} +spk_emb_dim: 64 +n_feats: 80 +data_statistics: ${data.data_statistics} +out_size: null # Must be 
divisible by 4 +prior_loss: true +use_precomputed_durations: ${data.load_durations} diff --git a/third_party/Matcha-TTS/configs/model/optimizer/adam.yaml b/third_party/Matcha-TTS/configs/model/optimizer/adam.yaml new file mode 100644 index 0000000000000000000000000000000000000000..42795577474eaee5b0b96845a95e1a11c9152385 --- /dev/null +++ b/third_party/Matcha-TTS/configs/model/optimizer/adam.yaml @@ -0,0 +1,4 @@ +_target_: torch.optim.Adam +_partial_: true +lr: 1e-4 +weight_decay: 0.0 diff --git a/third_party/Matcha-TTS/configs/paths/default.yaml b/third_party/Matcha-TTS/configs/paths/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ec81db2d34712909a79be3e42e65efe08c35ecee --- /dev/null +++ b/third_party/Matcha-TTS/configs/paths/default.yaml @@ -0,0 +1,18 @@ +# path to root directory +# this requires PROJECT_ROOT environment variable to exist +# you can replace it with "." if you want the root to be the current working directory +root_dir: ${oc.env:PROJECT_ROOT} + +# path to data directory +data_dir: ${paths.root_dir}/data/ + +# path to logging directory +log_dir: ${paths.root_dir}/logs/ + +# path to output directory, created dynamically by hydra +# path generation pattern is specified in `configs/hydra/default.yaml` +# use it to store all files generated during the run, like ckpts and metrics +output_dir: ${hydra:runtime.output_dir} + +# path to working directory +work_dir: ${hydra:runtime.cwd} diff --git a/third_party/Matcha-TTS/configs/train.yaml b/third_party/Matcha-TTS/configs/train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e6f5c2e7b9781758c8d25f941f004ca383c3f494 --- /dev/null +++ b/third_party/Matcha-TTS/configs/train.yaml @@ -0,0 +1,51 @@ +# @package _global_ + +# specify here default configuration +# order of defaults determines the order in which configs override each other +defaults: + - _self_ + - data: ljspeech + - model: matcha + - callbacks: default + - logger: tensorboard # set logger here or use command line (e.g. `python train.py logger=tensorboard`) + - trainer: default + - paths: default + - extras: default + - hydra: default + + # experiment configs allow for version control of specific hyperparameters + # e.g. best hyperparameters for given model and datamodule + - experiment: null + + # config for hyperparameter optimization + - hparams_search: null + + # optional local config for machine/user specific settings + # it's optional since it doesn't need to exist and is excluded from version control + - optional local: default + + # debugging config (enable through command line, e.g. `python train.py debug=default) + - debug: null + +# task name, determines output directory path +task_name: "train" + +run_name: ??? 
+ +# tags to help you identify your experiments +# you can overwrite this in experiment configs +# overwrite from command line with `python train.py tags="[first_tag, second_tag]"` +tags: ["dev"] + +# set False to skip model training +train: True + +# evaluate on test set, using best model weights achieved during training +# lightning chooses best weights based on the metric specified in checkpoint callback +test: True + +# simply provide checkpoint path to resume training +ckpt_path: null + +# seed for random number generators in pytorch, numpy and python.random +seed: 1234 diff --git a/third_party/Matcha-TTS/configs/trainer/cpu.yaml b/third_party/Matcha-TTS/configs/trainer/cpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b7d6767e60c956567555980654f15e7bb673a41f --- /dev/null +++ b/third_party/Matcha-TTS/configs/trainer/cpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: cpu +devices: 1 diff --git a/third_party/Matcha-TTS/configs/trainer/ddp.yaml b/third_party/Matcha-TTS/configs/trainer/ddp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..94b43e20ca7bf1f2ea92627fd46906e4f0a273a1 --- /dev/null +++ b/third_party/Matcha-TTS/configs/trainer/ddp.yaml @@ -0,0 +1,9 @@ +defaults: + - default + +strategy: ddp + +accelerator: gpu +devices: [0,1] +num_nodes: 1 +sync_batchnorm: True diff --git a/third_party/Matcha-TTS/configs/trainer/ddp_sim.yaml b/third_party/Matcha-TTS/configs/trainer/ddp_sim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8404419e5c295654967d0dfb73a7366e75be2f1f --- /dev/null +++ b/third_party/Matcha-TTS/configs/trainer/ddp_sim.yaml @@ -0,0 +1,7 @@ +defaults: + - default + +# simulate DDP on CPU, useful for debugging +accelerator: cpu +devices: 2 +strategy: ddp_spawn diff --git a/third_party/Matcha-TTS/configs/trainer/default.yaml b/third_party/Matcha-TTS/configs/trainer/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee3d370d8ca6b08d7ee7a86d34184c2104f0e1ef --- /dev/null +++ b/third_party/Matcha-TTS/configs/trainer/default.yaml @@ -0,0 +1,20 @@ +_target_: lightning.pytorch.trainer.Trainer + +default_root_dir: ${paths.output_dir} + +max_epochs: -1 + +accelerator: gpu +devices: [0] + +# mixed precision for extra speed-up +precision: 16-mixed + +# perform a validation loop every N training epochs +check_val_every_n_epoch: 1 + +# set True to to ensure deterministic results +# makes training slower but gives more reproducibility than just setting seeds +deterministic: False + +gradient_clip_val: 5.0 diff --git a/third_party/Matcha-TTS/configs/trainer/gpu.yaml b/third_party/Matcha-TTS/configs/trainer/gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b2389510a90f5f0161cff6ccfcb4a96097ddf9a1 --- /dev/null +++ b/third_party/Matcha-TTS/configs/trainer/gpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: gpu +devices: 1 diff --git a/third_party/Matcha-TTS/configs/trainer/mps.yaml b/third_party/Matcha-TTS/configs/trainer/mps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1ecf6d5cc3a34ca127c5510f4a18e989561e38e4 --- /dev/null +++ b/third_party/Matcha-TTS/configs/trainer/mps.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: mps +devices: 1 diff --git a/third_party/Matcha-TTS/matcha/VERSION b/third_party/Matcha-TTS/matcha/VERSION new file mode 100644 index 0000000000000000000000000000000000000000..ea5abc8f95c042c48eff77805a033599f816a545 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/VERSION 
@@ -0,0 +1 @@ +0.0.7.0 diff --git a/third_party/Matcha-TTS/matcha/__init__.py b/third_party/Matcha-TTS/matcha/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/matcha/app.py b/third_party/Matcha-TTS/matcha/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d68fbaa2d10d1faab606d89906af5e8b6baa5aa4 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/app.py @@ -0,0 +1,357 @@ +import tempfile +from argparse import Namespace +from pathlib import Path + +import gradio as gr +import soundfile as sf +import torch + +from matcha.cli import ( + MATCHA_URLS, + VOCODER_URLS, + assert_model_downloaded, + get_device, + load_matcha, + load_vocoder, + process_text, + to_waveform, +) +from matcha.utils.utils import get_user_data_dir, plot_tensor + +LOCATION = Path(get_user_data_dir()) + +args = Namespace( + cpu=False, + model="matcha_vctk", + vocoder="hifigan_univ_v1", + spk=0, +) + +CURRENTLY_LOADED_MODEL = args.model + + +def MATCHA_TTS_LOC(x): + return LOCATION / f"{x}.ckpt" + + +def VOCODER_LOC(x): + return LOCATION / f"{x}" + + +LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png" +RADIO_OPTIONS = { + "Multi Speaker (VCTK)": { + "model": "matcha_vctk", + "vocoder": "hifigan_univ_v1", + }, + "Single Speaker (LJ Speech)": { + "model": "matcha_ljspeech", + "vocoder": "hifigan_T2_v1", + }, +} + +# Ensure all the required models are downloaded +assert_model_downloaded(MATCHA_TTS_LOC("matcha_ljspeech"), MATCHA_URLS["matcha_ljspeech"]) +assert_model_downloaded(VOCODER_LOC("hifigan_T2_v1"), VOCODER_URLS["hifigan_T2_v1"]) +assert_model_downloaded(MATCHA_TTS_LOC("matcha_vctk"), MATCHA_URLS["matcha_vctk"]) +assert_model_downloaded(VOCODER_LOC("hifigan_univ_v1"), VOCODER_URLS["hifigan_univ_v1"]) + +device = get_device(args) + +# Load default model +model = load_matcha(args.model, MATCHA_TTS_LOC(args.model), device) +vocoder, denoiser = load_vocoder(args.vocoder, VOCODER_LOC(args.vocoder), device) + + +def load_model(model_name, vocoder_name): + model = load_matcha(model_name, MATCHA_TTS_LOC(model_name), device) + vocoder, denoiser = load_vocoder(vocoder_name, VOCODER_LOC(vocoder_name), device) + return model, vocoder, denoiser + + +def load_model_ui(model_type, textbox): + model_name, vocoder_name = RADIO_OPTIONS[model_type]["model"], RADIO_OPTIONS[model_type]["vocoder"] + + global model, vocoder, denoiser, CURRENTLY_LOADED_MODEL # pylint: disable=global-statement + if CURRENTLY_LOADED_MODEL != model_name: + model, vocoder, denoiser = load_model(model_name, vocoder_name) + CURRENTLY_LOADED_MODEL = model_name + + if model_name == "matcha_ljspeech": + spk_slider = gr.update(visible=False, value=-1) + single_speaker_examples = gr.update(visible=True) + multi_speaker_examples = gr.update(visible=False) + length_scale = gr.update(value=0.95) + else: + spk_slider = gr.update(visible=True, value=0) + single_speaker_examples = gr.update(visible=False) + multi_speaker_examples = gr.update(visible=True) + length_scale = gr.update(value=0.85) + + return ( + textbox, + gr.update(interactive=True), + spk_slider, + single_speaker_examples, + multi_speaker_examples, + length_scale, + ) + + +@torch.inference_mode() +def process_text_gradio(text): + output = process_text(1, text, device) + return output["x_phones"][1::2], output["x"], output["x_lengths"] + + +@torch.inference_mode() +def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk): + spk = 
torch.tensor([spk], device=device, dtype=torch.long) if spk >= 0 else None + output = model.synthesise( + text, + text_length, + n_timesteps=n_timesteps, + temperature=temperature, + spks=spk, + length_scale=length_scale, + ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp: + sf.write(fp.name, output["waveform"], 22050, "PCM_24") + + return fp.name, plot_tensor(output["mel"].squeeze().cpu().numpy()) + + +def multispeaker_example_cacher(text, n_timesteps, mel_temp, length_scale, spk): + global CURRENTLY_LOADED_MODEL # pylint: disable=global-statement + if CURRENTLY_LOADED_MODEL != "matcha_vctk": + global model, vocoder, denoiser # pylint: disable=global-statement + model, vocoder, denoiser = load_model("matcha_vctk", "hifigan_univ_v1") + CURRENTLY_LOADED_MODEL = "matcha_vctk" + + phones, text, text_lengths = process_text_gradio(text) + audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk) + return phones, audio, mel_spectrogram + + +def ljspeech_example_cacher(text, n_timesteps, mel_temp, length_scale, spk=-1): + global CURRENTLY_LOADED_MODEL # pylint: disable=global-statement + if CURRENTLY_LOADED_MODEL != "matcha_ljspeech": + global model, vocoder, denoiser # pylint: disable=global-statement + model, vocoder, denoiser = load_model("matcha_ljspeech", "hifigan_T2_v1") + CURRENTLY_LOADED_MODEL = "matcha_ljspeech" + + phones, text, text_lengths = process_text_gradio(text) + audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk) + return phones, audio, mel_spectrogram + + +def main(): + description = """# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching + ### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/) + We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up ODE-based speech synthesis. Our method: + + + * Is probabilistic + * Has compact memory footprint + * Sounds highly natural + * Is very fast to synthesise from + + + Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS). Read our [arXiv preprint for more details](https://arxiv.org/abs/2309.03199). + Code is available in our [GitHub repository](https://github.com/shivammehta25/Matcha-TTS), along with pre-trained models. + + Cached examples are available at the bottom of the page. + """ + + with gr.Blocks(title="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching") as demo: + processed_text = gr.State(value=None) + processed_text_len = gr.State(value=None) + + with gr.Box(): + with gr.Row(): + gr.Markdown(description, scale=3) + with gr.Column(): + gr.Image(LOGO_URL, label="Matcha-TTS logo", height=50, width=50, scale=1, show_label=False) + html = '
' + gr.HTML(html) + + with gr.Box(): + radio_options = list(RADIO_OPTIONS.keys()) + model_type = gr.Radio( + radio_options, value=radio_options[0], label="Choose a Model", interactive=True, container=False + ) + + with gr.Row(): + gr.Markdown("# Text Input") + with gr.Row(): + text = gr.Textbox(value="", lines=2, label="Text to synthesise", scale=3) + spk_slider = gr.Slider( + minimum=0, maximum=107, step=1, value=args.spk, label="Speaker ID", interactive=True, scale=1 + ) + + with gr.Row(): + gr.Markdown("### Hyper parameters") + with gr.Row(): + n_timesteps = gr.Slider( + label="Number of ODE steps", + minimum=1, + maximum=100, + step=1, + value=10, + interactive=True, + ) + length_scale = gr.Slider( + label="Length scale (Speaking rate)", + minimum=0.5, + maximum=1.5, + step=0.05, + value=1.0, + interactive=True, + ) + mel_temp = gr.Slider( + label="Sampling temperature", + minimum=0.00, + maximum=2.001, + step=0.16675, + value=0.667, + interactive=True, + ) + + synth_btn = gr.Button("Synthesise") + + with gr.Box(): + with gr.Row(): + gr.Markdown("### Phonetised text") + phonetised_text = gr.Textbox(interactive=False, scale=10, label="Phonetised text") + + with gr.Box(): + with gr.Row(): + mel_spectrogram = gr.Image(interactive=False, label="mel spectrogram") + + # with gr.Row(): + audio = gr.Audio(interactive=False, label="Audio") + + with gr.Row(visible=False) as example_row_lj_speech: + examples = gr.Examples( # pylint: disable=unused-variable + examples=[ + [ + "We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up O D E-based speech synthesis.", + 50, + 0.677, + 0.95, + ], + [ + "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.", + 2, + 0.677, + 0.95, + ], + [ + "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.", + 4, + 0.677, + 0.95, + ], + [ + "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.", + 10, + 0.677, + 0.95, + ], + [ + "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.", + 50, + 0.677, + 0.95, + ], + [ + "The narrative of these events is based largely on the recollections of the participants.", + 10, + 0.677, + 0.95, + ], + [ + "The jury did not believe him, and the verdict was for the defendants.", + 10, + 0.677, + 0.95, + ], + ], + fn=ljspeech_example_cacher, + inputs=[text, n_timesteps, mel_temp, length_scale], + outputs=[phonetised_text, audio, mel_spectrogram], + cache_examples=True, + ) + + with gr.Row() as example_row_multispeaker: + multi_speaker_examples = gr.Examples( # pylint: disable=unused-variable + examples=[ + [ + "Hello everyone! I am speaker 0 and I am here to tell you that Matcha-TTS is amazing!", + 10, + 0.677, + 0.85, + 0, + ], + [ + "Hello everyone! I am speaker 16 and I am here to tell you that Matcha-TTS is amazing!", + 10, + 0.677, + 0.85, + 16, + ], + [ + "Hello everyone! I am speaker 44 and I am here to tell you that Matcha-TTS is amazing!", + 50, + 0.677, + 0.85, + 44, + ], + [ + "Hello everyone! I am speaker 45 and I am here to tell you that Matcha-TTS is amazing!", + 50, + 0.677, + 0.85, + 45, + ], + [ + "Hello everyone! 
I am speaker 58 and I am here to tell you that Matcha-TTS is amazing!", + 4, + 0.677, + 0.85, + 58, + ], + ], + fn=multispeaker_example_cacher, + inputs=[text, n_timesteps, mel_temp, length_scale, spk_slider], + outputs=[phonetised_text, audio, mel_spectrogram], + cache_examples=True, + label="Multi Speaker Examples", + ) + + model_type.change(lambda x: gr.update(interactive=False), inputs=[synth_btn], outputs=[synth_btn]).then( + load_model_ui, + inputs=[model_type, text], + outputs=[text, synth_btn, spk_slider, example_row_lj_speech, example_row_multispeaker, length_scale], + ) + + synth_btn.click( + fn=process_text_gradio, + inputs=[ + text, + ], + outputs=[phonetised_text, processed_text, processed_text_len], + api_name="matcha_tts", + queue=True, + ).then( + fn=synthesise_mel, + inputs=[processed_text, processed_text_len, n_timesteps, mel_temp, length_scale, spk_slider], + outputs=[audio, mel_spectrogram], + ) + + demo.queue().launch(share=True) + + +if __name__ == "__main__": + main() diff --git a/third_party/Matcha-TTS/matcha/cli.py b/third_party/Matcha-TTS/matcha/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..7daf13073a01326cc8150a0f29453e635f31d719 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/cli.py @@ -0,0 +1,419 @@ +import argparse +import datetime as dt +import os +import warnings +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import soundfile as sf +import torch + +from matcha.hifigan.config import v1 +from matcha.hifigan.denoiser import Denoiser +from matcha.hifigan.env import AttrDict +from matcha.hifigan.models import Generator as HiFiGAN +from matcha.models.matcha_tts import MatchaTTS +from matcha.text import sequence_to_text, text_to_sequence +from matcha.utils.utils import assert_model_downloaded, get_user_data_dir, intersperse + +MATCHA_URLS = { + "matcha_ljspeech": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_ljspeech.ckpt", + "matcha_vctk": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_vctk.ckpt", +} + +VOCODER_URLS = { + "hifigan_T2_v1": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/generator_v1", # Old url: https://drive.google.com/file/d/14NENd4equCBLyyCSke114Mv6YR_j_uFs/view?usp=drive_link + "hifigan_univ_v1": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/g_02500000", # Old url: https://drive.google.com/file/d/1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW/view?usp=drive_link +} + +MULTISPEAKER_MODEL = { + "matcha_vctk": {"vocoder": "hifigan_univ_v1", "speaking_rate": 0.85, "spk": 0, "spk_range": (0, 107)} +} + +SINGLESPEAKER_MODEL = {"matcha_ljspeech": {"vocoder": "hifigan_T2_v1", "speaking_rate": 0.95, "spk": None}} + + +def plot_spectrogram_to_numpy(spectrogram, filename): + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.title("Synthesised Mel-Spectrogram") + fig.canvas.draw() + plt.savefig(filename) + + +def process_text(i: int, text: str, device: torch.device): + print(f"[{i}] - Input text: {text}") + x = torch.tensor( + intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0), + dtype=torch.long, + device=device, + )[None] + x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device) + x_phones = sequence_to_text(x.squeeze(0).tolist()) + print(f"[{i}] - Phonetised text: 
{x_phones[1::2]}") + + return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones} + + +def get_texts(args): + if args.text: + texts = [args.text] + else: + with open(args.file, encoding="utf-8") as f: + texts = f.readlines() + return texts + + +def assert_required_models_available(args): + save_dir = get_user_data_dir() + if not hasattr(args, "checkpoint_path") and args.checkpoint_path is None: + model_path = args.checkpoint_path + else: + model_path = save_dir / f"{args.model}.ckpt" + assert_model_downloaded(model_path, MATCHA_URLS[args.model]) + + vocoder_path = save_dir / f"{args.vocoder}" + assert_model_downloaded(vocoder_path, VOCODER_URLS[args.vocoder]) + return {"matcha": model_path, "vocoder": vocoder_path} + + +def load_hifigan(checkpoint_path, device): + h = AttrDict(v1) + hifigan = HiFiGAN(h).to(device) + hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)["generator"]) + _ = hifigan.eval() + hifigan.remove_weight_norm() + return hifigan + + +def load_vocoder(vocoder_name, checkpoint_path, device): + print(f"[!] Loading {vocoder_name}!") + vocoder = None + if vocoder_name in ("hifigan_T2_v1", "hifigan_univ_v1"): + vocoder = load_hifigan(checkpoint_path, device) + else: + raise NotImplementedError( + f"Vocoder {vocoder_name} not implemented! define a load_<> method for it" + ) + + denoiser = Denoiser(vocoder, mode="zeros") + print(f"[+] {vocoder_name} loaded!") + return vocoder, denoiser + + +def load_matcha(model_name, checkpoint_path, device): + print(f"[!] Loading {model_name}!") + model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device) + _ = model.eval() + + print(f"[+] {model_name} loaded!") + return model + + +def to_waveform(mel, vocoder, denoiser=None): + audio = vocoder(mel).clamp(-1, 1) + if denoiser is not None: + audio = denoiser(audio.squeeze(), strength=0.00025).cpu().squeeze() + + return audio.cpu().squeeze() + + +def save_to_folder(filename: str, output: dict, folder: str): + folder = Path(folder) + folder.mkdir(exist_ok=True, parents=True) + plot_spectrogram_to_numpy(np.array(output["mel"].squeeze().float().cpu()), f"{filename}.png") + np.save(folder / f"{filename}", output["mel"].cpu().numpy()) + sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24") + return folder.resolve() / f"{filename}.wav" + + +def validate_args(args): + assert ( + args.text or args.file + ), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms." + assert args.temperature >= 0, "Sampling temperature cannot be negative" + assert args.steps > 0, "Number of ODE steps must be greater than 0" + + if args.checkpoint_path is None: + # When using pretrained models + if args.model in SINGLESPEAKER_MODEL: + args = validate_args_for_single_speaker_model(args) + + if args.model in MULTISPEAKER_MODEL: + args = validate_args_for_multispeaker_model(args) + else: + # When using a custom model + if args.vocoder != "hifigan_univ_v1": + warn_ = "[-] Using custom model checkpoint! I would suggest passing --vocoder hifigan_univ_v1, unless the custom model is trained on LJ Speech." 
+ warnings.warn(warn_, UserWarning) + if args.speaking_rate is None: + args.speaking_rate = 1.0 + + if args.batched: + assert args.batch_size > 0, "Batch size must be greater than 0" + assert args.speaking_rate > 0, "Speaking rate must be greater than 0" + + return args + + +def validate_args_for_multispeaker_model(args): + if args.vocoder is not None: + if args.vocoder != MULTISPEAKER_MODEL[args.model]["vocoder"]: + warn_ = f"[-] Using {args.model} model! I would suggest passing --vocoder {MULTISPEAKER_MODEL[args.model]['vocoder']}" + warnings.warn(warn_, UserWarning) + else: + args.vocoder = MULTISPEAKER_MODEL[args.model]["vocoder"] + + if args.speaking_rate is None: + args.speaking_rate = MULTISPEAKER_MODEL[args.model]["speaking_rate"] + + spk_range = MULTISPEAKER_MODEL[args.model]["spk_range"] + if args.spk is not None: + assert ( + args.spk >= spk_range[0] and args.spk <= spk_range[-1] + ), f"Speaker ID must be between {spk_range} for this model." + else: + available_spk_id = MULTISPEAKER_MODEL[args.model]["spk"] + warn_ = f"[!] Speaker ID not provided! Using speaker ID {available_spk_id}" + warnings.warn(warn_, UserWarning) + args.spk = available_spk_id + + return args + + +def validate_args_for_single_speaker_model(args): + if args.vocoder is not None: + if args.vocoder != SINGLESPEAKER_MODEL[args.model]["vocoder"]: + warn_ = f"[-] Using {args.model} model! I would suggest passing --vocoder {SINGLESPEAKER_MODEL[args.model]['vocoder']}" + warnings.warn(warn_, UserWarning) + else: + args.vocoder = SINGLESPEAKER_MODEL[args.model]["vocoder"] + + if args.speaking_rate is None: + args.speaking_rate = SINGLESPEAKER_MODEL[args.model]["speaking_rate"] + + if args.spk != SINGLESPEAKER_MODEL[args.model]["spk"]: + warn_ = f"[-] Ignoring speaker id {args.spk} for {args.model}" + warnings.warn(warn_, UserWarning) + args.spk = SINGLESPEAKER_MODEL[args.model]["spk"] + + return args + + +@torch.inference_mode() +def cli(): + parser = argparse.ArgumentParser( + description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching" + ) + parser.add_argument( + "--model", + type=str, + default="matcha_ljspeech", + help="Model to use", + choices=MATCHA_URLS.keys(), + ) + + parser.add_argument( + "--checkpoint_path", + type=str, + default=None, + help="Path to the custom model checkpoint", + ) + + parser.add_argument( + "--vocoder", + type=str, + default=None, + help="Vocoder to use (default: will use the one suggested with the pretrained model))", + choices=VOCODER_URLS.keys(), + ) + parser.add_argument("--text", type=str, default=None, help="Text to synthesize") + parser.add_argument("--file", type=str, default=None, help="Text file to synthesize") + parser.add_argument("--spk", type=int, default=None, help="Speaker ID") + parser.add_argument( + "--temperature", + type=float, + default=0.667, + help="Variance of the x0 noise (default: 0.667)", + ) + parser.add_argument( + "--speaking_rate", + type=float, + default=None, + help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)", + ) + parser.add_argument("--steps", type=int, default=10, help="Number of ODE steps (default: 10)") + parser.add_argument("--cpu", action="store_true", help="Use CPU for inference (default: use GPU if available)") + parser.add_argument( + "--denoiser_strength", + type=float, + default=0.00025, + help="Strength of the vocoder bias denoiser (default: 0.00025)", + ) + parser.add_argument( + "--output_folder", + type=str, + default=os.getcwd(), + help="Output folder to save results 
(default: current dir)", + ) + parser.add_argument("--batched", action="store_true", help="Batched inference (default: False)") + parser.add_argument( + "--batch_size", type=int, default=32, help="Batch size only useful when --batched (default: 32)" + ) + + args = parser.parse_args() + + args = validate_args(args) + device = get_device(args) + print_config(args) + paths = assert_required_models_available(args) + + if args.checkpoint_path is not None: + print(f"[🍵] Loading custom model from {args.checkpoint_path}") + paths["matcha"] = args.checkpoint_path + args.model = "custom_model" + + model = load_matcha(args.model, paths["matcha"], device) + vocoder, denoiser = load_vocoder(args.vocoder, paths["vocoder"], device) + + texts = get_texts(args) + + spk = torch.tensor([args.spk], device=device, dtype=torch.long) if args.spk is not None else None + if len(texts) == 1 or not args.batched: + unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk) + else: + batched_synthesis(args, device, model, vocoder, denoiser, texts, spk) + + +class BatchedSynthesisDataset(torch.utils.data.Dataset): + def __init__(self, processed_texts): + self.processed_texts = processed_texts + + def __len__(self): + return len(self.processed_texts) + + def __getitem__(self, idx): + return self.processed_texts[idx] + + +def batched_collate_fn(batch): + x = [] + x_lengths = [] + + for b in batch: + x.append(b["x"].squeeze(0)) + x_lengths.append(b["x_lengths"]) + + x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True) + x_lengths = torch.concat(x_lengths, dim=0) + return {"x": x, "x_lengths": x_lengths} + + +def batched_synthesis(args, device, model, vocoder, denoiser, texts, spk): + total_rtf = [] + total_rtf_w = [] + processed_text = [process_text(i, text, "cpu") for i, text in enumerate(texts)] + dataloader = torch.utils.data.DataLoader( + BatchedSynthesisDataset(processed_text), + batch_size=args.batch_size, + collate_fn=batched_collate_fn, + num_workers=8, + ) + for i, batch in enumerate(dataloader): + i = i + 1 + start_t = dt.datetime.now() + b = batch["x"].shape[0] + output = model.synthesise( + batch["x"].to(device), + batch["x_lengths"].to(device), + n_timesteps=args.steps, + temperature=args.temperature, + spks=spk.expand(b) if spk is not None else spk, + length_scale=args.speaking_rate, + ) + + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + t = (dt.datetime.now() - start_t).total_seconds() + rtf_w = t * 22050 / (output["waveform"].shape[-1]) + print(f"[🍵-Batch: {i}] Matcha-TTS RTF: {output['rtf']:.4f}") + print(f"[🍵-Batch: {i}] Matcha-TTS + VOCODER RTF: {rtf_w:.4f}") + total_rtf.append(output["rtf"]) + total_rtf_w.append(rtf_w) + for j in range(output["mel"].shape[0]): + base_name = f"utterance_{j:03d}_speaker_{args.spk:03d}" if args.spk is not None else f"utterance_{j:03d}" + length = output["mel_lengths"][j] + new_dict = {"mel": output["mel"][j][:, :length], "waveform": output["waveform"][j][: length * 256]} + location = save_to_folder(base_name, new_dict, args.output_folder) + print(f"[🍵-{j}] Waveform saved: {location}") + + print("".join(["="] * 100)) + print(f"[🍵] Average Matcha-TTS RTF: {np.mean(total_rtf):.4f} ± {np.std(total_rtf)}") + print(f"[🍵] Average Matcha-TTS + VOCODER RTF: {np.mean(total_rtf_w):.4f} ± {np.std(total_rtf_w)}") + print("[🍵] Enjoy the freshly whisked 🍵 Matcha-TTS!") + + +def unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk): + total_rtf = [] + total_rtf_w = [] + for i, text in enumerate(texts): + i = i + 1 + base_name = 
f"utterance_{i:03d}_speaker_{args.spk:03d}" if args.spk is not None else f"utterance_{i:03d}" + + print("".join(["="] * 100)) + text = text.strip() + text_processed = process_text(i, text, device) + + print(f"[🍵] Whisking Matcha-T(ea)TS for: {i}") + start_t = dt.datetime.now() + output = model.synthesise( + text_processed["x"], + text_processed["x_lengths"], + n_timesteps=args.steps, + temperature=args.temperature, + spks=spk, + length_scale=args.speaking_rate, + ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + # RTF with HiFiGAN + t = (dt.datetime.now() - start_t).total_seconds() + rtf_w = t * 22050 / (output["waveform"].shape[-1]) + print(f"[🍵-{i}] Matcha-TTS RTF: {output['rtf']:.4f}") + print(f"[🍵-{i}] Matcha-TTS + VOCODER RTF: {rtf_w:.4f}") + total_rtf.append(output["rtf"]) + total_rtf_w.append(rtf_w) + + location = save_to_folder(base_name, output, args.output_folder) + print(f"[+] Waveform saved: {location}") + + print("".join(["="] * 100)) + print(f"[🍵] Average Matcha-TTS RTF: {np.mean(total_rtf):.4f} ± {np.std(total_rtf)}") + print(f"[🍵] Average Matcha-TTS + VOCODER RTF: {np.mean(total_rtf_w):.4f} ± {np.std(total_rtf_w)}") + print("[🍵] Enjoy the freshly whisked 🍵 Matcha-TTS!") + + +def print_config(args): + print("[!] Configurations: ") + print(f"\t- Model: {args.model}") + print(f"\t- Vocoder: {args.vocoder}") + print(f"\t- Temperature: {args.temperature}") + print(f"\t- Speaking rate: {args.speaking_rate}") + print(f"\t- Number of ODE steps: {args.steps}") + print(f"\t- Speaker: {args.spk}") + + +def get_device(args): + if torch.cuda.is_available() and not args.cpu: + print("[+] GPU Available! Using GPU") + device = torch.device("cuda") + else: + print("[-] GPU not available or forced CPU run! Using CPU") + device = torch.device("cpu") + return device + + +if __name__ == "__main__": + cli() diff --git a/third_party/Matcha-TTS/matcha/data/__init__.py b/third_party/Matcha-TTS/matcha/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/matcha/data/components/__init__.py b/third_party/Matcha-TTS/matcha/data/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/matcha/data/text_mel_datamodule.py b/third_party/Matcha-TTS/matcha/data/text_mel_datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..e10dfcb8bba8fbd1d04272a70d5acfe886ae5107 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/data/text_mel_datamodule.py @@ -0,0 +1,274 @@ +import random +from pathlib import Path +from typing import Any, Dict, Optional + +import numpy as np +import torch +import torchaudio as ta +from lightning import LightningDataModule +from torch.utils.data.dataloader import DataLoader + +from matcha.text import text_to_sequence +from matcha.utils.audio import mel_spectrogram +from matcha.utils.model import fix_len_compatibility, normalize +from matcha.utils.utils import intersperse + + +def parse_filelist(filelist_path, split_char="|"): + with open(filelist_path, encoding="utf-8") as f: + filepaths_and_text = [line.strip().split(split_char) for line in f] + return filepaths_and_text + + +class TextMelDataModule(LightningDataModule): + def __init__( # pylint: disable=unused-argument + self, + name, + train_filelist_path, + valid_filelist_path, + batch_size, + num_workers, + pin_memory, + cleaners, + add_blank, + n_spks, + n_fft, + 
n_feats, + sample_rate, + hop_length, + win_length, + f_min, + f_max, + data_statistics, + seed, + load_durations, + ): + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False) + + def setup(self, stage: Optional[str] = None): # pylint: disable=unused-argument + """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. + + This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be + careful not to execute things like random split twice! + """ + # load and split datasets only if not loaded already + + self.trainset = TextMelDataset( # pylint: disable=attribute-defined-outside-init + self.hparams.train_filelist_path, + self.hparams.n_spks, + self.hparams.cleaners, + self.hparams.add_blank, + self.hparams.n_fft, + self.hparams.n_feats, + self.hparams.sample_rate, + self.hparams.hop_length, + self.hparams.win_length, + self.hparams.f_min, + self.hparams.f_max, + self.hparams.data_statistics, + self.hparams.seed, + self.hparams.load_durations, + ) + self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init + self.hparams.valid_filelist_path, + self.hparams.n_spks, + self.hparams.cleaners, + self.hparams.add_blank, + self.hparams.n_fft, + self.hparams.n_feats, + self.hparams.sample_rate, + self.hparams.hop_length, + self.hparams.win_length, + self.hparams.f_min, + self.hparams.f_max, + self.hparams.data_statistics, + self.hparams.seed, + self.hparams.load_durations, + ) + + def train_dataloader(self): + return DataLoader( + dataset=self.trainset, + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=True, + collate_fn=TextMelBatchCollate(self.hparams.n_spks), + ) + + def val_dataloader(self): + return DataLoader( + dataset=self.validset, + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=False, + collate_fn=TextMelBatchCollate(self.hparams.n_spks), + ) + + def teardown(self, stage: Optional[str] = None): + """Clean up after fit or test.""" + pass # pylint: disable=unnecessary-pass + + def state_dict(self): + """Extra things to save to checkpoint.""" + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]): + """Things to do when loading checkpoint.""" + pass # pylint: disable=unnecessary-pass + + +class TextMelDataset(torch.utils.data.Dataset): + def __init__( + self, + filelist_path, + n_spks, + cleaners, + add_blank=True, + n_fft=1024, + n_mels=80, + sample_rate=22050, + hop_length=256, + win_length=1024, + f_min=0.0, + f_max=8000, + data_parameters=None, + seed=None, + load_durations=False, + ): + self.filepaths_and_text = parse_filelist(filelist_path) + self.n_spks = n_spks + self.cleaners = cleaners + self.add_blank = add_blank + self.n_fft = n_fft + self.n_mels = n_mels + self.sample_rate = sample_rate + self.hop_length = hop_length + self.win_length = win_length + self.f_min = f_min + self.f_max = f_max + self.load_durations = load_durations + + if data_parameters is not None: + self.data_parameters = data_parameters + else: + self.data_parameters = {"mel_mean": 0, "mel_std": 1} + random.seed(seed) + random.shuffle(self.filepaths_and_text) + + def get_datapoint(self, filepath_and_text): + if self.n_spks > 1: + filepath, spk, text = ( + filepath_and_text[0], + int(filepath_and_text[1]), + filepath_and_text[2], + ) + else: + 
filepath, text = filepath_and_text[0], filepath_and_text[1] + spk = None + + text, cleaned_text = self.get_text(text, add_blank=self.add_blank) + mel = self.get_mel(filepath) + + durations = self.get_durations(filepath, text) if self.load_durations else None + + return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text, "durations": durations} + + def get_durations(self, filepath, text): + filepath = Path(filepath) + data_dir, name = filepath.parent.parent, filepath.stem + + try: + dur_loc = data_dir / "durations" / f"{name}.npy" + durs = torch.from_numpy(np.load(dur_loc).astype(int)) + + except FileNotFoundError as e: + raise FileNotFoundError( + f"Tried loading the durations but durations didn't exist at {dur_loc}, make sure you've generate the durations first using: python matcha/utils/get_durations_from_trained_model.py \n" + ) from e + + assert len(durs) == len(text), f"Length of durations {len(durs)} and text {len(text)} do not match" + + return durs + + def get_mel(self, filepath): + audio, sr = ta.load(filepath) + assert sr == self.sample_rate + mel = mel_spectrogram( + audio, + self.n_fft, + self.n_mels, + self.sample_rate, + self.hop_length, + self.win_length, + self.f_min, + self.f_max, + center=False, + ).squeeze() + mel = normalize(mel, self.data_parameters["mel_mean"], self.data_parameters["mel_std"]) + return mel + + def get_text(self, text, add_blank=True): + text_norm, cleaned_text = text_to_sequence(text, self.cleaners) + if self.add_blank: + text_norm = intersperse(text_norm, 0) + text_norm = torch.IntTensor(text_norm) + return text_norm, cleaned_text + + def __getitem__(self, index): + datapoint = self.get_datapoint(self.filepaths_and_text[index]) + return datapoint + + def __len__(self): + return len(self.filepaths_and_text) + + +class TextMelBatchCollate: + def __init__(self, n_spks): + self.n_spks = n_spks + + def __call__(self, batch): + B = len(batch) + y_max_length = max([item["y"].shape[-1] for item in batch]) + y_max_length = fix_len_compatibility(y_max_length) + x_max_length = max([item["x"].shape[-1] for item in batch]) + n_feats = batch[0]["y"].shape[-2] + + y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32) + x = torch.zeros((B, x_max_length), dtype=torch.long) + durations = torch.zeros((B, x_max_length), dtype=torch.long) + + y_lengths, x_lengths = [], [] + spks = [] + filepaths, x_texts = [], [] + for i, item in enumerate(batch): + y_, x_ = item["y"], item["x"] + y_lengths.append(y_.shape[-1]) + x_lengths.append(x_.shape[-1]) + y[i, :, : y_.shape[-1]] = y_ + x[i, : x_.shape[-1]] = x_ + spks.append(item["spk"]) + filepaths.append(item["filepath"]) + x_texts.append(item["x_text"]) + if item["durations"] is not None: + durations[i, : item["durations"].shape[-1]] = item["durations"] + + y_lengths = torch.tensor(y_lengths, dtype=torch.long) + x_lengths = torch.tensor(x_lengths, dtype=torch.long) + spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None + + return { + "x": x, + "x_lengths": x_lengths, + "y": y, + "y_lengths": y_lengths, + "spks": spks, + "filepaths": filepaths, + "x_texts": x_texts, + "durations": durations if not torch.eq(durations, 0).all() else None, + } diff --git a/third_party/Matcha-TTS/matcha/hifigan/LICENSE b/third_party/Matcha-TTS/matcha/hifigan/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..91751daed806f63ac594cf077a3065f719a41662 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/hifigan/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 
Jungil Kong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/Matcha-TTS/matcha/hifigan/README.md b/third_party/Matcha-TTS/matcha/hifigan/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5db25850451a794b1db1b15b08e82c1d802edbb3 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/hifigan/README.md @@ -0,0 +1,101 @@ +# HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis + +### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae + +In our [paper](https://arxiv.org/abs/2010.05646), +we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
+We provide our implementation and pretrained models as open source in this repository. + +**Abstract :** +Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms. +Although such methods improve the sampling efficiency and memory usage, +their sample quality has not yet reached that of autoregressive and flow-based generative models. +In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis. +As speech audio consists of sinusoidal signals with various periods, +we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality. +A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method +demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than +real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen +speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times +faster than real-time on CPU with comparable quality to an autoregressive counterpart. + +Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples. + +## Pre-requisites + +1. Python >= 3.6 +2. Clone this repository. +3. Install python requirements. Please refer [requirements.txt](requirements.txt) +4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/). + And move all wav files to `LJSpeech-1.1/wavs` + +## Training + +``` +python train.py --config config_v1.json +``` + +To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
+Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
+You can change the path by adding the `--checkpoint_path` option. + +Validation loss during training with the V1 generator.
+![validation loss](./validation_loss.png) + +## Pretrained Model + +You can also use the pretrained models we provide.
+[Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
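As a rough illustration of how these checkpoints are consumed elsewhere in this patch: `load_hifigan()` in `matcha/cli.py` builds the V1 `Generator` from the bundled config, loads the `"generator"` state dict, and pairs it with the bias `Denoiser`. The sketch below mirrors those steps; the local checkpoint filename (`generator_v1`) and the placeholder mel shape are assumptions for illustration, not part of the patch.

```python
import torch

from matcha.hifigan.config import v1
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.env import AttrDict
from matcha.hifigan.models import Generator as HiFiGAN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the V1 generator and load a downloaded checkpoint
# (same steps as load_hifigan() in matcha/cli.py).
hifigan = HiFiGAN(AttrDict(v1)).to(device)
state = torch.load("generator_v1", map_location=device)  # assumed local path
hifigan.load_state_dict(state["generator"])
hifigan.eval()
hifigan.remove_weight_norm()

# Optional bias denoiser, as used by to_waveform() in matcha/cli.py.
denoiser = Denoiser(hifigan, mode="zeros")

mel = torch.randn(1, 80, 200, device=device)  # placeholder [batch, n_mels, frames]
with torch.inference_mode():
    audio = hifigan(mel).clamp(-1.0, 1.0)                      # [1, 1, samples]
    audio = denoiser(audio.squeeze(), strength=0.00025).cpu()  # subtract vocoder bias
```

In practice the mel-spectrogram would come from an acoustic model such as Matcha-TTS or Tacotron2 rather than `torch.randn`.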
+Details of each folder are as in follows: + +| Folder Name | Generator | Dataset | Fine-Tuned | +| ------------ | --------- | --------- | ------------------------------------------------------ | +| LJ_V1 | V1 | LJSpeech | No | +| LJ_V2 | V2 | LJSpeech | No | +| LJ_V3 | V3 | LJSpeech | No | +| LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| VCTK_V1 | V1 | VCTK | No | +| VCTK_V2 | V2 | VCTK | No | +| VCTK_V3 | V3 | VCTK | No | +| UNIVERSAL_V1 | V1 | Universal | No | + +We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets. + +## Fine-Tuning + +1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
+ The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.
+ Example: + ` Audio File : LJ001-0001.wav +Mel-Spectrogram File : LJ001-0001.npy` +2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.
+3. Run the following command. + ``` + python train.py --fine_tuning True --config config_v1.json + ``` + For other command line options, please refer to the training section. + +## Inference from wav file + +1. Make `test_files` directory and copy wav files into the directory. +2. Run the following command. + ` python inference.py --checkpoint_file [generator checkpoint file path]` + Generated wav files are saved in `generated_files` by default.
+ You can change the path by adding `--output_dir` option. + +## Inference for end-to-end speech synthesis + +1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.
+ You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2), + [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth. +2. Run the following command. + ` python inference_e2e.py --checkpoint_file [generator checkpoint file path]` + Generated wav files are saved in `generated_files_from_mel` by default.
+ You can change the path by adding `--output_dir` option. + +## Acknowledgements + +We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips) +and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this. diff --git a/third_party/Matcha-TTS/matcha/hifigan/__init__.py b/third_party/Matcha-TTS/matcha/hifigan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/matcha/hifigan/config.py b/third_party/Matcha-TTS/matcha/hifigan/config.py new file mode 100644 index 0000000000000000000000000000000000000000..b3abea9e151a08864353d32066bd4935e24b82e7 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/hifigan/config.py @@ -0,0 +1,28 @@ +v1 = { + "resblock": "1", + "num_gpus": 0, + "batch_size": 16, + "learning_rate": 0.0004, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "resblock_initial_channel": 256, + "segment_size": 8192, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + "sampling_rate": 22050, + "fmin": 0, + "fmax": 8000, + "fmax_loss": None, + "num_workers": 4, + "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1}, +} diff --git a/third_party/Matcha-TTS/matcha/hifigan/denoiser.py b/third_party/Matcha-TTS/matcha/hifigan/denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..9fd33312a09b1940374a0e29a97fe3a1a1dac7d2 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/hifigan/denoiser.py @@ -0,0 +1,64 @@ +# Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py + +"""Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio.""" +import torch + + +class Denoiser(torch.nn.Module): + """Removes model bias from audio produced with waveglow""" + + def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"): + super().__init__() + self.filter_length = filter_length + self.hop_length = int(filter_length / n_overlap) + self.win_length = win_length + + dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device + self.device = device + if mode == "zeros": + mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device) + elif mode == "normal": + mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device) + else: + raise Exception(f"Mode {mode} if not supported") + + def stft_fn(audio, n_fft, hop_length, win_length, window): + spec = torch.stft( + audio, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + return_complex=True, + ) + spec = torch.view_as_real(spec) + return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0]) + + self.stft = lambda x: stft_fn( + audio=x, + n_fft=self.filter_length, + hop_length=self.hop_length, + win_length=self.win_length, + window=torch.hann_window(self.win_length, device=device), + ) + self.istft = lambda x, y: torch.istft( + torch.complex(x * torch.cos(y), x * torch.sin(y)), + n_fft=self.filter_length, + hop_length=self.hop_length, + win_length=self.win_length, + window=torch.hann_window(self.win_length, 
device=device), + ) + + with torch.no_grad(): + bias_audio = vocoder(mel_input).float().squeeze(0) + bias_spec, _ = self.stft(bias_audio) + + self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None]) + + @torch.inference_mode() + def forward(self, audio, strength=0.0005): + audio_spec, audio_angles = self.stft(audio) + audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength + audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) + audio_denoised = self.istft(audio_spec_denoised, audio_angles) + return audio_denoised diff --git a/third_party/Matcha-TTS/matcha/hifigan/env.py b/third_party/Matcha-TTS/matcha/hifigan/env.py new file mode 100644 index 0000000000000000000000000000000000000000..9ea4f948a3f002921bf9bc24f52cbc1c0b1fc2ec --- /dev/null +++ b/third_party/Matcha-TTS/matcha/hifigan/env.py @@ -0,0 +1,17 @@ +""" from https://github.com/jik876/hifi-gan """ + +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/third_party/Matcha-TTS/matcha/hifigan/meldataset.py b/third_party/Matcha-TTS/matcha/hifigan/meldataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8b43ea7965e04a52d5427a485ee911b743057c4a --- /dev/null +++ b/third_party/Matcha-TTS/matcha/hifigan/meldataset.py @@ -0,0 +1,217 @@ +""" from https://github.com/jik876/hifi-gan """ + +import math +import os +import random + +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from librosa.util import normalize +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + if fmax not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + 
normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +def get_dataset_filelist(a): + with open(a.input_training_file, encoding="utf-8") as fi: + training_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 + ] + + with open(a.input_validation_file, encoding="utf-8") as fi: + validation_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 + ] + return training_files, validation_files + + +class MelDataset(torch.utils.data.Dataset): + def __init__( + self, + training_files, + segment_size, + n_fft, + num_mels, + hop_size, + win_size, + sampling_rate, + fmin, + fmax, + split=True, + shuffle=True, + n_cache_reuse=1, + device=None, + fmax_loss=None, + fine_tuning=False, + base_mels_path=None, + ): + self.audio_files = training_files + random.seed(1234) + if shuffle: + random.shuffle(self.audio_files) + self.segment_size = segment_size + self.sampling_rate = sampling_rate + self.split = split + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.fmax_loss = fmax_loss + self.cached_wav = None + self.n_cache_reuse = n_cache_reuse + self._cache_ref_count = 0 + self.device = device + self.fine_tuning = fine_tuning + self.base_mels_path = base_mels_path + + def __getitem__(self, index): + filename = self.audio_files[index] + if self._cache_ref_count == 0: + audio, sampling_rate = load_wav(filename) + audio = audio / MAX_WAV_VALUE + if not self.fine_tuning: + audio = normalize(audio) * 0.95 + self.cached_wav = audio + if sampling_rate != self.sampling_rate: + raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR") + self._cache_ref_count = self.n_cache_reuse + else: + audio = self.cached_wav + self._cache_ref_count -= 1 + + audio = torch.FloatTensor(audio) + audio = audio.unsqueeze(0) + + if not self.fine_tuning: + if self.split: + if audio.size(1) >= self.segment_size: + max_audio_start = audio.size(1) - self.segment_size + audio_start = random.randint(0, max_audio_start) + audio = audio[:, audio_start : audio_start + self.segment_size] + else: + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") + + mel = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax, + center=False, + ) + else: + mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy")) + mel = torch.from_numpy(mel) + + if len(mel.shape) < 3: + mel = mel.unsqueeze(0) + + if self.split: + frames_per_seg = math.ceil(self.segment_size / self.hop_size) + + if audio.size(1) >= self.segment_size: + mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) + mel = mel[:, :, mel_start : mel_start + frames_per_seg] + audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size] + else: + mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant") + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") + + mel_loss = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + 
self.fmax_loss, + center=False, + ) + + return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) + + def __len__(self): + return len(self.audio_files) diff --git a/third_party/Matcha-TTS/matcha/hifigan/models.py b/third_party/Matcha-TTS/matcha/hifigan/models.py new file mode 100644 index 0000000000000000000000000000000000000000..d209d9a4e99ec29e4167a5a2eaa62d72b3eff694 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/hifigan/models.py @@ -0,0 +1,368 @@ +""" from https://github.com/jik876/hifi-gan """ + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from .xutils import get_padding, init_weights + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.h = h + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super().__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if h.resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + 
h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super().__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for _, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super().__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + 
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ] + ) + self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/third_party/Matcha-TTS/matcha/hifigan/xutils.py b/third_party/Matcha-TTS/matcha/hifigan/xutils.py new file mode 100644 index 0000000000000000000000000000000000000000..eefadcb7a1d0bf9015e636b88fee3e22c9771bc5 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/hifigan/xutils.py @@ -0,0 +1,60 @@ +""" from https://github.com/jik876/hifi-gan """ + +import glob +import os + +import matplotlib +import torch +from torch.nn.utils import weight_norm + +matplotlib.use("Agg") +import matplotlib.pylab as plt + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print(f"Loading '{filepath}'") + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print(f"Saving checkpoint to {filepath}") + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + "????????") + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] diff --git a/third_party/Matcha-TTS/matcha/models/__init__.py b/third_party/Matcha-TTS/matcha/models/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/matcha/models/baselightningmodule.py b/third_party/Matcha-TTS/matcha/models/baselightningmodule.py new file mode 100644 index 0000000000000000000000000000000000000000..f8abe7b44f44688ff00720f7e56e34b75894d176 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/models/baselightningmodule.py @@ -0,0 +1,210 @@ +""" +This is a base lightning module that can be used to train a model. +The benefit of this abstraction is that all the logic outside of model definition can be reused for different models. +""" +import inspect +from abc import ABC +from typing import Any, Dict + +import torch +from lightning import LightningModule +from lightning.pytorch.utilities import grad_norm + +from matcha import utils +from matcha.utils.utils import plot_tensor + +log = utils.get_pylogger(__name__) + + +class BaseLightningClass(LightningModule, ABC): + def update_data_statistics(self, data_statistics): + if data_statistics is None: + data_statistics = { + "mel_mean": 0.0, + "mel_std": 1.0, + } + + self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) + self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) + + def configure_optimizers(self) -> Any: + optimizer = self.hparams.optimizer(params=self.parameters()) + if self.hparams.scheduler not in (None, {}): + scheduler_args = {} + # Manage last epoch for exponential schedulers + if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters: + if hasattr(self, "ckpt_loaded_epoch"): + current_epoch = self.ckpt_loaded_epoch - 1 + else: + current_epoch = -1 + + scheduler_args.update({"optimizer": optimizer}) + scheduler = self.hparams.scheduler.scheduler(**scheduler_args) + scheduler.last_epoch = current_epoch + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": self.hparams.scheduler.lightning_args.interval, + "frequency": self.hparams.scheduler.lightning_args.frequency, + "name": "learning_rate", + }, + } + + return {"optimizer": optimizer} + + def get_losses(self, batch): + x, x_lengths = batch["x"], batch["x_lengths"] + y, y_lengths = batch["y"], batch["y_lengths"] + spks = batch["spks"] + + dur_loss, prior_loss, diff_loss, *_ = self( + x=x, + x_lengths=x_lengths, + y=y, + y_lengths=y_lengths, + spks=spks, + out_size=self.out_size, + durations=batch["durations"], + ) + return { + "dur_loss": dur_loss, + "prior_loss": prior_loss, + "diff_loss": diff_loss, + } + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + self.ckpt_loaded_epoch = checkpoint["epoch"] # pylint: disable=attribute-defined-outside-init + + def training_step(self, batch: Any, batch_idx: int): + loss_dict = self.get_losses(batch) + self.log( + "step", + float(self.global_step), + on_step=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + self.log( + "sub_loss/train_dur_loss", + loss_dict["dur_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/train_prior_loss", + loss_dict["prior_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/train_diff_loss", + loss_dict["diff_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + + total_loss = sum(loss_dict.values()) + self.log( + "loss/train", + total_loss, + on_step=True, + on_epoch=True, + logger=True, + prog_bar=True, + sync_dist=True, + ) + + return {"loss": total_loss, "log": 
loss_dict} + + def validation_step(self, batch: Any, batch_idx: int): + loss_dict = self.get_losses(batch) + self.log( + "sub_loss/val_dur_loss", + loss_dict["dur_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/val_prior_loss", + loss_dict["prior_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/val_diff_loss", + loss_dict["diff_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + + total_loss = sum(loss_dict.values()) + self.log( + "loss/val", + total_loss, + on_step=True, + on_epoch=True, + logger=True, + prog_bar=True, + sync_dist=True, + ) + + return total_loss + + def on_validation_end(self) -> None: + if self.trainer.is_global_zero: + one_batch = next(iter(self.trainer.val_dataloaders)) + if self.current_epoch == 0: + log.debug("Plotting original samples") + for i in range(2): + y = one_batch["y"][i].unsqueeze(0).to(self.device) + self.logger.experiment.add_image( + f"original/{i}", + plot_tensor(y.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + + log.debug("Synthesising...") + for i in range(2): + x = one_batch["x"][i].unsqueeze(0).to(self.device) + x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device) + spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None + output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks) + y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"] + attn = output["attn"] + self.logger.experiment.add_image( + f"generated_enc/{i}", + plot_tensor(y_enc.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + self.logger.experiment.add_image( + f"generated_dec/{i}", + plot_tensor(y_dec.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + self.logger.experiment.add_image( + f"alignment/{i}", + plot_tensor(attn.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + + def on_before_optimizer_step(self, optimizer): + self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()}) diff --git a/third_party/Matcha-TTS/matcha/models/components/__init__.py b/third_party/Matcha-TTS/matcha/models/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/matcha/models/components/decoder.py b/third_party/Matcha-TTS/matcha/models/components/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1137cd7008e9d07b4f306926a82e44c2b2cddbdf --- /dev/null +++ b/third_party/Matcha-TTS/matcha/models/components/decoder.py @@ -0,0 +1,443 @@ +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from conformer import ConformerBlock +from diffusers.models.activations import get_activation +from einops import pack, rearrange, repeat + +from matcha.models.components.transformer import BasicTransformerBlock + + +class SinusoidalPosEmb(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even" + + def forward(self, x, scale=1000): + if x.ndim < 1: + x = x.unsqueeze(0) + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), 
emb.cos()), dim=-1) + return emb + + +class Block1D(torch.nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.block = torch.nn.Sequential( + torch.nn.Conv1d(dim, dim_out, 3, padding=1), + torch.nn.GroupNorm(groups, dim_out), + nn.Mish(), + ) + + def forward(self, x, mask): + output = self.block(x * mask) + return output * mask + + +class ResnetBlock1D(torch.nn.Module): + def __init__(self, dim, dim_out, time_emb_dim, groups=8): + super().__init__() + self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)) + + self.block1 = Block1D(dim, dim_out, groups=groups) + self.block2 = Block1D(dim_out, dim_out, groups=groups) + + self.res_conv = torch.nn.Conv1d(dim, dim_out, 1) + + def forward(self, x, mask, time_emb): + h = self.block1(x, mask) + h += self.mlp(time_emb).unsqueeze(-1) + h = self.block2(h, mask) + output = h + self.res_conv(x * mask) + return output + + +class Downsample1D(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Upsample1D(nn.Module): + """A 1D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. 
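# Illustrative usage sketch (not part of the upstream file): with the default
# ConvTranspose1d(kernel_size=4, stride=2, padding=1) used here, Upsample1D doubles the time axis.
import torch
up = Upsample1D(channels=256)          # default: use_conv_transpose=True
x = torch.randn(8, 256, 100)           # (batch, channels, time)
assert up(x).shape == (8, 256, 200)    # time dimension is doubled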
+ """ + + def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + self.conv = None + if use_conv_transpose: + self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, inputs): + assert inputs.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(inputs) + + outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") + + if self.use_conv: + outputs = self.conv(outputs) + + return outputs + + +class ConformerWrapper(ConformerBlock): + def __init__( # pylint: disable=useless-super-delegation + self, + *, + dim, + dim_head=64, + heads=8, + ff_mult=4, + conv_expansion_factor=2, + conv_kernel_size=31, + attn_dropout=0, + ff_dropout=0, + conv_dropout=0, + conv_causal=False, + ): + super().__init__( + dim=dim, + dim_head=dim_head, + heads=heads, + ff_mult=ff_mult, + conv_expansion_factor=conv_expansion_factor, + conv_kernel_size=conv_kernel_size, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + conv_dropout=conv_dropout, + conv_causal=conv_causal, + ) + + def forward( + self, + hidden_states, + attention_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + ): + return super().forward(x=hidden_states, mask=attention_mask.bool()) + + +class Decoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + channels=(256, 256), + dropout=0.05, + attention_head_dim=64, + n_blocks=1, + num_mid_blocks=2, + num_heads=4, + act_fn="snake", + down_block_type="transformer", + mid_block_type="transformer", + up_block_type="transformer", + ): + super().__init__() + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + + self.down_blocks = nn.ModuleList([]) + self.mid_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + output_channel = in_channels + for i in range(len(channels)): # pylint: disable=consider-using-enumerate + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + transformer_blocks = nn.ModuleList( + [ + self.get_block( + down_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + + self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) + + for i in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + + transformer_blocks = nn.ModuleList( + [ + self.get_block( + mid_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i 
in range(len(channels) - 1): + input_channel = channels[i] + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + + resnet = ResnetBlock1D( + dim=2 * input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.ModuleList( + [ + self.get_block( + up_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + + self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) + + self.final_block = Block1D(channels[-1], channels[-1]) + self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) + + self.initialize_weights() + # nn.init.normal_(self.final_proj.weight) + + @staticmethod + def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn): + if block_type == "conformer": + block = ConformerWrapper( + dim=dim, + dim_head=attention_head_dim, + heads=num_heads, + ff_mult=1, + conv_expansion_factor=2, + ff_dropout=dropout, + attn_dropout=dropout, + conv_dropout=dropout, + conv_kernel_size=31, + ) + elif block_type == "transformer": + block = BasicTransformerBlock( + dim=dim, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + else: + raise ValueError(f"Unknown block type {block_type}") + + return block + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, mask, mu, t, spks=None, cond=None): + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. 
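# Illustrative shape sketch (sizes are examples, not taken from the diff; requires the module's
# dependencies: torch, diffusers, conformer, einops). act_fn="snakebeta" is passed explicitly
# because the FeedForward in matcha/models/components/transformer.py implements "snakebeta"
# but not the "snake" default declared above.
import torch
dec = Decoder(in_channels=160, out_channels=80, channels=(256, 256), act_fn="snakebeta")
x = torch.randn(2, 80, 100)      # noisy sample x_t: (batch, n_feats, time)
mu = torch.randn(2, 80, 100)     # encoder output; packed with x along channels -> 160
mask = torch.ones(2, 1, 100)
t = torch.rand(2)                # one flow-matching timestep per batch item
out = dec(x, mask, mu, t)        # predicted vector field, shape (2, 80, 100)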
+ + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + + t = self.time_embeddings(t) + t = self.time_mlp(t) + + x = pack([x, mu], "b * t")[0] + + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = pack([x, spks], "b * t")[0] + + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x = resnet(x, mask_down, t) + x = rearrange(x, "b c t -> b t c") + mask_down = rearrange(mask_down, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_down, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_down = rearrange(mask_down, "b t -> b 1 t") + hiddens.append(x) # Save hidden states for skip connections + x = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + + masks = masks[:-1] + mask_mid = masks[-1] + + for resnet, transformer_blocks in self.mid_blocks: + x = resnet(x, mask_mid, t) + x = rearrange(x, "b c t -> b t c") + mask_mid = rearrange(mask_mid, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_mid, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_mid = rearrange(mask_mid, "b t -> b 1 t") + + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t) + x = rearrange(x, "b c t -> b t c") + mask_up = rearrange(mask_up, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_up, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_up = rearrange(mask_up, "b t -> b 1 t") + x = upsample(x * mask_up) + + x = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + + return output * mask diff --git a/third_party/Matcha-TTS/matcha/models/components/flow_matching.py b/third_party/Matcha-TTS/matcha/models/components/flow_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..5cad7431ef66a8d11da32a77c1af7f6e31d6b774 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/models/components/flow_matching.py @@ -0,0 +1,132 @@ +from abc import ABC + +import torch +import torch.nn.functional as F + +from matcha.models.components.decoder import Decoder +from matcha.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +class BASECFM(torch.nn.Module, ABC): + def __init__( + self, + n_feats, + cfm_params, + n_spks=1, + spk_emb_dim=128, + ): + super().__init__() + self.n_feats = n_feats + self.n_spks = n_spks + self.spk_emb_dim = spk_emb_dim + self.solver = cfm_params.solver + if hasattr(cfm_params, "sigma_min"): + self.sigma_min = cfm_params.sigma_min + else: + self.sigma_min = 1e-4 + + self.estimator = None + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + spks (torch.Tensor, optional): speaker ids. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + z = torch.randn_like(mu) * temperature + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) + return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) + + def solve_euler(self, x, t_span, mu, mask, spks, cond): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + """ + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + + # I am storing this because I can later plot it by putting a debugger here and saving it to a file + # Or in future might add like a return_all_steps flag + sol = [] + + for step in range(1, len(t_span)): + dphi_dt = self.estimator(x, mask, mu, t, spks, cond) + + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1] + + def compute_loss(self, x1, mask, mu, spks=None, cond=None): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): target mask + shape: (batch_size, 1, mel_timesteps) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + spks (torch.Tensor, optional): speaker embedding. Defaults to None. + shape: (batch_size, spk_emb_dim) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) + """ + b, _, t = mu.shape + + # random timestep + t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + + loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / ( + torch.sum(mask) * u.shape[1] + ) + return loss, y + + +class CFM(BASECFM): + def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64): + super().__init__( + n_feats=in_channels, + cfm_params=cfm_params, + n_spks=n_spks, + spk_emb_dim=spk_emb_dim, + ) + + in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0) + # Just change the architecture of the estimator here + self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params) diff --git a/third_party/Matcha-TTS/matcha/models/components/text_encoder.py b/third_party/Matcha-TTS/matcha/models/components/text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a388d05d6351fa2c9d9632fed0942d51fbec067b --- /dev/null +++ b/third_party/Matcha-TTS/matcha/models/components/text_encoder.py @@ -0,0 +1,410 @@ +""" from https://github.com/jaywalnut310/glow-tts """ + +import math + +import torch +import torch.nn as nn +from einops import rearrange + +import matcha.utils as utils +from matcha.utils.model import sequence_mask + +log = utils.get_pylogger(__name__) + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-4): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = 
torch.nn.Parameter(torch.ones(channels)) + self.beta = torch.nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + n_dims = len(x.shape) + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) + + x = (x - mean) * torch.rsqrt(variance + self.eps) + + shape = [1, -1] + [1] * (n_dims - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.conv_layers = torch.nn.ModuleList() + self.norm_layers = torch.nn.ModuleList() + self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append( + torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = LayerNorm(filter_channels) + self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = LayerNorm(filter_channels) + self.proj = torch.nn.Conv1d(filter_channels, 1, 1) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class RotaryPositionalEmbeddings(nn.Module): + """ + ## RoPE module + + Rotary encoding transforms pairs of features by rotating in the 2D plane. + That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. + Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it + by an angle depending on the position of the token. 
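# Illustrative shape sketch (not from the upstream file): the module expects
# (batch, n_heads, seq_len, head_dim) and rotates only the first `d` features of each head.
import torch
rope = RotaryPositionalEmbeddings(d=32)   # rotate the first 32 of 64 head features
q = torch.randn(2, 4, 100, 64)            # (batch, n_heads, seq_len, head_dim)
assert rope(q).shape == q.shape           # shape is unchanged; positions are encoded by rotation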
+ """ + + def __init__(self, d: int, base: int = 10_000): + r""" + * `d` is the number of features $d$ + * `base` is the constant used for calculating $\Theta$ + """ + super().__init__() + + self.base = base + self.d = int(d) + self.cos_cached = None + self.sin_cached = None + + def _build_cache(self, x: torch.Tensor): + r""" + Cache $\cos$ and $\sin$ values + """ + # Return if cache is already built + if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]: + return + + # Get sequence length + seq_len = x.shape[0] + + # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.einsum("n,d->nd", seq_idx, theta) + + # Concatenate so that for row $m$ we have + # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$ + idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1) + + # Cache them + self.cos_cached = idx_theta2.cos()[:, None, None, :] + self.sin_cached = idx_theta2.sin()[:, None, None, :] + + def _neg_half(self, x: torch.Tensor): + # $\frac{d}{2}$ + d_2 = self.d // 2 + + # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1) + + def forward(self, x: torch.Tensor): + """ + * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]` + """ + # Cache $\cos$ and $\sin$ values + x = rearrange(x, "b h t d -> t b h d") + + self._build_cache(x) + + # Split the features, we can choose to apply rotary embeddings only to a partial set of features. 
+ x_rope, x_pass = x[..., : self.d], x[..., self.d :] + + # Calculate + # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + neg_half_x = self._neg_half(x_rope) + + x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]]) + + return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d") + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + heads_share=True, + p_dropout=0.0, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.heads_share = heads_share + self.proximal_bias = proximal_bias + self.p_dropout = p_dropout + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = torch.nn.Conv1d(channels, channels, 1) + self.conv_k = torch.nn.Conv1d(channels, channels, 1) + self.conv_v = torch.nn.Conv1d(channels, channels, 1) + + # from https://nn.labml.ai/transformers/rope/index.html + self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + + self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) + self.drop = torch.nn.Dropout(p_dropout) + + torch.nn.init.xavier_uniform_(self.conv_q.weight) + torch.nn.init.xavier_uniform_(self.conv_k.weight) + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + torch.nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads) + key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads) + value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads) + + query = self.query_rotary_pe(query) + key = self.key_rotary_pe(key) + + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." 
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + p_attn = torch.nn.functional.softmax(scores, dim=-1) + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) + return output, p_attn + + @staticmethod + def _attention_bias_proximal(length): + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2) + self.drop = torch.nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + return x * x_mask + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + **kwargs, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.attn_layers = torch.nn.ModuleList() + self.norm_layers_1 = torch.nn.ModuleList() + self.ffn_layers = torch.nn.ModuleList() + self.norm_layers_2 = torch.nn.ModuleList() + for _ in range(self.n_layers): + self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.n_layers): + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class TextEncoder(nn.Module): + def __init__( + self, + encoder_type, + encoder_params, + duration_predictor_params, + n_vocab, + n_spks=1, + spk_emb_dim=128, + ): + super().__init__() + self.encoder_type = encoder_type + self.n_vocab = n_vocab + self.n_feats = encoder_params.n_feats + self.n_channels = encoder_params.n_channels + self.spk_emb_dim = spk_emb_dim + self.n_spks = n_spks + + self.emb = torch.nn.Embedding(n_vocab, self.n_channels) + torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5) + + if encoder_params.prenet: + self.prenet = ConvReluNorm( + self.n_channels, + self.n_channels, + self.n_channels, + kernel_size=5, + n_layers=3, + p_dropout=0.5, + ) + else: + self.prenet = lambda x, x_mask: x + + self.encoder = Encoder( + encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0), + encoder_params.filter_channels, + 
encoder_params.n_heads, + encoder_params.n_layers, + encoder_params.kernel_size, + encoder_params.p_dropout, + ) + + self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1) + self.proj_w = DurationPredictor( + self.n_channels + (spk_emb_dim if n_spks > 1 else 0), + duration_predictor_params.filter_channels_dp, + duration_predictor_params.kernel_size, + duration_predictor_params.p_dropout, + ) + + def forward(self, x, x_lengths, spks=None): + """Run forward pass to the transformer based encoder and duration predictor + + Args: + x (torch.Tensor): text input + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): text input lengths + shape: (batch_size,) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size,) + + Returns: + mu (torch.Tensor): average output of the encoder + shape: (batch_size, n_feats, max_text_length) + logw (torch.Tensor): log duration predicted by the duration predictor + shape: (batch_size, 1, max_text_length) + x_mask (torch.Tensor): mask for the text input + shape: (batch_size, 1, max_text_length) + """ + x = self.emb(x) * math.sqrt(self.n_channels) + x = torch.transpose(x, 1, -1) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.prenet(x, x_mask) + if self.n_spks > 1: + x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1) + x = self.encoder(x, x_mask) + mu = self.proj_m(x) * x_mask + + x_dp = torch.detach(x) + logw = self.proj_w(x_dp, x_mask) + + return mu, logw, x_mask diff --git a/third_party/Matcha-TTS/matcha/models/components/transformer.py b/third_party/Matcha-TTS/matcha/models/components/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..dd1afa3aff5383912209e508676c6885e13ef4ee --- /dev/null +++ b/third_party/Matcha-TTS/matcha/models/components/transformer.py @@ -0,0 +1,316 @@ +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +from diffusers.models.attention import ( + GEGLU, + GELU, + AdaLayerNorm, + AdaLayerNormZero, + ApproximateGELU, +) +from diffusers.models.attention_processor import Attention +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.utils.torch_utils import maybe_allow_in_graph + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. 
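# Tiny numeric sketch of the docstring formula (illustrative only): with the default
# log-scale zero init, alpha = beta = exp(0) = 1, so the nonlinearity reduces to x + sin^2(x).
import torch
x = torch.linspace(-3, 3, 7)
alpha, beta = 1.0, 1.0
y = x + (1.0 / beta) * torch.sin(alpha * x) ** 2   # SnakeBeta(x) = x + (1/b) * sin^2(a * x)
# Note that the module also applies a linear projection (self.proj) before this
# nonlinearity, so its input and output feature sizes can differ.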
+ """ + super().__init__() + self.in_features = out_features if isinstance(out_features, list) else [out_features] + self.proj = LoRACompatibleLinear(in_features, out_features) + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha) + self.beta = nn.Parameter(torch.ones(self.in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + x = self.proj(x) + if self.alpha_logscale: + alpha = torch.exp(self.alpha) + beta = torch.exp(self.beta) + else: + alpha = self.alpha + beta = self.beta + + x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2) + + return x + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + elif activation_fn == "snakebeta": + act_fn = SnakeBeta(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. 
In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + # scale_qk=False, # uncomment this to not to use flash attention + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. 
Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 
+ ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states diff --git a/third_party/Matcha-TTS/matcha/models/matcha_tts.py b/third_party/Matcha-TTS/matcha/models/matcha_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..07f95ad2e31a2de94974c21f15e28ab5445ff6fc --- /dev/null +++ b/third_party/Matcha-TTS/matcha/models/matcha_tts.py @@ -0,0 +1,244 @@ +import datetime as dt +import math +import random + +import torch + +import matcha.utils.monotonic_align as monotonic_align +from matcha import utils +from matcha.models.baselightningmodule import BaseLightningClass +from matcha.models.components.flow_matching import CFM +from matcha.models.components.text_encoder import TextEncoder +from matcha.utils.model import ( + denormalize, + duration_loss, + fix_len_compatibility, + generate_path, + sequence_mask, +) + +log = utils.get_pylogger(__name__) + + +class MatchaTTS(BaseLightningClass): # 🍵 + def __init__( + self, + n_vocab, + n_spks, + spk_emb_dim, + n_feats, + encoder, + decoder, + cfm, + data_statistics, + out_size, + optimizer=None, + scheduler=None, + prior_loss=True, + use_precomputed_durations=False, + ): + super().__init__() + + self.save_hyperparameters(logger=False) + + self.n_vocab = n_vocab + self.n_spks = n_spks + self.spk_emb_dim = spk_emb_dim + self.n_feats = n_feats + self.out_size = out_size + self.prior_loss = prior_loss + self.use_precomputed_durations = use_precomputed_durations + + if n_spks > 1: + self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) + + self.encoder = TextEncoder( + encoder.encoder_type, + encoder.encoder_params, + encoder.duration_predictor_params, + n_vocab, + n_spks, + spk_emb_dim, + ) + + self.decoder = CFM( + in_channels=2 * encoder.encoder_params.n_feats, + out_channel=encoder.encoder_params.n_feats, + cfm_params=cfm, + decoder_params=decoder, + n_spks=n_spks, + spk_emb_dim=spk_emb_dim, + ) + + self.update_data_statistics(data_statistics) + + @torch.inference_mode() + def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0): + """ + Generates mel-spectrogram from text. Returns: + 1. encoder outputs + 2. decoder outputs + 3. generated alignment + + Args: + x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): lengths of texts in batch. + shape: (batch_size,) + n_timesteps (int): number of steps to use for reverse diffusion in decoder. + temperature (float, optional): controls variance of terminal distribution. + spks (bool, optional): speaker ids. + shape: (batch_size,) + length_scale (float, optional): controls speech pace. + Increase value to slow down generated speech and vice versa. 
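# Illustrative usage sketch (the checkpoint path and input ids are hypothetical; text-to-id
# conversion and the vocoder live outside this class):
import torch
model = MatchaTTS.load_from_checkpoint("matcha_ljspeech.ckpt", map_location="cpu")
model.eval()
x = torch.randint(1, 100, (1, 42))     # phoneme ids, e.g. from matcha.text.text_to_sequence
x_lengths = torch.tensor([42])
out = model.synthesise(x, x_lengths, n_timesteps=10, temperature=0.667, length_scale=1.0)
mel = out["mel"]                       # (1, n_feats, mel_len); feed to a vocoder such as HiFi-GAN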
+ + Returns: + dict: { + "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Average mel spectrogram generated by the encoder + "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Refined mel spectrogram improved by the CFM + "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length), + # Alignment map between text and mel spectrogram + "mel": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Denormalized mel spectrogram + "mel_lengths": torch.Tensor, shape: (batch_size,), + # Lengths of mel spectrograms + "rtf": float, + # Real-time factor + """ + # For RTF computation + t = dt.datetime.now() + + if self.n_spks > 1: + # Get speaker embedding + spks = self.spk_emb(spks.long()) + + # Get encoder_outputs `mu_x` and log-scaled token durations `logw` + mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) + + w = torch.exp(logw) * x_mask + w_ceil = torch.ceil(w) * length_scale + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_max_length = y_lengths.max() + y_max_length_ = fix_len_compatibility(y_max_length) + + # Using obtained durations `w` construct alignment map `attn` + y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) + attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) + + # Align encoded text and get mu_y + mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) + mu_y = mu_y.transpose(1, 2) + encoder_outputs = mu_y[:, :, :y_max_length] + + # Generate sample tracing the probability flow + decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks) + decoder_outputs = decoder_outputs[:, :, :y_max_length] + + t = (dt.datetime.now() - t).total_seconds() + rtf = t * 22050 / (decoder_outputs.shape[-1] * 256) + + return { + "encoder_outputs": encoder_outputs, + "decoder_outputs": decoder_outputs, + "attn": attn[:, :, :y_max_length], + "mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std), + "mel_lengths": y_lengths, + "rtf": rtf, + } + + def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None): + """ + Computes 3 losses: + 1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS). + 2. prior loss: loss between mel-spectrogram and encoder outputs. + 3. flow matching loss: loss between mel-spectrogram and decoder outputs. + + Args: + x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): lengths of texts in batch. + shape: (batch_size,) + y (torch.Tensor): batch of corresponding mel-spectrograms. + shape: (batch_size, n_feats, max_mel_length) + y_lengths (torch.Tensor): lengths of mel-spectrograms in batch. + shape: (batch_size,) + out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained. + Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size. + spks (torch.Tensor, optional): speaker ids. 
+ shape: (batch_size,) + """ + if self.n_spks > 1: + # Get speaker embedding + spks = self.spk_emb(spks) + + # Get encoder_outputs `mu_x` and log-scaled token durations `logw` + mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) + y_max_length = y.shape[-1] + + y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) + attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) + + if self.use_precomputed_durations: + attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1)) + else: + # Use MAS to find most likely alignment `attn` between text and mel-spectrogram + with torch.no_grad(): + const = -0.5 * math.log(2 * math.pi) * self.n_feats + factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device) + y_square = torch.matmul(factor.transpose(1, 2), y**2) + y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y) + mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1) + log_prior = y_square - y_mu_double + mu_square + const + + attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1)) + attn = attn.detach() # b, t_text, T_mel + + # Compute loss between predicted log-scaled durations and those obtained from MAS + # refered to as prior loss in the paper + logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask + dur_loss = duration_loss(logw, logw_, x_lengths) + + # Cut a small segment of mel-spectrogram in order to increase batch size + # - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it + # - Do not need this hack for Matcha-TTS, but it works with it as well + if not isinstance(out_size, type(None)): + max_offset = (y_lengths - out_size).clamp(0) + offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy())) + out_offset = torch.LongTensor( + [torch.tensor(random.choice(range(start, end)) if end > start else 0) for start, end in offset_ranges] + ).to(y_lengths) + attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device) + y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device) + + y_cut_lengths = [] + for i, (y_, out_offset_) in enumerate(zip(y, out_offset)): + y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0) + y_cut_lengths.append(y_cut_length) + cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length + y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper] + attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper] + + y_cut_lengths = torch.LongTensor(y_cut_lengths) + y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask) + + attn = attn_cut + y = y_cut + y_mask = y_cut_mask + + # Align encoded text with mel-spectrogram and get mu_y segment + mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) + mu_y = mu_y.transpose(1, 2) + + # Compute loss of the decoder + diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond) + + if self.prior_loss: + prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask) + prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats) + else: + prior_loss = 0 + + return dur_loss, prior_loss, diff_loss, attn diff --git a/third_party/Matcha-TTS/matcha/onnx/__init__.py b/third_party/Matcha-TTS/matcha/onnx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/matcha/onnx/export.py 
b/third_party/Matcha-TTS/matcha/onnx/export.py new file mode 100644 index 0000000000000000000000000000000000000000..9b795086158e1ad8a4bb5cd92306f3fa765f71ea --- /dev/null +++ b/third_party/Matcha-TTS/matcha/onnx/export.py @@ -0,0 +1,181 @@ +import argparse +import random +from pathlib import Path + +import numpy as np +import torch +from lightning import LightningModule + +from matcha.cli import VOCODER_URLS, load_matcha, load_vocoder + +DEFAULT_OPSET = 15 + +SEED = 1234 +random.seed(SEED) +np.random.seed(SEED) +torch.manual_seed(SEED) +torch.cuda.manual_seed(SEED) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + + +class MatchaWithVocoder(LightningModule): + def __init__(self, matcha, vocoder): + super().__init__() + self.matcha = matcha + self.vocoder = vocoder + + def forward(self, x, x_lengths, scales, spks=None): + mel, mel_lengths = self.matcha(x, x_lengths, scales, spks) + wavs = self.vocoder(mel).clamp(-1, 1) + lengths = mel_lengths * 256 + return wavs.squeeze(1), lengths + + +def get_exportable_module(matcha, vocoder, n_timesteps): + """ + Return an appropriate `LighteningModule` and output-node names + based on whether the vocoder is embedded in the final graph + """ + + def onnx_forward_func(x, x_lengths, scales, spks=None): + """ + Custom forward function for accepting + scaler parameters as tensors + """ + # Extract scaler parameters from tensors + temperature = scales[0] + length_scale = scales[1] + output = matcha.synthesise(x, x_lengths, n_timesteps, temperature, spks, length_scale) + return output["mel"], output["mel_lengths"] + + # Monkey-patch Matcha's forward function + matcha.forward = onnx_forward_func + + if vocoder is None: + model, output_names = matcha, ["mel", "mel_lengths"] + else: + model = MatchaWithVocoder(matcha, vocoder) + output_names = ["wav", "wav_lengths"] + return model, output_names + + +def get_inputs(is_multi_speaker): + """ + Create dummy inputs for tracing + """ + dummy_input_length = 50 + x = torch.randint(low=0, high=20, size=(1, dummy_input_length), dtype=torch.long) + x_lengths = torch.LongTensor([dummy_input_length]) + + # Scales + temperature = 0.667 + length_scale = 1.0 + scales = torch.Tensor([temperature, length_scale]) + + model_inputs = [x, x_lengths, scales] + input_names = [ + "x", + "x_lengths", + "scales", + ] + + if is_multi_speaker: + spks = torch.LongTensor([1]) + model_inputs.append(spks) + input_names.append("spks") + + return tuple(model_inputs), input_names + + +def main(): + parser = argparse.ArgumentParser(description="Export 🍵 Matcha-TTS to ONNX") + + parser.add_argument( + "checkpoint_path", + type=str, + help="Path to the model checkpoint", + ) + parser.add_argument("output", type=str, help="Path to output `.onnx` file") + parser.add_argument( + "--n-timesteps", type=int, default=5, help="Number of steps to use for reverse diffusion in decoder (default 5)" + ) + parser.add_argument( + "--vocoder-name", + type=str, + choices=list(VOCODER_URLS.keys()), + default=None, + help="Name of the vocoder to embed in the ONNX graph", + ) + parser.add_argument( + "--vocoder-checkpoint-path", + type=str, + default=None, + help="Vocoder checkpoint to embed in the ONNX graph for an `e2e` like experience", + ) + parser.add_argument("--opset", type=int, default=DEFAULT_OPSET, help="ONNX opset version to use (default 15") + + args = parser.parse_args() + + print(f"[🍵] Loading Matcha checkpoint from {args.checkpoint_path}") + print(f"Setting n_timesteps to {args.n_timesteps}") + + checkpoint_path = 
Path(args.checkpoint_path) + matcha = load_matcha(checkpoint_path.stem, checkpoint_path, "cpu") + + if args.vocoder_name or args.vocoder_checkpoint_path: + assert ( + args.vocoder_name and args.vocoder_checkpoint_path + ), "Both vocoder_name and vocoder-checkpoint are required when embedding the vocoder in the ONNX graph." + vocoder, _ = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu") + else: + vocoder = None + + is_multi_speaker = matcha.n_spks > 1 + + dummy_input, input_names = get_inputs(is_multi_speaker) + model, output_names = get_exportable_module(matcha, vocoder, args.n_timesteps) + + # Set dynamic shape for inputs/outputs + dynamic_axes = { + "x": {0: "batch_size", 1: "time"}, + "x_lengths": {0: "batch_size"}, + } + + if vocoder is None: + dynamic_axes.update( + { + "mel": {0: "batch_size", 2: "time"}, + "mel_lengths": {0: "batch_size"}, + } + ) + else: + print("Embedding the vocoder in the ONNX graph") + dynamic_axes.update( + { + "wav": {0: "batch_size", 1: "time"}, + "wav_lengths": {0: "batch_size"}, + } + ) + + if is_multi_speaker: + dynamic_axes["spks"] = {0: "batch_size"} + + # Create the output directory (if not exists) + Path(args.output).parent.mkdir(parents=True, exist_ok=True) + + model.to_onnx( + args.output, + dummy_input, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=args.opset, + export_params=True, + do_constant_folding=True, + ) + print(f"[🍵] ONNX model exported to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/third_party/Matcha-TTS/matcha/onnx/infer.py b/third_party/Matcha-TTS/matcha/onnx/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..89ca92559c6df3776a07a038d7838242a3d19189 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/onnx/infer.py @@ -0,0 +1,168 @@ +import argparse +import os +import warnings +from pathlib import Path +from time import perf_counter + +import numpy as np +import onnxruntime as ort +import soundfile as sf +import torch + +from matcha.cli import plot_spectrogram_to_numpy, process_text + + +def validate_args(args): + assert ( + args.text or args.file + ), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms." 
+    assert args.temperature >= 0, "Sampling temperature cannot be negative"
+    assert args.speaking_rate >= 0, "Speaking rate cannot be negative"
+    return args
+
+
+def write_wavs(model, inputs, output_dir, external_vocoder=None):
+    if external_vocoder is None:
+        print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly")
+        t0 = perf_counter()
+        wavs, wav_lengths = model.run(None, inputs)
+        infer_secs = perf_counter() - t0
+        mel_infer_secs = vocoder_infer_secs = None
+    else:
+        print("[🍵] Generating mel using Matcha")
+        mel_t0 = perf_counter()
+        mels, mel_lengths = model.run(None, inputs)
+        mel_infer_secs = perf_counter() - mel_t0
+        print("Generating waveform from mel using external vocoder")
+        vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels}
+        vocoder_t0 = perf_counter()
+        wavs = external_vocoder.run(None, vocoder_inputs)[0]
+        vocoder_infer_secs = perf_counter() - vocoder_t0
+        wavs = wavs.squeeze(1)
+        wav_lengths = mel_lengths * 256
+        infer_secs = mel_infer_secs + vocoder_infer_secs
+
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    for i, (wav, wav_length) in enumerate(zip(wavs, wav_lengths)):
+        output_filename = output_dir.joinpath(f"output_{i + 1}.wav")
+        audio = wav[:wav_length]
+        print(f"Writing audio to {output_filename}")
+        sf.write(output_filename, audio, 22050, "PCM_24")
+
+    wav_secs = wav_lengths.sum() / 22050
+    print(f"Inference seconds: {infer_secs}")
+    print(f"Generated wav seconds: {wav_secs}")
+    rtf = infer_secs / wav_secs
+    if mel_infer_secs is not None:
+        mel_rtf = mel_infer_secs / wav_secs
+        print(f"Matcha RTF: {mel_rtf}")
+    if vocoder_infer_secs is not None:
+        vocoder_rtf = vocoder_infer_secs / wav_secs
+        print(f"Vocoder RTF: {vocoder_rtf}")
+    print(f"Overall RTF: {rtf}")
+
+
+def write_mels(model, inputs, output_dir):
+    t0 = perf_counter()
+    mels, mel_lengths = model.run(None, inputs)
+    infer_secs = perf_counter() - t0
+
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    for i, mel in enumerate(mels):
+        output_stem = output_dir.joinpath(f"output_{i + 1}")
+        plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png"))
+        np.save(output_stem.with_suffix(".npy"), mel)
+
+    wav_secs = (mel_lengths * 256).sum() / 22050
+    print(f"Inference seconds: {infer_secs}")
+    print(f"Generated wav seconds: {wav_secs}")
+    rtf = infer_secs / wav_secs
+    print(f"RTF: {rtf}")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
+    )
+    parser.add_argument(
+        "model",
+        type=str,
+        help="ONNX model to use",
+    )
+    parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)")
+    parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
+    parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")
+    parser.add_argument("--spk", type=int, default=None, help="Speaker ID")
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.667,
+        help="Variance of the x0 noise (default: 0.667)",
+    )
+    parser.add_argument(
+        "--speaking-rate",
+        type=float,
+        default=1.0,
+        help="Change the speaking rate; a higher value means slower speech (default: 1.0)",
+    )
+    parser.add_argument("--gpu", action="store_true", help="Use GPU for inference (default: CPU)")
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default=os.getcwd(),
+        help="Output folder to save results (default: current dir)",
+    )
+
+    args = parser.parse_args()
+    args = validate_args(args)
+
+    if args.gpu:
+        providers = ["CUDAExecutionProvider"]
+    else:
+        providers = ["CPUExecutionProvider"]
+    model = ort.InferenceSession(args.model, providers=providers)
+
+    model_inputs = model.get_inputs()
+    model_outputs = list(model.get_outputs())
+
+    if args.text:
+        text_lines = args.text.splitlines()
+    else:
+        with open(args.file, encoding="utf-8") as file:
+            text_lines = file.read().splitlines()
+
+    processed_lines = [process_text(0, line, "cpu") for line in text_lines]
+    x = [line["x"].squeeze() for line in processed_lines]
+    # Pad
+    x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True)
+    x = x.detach().cpu().numpy()
+    x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64)
+    inputs = {
+        "x": x,
+        "x_lengths": x_lengths,
+        "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32),
+    }
+    is_multi_speaker = len(model_inputs) == 4
+    if is_multi_speaker:
+        if args.spk is None:
+            args.spk = 0
+            warn = "[!] Speaker ID not provided! Using speaker ID 0"
+            warnings.warn(warn, UserWarning)
+        inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64)
+
+    has_vocoder_embedded = model_outputs[0].name == "wav"
+    if has_vocoder_embedded:
+        write_wavs(model, inputs, args.output_dir)
+    elif args.vocoder:
+        external_vocoder = ort.InferenceSession(args.vocoder, providers=providers)
+        write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder)
+    else:
+        warn = "[!] No vocoder is embedded in the graph and no external vocoder was provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory"
+        warnings.warn(warn, UserWarning)
+        write_mels(model, inputs, args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/third_party/Matcha-TTS/matcha/text/__init__.py b/third_party/Matcha-TTS/matcha/text/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c75d6b5714a0a2d30b95e00a5377c13f29d9b8a
--- /dev/null
+++ b/third_party/Matcha-TTS/matcha/text/__init__.py
@@ -0,0 +1,53 @@
+""" from https://github.com/keithito/tacotron """
+from matcha.text import cleaners
+from matcha.text.symbols import symbols
+
+# Mappings from symbol to numeric ID and vice versa:
+_symbol_to_id = {s: i for i, s in enumerate(symbols)}
+_id_to_symbol = {i: s for i, s in enumerate(symbols)}  # pylint: disable=unnecessary-comprehension
+
+
+def text_to_sequence(text, cleaner_names):
+    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+    Args:
+        text: string to convert to a sequence
+        cleaner_names: names of the cleaner functions to run the text through
+    Returns:
+        List of integers corresponding to the symbols in the text
+    """
+    sequence = []
+
+    clean_text = _clean_text(text, cleaner_names)
+    for symbol in clean_text:
+        symbol_id = _symbol_to_id[symbol]
+        sequence += [symbol_id]
+    return sequence, clean_text
+
+
+def cleaned_text_to_sequence(cleaned_text):
+    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+    Args:
+        cleaned_text: string (already cleaned) to convert to a sequence
+    Returns:
+        List of integers corresponding to the symbols in the text
+    """
+    sequence = [_symbol_to_id[symbol] for symbol in cleaned_text]
+    return sequence
+
+
+def sequence_to_text(sequence):
+    """Converts a sequence of IDs back to a string"""
+    result = ""
+    for symbol_id in sequence:
+        s = _id_to_symbol[symbol_id]
+        result += s
+    return result
+
+
+def _clean_text(text, cleaner_names):
+    for name in cleaner_names:
+        cleaner = getattr(cleaners, name)
+        if not cleaner:
+            raise Exception("Unknown cleaner: %s" % name)
+        text = cleaner(text)
+    return text
diff --git a/third_party/Matcha-TTS/matcha/text/cleaners.py b/third_party/Matcha-TTS/matcha/text/cleaners.py
new file mode 100644
index 0000000000000000000000000000000000000000..36776e355257625749290f04c705335e72ffcb52
--- /dev/null
+++ b/third_party/Matcha-TTS/matcha/text/cleaners.py
@@ -0,0 +1,121 @@
+""" from https://github.com/keithito/tacotron
+
+Cleaners are transformations that run over the input text at both training and eval time.
+
+Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
+hyperparameter. Some cleaners are English-specific. You'll typically want to use:
+  1. "english_cleaners" for English text
+  2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
+     the Unidecode library (https://pypi.python.org/pypi/Unidecode)
+  3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
+     the symbols in symbols.py to match your data).
+"""
+
+import logging
+import re
+
+import phonemizer
+from unidecode import unidecode
+
+# To avoid excessive logging we set the log level of the phonemizer package to Critical
+critical_logger = logging.getLogger("phonemizer")
+critical_logger.setLevel(logging.CRITICAL)
+
+# Initialising the phonemizer globally significantly improves speed,
+# since the phonemizer is not re-initialised on every call.
+# This might be less flexible, but it is much faster.
+global_phonemizer = phonemizer.backend.EspeakBackend(
+    language="en-us",
+    preserve_punctuation=True,
+    with_stress=True,
+    language_switch="remove-flags",
+    logger=critical_logger,
+)
+
+
+# Regular expression matching whitespace:
+_whitespace_re = re.compile(r"\s+")
+
+# List of (regular expression, replacement) pairs for abbreviations:
+_abbreviations = [
+    (re.compile("\\b%s\\."
% x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + """Pipeline for non-English text that transliterates to ASCII.""" + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners2(text): + """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_abbreviations(text) + phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0] + phonemes = collapse_whitespace(phonemes) + return phonemes + + +# I am removing this due to incompatibility with several version of python +# However, if you want to use it, you can uncomment it +# and install piper-phonemize with the following command: +# pip install piper-phonemize + +# import piper_phonemize +# def english_cleaners_piper(text): +# """Pipeline for English text, including abbreviation expansion. 
+ punctuation + stress""" +# text = convert_to_ascii(text) +# text = lowercase(text) +# text = expand_abbreviations(text) +# phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0]) +# phonemes = collapse_whitespace(phonemes) +# return phonemes diff --git a/third_party/Matcha-TTS/matcha/text/numbers.py b/third_party/Matcha-TTS/matcha/text/numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..f99a8686dcb73532091122613e74bd643a8a327f --- /dev/null +++ b/third_party/Matcha-TTS/matcha/text/numbers.py @@ -0,0 +1,71 @@ +""" from https://github.com/keithito/tacotron """ + +import re + +import inflect + +_inflect = inflect.engine() +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") +_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"[0-9]+") + + +def _remove_commas(m): + return m.group(1).replace(",", "") + + +def _expand_decimal_point(m): + return m.group(1).replace(".", " point ") + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return match + " dollars" + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return f"{dollars} {dollar_unit}, {cents} {cent_unit}" + elif dollars: + dollar_unit = "dollar" if dollars == 1 else "dollars" + return f"{dollars} {dollar_unit}" + elif cents: + cent_unit = "cent" if cents == 1 else "cents" + return f"{cents} {cent_unit}" + else: + return "zero dollars" + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return "two thousand" + elif num > 2000 and num < 2010: + return "two thousand " + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + " hundred" + else: + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") + else: + return _inflect.number_to_words(num, andword="") + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r"\1 pounds", text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/third_party/Matcha-TTS/matcha/text/symbols.py b/third_party/Matcha-TTS/matcha/text/symbols.py new file mode 100644 index 0000000000000000000000000000000000000000..7018df549a1e50c3be20416069b6913c641024bd --- /dev/null +++ b/third_party/Matcha-TTS/matcha/text/symbols.py @@ -0,0 +1,17 @@ +""" from https://github.com/keithito/tacotron + +Defines the set of symbols used in text input to the model. 
+""" +_pad = "_" +_punctuation = ';:,.!?¡¿—…"«»“” ' +_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +_letters_ipa = ( + "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" +) + + +# Export all symbols: +symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + +# Special symbol ids +SPACE_ID = symbols.index(" ") diff --git a/third_party/Matcha-TTS/matcha/train.py b/third_party/Matcha-TTS/matcha/train.py new file mode 100644 index 0000000000000000000000000000000000000000..d1d64c6c44af2622be5e6bf368616feb6619ed7e --- /dev/null +++ b/third_party/Matcha-TTS/matcha/train.py @@ -0,0 +1,122 @@ +from typing import Any, Dict, List, Optional, Tuple + +import hydra +import lightning as L +import rootutils +from lightning import Callback, LightningDataModule, LightningModule, Trainer +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +from matcha import utils + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +# ------------------------------------------------------------------------------------ # +# the setup_root above is equivalent to: +# - adding project root dir to PYTHONPATH +# (so you don't need to force user to install project as a package) +# (necessary before importing any local modules e.g. `from src import utils`) +# - setting up PROJECT_ROOT environment variable +# (which is used as a base for paths in "configs/paths/default.yaml") +# (this way all filepaths are the same no matter where you run the code) +# - loading environment variables from ".env" in root dir +# +# you can remove it if you: +# 1. either install project as a package or move entry files to project root dir +# 2. set `root_dir` to "." in "configs/paths/default.yaml" +# +# more info: https://github.com/ashleve/rootutils +# ------------------------------------------------------------------------------------ # + + +log = utils.get_pylogger(__name__) + + +@utils.task_wrapper +def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Trains the model. Can additionally evaluate on a testset, using best weights obtained during + training. + + This method is wrapped in optional @task_wrapper decorator, that controls the behavior during + failure. Useful for multiruns, saving info about the crash, etc. + + :param cfg: A DictConfig configuration composed by Hydra. + :return: A tuple with metrics and dict with all instantiated objects. 
+ """ + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=True) + + log.info(f"Instantiating datamodule <{cfg.data._target_}>") # pylint: disable=protected-access + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) + + log.info(f"Instantiating model <{cfg.model._target_}>") # pylint: disable=protected-access + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating callbacks...") + callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) + + log.info("Instantiating loggers...") + logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") # pylint: disable=protected-access + trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + utils.log_hyperparameters(object_dict) + + if cfg.get("train"): + log.info("Starting training!") + trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + + train_metrics = trainer.callback_metrics + + if cfg.get("test"): + log.info("Starting testing!") + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + log.warning("Best ckpt not found! Using current weights for testing...") + ckpt_path = None + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + log.info(f"Best ckpt path: {ckpt_path}") + + test_metrics = trainer.callback_metrics + + # merge train and test metrics + metric_dict = {**train_metrics, **test_metrics} + + return metric_dict, object_dict + + +@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") +def main(cfg: DictConfig) -> Optional[float]: + """Main entry point for training. + + :param cfg: DictConfig configuration composed by Hydra. + :return: Optional[float] with optimized metric value. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
+ utils.extras(cfg) + + # train the model + metric_dict, _ = train(cfg) + + # safely retrieve metric value for hydra-based hyperparameter optimization + metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")) + + # return optimized metric + return metric_value + + +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter diff --git a/third_party/Matcha-TTS/matcha/utils/__init__.py b/third_party/Matcha-TTS/matcha/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..074db6461184e8cbb86d977cb41d9ebd918e958a --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/__init__.py @@ -0,0 +1,5 @@ +from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers +from matcha.utils.logging_utils import log_hyperparameters +from matcha.utils.pylogger import get_pylogger +from matcha.utils.rich_utils import enforce_tags, print_config_tree +from matcha.utils.utils import extras, get_metric_value, task_wrapper diff --git a/third_party/Matcha-TTS/matcha/utils/audio.py b/third_party/Matcha-TTS/matcha/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..0bcd74df47fb006f68deb5a5f4a4c2fb0aa84f57 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/audio.py @@ -0,0 +1,82 @@ +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + if f"{str(fmax)}_{str(y.device)}" not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/third_party/Matcha-TTS/matcha/utils/generate_data_statistics.py 
b/third_party/Matcha-TTS/matcha/utils/generate_data_statistics.py new file mode 100644 index 0000000000000000000000000000000000000000..49ed3c1b072cc3292c899b200d657a8beec197f8 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/generate_data_statistics.py @@ -0,0 +1,112 @@ +r""" +The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it +when needed. + +Parameters from hparam.py will be used +""" +import argparse +import json +import os +import sys +from pathlib import Path + +import rootutils +import torch +from hydra import compose, initialize +from omegaconf import open_dict +from tqdm.auto import tqdm + +from matcha.data.text_mel_datamodule import TextMelDataModule +from matcha.utils.logging_utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int): + """Generate data mean and standard deviation helpful in data normalisation + + Args: + data_loader (torch.utils.data.Dataloader): _description_ + out_channels (int): mel spectrogram channels + """ + total_mel_sum = 0 + total_mel_sq_sum = 0 + total_mel_len = 0 + + for batch in tqdm(data_loader, leave=False): + mels = batch["y"] + mel_lengths = batch["y_lengths"] + + total_mel_len += torch.sum(mel_lengths) + total_mel_sum += torch.sum(mels) + total_mel_sq_sum += torch.sum(torch.pow(mels, 2)) + + data_mean = total_mel_sum / (total_mel_len * out_channels) + data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)) + + return {"mel_mean": data_mean.item(), "mel_std": data_std.item()} + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-i", + "--input-config", + type=str, + default="vctk.yaml", + help="The name of the yaml config file under configs/data", + ) + + parser.add_argument( + "-b", + "--batch-size", + type=int, + default="256", + help="Can have increased batch size for faster computation", + ) + + parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + required=False, + help="force overwrite the file", + ) + args = parser.parse_args() + output_file = Path(args.input_config).with_suffix(".json") + + if os.path.exists(output_file) and not args.force: + print("File already exists. Use -f to force overwrite") + sys.exit(1) + + with initialize(version_base="1.3", config_path="../../configs/data"): + cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) + + root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") + + with open_dict(cfg): + del cfg["hydra"] + del cfg["_target_"] + cfg["data_statistics"] = None + cfg["seed"] = 1234 + cfg["batch_size"] = args.batch_size + cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) + cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) + cfg["load_durations"] = False + + text_mel_datamodule = TextMelDataModule(**cfg) + text_mel_datamodule.setup() + data_loader = text_mel_datamodule.train_dataloader() + log.info("Dataloader loaded! 
Now computing stats...") + params = compute_data_statistics(data_loader, cfg["n_feats"]) + print(params) + json.dump( + params, + open(output_file, "w"), + ) + + +if __name__ == "__main__": + main() diff --git a/third_party/Matcha-TTS/matcha/utils/get_durations_from_trained_model.py b/third_party/Matcha-TTS/matcha/utils/get_durations_from_trained_model.py new file mode 100644 index 0000000000000000000000000000000000000000..0fe2f35c4238756158370ed1463bfa06f05f7e3d --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/get_durations_from_trained_model.py @@ -0,0 +1,195 @@ +r""" +The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it +when needed. + +Parameters from hparam.py will be used +""" +import argparse +import json +import os +import sys +from pathlib import Path + +import lightning +import numpy as np +import rootutils +import torch +from hydra import compose, initialize +from omegaconf import open_dict +from torch import nn +from tqdm.auto import tqdm + +from matcha.cli import get_device +from matcha.data.text_mel_datamodule import TextMelDataModule +from matcha.models.matcha_tts import MatchaTTS +from matcha.utils.logging_utils import pylogger +from matcha.utils.utils import get_phoneme_durations + +log = pylogger.get_pylogger(__name__) + + +def save_durations_to_folder( + attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str +): + durations = attn.squeeze().sum(1)[:x_length].numpy() + durations_json = get_phoneme_durations(durations, text) + output = output_folder / Path(filepath).name.replace(".wav", ".npy") + with open(output.with_suffix(".json"), "w", encoding="utf-8") as f: + json.dump(durations_json, f, indent=4, ensure_ascii=False) + + np.save(output, durations) + + +@torch.inference_mode() +def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder): + """Generate durations from the model for each datapoint and save it in a folder + + Args: + data_loader (torch.utils.data.DataLoader): Dataloader + model (nn.Module): MatchaTTS model + device (torch.device): GPU or CPU + """ + + for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"): + x, x_lengths = batch["x"], batch["x_lengths"] + y, y_lengths = batch["y"], batch["y_lengths"] + spks = batch["spks"] + x = x.to(device) + y = y.to(device) + x_lengths = x_lengths.to(device) + y_lengths = y_lengths.to(device) + spks = spks.to(device) if spks is not None else None + + _, _, _, attn = model( + x=x, + x_lengths=x_lengths, + y=y, + y_lengths=y_lengths, + spks=spks, + ) + attn = attn.cpu() + for i in range(attn.shape[0]): + save_durations_to_folder( + attn[i], + x_lengths[i].item(), + y_lengths[i].item(), + batch["filepaths"][i], + output_folder, + batch["x_texts"][i], + ) + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-i", + "--input-config", + type=str, + default="ljspeech.yaml", + help="The name of the yaml config file under configs/data", + ) + + parser.add_argument( + "-b", + "--batch-size", + type=int, + default="32", + help="Can have increased batch size for faster computation", + ) + + parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + required=False, + help="force overwrite the file", + ) + parser.add_argument( + "-c", + "--checkpoint_path", + type=str, + required=True, + help="Path to the checkpoint file to load the model from", + ) + + parser.add_argument( + "-o", + "--output-folder", + 
type=str, + default=None, + help="Output folder to save the data statistics", + ) + + parser.add_argument( + "--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)" + ) + + args = parser.parse_args() + + with initialize(version_base="1.3", config_path="../../configs/data"): + cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) + + root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") + + with open_dict(cfg): + del cfg["hydra"] + del cfg["_target_"] + cfg["seed"] = 1234 + cfg["batch_size"] = args.batch_size + cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) + cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) + cfg["load_durations"] = False + + if args.output_folder is not None: + output_folder = Path(args.output_folder) + else: + output_folder = Path(cfg["train_filelist_path"]).parent / "durations" + + print(f"Output folder set to: {output_folder}") + + if os.path.exists(output_folder) and not args.force: + print("Folder already exists. Use -f to force overwrite") + sys.exit(1) + + output_folder.mkdir(parents=True, exist_ok=True) + + print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}") + print("Loading model...") + device = get_device(args) + model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device) + + text_mel_datamodule = TextMelDataModule(**cfg) + text_mel_datamodule.setup() + try: + print("Computing stats for training set if exists...") + train_dataloader = text_mel_datamodule.train_dataloader() + compute_durations(train_dataloader, model, device, output_folder) + except lightning.fabric.utilities.exceptions.MisconfigurationException: + print("No training set found") + + try: + print("Computing stats for validation set if exists...") + val_dataloader = text_mel_datamodule.val_dataloader() + compute_durations(val_dataloader, model, device, output_folder) + except lightning.fabric.utilities.exceptions.MisconfigurationException: + print("No validation set found") + + try: + print("Computing stats for test set if exists...") + test_dataloader = text_mel_datamodule.test_dataloader() + compute_durations(test_dataloader, model, device, output_folder) + except lightning.fabric.utilities.exceptions.MisconfigurationException: + print("No test set found") + + print(f"[+] Done! 
Data statistics saved to: {output_folder}") + + +if __name__ == "__main__": + # Helps with generating durations for the dataset to train other architectures + # that cannot learn to align due to limited size of dataset + # Example usage: + # python python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model + # This will create a folder in data/processed_data/durations/ljspeech with the durations + main() diff --git a/third_party/Matcha-TTS/matcha/utils/instantiators.py b/third_party/Matcha-TTS/matcha/utils/instantiators.py new file mode 100644 index 0000000000000000000000000000000000000000..5547b4ed61ed8c21e63c528f58526a949879a94f --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/instantiators.py @@ -0,0 +1,56 @@ +from typing import List + +import hydra +from lightning import Callback +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +from matcha.utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config. + + :param callbacks_cfg: A DictConfig object containing callback configurations. + :return: A list of instantiated callbacks. + """ + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") # pylint: disable=protected-access + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config. + + :param logger_cfg: A DictConfig object containing logger configurations. + :return: A list of instantiated loggers. + """ + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") # pylint: disable=protected-access + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/third_party/Matcha-TTS/matcha/utils/logging_utils.py b/third_party/Matcha-TTS/matcha/utils/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1a12d1ddafa25ca3ae8e497bcd7de2191f13659b --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/logging_utils.py @@ -0,0 +1,53 @@ +from typing import Any, Dict + +from lightning.pytorch.utilities import rank_zero_only +from omegaconf import OmegaConf + +from matcha.utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +@rank_zero_only +def log_hyperparameters(object_dict: Dict[str, Any]) -> None: + """Controls which config parts are saved by Lightning loggers. + + Additionally saves: + - Number of model parameters + + :param object_dict: A dictionary containing the following objects: + - `"cfg"`: A DictConfig object containing the main config. + - `"model"`: The Lightning model. + - `"trainer"`: The Lightning trainer. 
+ """ + hparams = {} + + cfg = OmegaConf.to_container(object_dict["cfg"]) + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad) + hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/third_party/Matcha-TTS/matcha/utils/model.py b/third_party/Matcha-TTS/matcha/utils/model.py new file mode 100644 index 0000000000000000000000000000000000000000..869cc6092f5952930534c47544fae88308e96abf --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/model.py @@ -0,0 +1,90 @@ +""" from https://github.com/jaywalnut310/glow-tts """ + +import numpy as np +import torch + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def fix_len_compatibility(length, num_downsamplings_in_unet=2): + factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet) + length = (length / factor).ceil() * factor + if not torch.onnx.is_in_onnx_export(): + return length.int().item() + else: + return length + + +def convert_pad_shape(pad_shape): + inverted_shape = pad_shape[::-1] + pad_shape = [item for sublist in inverted_shape for item in sublist] + return pad_shape + + +def generate_path(duration, mask): + device = duration.device + + b, t_x, t_y = mask.shape + cum_duration = torch.cumsum(duration, 1) + path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path * mask + return path + + +def duration_loss(logw, logw_, lengths): + loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths) + return loss + + +def normalize(data, mu, std): + if not isinstance(mu, (float, int)): + if isinstance(mu, list): + mu = torch.tensor(mu, dtype=data.dtype, device=data.device) + elif isinstance(mu, torch.Tensor): + mu = mu.to(data.device) + elif isinstance(mu, np.ndarray): + mu = torch.from_numpy(mu).to(data.device) + mu = mu.unsqueeze(-1) + + if not isinstance(std, (float, int)): + if isinstance(std, list): + std = torch.tensor(std, dtype=data.dtype, device=data.device) + elif isinstance(std, torch.Tensor): + std = std.to(data.device) + elif isinstance(std, np.ndarray): + std = torch.from_numpy(std).to(data.device) + std = std.unsqueeze(-1) + + return (data - mu) / std + + +def denormalize(data, mu, std): + if not isinstance(mu, float): + if isinstance(mu, list): + mu = torch.tensor(mu, dtype=data.dtype, device=data.device) + elif isinstance(mu, torch.Tensor): + mu = 
mu.to(data.device) + elif isinstance(mu, np.ndarray): + mu = torch.from_numpy(mu).to(data.device) + mu = mu.unsqueeze(-1) + + if not isinstance(std, float): + if isinstance(std, list): + std = torch.tensor(std, dtype=data.dtype, device=data.device) + elif isinstance(std, torch.Tensor): + std = std.to(data.device) + elif isinstance(std, np.ndarray): + std = torch.from_numpy(std).to(data.device) + std = std.unsqueeze(-1) + + return data * std + mu diff --git a/third_party/Matcha-TTS/matcha/utils/monotonic_align/__init__.py b/third_party/Matcha-TTS/matcha/utils/monotonic_align/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eee6e0d47c2e3612ef02bc17442e6886998e5a94 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/monotonic_align/__init__.py @@ -0,0 +1,22 @@ +import numpy as np +import torch + +from matcha.utils.monotonic_align.core import maximum_path_c + + +def maximum_path(value, mask): + """Cython optimised version. + value: [b, t_x, t_y] + mask: [b, t_x, t_y] + """ + value = value * mask + device = value.device + dtype = value.dtype + value = value.data.cpu().numpy().astype(np.float32) + path = np.zeros_like(value).astype(np.int32) + mask = mask.data.cpu().numpy() + + t_x_max = mask.sum(1)[:, 0].astype(np.int32) + t_y_max = mask.sum(2)[:, 0].astype(np.int32) + maximum_path_c(path, value, t_x_max, t_y_max) + return torch.from_numpy(path).to(device=device, dtype=dtype) diff --git a/third_party/Matcha-TTS/matcha/utils/monotonic_align/core.pyx b/third_party/Matcha-TTS/matcha/utils/monotonic_align/core.pyx new file mode 100644 index 0000000000000000000000000000000000000000..091fcc3a50a51f3d3fee47a70825260757e6d885 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/monotonic_align/core.pyx @@ -0,0 +1,47 @@ +import numpy as np + +cimport cython +cimport numpy as np + +from cython.parallel import prange + + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: + cdef int x + cdef int y + cdef float v_prev + cdef float v_cur + cdef float tmp + cdef int index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[x, y-1] + if x == 0: + if y == 0: + v_prev = 0. 
+ else: + v_prev = max_neg_val + else: + v_prev = value[x-1, y-1] + value[x, y] = max(v_cur, v_prev) + value[x, y] + + for y in range(t_y - 1, -1, -1): + path[index, y] = 1 + if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): + index = index - 1 + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: + cdef int b = values.shape[0] + + cdef int i + for i in prange(b, nogil=True): + maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) diff --git a/third_party/Matcha-TTS/matcha/utils/monotonic_align/setup.py b/third_party/Matcha-TTS/matcha/utils/monotonic_align/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..f22bc6a35a5a04c9e6d7b82040973722c9b770c9 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/monotonic_align/setup.py @@ -0,0 +1,7 @@ +# from distutils.core import setup +# from Cython.Build import cythonize +# import numpy + +# setup(name='monotonic_align', +# ext_modules=cythonize("core.pyx"), +# include_dirs=[numpy.get_include()]) diff --git a/third_party/Matcha-TTS/matcha/utils/pylogger.py b/third_party/Matcha-TTS/matcha/utils/pylogger.py new file mode 100644 index 0000000000000000000000000000000000000000..61600678029362e110f655edb91d5f3bc5b1cd1c --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/pylogger.py @@ -0,0 +1,21 @@ +import logging + +from lightning.pytorch.utilities import rank_zero_only + + +def get_pylogger(name: str = __name__) -> logging.Logger: + """Initializes a multi-GPU-friendly python command line logger. + + :param name: The name of the logger, defaults to ``__name__``. + + :return: A logger object. + """ + logger = logging.getLogger(name) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") + for level in logging_levels: + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger diff --git a/third_party/Matcha-TTS/matcha/utils/rich_utils.py b/third_party/Matcha-TTS/matcha/utils/rich_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f602f6e9351d948946eb419eb4e420190ea634bc --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/rich_utils.py @@ -0,0 +1,101 @@ +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from lightning.pytorch.utilities import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict +from rich.prompt import Prompt + +from matcha.utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "data", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints the contents of a DictConfig as a tree structure using the Rich library. + + :param cfg: A DictConfig composed by Hydra. + :param print_order: Determines in what order config components are printed. Default is ``("data", "model", + "callbacks", "logger", "trainer", "paths", "extras")``. + :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. 
+ :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. + """ + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + _ = ( + queue.append(field) + if field in cfg + else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...") + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config. + + :param cfg: A DictConfig composed by Hydra. + :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. + """ + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) diff --git a/third_party/Matcha-TTS/matcha/utils/utils.py b/third_party/Matcha-TTS/matcha/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fc3a48ec2b532ff8e034181d71ed5f4d7823d9be --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/utils.py @@ -0,0 +1,259 @@ +import os +import sys +import warnings +from importlib.util import find_spec +from math import ceil +from pathlib import Path +from typing import Any, Callable, Dict, Tuple + +import gdown +import matplotlib.pyplot as plt +import numpy as np +import torch +import wget +from omegaconf import DictConfig + +from matcha.utils import pylogger, rich_utils + +log = pylogger.get_pylogger(__name__) + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + + :param cfg: A DictConfig object containing the config tree. + """ + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! 
") + rich_utils.enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that controls the failure behavior when executing the task function. + + This wrapper can be used to: + - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) + - save the exception to a `.log` file + - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) + - etc. (adjust depending on your needs) + + Example: + ``` + @utils.task_wrapper + def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + ... + return metric_dict, object_dict + ``` + + :param task_func: The task function to be wrapped. + + :return: The wrapped task function. + """ + + def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + # execute the task + try: + metric_dict, object_dict = task_func(cfg=cfg) + + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") + + # some hyperparameter combinations might be invalid or cause out-of-memory errors + # so when using hparam search plugins like Optuna, you might want to disable + # raising the below exception to avoid multirun failure + raise ex + + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.output_dir}") + + # always close wandb run (even if exception occurs so multirun won't fail) + if find_spec("wandb"): # check if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() + + return metric_dict, object_dict + + return wrap + + +def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float: + """Safely retrieves value of the metric logged in LightningModule. + + :param metric_dict: A dict containing metric values. + :param metric_name: The name of the metric to retrieve. + :return: The value of the metric. + """ + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise ValueError( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") + + return metric_value + + +def intersperse(lst, item): + # Adds blank symbol + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def save_figure_to_numpy(fig): + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + return data + + +def plot_tensor(tensor): + plt.style.use("default") + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.tight_layout() + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + +def save_plot(tensor, savepath): + plt.style.use("default") + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.tight_layout() + fig.canvas.draw() + plt.savefig(savepath) + plt.close() + + +def to_numpy(tensor): + if isinstance(tensor, np.ndarray): + return tensor + elif isinstance(tensor, torch.Tensor): + return tensor.detach().cpu().numpy() + elif isinstance(tensor, list): + return np.array(tensor) + else: + raise TypeError("Unsupported type for conversion to numpy array") + + +def get_user_data_dir(appname="matcha_tts"): + """ + Args: + appname (str): Name of application + + Returns: + Path: path to user data directory + """ + + MATCHA_HOME = os.environ.get("MATCHA_HOME") + if MATCHA_HOME is not None: + ans = Path(MATCHA_HOME).expanduser().resolve(strict=False) + elif sys.platform == "win32": + import winreg # pylint: disable=import-outside-toplevel + + key = winreg.OpenKey( + winreg.HKEY_CURRENT_USER, + r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders", + ) + dir_, _ = winreg.QueryValueEx(key, "Local AppData") + ans = Path(dir_).resolve(strict=False) + elif sys.platform == "darwin": + ans = Path("~/Library/Application Support/").expanduser() + else: + ans = Path.home().joinpath(".local/share") + + final_path = ans.joinpath(appname) + final_path.mkdir(parents=True, exist_ok=True) + return final_path + + +def assert_model_downloaded(checkpoint_path, url, use_wget=True): + if Path(checkpoint_path).exists(): + log.debug(f"[+] Model already present at {checkpoint_path}!") + print(f"[+] Model already present at {checkpoint_path}!") + return + log.info(f"[-] Model not found at {checkpoint_path}! Will download it") + print(f"[-] Model not found at {checkpoint_path}! 
Will download it") + checkpoint_path = str(checkpoint_path) + if not use_wget: + gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True) + else: + wget.download(url=url, out=checkpoint_path) + + +def get_phoneme_durations(durations, phones): + prev = durations[0] + merged_durations = [] + # Convolve with stride 2 + for i in range(1, len(durations), 2): + if i == len(durations) - 2: + # if it is last take full value + next_half = durations[i + 1] + else: + next_half = ceil(durations[i + 1] / 2) + + curr = prev + durations[i] + next_half + prev = durations[i + 1] - next_half + merged_durations.append(curr) + + assert len(phones) == len(merged_durations) + assert len(merged_durations) == (len(durations) - 1) // 2 + + merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long) + start = torch.tensor(0) + duration_json = [] + for i, duration in enumerate(merged_durations): + duration_json.append( + { + phones[i]: { + "starttime": start.item(), + "endtime": duration.item(), + "duration": duration.item() - start.item(), + } + } + ) + start = duration + + assert list(duration_json[-1].values())[0]["endtime"] == sum( + durations + ), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}" + return duration_json diff --git a/third_party/Matcha-TTS/notebooks/.gitkeep b/third_party/Matcha-TTS/notebooks/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/pyproject.toml b/third_party/Matcha-TTS/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..74aa39300a61b8b3607dc634d68aa47013141ec5 --- /dev/null +++ b/third_party/Matcha-TTS/pyproject.toml @@ -0,0 +1,51 @@ +[build-system] +requires = ["setuptools", "wheel", "cython==0.29.35", "numpy==1.24.3", "packaging"] + +[tool.black] +line-length = 120 +target-version = ['py310'] +exclude = ''' + +( + /( + \.eggs # exclude a few common directories in the + | \.git # root of the project + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + )/ + | foo.py # also separately exclude a file named foo.py in + # the root of the project +) +''' + +[tool.pytest.ini_options] +addopts = [ + "--color=yes", + "--durations=0", + "--strict-markers", + "--doctest-modules", +] +filterwarnings = [ + "ignore::DeprecationWarning", + "ignore::UserWarning", +] +log_cli = "True" +markers = [ + "slow: slow tests", +] +minversion = "6.0" +testpaths = "tests/" + +[tool.coverage.report] +exclude_lines = [ + "pragma: nocover", + "raise NotImplementedError", + "raise NotImplementedError()", + "if __name__ == .__main__.:", +] diff --git a/third_party/Matcha-TTS/requirements.txt b/third_party/Matcha-TTS/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6b0eabfbd38a9483d84d3ea960671460fe69c89e --- /dev/null +++ b/third_party/Matcha-TTS/requirements.txt @@ -0,0 +1,44 @@ +# --------- pytorch --------- # +torch>=2.0.0 +torchvision>=0.15.0 +lightning>=2.0.0 +torchmetrics>=0.11.4 + +# --------- hydra --------- # +hydra-core==1.3.2 +hydra-colorlog==1.2.0 +hydra-optuna-sweeper==1.2.0 + +# --------- loggers --------- # +# wandb +# neptune-client +# mlflow +# comet-ml +# aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550 + +# --------- others --------- # +rootutils # standardizing the project root setup +pre-commit # hooks for applying linters on commit +rich # beautiful text formatting in terminal +pytest # tests 
+# sh # for running bash commands in some tests (linux/macos only) +phonemizer # phonemization of text +tensorboard +librosa +Cython +numpy +einops +inflect +Unidecode +scipy +torchaudio +matplotlib +pandas +conformer==0.3.2 +diffusers # developed using version ==0.25.0 +notebook +ipywidgets +gradio==3.43.2 +gdown +wget +seaborn diff --git a/third_party/Matcha-TTS/scripts/schedule.sh b/third_party/Matcha-TTS/scripts/schedule.sh new file mode 100644 index 0000000000000000000000000000000000000000..44b3da1116ef4d54e9acffee7d639d549e136d45 --- /dev/null +++ b/third_party/Matcha-TTS/scripts/schedule.sh @@ -0,0 +1,7 @@ +#!/bin/bash +# Schedule execution of many runs +# Run from root folder with: bash scripts/schedule.sh + +python src/train.py trainer.max_epochs=5 logger=csv + +python src/train.py trainer.max_epochs=10 logger=csv diff --git a/third_party/Matcha-TTS/setup.py b/third_party/Matcha-TTS/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..a49c2ccd4b7eaa960f15cb682a40d8595101b2ab --- /dev/null +++ b/third_party/Matcha-TTS/setup.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +import os + +import numpy +from Cython.Build import cythonize +from setuptools import Extension, find_packages, setup + +exts = [ + Extension( + name="matcha.utils.monotonic_align.core", + sources=["matcha/utils/monotonic_align/core.pyx"], + ) +] + +with open("README.md", encoding="utf-8") as readme_file: + README = readme_file.read() + +cwd = os.path.dirname(os.path.abspath(__file__)) +with open(os.path.join(cwd, "matcha", "VERSION")) as fin: + version = fin.read().strip() + +setup( + name="matcha-tts", + version=version, + description="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching", + long_description=README, + long_description_content_type="text/markdown", + author="Shivam Mehta", + author_email="shivam.mehta25@gmail.com", + url="https://shivammehta25.github.io/Matcha-TTS", + install_requires=[str(r) for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))], + include_dirs=[numpy.get_include()], + include_package_data=True, + packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]), + # use this to customize global commands available in the terminal after installing the package + entry_points={ + "console_scripts": [ + "matcha-data-stats=matcha.utils.generate_data_statistics:main", + "matcha-tts=matcha.cli:cli", + "matcha-tts-app=matcha.app:main", + "matcha-tts-get-durations=matcha.utils.get_durations_from_trained_model:main", + ] + }, + ext_modules=cythonize(exts, language_level=3), + python_requires=">=3.9.0", +) diff --git a/third_party/Matcha-TTS/synthesis.ipynb b/third_party/Matcha-TTS/synthesis.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..1e47c534f93b2901910386cd02a75018abfc2570 --- /dev/null +++ b/third_party/Matcha-TTS/synthesis.ipynb @@ -0,0 +1,419 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f37f4e3b-f764-4502-a6a2-6417bd9bfab9", + "metadata": {}, + "source": [ + "# Matcha-TTS: A fast TTS architecture with conditional flow matching\n", + "---\n", + "[Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/)\n", + "\n", + "We introduce Matcha-TTS, a new encoder-decoder architecture for speedy TTS acoustic modelling, trained using optimal-transport conditional flow matching (OT-CFM). 
This yields an ODE-based decoder capable of high output quality in fewer synthesis steps than models trained using score matching. Careful design choices additionally ensure each synthesis step is fast to run. The method is probabilistic, non-autoregressive, and learns to speak from scratch without external alignments. Compared to strong pre-trained baseline models, the Matcha-TTS system has the smallest memory footprint, rivals the speed of the fastest models on long utterances, and attains the highest mean opinion score in a listening test.\n", + "\n", + "Demo Page: https://shivammehta25.github.io/Matcha-TTS \\\n", + "Code: https://github.com/shivammehta25/Matcha-TTS\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "148f4bc0-c28e-4670-9a5e-4c7928ab8992", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: CUDA_VISIBLE_DEVICES=0\n" + ] + } + ], + "source": [ + "%env CUDA_VISIBLE_DEVICES=0" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8d5876c0-b47e-4c80-9e9c-62550f81b64e", + "metadata": {}, + "outputs": [], + "source": [ + "import datetime as dt\n", + "from pathlib import Path\n", + "\n", + "import IPython.display as ipd\n", + "import numpy as np\n", + "import soundfile as sf\n", + "import torch\n", + "from tqdm.auto import tqdm\n", + "\n", + "# Hifigan imports\n", + "from matcha.hifigan.config import v1\n", + "from matcha.hifigan.denoiser import Denoiser\n", + "from matcha.hifigan.env import AttrDict\n", + "from matcha.hifigan.models import Generator as HiFiGAN\n", + "# Matcha imports\n", + "from matcha.models.matcha_tts import MatchaTTS\n", + "from matcha.text import sequence_to_text, text_to_sequence\n", + "from matcha.utils.model import denormalize\n", + "from matcha.utils.utils import get_user_data_dir, intersperse" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b1a30306-588c-4f22-8d9b-e2676880b0e5", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "# This allows for real time code changes being reflected in the notebook, no need to restart the kernel" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a312856b-01a9-4d75-a4c8-4666dffa0692", + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" + ] + }, + { + "cell_type": "markdown", + "id": "88f3b3c3-d014-443b-84eb-e143cdec3e21", + "metadata": {}, + "source": [ + "## Filepaths" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7640a4c1-44ce-447c-a8ff-45012fb7bddd", + "metadata": {}, + "outputs": [], + "source": [ + "MATCHA_CHECKPOINT = get_user_data_dir()/\"matcha_ljspeech.ckpt\"\n", + "HIFIGAN_CHECKPOINT = get_user_data_dir() / \"hifigan_T2_v1\"\n", + "OUTPUT_FOLDER = \"synth_output\"" + ] + }, + { + "cell_type": "markdown", + "id": "6477a3a9-71f2-4d2f-bb86-bdf3e31c2461", + "metadata": {}, + "source": [ + "## Load Matcha-TTS" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "26a16230-04ba-4825-a844-2fb5ab945e24", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model loaded! 
Parameter count: 18,204,193\n" + ] + } + ], + "source": [ + "def load_model(checkpoint_path):\n", + " model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)\n", + " model.eval()\n", + " return model\n", + "count_params = lambda x: f\"{sum(p.numel() for p in x.parameters()):,}\"\n", + "\n", + "\n", + "model = load_model(MATCHA_CHECKPOINT)\n", + "print(f\"Model loaded! Parameter count: {count_params(model)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "3077b84b-e3b6-42e1-a84b-2f7084b13f92", + "metadata": {}, + "source": [ + "## Load HiFi-GAN (Vocoder)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f6b68184-968d-4868-9029-f0c40e9e68af", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Removing weight norm...\n" + ] + } + ], + "source": [ + "def load_vocoder(checkpoint_path):\n", + " h = AttrDict(v1)\n", + " hifigan = HiFiGAN(h).to(device)\n", + " hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)['generator'])\n", + " _ = hifigan.eval()\n", + " hifigan.remove_weight_norm()\n", + " return hifigan\n", + "\n", + "vocoder = load_vocoder(HIFIGAN_CHECKPOINT)\n", + "denoiser = Denoiser(vocoder, mode='zeros')" + ] + }, + { + "cell_type": "markdown", + "id": "4cbc2ba0-09ff-40e2-9e60-6b77b534f9fb", + "metadata": {}, + "source": [ + "### Helper functions to synthesise" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "880a1879-24fd-4757-849c-850339120796", + "metadata": {}, + "outputs": [], + "source": [ + "@torch.inference_mode()\n", + "def process_text(text: str):\n", + " x = torch.tensor(intersperse(text_to_sequence(text, ['english_cleaners2'])[0], 0),dtype=torch.long, device=device)[None]\n", + " x_lengths = torch.tensor([x.shape[-1]],dtype=torch.long, device=device)\n", + " x_phones = sequence_to_text(x.squeeze(0).tolist())\n", + " return {\n", + " 'x_orig': text,\n", + " 'x': x,\n", + " 'x_lengths': x_lengths,\n", + " 'x_phones': x_phones\n", + " }\n", + "\n", + "\n", + "@torch.inference_mode()\n", + "def synthesise(text, spks=None):\n", + " text_processed = process_text(text)\n", + " start_t = dt.datetime.now()\n", + " output = model.synthesise(\n", + " text_processed['x'], \n", + " text_processed['x_lengths'],\n", + " n_timesteps=n_timesteps,\n", + " temperature=temperature,\n", + " spks=spks,\n", + " length_scale=length_scale\n", + " )\n", + " # merge everything to one dict \n", + " output.update({'start_t': start_t, **text_processed})\n", + " return output\n", + "\n", + "@torch.inference_mode()\n", + "def to_waveform(mel, vocoder):\n", + " audio = vocoder(mel).clamp(-1, 1)\n", + " audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()\n", + " return audio.cpu().squeeze()\n", + " \n", + "def save_to_folder(filename: str, output: dict, folder: str):\n", + " folder = Path(folder)\n", + " folder.mkdir(exist_ok=True, parents=True)\n", + " np.save(folder / f'{filename}', output['mel'].cpu().numpy())\n", + " sf.write(folder / f'{filename}.wav', output['waveform'], 22050, 'PCM_24')" + ] + }, + { + "cell_type": "markdown", + "id": "78f857e3-2ef7-4c86-b776-596c4d3cf875", + "metadata": {}, + "source": [ + "## Setup text to synthesise" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "2e0a9acd-0845-4192-ba09-b9683e28a3ac", + "metadata": {}, + "outputs": [], + "source": [ + "texts = [\n", + " \"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though 
transparent.\"\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "a9da9e2d-99b9-4c6f-8a08-c828e2cba121", + "metadata": {}, + "source": [ + "### Hyperparameters" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f0d216e5-4895-4da8-9d24-9e61021d2556", + "metadata": {}, + "outputs": [], + "source": [ + "## Number of ODE Solver steps\n", + "n_timesteps = 10\n", + "\n", + "## Changes to the speaking rate\n", + "length_scale=1.0\n", + "\n", + "## Sampling temperature\n", + "temperature = 0.667" + ] + }, + { + "cell_type": "markdown", + "id": "b93aac89-c7f8-4975-8510-4e763c9689f4", + "metadata": {}, + "source": [ + "## Synthesis" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "5a227963-aa12-43b9-a706-1168b6fc0ba5", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8342d12401c54017b0e19b8d293a06bf", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00\n", + " \n", + " Your browser does not support the audio element.\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of ODE steps: 10\n", + "Mean RTF:\t\t\t\t0.017228 ± 0.000000\n", + "Mean RTF Waveform (incl. vocoder):\t0.021445 ± 0.000000\n" + ] + } + ], + "source": [ + "outputs, rtfs = [], []\n", + "rtfs_w = []\n", + "for i, text in enumerate(tqdm(texts)):\n", + " output = synthesise(text) #, torch.tensor([15], device=device, dtype=torch.long).unsqueeze(0))\n", + " output['waveform'] = to_waveform(output['mel'], vocoder)\n", + "\n", + " # Compute Real Time Factor (RTF) with HiFi-GAN\n", + " t = (dt.datetime.now() - output['start_t']).total_seconds()\n", + " rtf_w = t * 22050 / (output['waveform'].shape[-1])\n", + "\n", + " ## Pretty print\n", + " print(f\"{'*' * 53}\")\n", + " print(f\"Input text - {i}\")\n", + " print(f\"{'-' * 53}\")\n", + " print(output['x_orig'])\n", + " print(f\"{'*' * 53}\")\n", + " print(f\"Phonetised text - {i}\")\n", + " print(f\"{'-' * 53}\")\n", + " print(output['x_phones'])\n", + " print(f\"{'*' * 53}\")\n", + " print(f\"RTF:\\t\\t{output['rtf']:.6f}\")\n", + " print(f\"RTF Waveform:\\t{rtf_w:.6f}\")\n", + " rtfs.append(output['rtf'])\n", + " rtfs_w.append(rtf_w)\n", + "\n", + " ## Display the synthesised waveform\n", + " ipd.display(ipd.Audio(output['waveform'], rate=22050))\n", + "\n", + " ## Save the generated waveform\n", + " save_to_folder(i, output, OUTPUT_FOLDER)\n", + "\n", + "print(f\"Number of ODE steps: {n_timesteps}\")\n", + "print(f\"Mean RTF:\\t\\t\\t\\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}\")\n", + "print(f\"Mean RTF Waveform (incl. vocoder):\\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3e85c3f-1623-4647-b40c-fa96907656fc", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
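
A quick worked example of the get_phoneme_durations helper above, which merges the durations predicted for the interspersed [blank, phone, blank, ..., phone, blank] token sequence into per-phoneme start/end frames. The numbers below are made up, and the import path is an assumption based on the matcha-tts-get-durations console-script entry point declared in setup.py:

# Assumed module path; adjust if the helper lives elsewhere in the package.
from matcha.utils.get_durations_from_trained_model import get_phoneme_durations

durations = [2, 4, 2, 6, 3]   # frames for [blank, 'HH', blank, 'AY', blank]
phones = ["HH", "AY"]
# 'HH' absorbs the leading blank (2), its own 4 frames, and ceil(2 / 2) = 1 frame of the middle blank -> 7
# 'AY' absorbs the leftover 1 frame of that blank, its own 6 frames, and the full trailing blank (3)  -> 10
print(get_phoneme_durations(durations, phones))
# [{'HH': {'starttime': 0, 'endtime': 7, 'duration': 7}},
#  {'AY': {'starttime': 7, 'endtime': 17, 'duration': 10}}]   # endtime 17 == sum(durations)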
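
The notebook's hyperparameter cell sets n_timesteps = 10, the number of ODE solver steps the flow-matching decoder takes to turn noise into a mel-spectrogram. A minimal, generic fixed-step Euler sketch of that idea follows; it is not the actual Matcha-TTS decoder code, vector_field stands in for the learned conditional vector field, and the tensor shapes are illustrative only:

import torch

def euler_ode_sample(vector_field, x0, n_timesteps=10):
    # Fixed-step Euler integration of dx/dt = v(x, t) from t = 0 (noise) to t = 1 (data).
    x, dt = x0, 1.0 / n_timesteps
    for step in range(n_timesteps):
        t = torch.full((x.shape[0],), step * dt)
        x = x + dt * vector_field(x, t)
    return x

# Placeholder dynamics, just to show the call pattern and shapes:
noise = torch.randn(1, 80, 200)                       # (batch, mel bins, frames)
mel = euler_ode_sample(lambda x, t: -x, noise, n_timesteps=10)

Fewer steps trade some output quality for speed, which is what the real-time factors printed at the end of the synthesis loop measure.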
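
Those RTF figures are wall-clock synthesis time divided by the duration of the generated audio; the notebook computes the same ratio from sample counts at 22.05 kHz. With hypothetical numbers:

elapsed_s = 0.45                  # hypothetical wall-clock time to synthesise one utterance
audio_s = 21.0                    # hypothetical length of the generated waveform, in seconds
rtf = elapsed_s / audio_s         # ~0.021: synthesis runs roughly 47x faster than real time
# equivalent to the notebook's  t * 22050 / output['waveform'].shape[-1]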