TomatoCocotree committed
Commit 6a62ffb · Parent(s): 59656d8

Upload
This view is limited to 50 files because it contains too many changes. See raw diff.

- Dockerfile +21 -0
 - LICENSE +24 -0
 - api_key.txt +1 -0
 - constants.py +49 -0
 - data/models/coqui/.placeholder +2 -0
 - data/models/rvc/.placeholder +3 -0
 - data/tmp/.placeholder +2 -0
 - docker/Dockerfile +35 -0
 - docker/docker-compose.yml +23 -0
 - docker/readme.md +10 -0
 - modules/classify/classify_module.py +41 -0
 - modules/speech_recognition/streaming_module.py +121 -0
 - modules/speech_recognition/vosk_module.py +77 -0
 - modules/speech_recognition/whisper_module.py +56 -0
 - modules/text_to_speech/coqui/coqui_module.py +333 -0
 - modules/utils.py +15 -0
 - modules/voice_conversion/fairseq/LICENSE +21 -0
 - modules/voice_conversion/fairseq/__init__.py +45 -0
 - modules/voice_conversion/fairseq/binarizer.py +381 -0
 - modules/voice_conversion/fairseq/checkpoint_utils.py +905 -0
 - modules/voice_conversion/fairseq/data/__init__.py +130 -0
 - modules/voice_conversion/fairseq/data/add_target_dataset.py +83 -0
 - modules/voice_conversion/fairseq/data/append_token_dataset.py +41 -0
 - modules/voice_conversion/fairseq/data/audio/__init__.py +93 -0
 - modules/voice_conversion/fairseq/data/audio/audio_utils.py +389 -0
 - modules/voice_conversion/fairseq/data/audio/data_cfg.py +387 -0
 - modules/voice_conversion/fairseq/data/audio/dataset_transforms/__init__.py +53 -0
 - modules/voice_conversion/fairseq/data/audio/dataset_transforms/concataugment.py +61 -0
 - modules/voice_conversion/fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py +105 -0
 - modules/voice_conversion/fairseq/data/audio/feature_transforms/__init__.py +43 -0
 - modules/voice_conversion/fairseq/data/audio/feature_transforms/delta_deltas.py +37 -0
 - modules/voice_conversion/fairseq/data/audio/feature_transforms/global_cmvn.py +29 -0
 - modules/voice_conversion/fairseq/data/audio/feature_transforms/specaugment.py +131 -0
 - modules/voice_conversion/fairseq/data/audio/feature_transforms/utterance_cmvn.py +41 -0
 - modules/voice_conversion/fairseq/data/audio/frm_text_to_speech_dataset.py +205 -0
 - modules/voice_conversion/fairseq/data/audio/hubert_dataset.py +356 -0
 - modules/voice_conversion/fairseq/data/audio/multi_modality_dataset.py +284 -0
 - modules/voice_conversion/fairseq/data/audio/raw_audio_dataset.py +393 -0
 - modules/voice_conversion/fairseq/data/audio/speech_to_speech_dataset.py +379 -0
 - modules/voice_conversion/fairseq/data/audio/speech_to_text_dataset.py +733 -0
 - modules/voice_conversion/fairseq/data/audio/speech_to_text_joint_dataset.py +359 -0
 - modules/voice_conversion/fairseq/data/audio/text_to_speech_dataset.py +250 -0
 - modules/voice_conversion/fairseq/data/audio/waveform_transforms/__init__.py +48 -0
 - modules/voice_conversion/fairseq/data/audio/waveform_transforms/noiseaugment.py +201 -0
 - modules/voice_conversion/fairseq/data/backtranslation_dataset.py +165 -0
 - modules/voice_conversion/fairseq/data/base_wrapper_dataset.py +78 -0
 - modules/voice_conversion/fairseq/data/bucket_pad_length_dataset.py +78 -0
 - modules/voice_conversion/fairseq/data/codedataset.py +576 -0
 - modules/voice_conversion/fairseq/data/colorize_dataset.py +25 -0
 - modules/voice_conversion/fairseq/data/concat_dataset.py +124 -0
Dockerfile
ADDED
@@ -0,0 +1,21 @@
+FROM python:3.11
+
+WORKDIR /app
+
+COPY requirements-complete.txt .
+RUN pip install -r requirements-complete.txt
+
+RUN mkdir /.cache && chmod -R 777 /.cache
+RUN mkdir .chroma && chmod -R 777 .chroma
+
+COPY . .
+
+
+RUN chmod -R 777 /app
+
+RUN --mount=type=secret,id=password,mode=0444,required=true \
+    cat /run/secrets/password > /test
+
+EXPOSE 7860
+
+CMD ["python", "server.py", "--cpu", "--enable-modules=caption,summarize,classify,silero-tts,edge-tts,chromadb"]

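Because the RUN --mount=type=secret step is marked required=true, this image only builds under BuildKit with a build secret named "password" supplied, for example "DOCKER_BUILDKIT=1 docker build --secret id=password,src=password.txt ." (the source file name here is illustrative). This is the mechanism Hugging Face Spaces uses to expose repository secrets during Docker builds.
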
LICENSE
ADDED
@@ -0,0 +1,24 @@
+This is free and unencumbered software released into the public domain.
+
+Anyone is free to copy, modify, publish, use, compile, sell, or
+distribute this software, either in source code form or as a compiled
+binary, for any purpose, commercial or non-commercial, and by any
+means.
+
+In jurisdictions that recognize copyright laws, the author or authors
+of this software dedicate any and all copyright interest in the
+software to the public domain. We make this dedication for the benefit
+of the public at large and to the detriment of our heirs and
+successors. We intend this dedication to be an overt act of
+relinquishment in perpetuity of all present and future rights to this
+software under copyright law.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
+OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
+ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+OTHER DEALINGS IN THE SOFTWARE.
+
+For more information, please refer to <https://unlicense.org>

api_key.txt
ADDED
@@ -0,0 +1 @@
+CHANGEME

constants.py
ADDED
@@ -0,0 +1,49 @@
+# Constants
+DEFAULT_CUDA_DEVICE = "cuda:0"
+# Also try: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10'
+DEFAULT_SUMMARIZATION_MODEL = "Qiliang/bart-large-cnn-samsum-ChatGPT_v3"
+# Also try: 'joeddav/distilbert-base-uncased-go-emotions-student'
+DEFAULT_CLASSIFICATION_MODEL = "nateraw/bert-base-uncased-emotion"
+# Also try: 'Salesforce/blip-image-captioning-base'
+DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
+DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
+DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
+DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
+DEFAULT_REMOTE_SD_PORT = 7860
+DEFAULT_CHROMA_PORT = 8000
+SILERO_SAMPLES_PATH = "tts_samples"
+SILERO_SAMPLE_TEXT = "The quick brown fox jumps over the lazy dog"
+DEFAULT_SUMMARIZE_PARAMS = {
+    "temperature": 1.0,
+    "repetition_penalty": 1.0,
+    "max_length": 500,
+    "min_length": 200,
+    "length_penalty": 1.5,
+    "bad_words": [
+        "\n",
+        '"',
+        "*",
+        "[",
+        "]",
+        "{",
+        "}",
+        ":",
+        "(",
+        ")",
+        "<",
+        ">",
+        "Â",
+        "The text ends",
+        "The story ends",
+        "The text is",
+        "The story is",
+    ],
+}
+
+PROMPT_PREFIX = "best quality, absurdres, "
+NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
+error hands, bad hands, error fingers, bad fingers, missing fingers
+error legs, bad legs, multiple legs, missing legs, error lighting,
+error shadow, error reflection, text, error, extra digit, fewer digits,
+cropped, worst quality, low quality, normal quality, jpeg artifacts,
+signature, watermark, username, blurry"""

data/models/coqui/.placeholder
ADDED
@@ -0,0 +1,2 @@
+Put Coqui model folders here.
+Each folder must contain both a "model.pth" and a "config.json" file.

data/models/rvc/.placeholder
ADDED
@@ -0,0 +1,3 @@
+Put RVC model folders here.
+Each folder must have a ".pth" file in it.
+A ".index" file is optional but can help improve processing time/quality.

data/tmp/.placeholder
ADDED
@@ -0,0 +1,2 @@
+This is a temporary file folder.
+It may contain RVC input/output files kept for research purposes.

docker/Dockerfile
ADDED
@@ -0,0 +1,35 @@
+FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
+
+EXPOSE 5100
+
+ENV PATH="/root/miniconda3/bin:${PATH}"
+ARG PATH="/root/miniconda3/bin:${PATH}"
+
+ENV DEBIAN_FRONTEND noninteractive
+RUN apt-get update && apt-get install -y --no-install-recommends \
+        python3 python3-venv wget build-essential
+
+RUN wget \
+    https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
+    && mkdir /root/.conda \
+    && bash Miniconda3-latest-Linux-x86_64.sh -b \
+    && rm -f Miniconda3-latest-Linux-x86_64.sh
+
+RUN conda --version
+
+RUN conda init
+
+RUN conda create -n extras
+
+RUN /bin/bash -c "source activate extras"
+
+RUN conda install pytorch torchvision torchaudio pytorch-cuda=11.7 git -c pytorch -c nvidia -c conda-forge
+
+WORKDIR /sillytavern-extras/
+COPY . .
+
+ARG REQUIREMENTS
+RUN pip install -r $REQUIREMENTS
+
+ARG MODULES
+CMD ["python","server.py","--enable-modules=$MODULES"]

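One caveat in the final line: the exec (JSON) form of CMD is not processed by a shell, so "$MODULES" is passed to the program literally rather than expanded. Expansion would require the shell form (CMD python server.py --enable-modules=$MODULES), and since ARG values do not persist to run time, the value would also have to be copied into an ENV to be visible there.
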
docker/docker-compose.yml
ADDED
@@ -0,0 +1,23 @@
+version: "3"
+services:
+  sillytavern-extras:
+    runtime: nvidia
+    image: cohee1207/sillytavern-extras
+    build:
+      context: ../
+      dockerfile: docker/Dockerfile
+      args:
+        REQUIREMENTS: requirements.txt
+        MODULES: caption,summarize,classify
+#        REQUIREMENTS: requirements-complete.txt
+#        MODULES: caption,summarize,classify,sd,silero-tts,edge-tts,chromadb
+    volumes:
+      #- "./chromadb:/chromadb"
+      - "./cache:/root/.cache"
+      - "./api_key.txt:/sillytavern-extras/api_key.txt:rw"
+    ports:
+      - "5100:5100"
+    environment:
+      - NVIDIA_VISIBLE_DEVICES=all
+    command: python server.py --enable-modules=caption,summarize,classify
+#    command: python server.py --enable-modules=caption,summarize,classify,sd,silero-tts,edge-tts,chromadb

docker/readme.md
ADDED
@@ -0,0 +1,10 @@
+# Docker Usage
+
+## Building the image
+
+*This is assuming you have docker and docker compose installed and running.*
+
+1. Open a terminal and set your current directory to the "docker" directory in your clone of this repo.
+2. Adjust the "docker-compose.yml" file to match your needs. The default selection and the selection with all modules are provided as examples.
+3. Once ready, run the command "docker compose build" to build the "cohee1207/sillytavern-extras" docker image.
+

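Once the build finishes, running "docker compose up" from the same directory starts the container with the port mapping (5100), volumes, and module selection defined in docker-compose.yml.
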
modules/classify/classify_module.py
ADDED
@@ -0,0 +1,41 @@
+"""
+Classify module for SillyTavern Extras
+
+Authors:
+    - Tony Ribeiro (https://github.com/Tony-sama)
+    - Cohee (https://github.com/Cohee1207)
+
+Provides classification features for text
+
+References:
+    - https://huggingface.co/tasks/text-classification
+"""
+
+from transformers import pipeline
+
+DEBUG_PREFIX = "<Classify module>"
+
+# Models init
+
+text_emotion_pipe = None
+
+def init_text_emotion_classifier(model_name: str, device: str, torch_dtype: str) -> None:
+    global text_emotion_pipe
+
+    print(DEBUG_PREFIX, "Initializing text classification pipeline with model", model_name)
+    text_emotion_pipe = pipeline(
+        "text-classification",
+        model=model_name,
+        top_k=None,
+        device=device,
+        torch_dtype=torch_dtype,
+    )
+
+
+def classify_text_emotion(text: str) -> list:
+    output = text_emotion_pipe(
+        text,
+        truncation=True,
+        max_length=text_emotion_pipe.model.config.max_position_embeddings,
+    )[0]
+    return sorted(output, key=lambda x: x["score"], reverse=True)

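A minimal usage sketch of the two functions above, assuming the repo root is on sys.path and transformers plus a torch backend are installed; the model name comes from constants.py, while the device, dtype, and sample text are illustrative:

    from modules.classify import classify_module

    # Build the pipeline once at startup (server.py normally supplies these values).
    classify_module.init_text_emotion_classifier(
        "nateraw/bert-base-uncased-emotion", device="cpu", torch_dtype="float32"
    )

    # Returns every label with a score, highest first, e.g.
    # [{'label': 'joy', 'score': ...}, {'label': 'love', 'score': ...}, ...]
    print(classify_module.classify_text_emotion("I am so happy today!"))
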
modules/speech_recognition/streaming_module.py
ADDED
@@ -0,0 +1,121 @@
+"""
+Speech-to-text module based on Vosk and Whisper for SillyTavern Extras
+    - Vosk website: https://alphacephei.com/vosk/
+    - Vosk api: https://github.com/alphacep/vosk-api
+    - Whisper github: https://github.com/openai/whisper
+
+Authors:
+    - Tony Ribeiro (https://github.com/Tony-sama)
+
+Models are saved into user cache folder, example: C:/Users/toto/.cache/whisper and C:/Users/toto/.cache/vosk
+
+References:
+    - Code adapted from:
+        - whisper github: https://github.com/openai/whisper
+        - oobabooga text-generation-webui github: https://github.com/oobabooga/text-generation-webui
+        - vosk github: https://github.com/alphacep/vosk-api/blob/master/python/example/test_microphone.py
+"""
+from flask import jsonify, abort
+
+import queue
+import sys
+import sounddevice as sd
+import soundfile as sf
+import io
+import numpy as np
+from scipy.io.wavfile import write
+
+import vosk
+import whisper
+
+DEBUG_PREFIX = "<stt streaming module>"
+RECORDING_FILE_PATH = "stt_test.wav"
+
+whisper_model = None
+vosk_model = None
+device = None
+
+def load_model(file_path=None):
+    """
+    Load given vosk model from file or default to en-us model.
+    Download model to user cache folder, example: C:/Users/toto/.cache/vosk
+    """
+
+    if file_path is None:
+        return (whisper.load_model("base.en"), vosk.Model(lang="en-us"))
+    else:
+        return (whisper.load_model(file_path), vosk.Model(lang="en-us"))
+
+def convert_bytearray_to_wav_ndarray(input_bytearray: bytes, sampling_rate=16000):
+    """
+    Convert a bytearray to wav format to output in a file for quality check debugging
+    """
+    bytes_wav = bytes()
+    byte_io = io.BytesIO(bytes_wav)
+    write(byte_io, sampling_rate, np.frombuffer(input_bytearray, dtype=np.int16))
+    output_wav = byte_io.read()
+    output, _ = sf.read(io.BytesIO(output_wav))
+    return output
+
+def record_and_transcript():
+    """
+    Continuously record from mic and transcript voice.
+    Return the transcript once no more voice is detected.
+    """
+    if whisper_model is None:
+        print(DEBUG_PREFIX, "Whisper model not initialized yet.")
+        return ""
+
+    q = queue.Queue()
+    stream_errors = list()
+
+    def callback(indata, frames, time, status):
+        """This is called (from a separate thread) for each audio block."""
+        if status:
+            print(status, file=sys.stderr)
+            stream_errors.append(status)
+        q.put(bytes(indata))
+
+    try:
+        device_info = sd.query_devices(device, "input")
+        # soundfile expects an int, sounddevice provides a float:
+        samplerate = int(device_info["default_samplerate"])
+
+        print(DEBUG_PREFIX, "Start recording from:", device_info["name"], "with samplerate", samplerate)
+
+        with sd.RawInputStream(samplerate=samplerate, blocksize=8000, device=device, dtype="int16", channels=1, callback=callback):
+
+            rec = vosk.KaldiRecognizer(vosk_model, samplerate)
+            full_recording = bytearray()
+            while True:
+                data = q.get()
+                if len(stream_errors) > 0:
+                    raise Exception(DEBUG_PREFIX+" Stream errors: "+str(stream_errors))
+
+                full_recording.extend(data)
+
+                if rec.AcceptWaveform(data):
+                    # Extract transcript string
+                    transcript = rec.Result()[14:-3]
+                    print(DEBUG_PREFIX, "Transcripted from microphone stream (vosk):", transcript)
+
+                    # ----------------------------------
+                    # DEBUG: save recording to wav file
+                    # ----------------------------------
+                    output_file = convert_bytearray_to_wav_ndarray(input_bytearray=full_recording, sampling_rate=samplerate)
+                    sf.write(file=RECORDING_FILE_PATH, data=output_file, samplerate=samplerate)
+                    print(DEBUG_PREFIX, "Recorded message saved to", RECORDING_FILE_PATH)
+
+                    # Whisper HACK
+                    result = whisper_model.transcribe(RECORDING_FILE_PATH)
+                    transcript = result["text"]
+                    print(DEBUG_PREFIX, "Transcripted from audio file (whisper):", transcript)
+                    # ----------------------------------
+
+                    return jsonify({"transcript": transcript})
+                #else:
+                #    print(rec.PartialResult())
+
+    except Exception as e: # No exception observed during test but we never know
+        print(e)
+        abort(500, DEBUG_PREFIX+" Exception occurs while recording")

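The slice rec.Result()[14:-3] above extracts the transcript from the JSON string Vosk returns (shaped like {"text" : "..."}) by fixed character offsets. A less brittle equivalent, assuming only that shape, is to parse the result instead:

    import json

    # Parse Vosk's JSON result rather than slicing by character offsets.
    transcript = json.loads(rec.Result()).get("text", "")
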
modules/speech_recognition/vosk_module.py
ADDED
@@ -0,0 +1,77 @@
+"""
+Speech-to-text module based on Vosk for SillyTavern Extras
+    - Vosk website: https://alphacephei.com/vosk/
+    - Vosk api: https://github.com/alphacep/vosk-api
+
+Authors:
+    - Tony Ribeiro (https://github.com/Tony-sama)
+
+Models are saved into user cache folder, example: C:/Users/toto/.cache/vosk
+
+References:
+    - Code adapted from: https://github.com/alphacep/vosk-api/blob/master/python/example/test_simple.py
+"""
+from flask import jsonify, abort, request
+
+import wave
+from vosk import Model, KaldiRecognizer, SetLogLevel
+import soundfile
+
+DEBUG_PREFIX = "<stt vosk module>"
+RECORDING_FILE_PATH = "stt_test.wav"
+
+model = None
+
+SetLogLevel(-1)
+
+def load_model(file_path=None):
+    """
+    Load given vosk model from file or default to en-us model.
+    Download model to user cache folder, example: C:/Users/toto/.cache/vosk
+    """
+
+    if file_path is None:
+        return Model(lang="en-us")
+    else:
+        return Model(file_path)
+
+def process_audio():
+    """
+    Transcribe the request audio file to text using Vosk
+    """
+
+    if model is None:
+        print(DEBUG_PREFIX, "Vosk model not initialized yet.")
+        return ""
+
+    try:
+        file = request.files.get('AudioFile')
+        file.save(RECORDING_FILE_PATH)
+
+        # Read and rewrite the file with soundfile
+        data, samplerate = soundfile.read(RECORDING_FILE_PATH)
+        soundfile.write(RECORDING_FILE_PATH, data, samplerate)
+
+        wf = wave.open(RECORDING_FILE_PATH, "rb")
+        if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE":
+            print("Audio file must be WAV format mono PCM.")
+            abort(500, DEBUG_PREFIX+" Audio file must be WAV format mono PCM.")
+
+        rec = KaldiRecognizer(model, wf.getframerate())
+        #rec.SetWords(True)
+        #rec.SetPartialWords(True)
+
+        while True:
+            data = wf.readframes(4000)
+            if len(data) == 0:
+                break
+            if rec.AcceptWaveform(data):
+                break
+
+        transcript = rec.Result()[14:-3]
+        print(DEBUG_PREFIX, "Transcripted from request audio file:", transcript)
+        return jsonify({"transcript": transcript})
+
+    except Exception as e: # No exception observed during test but we never know
+        print(e)
+        abort(500, DEBUG_PREFIX+" Exception occurs while processing audio")

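A hedged client-side sketch for exercising this endpoint with the requests library; the URL path is an assumption (server.py defines the actual route), and sample.wav stands in for a mono 16-bit PCM WAV file:

    import requests

    with open("sample.wav", "rb") as f:
        resp = requests.post(
            "http://localhost:5100/api/speech-recognition/vosk/process-audio",  # hypothetical route
            files={"AudioFile": f},  # field name matches request.files.get('AudioFile')
        )
    print(resp.json()["transcript"])
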
modules/speech_recognition/whisper_module.py
ADDED
@@ -0,0 +1,56 @@
+"""
+Speech-to-text module based on Whisper for SillyTavern Extras
+    - Whisper github: https://github.com/openai/whisper
+
+Authors:
+    - Tony Ribeiro (https://github.com/Tony-sama)
+
+Models are saved into user cache folder, example: C:/Users/toto/.cache/whisper
+
+References:
+    - Code adapted from:
+        - whisper github: https://github.com/openai/whisper
+        - oobabooga text-generation-webui github: https://github.com/oobabooga/text-generation-webui
+"""
+from flask import jsonify, abort, request
+
+import whisper
+
+DEBUG_PREFIX = "<stt whisper module>"
+RECORDING_FILE_PATH = "stt_test.wav"
+
+model = None
+
+def load_model(file_path=None):
+    """
+    Load given whisper model from file or default to the "base.en" model.
+    Download model to user cache folder, example: C:/Users/toto/.cache/whisper
+    """
+
+    if file_path is None:
+        return whisper.load_model("base.en")
+    else:
+        return whisper.load_model(file_path)
+
+def process_audio():
+    """
+    Transcribe the request audio file to text using Whisper
+    """
+
+    if model is None:
+        print(DEBUG_PREFIX, "Whisper model not initialized yet.")
+        return ""
+
+    try:
+        file = request.files.get('AudioFile')
+        file.save(RECORDING_FILE_PATH)
+
+        result = model.transcribe(RECORDING_FILE_PATH)
+        transcript = result["text"]
+        print(DEBUG_PREFIX, "Transcripted from audio file (whisper):", transcript)
+
+        return jsonify({"transcript": transcript})
+
+    except Exception as e: # No exception observed during test but we never know
+        print(e)
+        abort(500, DEBUG_PREFIX+" Exception occurs while processing audio")

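The module can also be exercised without the Flask route. A minimal sketch, assuming openai-whisper and ffmpeg are installed and sample.wav is a local audio file:

    import modules.speech_recognition.whisper_module as stt

    stt.model = stt.load_model()  # downloads "base.en" into ~/.cache/whisper on first use
    print(stt.model.transcribe("sample.wav")["text"])
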
        modules/text_to_speech/coqui/coqui_module.py
    ADDED
    
@@ -0,0 +1,333 @@
"""
Coqui module for SillyTavern Extras

Authors:
    - Pyrater (https://github.com/pyrater)
    - Tony Ribeiro (https://github.com/Tony-sama)

Models are saved into the user cache folder: "C:/Users/<username>/AppData/Local/tts"

References:
    - Code adapted from:
        - Coqui TTS: https://tts.readthedocs.io/en/latest/
        - Audio-webui: https://github.com/gitmylo/audio-webui
"""
import json
import os
import io
import shutil

from flask import abort, request, send_file, jsonify

from TTS.api import TTS
from TTS.utils.manage import ModelManager

from modules.utils import silence_log

DEBUG_PREFIX = "<Coqui-TTS module>"
COQUI_MODELS_PATH = "data/models/coqui/"
IGNORED_FILES = [".placeholder"]
COQUI_LOCAL_MODEL_FILE_NAME = "model.pth"
COQUI_LOCAL_CONFIG_FILE_NAME = "config.json"

gpu_mode = False
is_downloading = False

def install_model(model_id):
    global gpu_mode
    audio_buffer = io.BytesIO()
    speaker_id = None
    language_id = None

    print(DEBUG_PREFIX, "Loading model", model_id)
    try:
        tts = TTS(model_name=model_id, progress_bar=True, gpu=gpu_mode)

        if tts.is_multi_lingual:
            language_id = tts.languages[0]

        if tts.is_multi_speaker:
            speaker_id = tts.speakers[0]

        tts.tts_to_file(text="this is a test message", file_path=audio_buffer, speaker=speaker_id, language=language_id)
    except Exception as e:
        print(DEBUG_PREFIX, "ERROR:", e)
        print("Model", model_id, "cannot be loaded, maybe a wrong model name? Must be one of:")
        for i in TTS.list_models():
            print(i)
        return False

    print(DEBUG_PREFIX, "Success")
    return True

def coqui_check_model_state():
    """
    Check if the requested model is installed on the server machine.
    """
    try:
        model_state = "absent"
        request_json = request.get_json()
        model_id = request_json["model_id"]

        print(DEBUG_PREFIX, "Searching for model", model_id)

        coqui_models_folder = ModelManager().output_prefix  # models location

        # Check if the tts folder exists
        if os.path.isdir(coqui_models_folder):

            installed_models = os.listdir(coqui_models_folder)

            model_folder_exists = False
            model_folder = None

            for i in installed_models:
                if model_id == i.replace("--", "/", 3):  # folder names use "--" where the model id uses "/"
                    model_folder_exists = True
                    model_folder = i
                    print(DEBUG_PREFIX, "Folder found:", model_folder)

            # Check for a failed download (a leftover zip means extraction never finished)
            if model_folder_exists:
                content = os.listdir(os.path.join(coqui_models_folder, model_folder))
                print(DEBUG_PREFIX, "Checking content:", content)
                for i in content:
                    if i == model_folder + ".zip":
                        print("Corrupted install found, model download must have failed previously")
                        model_state = "corrupted"
                        break

                if model_state != "corrupted":
                    model_state = "installed"

        response = json.dumps({"model_state": model_state})
        return response

    except Exception as e:
        print(e)
        abort(500, DEBUG_PREFIX + " Exception occurred while trying to search for installed model")
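# Illustrative note: Coqui's ModelManager stores a model id such as
# "tts_models/en/vctk/vits" on disk as "tts_models--en--vctk--vits", which is
# why the replace("--", "/", 3) round-trip above recovers the id from the
# folder name:
#   "tts_models--en--vctk--vits".replace("--", "/", 3)  # -> "tts_models/en/vctk/vits"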

def coqui_install_model():
    """
    Install the requested model on the server machine.
    """
    global gpu_mode
    global is_downloading

    try:
        model_installed = False
        request_json = request.get_json()
        model_id = request_json["model_id"]
        action = request_json["action"]

        print(DEBUG_PREFIX, "Received request", action, "for model", model_id)

        if is_downloading:
            print(DEBUG_PREFIX, "Rejected, already downloading a model")
            return json.dumps({"status": "downloading"})

        coqui_models_folder = ModelManager().output_prefix  # models location

        # Check if the tts folder exists
        if os.path.isdir(coqui_models_folder):
            installed_models = os.listdir(coqui_models_folder)
            model_path = None

            print(DEBUG_PREFIX, "Found", len(installed_models), "models in", coqui_models_folder)

            for i in installed_models:
                if model_id == i.replace("--", "/"):
                    model_installed = True
                    model_path = os.path.join(coqui_models_folder, i)

            if model_installed:
                print(DEBUG_PREFIX, "Model found:", model_id)
            else:
                print(DEBUG_PREFIX, "Model not found")

            if action == "download":
                if model_installed:
                    abort(500, DEBUG_PREFIX + " Bad request, model already installed.")

                is_downloading = True
                TTS(model_name=model_id, progress_bar=True, gpu=gpu_mode)
                is_downloading = False

            if action == "repare":
                if not model_installed:
                    abort(500, DEBUG_PREFIX + " Bad request: requested repair of a model that is not installed")

                print(DEBUG_PREFIX, "Deleting corrupted model folder:", model_path)
                shutil.rmtree(model_path, ignore_errors=True)

        # Instantiating TTS triggers the download (or re-download) of the model
        is_downloading = True
        TTS(model_name=model_id, progress_bar=True, gpu=gpu_mode)
        is_downloading = False

        response = json.dumps({"status": "done"})
        return response

    except Exception as e:
        is_downloading = False
        print(e)
        abort(500, DEBUG_PREFIX + " Exception occurred while trying to install model")
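# Example request bodies accepted by coqui_install_model (illustrative):
#   {"model_id": "tts_models/en/vctk/vits", "action": "download"}
#   {"model_id": "tts_models/en/vctk/vits", "action": "repare"}
# "repare" first deletes the corrupted model folder, then the unconditional
# TTS() instantiation above re-triggers the download.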

def coqui_get_local_models():
    """
    Return the list of valid local model folders found in data/models/coqui/.
    """
    try:
        print(DEBUG_PREFIX, "Received request for list of local Coqui models")

        folder_names = os.listdir(COQUI_MODELS_PATH)

        print(DEBUG_PREFIX, "Searching for models in", COQUI_MODELS_PATH)

        model_list = []
        for folder_name in folder_names:
            folder_path = COQUI_MODELS_PATH + folder_name

            if folder_name in IGNORED_FILES:
                continue

            # Must be a folder
            if not os.path.isdir(folder_path):
                print("> WARNING:", folder_name, "is not a folder, it should not be there, ignored")
                continue

            print("> Found model folder", folder_name)

            # Check for the pth file
            valid_folder = False
            for file_name in os.listdir(folder_path):
                if file_name.endswith(".pth"):
                    print(" > pth:", file_name)
                    valid_folder = True
                if file_name.endswith(".config"):
                    print(" > config:", file_name)

            if valid_folder:
                print(" > Valid folder added to list")
                model_list.append(folder_name)
            else:
                print(" > WARNING: Missing pth or config file, folder ignored")

        # Return the list of valid folders
        response = json.dumps({"models_list": model_list})
        return response

    except Exception as e:
        print(e)
        abort(500, DEBUG_PREFIX + " Exception occurred while searching for Coqui models.")


def coqui_generate_tts():
    """
    Process the request text with the requested Coqui model.
        - expected request: {
            "text": text,
            "model_id": voiceId,
            "language_id": language,
            "speaker_id": speaker
        }

        - model_id formats:
            - model_type/language/dataset/model_name
            - model_type/language/dataset/model_name[speaker_id]
            - model_type/language/dataset/model_name[speaker_id][language_id]
        - examples:
            - tts_models/ja/kokoro/tacotron2-DDC
            - tts_models/en/vctk/vits[0]
            - tts_models/multilingual/multi-dataset/your_tts[2][1]
    """
    global gpu_mode
    global is_downloading
    audio_buffer = io.BytesIO()

    try:
        request_json = request.get_json()
        #print(request_json)

        print(DEBUG_PREFIX, "Received TTS request for", request_json)

        if is_downloading:
            print(DEBUG_PREFIX, "Rejected, currently downloading a model, cannot perform TTS")
            abort(500, DEBUG_PREFIX + " Requested TTS while downloading a model")

        text = request_json["text"]
        model_name = request_json["model_id"]
        language_id = None
        speaker_id = None

        # Local model
        model_type = model_name.split("/")[0]
        if model_type == "local":
            return generate_tts_local(model_name.split("/")[1], text)

        if request_json["language_id"] != "none":
            language_id = request_json["language_id"]

        if request_json["speaker_id"] != "none":
            speaker_id = request_json["speaker_id"]

        print(DEBUG_PREFIX, "Loading tts\n - model:", model_name, "\n - speaker_id:", speaker_id, "\n - language_id:", language_id, "\n - using", ("GPU" if gpu_mode else "CPU"))

        is_downloading = True
        tts = TTS(model_name=model_name, progress_bar=True, gpu=gpu_mode)
        is_downloading = False

        if tts.is_multi_lingual:
            if language_id is None:
                abort(400, DEBUG_PREFIX + " Requested model " + model_name + " is multi-lingual but no language id was provided")
            language_id = tts.languages[int(language_id)]

        if tts.is_multi_speaker:
            if speaker_id is None:
                abort(400, DEBUG_PREFIX + " Requested model " + model_name + " is multi-speaker but no speaker id was provided")
            speaker_id = tts.speakers[int(speaker_id)]

        tts.tts_to_file(text=text, file_path=audio_buffer, speaker=speaker_id, language=language_id)

        print(DEBUG_PREFIX, "Success, saved to", audio_buffer)

        # Return the audio buffer as a response
        audio_buffer.seek(0)  # rewind so send_file streams from the start of the buffer
        response = send_file(audio_buffer, mimetype="audio/x-wav")
        audio_buffer = io.BytesIO()

        return response

    except Exception as e:
        print(e)
        abort(500, DEBUG_PREFIX + " Exception occurred while trying to process request " + str(request_json))

def generate_tts_local(model_folder, text):
    """
    Generate TTS using a local Coqui model.
    """
    audio_buffer = io.BytesIO()

    print(DEBUG_PREFIX, "Request for tts from local coqui model", model_folder)

    model_path = os.path.join(COQUI_MODELS_PATH, model_folder, COQUI_LOCAL_MODEL_FILE_NAME)
    config_path = os.path.join(COQUI_MODELS_PATH, model_folder, COQUI_LOCAL_CONFIG_FILE_NAME)

    if not os.path.exists(model_path):
        raise ValueError("File does not exist: " + model_path)

    if not os.path.exists(config_path):
        raise ValueError("File does not exist: " + config_path)

    print(DEBUG_PREFIX, "Loading local tts model", model_path, "using", ("GPU" if gpu_mode else "CPU"))
    tts = TTS(model_path=model_path, config_path=config_path, progress_bar=True, gpu=gpu_mode)
    tts.tts_to_file(text=text, file_path=audio_buffer)

    print(DEBUG_PREFIX, "Success, saved to", audio_buffer)

    # Return the audio buffer as a response
    audio_buffer.seek(0)  # rewind so send_file streams from the start of the buffer
    response = send_file(audio_buffer, mimetype="audio/x-wav")
    audio_buffer = io.BytesIO()

    return response
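For reference, a client-side sketch of a TTS request against coqui_generate_tts. The JSON shape and the example model id come from the docstring above; the route and port are assumptions, since route registration is not part of this diff:

import requests

payload = {
    "text": "this is a test message",
    "model_id": "tts_models/en/vctk/vits",
    "language_id": "none",  # "none" means not applicable (single-language model)
    "speaker_id": "0",      # index into tts.speakers for multi-speaker models
}
# Assumed route and port; adjust to the actual server configuration.
r = requests.post("http://localhost:5100/api/text-to-speech/coqui/generate-tts", json=payload)
with open("output.wav", "wb") as f:
    f.write(r.content)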
    	
        modules/utils.py
    ADDED
    
@@ -0,0 +1,15 @@

from contextlib import contextmanager
import os
import sys

@contextmanager
def silence_log():
    old_stdout = sys.stdout
    old_stderr = sys.stderr
    try:
        with open(os.devnull, "w") as new_target:
            sys.stdout = new_target
            sys.stderr = new_target
            yield new_target
    finally:
        sys.stdout = old_stdout
        sys.stderr = old_stderr
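A usage sketch for silence_log, which the Coqui module above imports: wrapping a chatty call in the context manager discards its console output for the duration of the with-block:

from modules.utils import silence_log

with silence_log():
    print("suppressed: this goes to os.devnull")
print("visible: output is restored here")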
    	
        modules/voice_conversion/fairseq/LICENSE
    ADDED
    
@@ -0,0 +1,21 @@
MIT License

Copyright (c) Facebook, Inc. and its affiliates.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
    	
        modules/voice_conversion/fairseq/__init__.py
    ADDED
    
@@ -0,0 +1,45 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

import os
import sys

try:
    from .version import __version__  # noqa
except ImportError:
    version_txt = os.path.join(os.path.dirname(__file__), "version.txt")
    with open(version_txt) as f:
        __version__ = f.read().strip()

__all__ = ["pdb"]

# backwards compatibility to support `from fairseq.X import Y`
from fairseq.distributed import utils as distributed_utils
from fairseq.logging import meters, metrics, progress_bar  # noqa

sys.modules["fairseq.distributed_utils"] = distributed_utils
sys.modules["fairseq.meters"] = meters
sys.modules["fairseq.metrics"] = metrics
sys.modules["fairseq.progress_bar"] = progress_bar

# initialize hydra
#from fairseq.dataclass.initialize import hydra_init

#hydra_init()

#import fairseq.criterions  # noqa
#import fairseq.distributed  # noqa
#import fairseq.models  # noqa
#import fairseq.modules  # noqa
#import fairseq.optim  # noqa
#import fairseq.optim.lr_scheduler  # noqa
#import fairseq.pdb  # noqa
#import fairseq.scoring  # noqa
#import fairseq.tasks  # noqa
#import fairseq.token_generation_constraints  # noqa

#import fairseq.benchmark  # noqa
#import fairseq.model_parallel  # noqa
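The sys.modules assignments above are what keep the legacy import paths (fairseq.metrics, fairseq.meters, ...) resolving after those modules moved under fairseq.logging. The same trick in a standalone sketch (the alias name is made up):

import sys
import json

# Register the stdlib json module under a made-up alias name.
sys.modules["legacy_json"] = json

# The import machinery consults sys.modules first, so the alias resolves
# without any file lookup.
import legacy_json
print(legacy_json.dumps({"ok": True}))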
    	
        modules/voice_conversion/fairseq/binarizer.py
    ADDED
    
@@ -0,0 +1,381 @@
| 1 | 
         
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # This source code is licensed under the MIT license found in the
         
     | 
| 4 | 
         
            +
            # LICENSE file in the root directory of this source tree.
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            import logging
         
     | 
| 7 | 
         
            +
            import os
         
     | 
| 8 | 
         
            +
            import typing as tp
         
     | 
| 9 | 
         
            +
            from abc import ABC, abstractmethod
         
     | 
| 10 | 
         
            +
            from collections import Counter
         
     | 
| 11 | 
         
            +
            from dataclasses import dataclass
         
     | 
| 12 | 
         
            +
            from multiprocessing import Pool
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
            import torch
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            from fairseq.data import Dictionary, indexed_dataset
         
     | 
| 17 | 
         
            +
            from fairseq.file_chunker_utils import Chunker, find_offsets
         
     | 
| 18 | 
         
            +
            from fairseq.file_io import PathManager
         
     | 
| 19 | 
         
            +
            from fairseq.tokenizer import tokenize_line
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            logger = logging.getLogger("binarizer")
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
             
     | 
@dataclass
class BinarizeSummary:
    """
    Keeps track of what's going on in the binarizer.
    """

    num_seq: int = 0
    replaced: tp.Optional[Counter] = None
    num_tok: int = 0

    @property
    def num_replaced(self) -> int:
        if self.replaced is None:
            return 0
        return sum(self.replaced.values())

    @property
    def replaced_percent(self) -> float:
        return 100 * self.num_replaced / self.num_tok

    def __str__(self) -> str:
        base = f"{self.num_seq} sents, {self.num_tok} tokens"
        if self.replaced is None:
            return base

        return f"{base}, {self.replaced_percent:.3}% replaced"

    def merge(self, other: "BinarizeSummary"):
        replaced = None
        if self.replaced is not None:
            replaced = self.replaced
        if other.replaced is not None:
            if replaced is None:
                replaced = other.replaced
            else:
                replaced += other.replaced
        self.replaced = replaced
        self.num_seq += other.num_seq
        self.num_tok += other.num_tok
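# Editor's sketch (not part of the original commit): BinarizeSummary is the
# bookkeeping unit the rest of this file passes around; each worker accumulates
# its own summary and the parent folds them together with merge().
def _example_merge_summaries():
    a = BinarizeSummary(num_seq=10, replaced=Counter({"rare_word": 2}), num_tok=100)
    b = BinarizeSummary(num_seq=5, replaced=Counter({"rare_word": 1}), num_tok=50)
    a.merge(b)
    return str(a)  # "15 sents, 150 tokens, 2.0% replaced"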
class Binarizer(ABC):
    """
    A binarizer describes how to take a string and build a tensor out of it.
    """

    @abstractmethod
    def binarize_line(
        self,
        line: str,
        summary: BinarizeSummary,
    ) -> torch.IntTensor:
        ...
def _worker_prefix(output_prefix: str, worker_id: int):
    return f"{output_prefix}.pt{worker_id}"
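# Editor's sketch (not part of the original commit): the only contract a
# concrete Binarizer must honor is binarize_line, which turns one text line
# into an IntTensor and updates the running summary. A toy hash-bucket
# implementation (all names here are hypothetical):
class _ToyHashBinarizer(Binarizer):
    def __init__(self, num_buckets: int = 1000) -> None:
        self.num_buckets = num_buckets

    def binarize_line(
        self,
        line: str,
        summary: BinarizeSummary,
    ) -> torch.IntTensor:
        ids = torch.IntTensor([hash(tok) % self.num_buckets for tok in line.split()])
        summary.num_seq += 1
        summary.num_tok += len(ids)
        return ids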
class FileBinarizer:
    """
    A file binarizer can take a file, tokenize it, and binarize each line to a tensor.
    """

    @classmethod
    def multiprocess_dataset(
        cls,
        input_file: str,
        dataset_impl: str,
        binarizer: Binarizer,
        output_prefix: str,
        vocab_size=None,
        num_workers=1,
    ) -> BinarizeSummary:
        final_summary = BinarizeSummary()

        offsets = find_offsets(input_file, num_workers)
        # find_offsets returns a list of positions [pos1, pos2, pos3, pos4] but we
        # want pairs: [(pos1, pos2), (pos2, pos3), (pos3, pos4)] to process the
        # chunks with start/end info, so we zip the list with itself shifted by one.
        (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            worker_results = [
                pool.apply_async(
                    cls._binarize_chunk_and_finalize,
                    args=(
                        binarizer,
                        input_file,
                        start_offset,
                        end_offset,
                        _worker_prefix(
                            output_prefix,
                            worker_id,
                        ),
                        dataset_impl,
                    ),
                    kwds={
                        "vocab_size": vocab_size,
                    }
                    if vocab_size is not None
                    else {},
                )
                for worker_id, (start_offset, end_offset) in enumerate(
                    more_chunks, start=1
                )
            ]

            pool.close()
            pool.join()
            for r in worker_results:
                summ = r.get()
                final_summary.merge(summ)

        # do not close the bin file yet, as we need to merge the worker results in
        final_ds, summ = cls._binarize_file_chunk(
            binarizer,
            input_file,
            offset_start=first_chunk[0],
            offset_end=first_chunk[1],
            output_prefix=output_prefix,
            dataset_impl=dataset_impl,
            vocab_size=vocab_size,
        )
        final_summary.merge(summ)

        if num_workers > 1:
            for worker_id in range(1, num_workers):
                # merge the worker outputs
                worker_output_prefix = _worker_prefix(
                    output_prefix,
                    worker_id,
                )
                final_ds.merge_file_(worker_output_prefix)
                try:
                    os.remove(indexed_dataset.data_file_path(worker_output_prefix))
                    os.remove(indexed_dataset.index_file_path(worker_output_prefix))
                except Exception as e:
                    logger.error(
                        f"couldn't remove {worker_output_prefix}.*", exc_info=e
                    )

        # now we can close the file
        idx_file = indexed_dataset.index_file_path(output_prefix)
        final_ds.finalize(idx_file)
        return final_summary

    @staticmethod
    def _binarize_file_chunk(
        binarizer: Binarizer,
        filename: str,
        offset_start: int,
        offset_end: int,
        output_prefix: str,
        dataset_impl: str,
        vocab_size=None,
    ) -> tp.Tuple[tp.Any, BinarizeSummary]:  # (dataset builder, BinarizeSummary)
        """
        Creates a dataset builder and appends binarized items to it. This function
        does not finalize the builder; that is useful if you want to do other things
        with your bin file, like appending/merging other files.
        """
        bin_file = indexed_dataset.data_file_path(output_prefix)
        ds = indexed_dataset.make_builder(
            bin_file,
            impl=dataset_impl,
            vocab_size=vocab_size,
        )
        summary = BinarizeSummary()

        with Chunker(
            PathManager.get_local_path(filename), offset_start, offset_end
        ) as line_iterator:
            for line in line_iterator:
                ds.add_item(binarizer.binarize_line(line, summary))

        return ds, summary

    @classmethod
    def _binarize_chunk_and_finalize(
        cls,
        binarizer: Binarizer,
        filename: str,
        offset_start: int,
        offset_end: int,
        output_prefix: str,
        dataset_impl: str,
        vocab_size=None,
    ):
        """
        Same as _binarize_file_chunk, but also finalizes the builder.
        """
        ds, summ = cls._binarize_file_chunk(
            binarizer,
            filename,
            offset_start,
            offset_end,
            output_prefix,
            dataset_impl,
            vocab_size=vocab_size,
        )

        idx_file = indexed_dataset.index_file_path(output_prefix)
        ds.finalize(idx_file)

        return summ
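# Editor's sketch (not part of the original commit): typical end-to-end use of
# FileBinarizer. The paths "dict.txt", "train.txt" and the prefix "train.tok"
# are hypothetical; "mmap" is one of fairseq's dataset_impl choices.
def _example_multiprocess_binarize():
    vocab = Dictionary.load("dict.txt")
    summary = FileBinarizer.multiprocess_dataset(
        input_file="train.txt",
        dataset_impl="mmap",
        binarizer=VocabularyDatasetBinarizer(vocab),
        output_prefix="train.tok",
        vocab_size=len(vocab),
        num_workers=4,
    )
    logger.info(summary)
    return summary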
class VocabularyDatasetBinarizer(Binarizer):
    """
    Takes a Dictionary/Vocabulary and assigns ids to each
    token using the dictionary's encode_line function.
    """

    def __init__(
        self,
        dict: Dictionary,
        tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
        append_eos: bool = True,
        reverse_order: bool = False,
        already_numberized: bool = False,
    ) -> None:
        self.dict = dict
        self.tokenize = tokenize
        self.append_eos = append_eos
        self.reverse_order = reverse_order
        self.already_numberized = already_numberized
        super().__init__()

    def binarize_line(
        self,
        line: str,
        summary: BinarizeSummary,
    ):
        if summary.replaced is None:
            summary.replaced = Counter()

        def replaced_consumer(word, idx):
            if idx == self.dict.unk_index and word != self.dict.unk_word:
                summary.replaced.update([word])

        if self.already_numberized:
            id_strings = line.strip().split()
            id_list = [int(id_string) for id_string in id_strings]
            if self.reverse_order:
                id_list.reverse()
            if self.append_eos:
                id_list.append(self.dict.eos())
            ids = torch.IntTensor(id_list)
        else:
            ids = self.dict.encode_line(
                line=line,
                line_tokenizer=self.tokenize,
                add_if_not_exist=False,
                consumer=replaced_consumer,
                append_eos=self.append_eos,
                reverse_order=self.reverse_order,
            )

        summary.num_seq += 1
        summary.num_tok += len(ids)
        return ids
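# Editor's sketch (not part of the original commit): binarizing a single line
# with a toy vocabulary; out-of-vocabulary words map to unk and are tallied in
# summary.replaced via the replaced_consumer above.
def _example_encode_one_line():
    vocab = Dictionary()
    for word in "hello world".split():
        vocab.add_symbol(word)
    summary = BinarizeSummary()
    ids = VocabularyDatasetBinarizer(vocab).binarize_line(
        "hello stranger world", summary
    )
    # ids ends with vocab.eos() because append_eos defaults to True
    return ids, summary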
class AlignmentDatasetBinarizer(Binarizer):
    """
    Binarizes by parsing a set of alignments and packing
    them in a tensor (see utils.parse_alignment).
    """

    def __init__(
        self,
        alignment_parser: tp.Callable[[str], torch.IntTensor],
    ) -> None:
        super().__init__()
        self.alignment_parser = alignment_parser

    def binarize_line(
        self,
        line: str,
        summary: BinarizeSummary,
    ):
        ids = self.alignment_parser(line)
        summary.num_seq += 1
        summary.num_tok += len(ids)
        return ids
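# Editor's sketch (not part of the original commit): the alignment parser is
# any callable from one text line to an IntTensor; fairseq's
# utils.parse_alignment is the usual choice, but a toy parser for "src-tgt"
# pairs shows the contract just as well.
def _example_binarize_alignments():
    def toy_parser(line: str) -> torch.IntTensor:
        return torch.IntTensor(
            [int(i) for pair in line.split() for i in pair.split("-")]
        )

    summary = BinarizeSummary()
    ids = AlignmentDatasetBinarizer(toy_parser).binarize_line("0-0 1-2 2-1", summary)
    return ids, summary  # ids holds the 6 packed indices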
class LegacyBinarizer:
    @classmethod
    def binarize(
        cls,
        filename: str,
        dico: Dictionary,
        consumer: tp.Callable[[torch.IntTensor], None],
        tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
        append_eos: bool = True,
        reverse_order: bool = False,
        offset: int = 0,
        end: int = -1,
        already_numberized: bool = False,
    ) -> tp.Dict[str, int]:
        binarizer = VocabularyDatasetBinarizer(
            dict=dico,
            tokenize=tokenize,
            append_eos=append_eos,
            reverse_order=reverse_order,
            already_numberized=already_numberized,
        )
        return cls._consume_file(
            filename,
            binarizer,
            consumer,
            offset_start=offset,
            offset_end=end,
        )

    @classmethod
    def binarize_alignments(
        cls,
        filename: str,
        alignment_parser: tp.Callable[[str], torch.IntTensor],
        consumer: tp.Callable[[torch.IntTensor], None],
        offset: int = 0,
        end: int = -1,
    ) -> tp.Dict[str, int]:
        binarizer = AlignmentDatasetBinarizer(alignment_parser)
        return cls._consume_file(
            filename,
            binarizer,
            consumer,
            offset_start=offset,
            offset_end=end,
        )

    @staticmethod
    def _consume_file(
        filename: str,
        binarizer: Binarizer,
        consumer: tp.Callable[[torch.IntTensor], None],
        offset_start: int,
        offset_end: int,
    ) -> tp.Dict[str, int]:
        summary = BinarizeSummary()

        with Chunker(
            PathManager.get_local_path(filename), offset_start, offset_end
        ) as line_iterator:
            for line in line_iterator:
                consumer(binarizer.binarize_line(line, summary))

        return {
            "nseq": summary.num_seq,
            "nunk": summary.num_replaced,
            "ntok": summary.num_tok,
            "replaced": summary.replaced,
        }
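# Editor's sketch (not part of the original commit): the legacy API pushes each
# binarized tensor into a consumer callback and reports counts as the plain
# dict built in _consume_file ("nseq"/"nunk"/"ntok"/"replaced"). "dict.txt"
# and "train.txt" are hypothetical paths.
def _example_legacy_binarize():
    vocab = Dictionary.load("dict.txt")
    tensors = []
    stats = LegacyBinarizer.binarize("train.txt", vocab, consumer=tensors.append)
    return tensors, stats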
    	
modules/voice_conversion/fairseq/checkpoint_utils.py
ADDED
@@ -0,0 +1,905 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import ast
import collections
import contextlib
import inspect
import logging
import os
import re
import time
import traceback
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
from fairseq.data import data_utils
from fairseq.dataclass.configs import CheckpointConfig
from fairseq.dataclass.utils import (
    convert_namespace_to_omegaconf,
    overwrite_args_by_name,
)
from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
from fairseq.file_io import PathManager
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig, OmegaConf, open_dict

logger = logging.getLogger(__name__)
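# Editor's sketch (not part of the original commit): the filename patterns that
# save_checkpoint below chooses among, shown for epoch=3, updates=12000 and an
# empty suffix.
def _example_checkpoint_names(epoch=3, updates=12000, suffix=""):
    return [
        "checkpoint{}{}.pt".format(epoch, suffix),  # end-of-epoch checkpoint
        "checkpoint_{}_{}{}.pt".format(epoch, updates, suffix),  # mid-epoch save
        "checkpoint_best{}.pt".format(suffix),  # best validation score so far
        "checkpoint_last{}.pt".format(suffix),  # always-refreshed last checkpoint
    ]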
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return

    trainer.consolidate_optimizer()  # TODO(SS): do we need this if no_save_optimizer_state
    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    suffix = trainer.checkpoint_suffix
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and cfg.save_interval_updates > 0
        and updates % cfg.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        worst_best = getattr(save_checkpoint, "best", None)
        chkpts = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if len(chkpts) > 0:
            p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
            worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
        # add random digits to resolve ties
        with data_utils.numpy_seed(epoch, updates, val_loss):
            rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)

        checkpoint_conds[
            "checkpoint.best_{}_{:.3f}{}{}.pt".format(
                cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
            )
        ] = worst_best is None or is_better(val_loss, worst_best)
    checkpoint_conds[
        "checkpoint_last{}.pt".format(suffix)
    ] = not cfg.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            if cfg.write_checkpoints_asynchronously:
                # TODO[ioPath]: Need to implement a delayed asynchronous
                # file copying/moving feature.
                logger.warning(
                    f"ioPath is not copying {checkpoints[0]} to {cp} "
                    "since async write mode is on."
                )
            else:
                assert PathManager.copy(
                    checkpoints[0], cp, overwrite=True
                ), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if not end_of_epoch and cfg.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        if cfg.keep_interval_updates_pattern == -1:
            checkpoints = checkpoint_paths(
                cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
            )
        else:
            checkpoints = checkpoint_paths(
         
     | 
| 144 | 
         
            +
                            cfg.save_dir,
         
     | 
| 145 | 
         
            +
                            pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
         
     | 
| 146 | 
         
            +
                            keep_match=True,
         
     | 
| 147 | 
         
            +
                        )
         
     | 
| 148 | 
         
            +
                        checkpoints = [
         
     | 
| 149 | 
         
            +
                            x[0]
         
     | 
| 150 | 
         
            +
                            for x in checkpoints
         
     | 
| 151 | 
         
            +
                            if x[1] % cfg.keep_interval_updates_pattern != 0
         
     | 
| 152 | 
         
            +
                        ]
         
     | 
| 153 | 
         
            +
             
     | 
| 154 | 
         
            +
                    for old_chk in checkpoints[cfg.keep_interval_updates :]:
         
     | 
| 155 | 
         
            +
                        if os.path.lexists(old_chk):
         
     | 
| 156 | 
         
            +
                            os.remove(old_chk)
         
     | 
| 157 | 
         
            +
                        elif PathManager.exists(old_chk):
         
     | 
| 158 | 
         
            +
                            PathManager.rm(old_chk)
         
     | 
| 159 | 
         
            +
             
     | 
| 160 | 
         
            +
                if cfg.keep_last_epochs > 0:
         
     | 
| 161 | 
         
            +
                    # remove old epoch checkpoints; checkpoints are sorted in descending order
         
     | 
| 162 | 
         
            +
                    checkpoints = checkpoint_paths(
         
     | 
| 163 | 
         
            +
                        cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
         
     | 
| 164 | 
         
            +
                    )
         
     | 
| 165 | 
         
            +
                    for old_chk in checkpoints[cfg.keep_last_epochs :]:
         
     | 
| 166 | 
         
            +
                        if os.path.lexists(old_chk):
         
     | 
| 167 | 
         
            +
                            os.remove(old_chk)
         
     | 
| 168 | 
         
            +
                        elif PathManager.exists(old_chk):
         
     | 
| 169 | 
         
            +
                            PathManager.rm(old_chk)
         
     | 
| 170 | 
         
            +
             
     | 
| 171 | 
         
            +
                if cfg.keep_best_checkpoints > 0:
         
     | 
| 172 | 
         
            +
                    # only keep the best N checkpoints according to validation metric
         
     | 
| 173 | 
         
            +
                    checkpoints = checkpoint_paths(
         
     | 
| 174 | 
         
            +
                        cfg.save_dir,
         
     | 
| 175 | 
         
            +
                        pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
         
     | 
| 176 | 
         
            +
                            cfg.best_checkpoint_metric, suffix
         
     | 
| 177 | 
         
            +
                        ),
         
     | 
| 178 | 
         
            +
                    )
         
     | 
| 179 | 
         
            +
                    if not cfg.maximize_best_checkpoint_metric:
         
     | 
| 180 | 
         
            +
                        checkpoints = checkpoints[::-1]
         
     | 
| 181 | 
         
            +
                    for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
         
     | 
| 182 | 
         
            +
                        if os.path.lexists(old_chk):
         
     | 
| 183 | 
         
            +
                            os.remove(old_chk)
         
     | 
| 184 | 
         
            +
                        elif PathManager.exists(old_chk):
         
     | 
| 185 | 
         
            +
                            PathManager.rm(old_chk)
         
     | 
| 186 | 
         
            +
             
     | 
| 187 | 
         
            +
             
     | 
| 188 | 
         
            +
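
# Illustrative sketch (not part of the upstream file): how the save conditions
# and the three pruning passes above interact. The config values here are
# hypothetical.
#
#     cfg.save_interval_updates = 1000  # write checkpoint_<epoch>_<updates>.pt every 1k updates
#     cfg.keep_interval_updates = 5     # ...but keep only the newest 5 of them
#     cfg.keep_last_epochs = 2          # keep only the 2 most recent checkpoint<epoch>.pt
#     cfg.keep_best_checkpoints = 3     # keep the 3 best checkpoint.best_<metric>_<score>.pt
#
# Under these settings each call to save_checkpoint() refreshes
# checkpoint_last.pt and checkpoint_best.pt, then prunes everything else down
# to the budgets above.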


def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.
    """

    reset_optimizer = cfg.reset_optimizer
    reset_lr_scheduler = cfg.reset_lr_scheduler
    optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
    reset_meters = cfg.reset_meters
    reset_dataloader = cfg.reset_dataloader

    if cfg.finetune_from_model is not None and (
        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
    ):
        raise ValueError(
            "--finetune-from-model can not be set together with either --reset-optimizer"
            " or reset_lr_scheduler or reset_meters or reset_dataloader"
        )

    suffix = trainer.checkpoint_suffix
    if (
        cfg.restore_file == "checkpoint_last.pt"
    ):  # default value of restore_file is 'checkpoint_last.pt'
        checkpoint_path = os.path.join(
            cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
        )
        first_launch = not PathManager.exists(checkpoint_path)
        if first_launch and getattr(cfg, "continue_once", None) is not None:
            checkpoint_path = cfg.continue_once
        elif cfg.finetune_from_model is not None and first_launch:
            # if there is no last checkpoint to restore, start the finetune from pretrained model
            # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
            if PathManager.exists(cfg.finetune_from_model):
                checkpoint_path = cfg.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(
                    f"loading pretrained model from {checkpoint_path}: "
                    "optimizer, lr scheduler, meters, dataloader will be reset"
                )
            else:
                raise ValueError(
                    f"--finetune-from-model {cfg.finetune_from_model} does not exist"
                )
    elif suffix is not None:
        checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = cfg.restore_file

    if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
        raise ValueError(
            "--finetune-from-model and --restore-file (non-default value) "
            "can not be specified together: " + str(cfg)
        )

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if (
        extra_state is not None
        and "best" in extra_state
        and not reset_optimizer
        and not reset_meters
    ):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)
    else:
        epoch_itr = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )

    trainer.lr_step(epoch_itr.epoch)

    return extra_state, epoch_itr
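
# Usage sketch (illustrative, not part of the upstream file): fairseq's train
# entry point restores state once before the epoch loop; `trainer`, `cfg` and
# `max_epoch` are assumed to exist in that context.
#
#     extra_state, epoch_itr = load_checkpoint(cfg.checkpoint, trainer)
#     while epoch_itr.next_epoch_idx <= max_epoch:
#         ...  # train one epoch, validate, then call save_checkpoint(...)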


def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    If doing single-GPU training or if the checkpoint is only being loaded by at
    most one process on each node (current default behavior is for only rank 0
    to read the checkpoint from disk), load_on_all_ranks should be False to
    avoid errors from torch.distributed not having been initialized or
    torch.distributed.barrier() hanging.

    If all processes on each node may be loading the checkpoint
    simultaneously, load_on_all_ranks should be set to True to avoid I/O
    conflicts.

    There's currently no support for > 1 but < all processes loading the
    checkpoint on each node.
    """
    local_path = PathManager.get_local_path(path)
    # The locally cached file returned by get_local_path() may be stale for
    # remote files that are periodically updated/overwritten (ex:
    # checkpoint_last.pt) - so we remove the local copy, sync across processes
    # (if needed), and then download a fresh copy.
    if local_path != path and PathManager.path_requires_pathmanager(path):
        try:
            os.remove(local_path)
        except FileNotFoundError:
            # With potentially multiple processes removing the same file, the
            # file being missing is benign (missing_ok isn't available until
            # Python 3.8).
            pass
        if load_on_all_ranks:
            torch.distributed.barrier()
        local_path = PathManager.get_local_path(path)

    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if "args" in state and state["args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None:

        # hack to be able to set Namespace in dict config. this should be removed when we update to newer
        # omegaconf version that supports object flags, or when we migrate all existing models
        from omegaconf import __version__ as oc_version
        from omegaconf import _utils

        if oc_version < "2.2":
            old_primitive = _utils.is_primitive_type
            _utils.is_primitive_type = lambda _: True

            state["cfg"] = OmegaConf.create(state["cfg"])

            _utils.is_primitive_type = old_primitive
            OmegaConf.set_struct(state["cfg"], True)
        else:
            state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})

        if arg_overrides is not None:
            overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state
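
# Example (illustrative): inspecting a checkpoint offline, entirely on CPU.
# The path is hypothetical.
#
#     state = load_checkpoint_to_cpu("checkpoints/checkpoint_last.pt")
#     print(state.keys())        # e.g. dict_keys(['cfg', 'model', 'extra_state', ...])
#     print(state["cfg"].model)  # config, upgraded and OmegaConf-backed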


def load_model_ensemble(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    """Loads an ensemble of models.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading
    """
    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble, args, _task = load_model_ensemble_and_task(
        filenames,
        arg_overrides,
        task,
        strict,
        suffix,
        num_shards,
        state,
    )
    return ensemble, args
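
# Example (illustrative; filenames are hypothetical): loading two checkpoints
# of the same model for ensembled decoding.
#
#     models, args = load_model_ensemble(
#         ["ckpt/checkpoint_best.pt", "ckpt/checkpoint_last.pt"]
#     )
#     for model in models:
#         model.eval()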


def get_maybe_sharded_checkpoint_filename(
    filename: str, suffix: str, shard_idx: int, num_shards: int
) -> str:
    orig_filename = filename
    filename = filename.replace(".pt", suffix + ".pt")
    fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
    model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
    if PathManager.exists(fsdp_filename):
        return fsdp_filename
    elif num_shards > 1:
        return model_parallel_filename
    else:
        return filename
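
# Sketch of the mapping above (illustrative values): for
# filename="checkpoint_last.pt", suffix="-rank0" and shard_idx=2, the FSDP
# candidate is "checkpoint_last-rank0-shard2.pt" and the model-parallel
# candidate is "checkpoint_last_part2.pt". The FSDP name wins if it exists on
# disk; otherwise the model-parallel name is used when num_shards > 1, and
# the plain suffixed name when loading is unsharded.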


def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )

            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            if "fsdp_metadata" in state and num_shards > 1:
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
                # check FSDP import before the code goes too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    model = task.build_model(cfg.model)
                    if (
                        "optimizer_history" in state
                        and len(state["optimizer_history"]) > 0
                        and "num_updates" in state["optimizer_history"][-1]
                    ):
                        model.set_num_updates(
                            state["optimizer_history"][-1]["num_updates"]
                        )
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # model parallel checkpoint or unsharded checkpoint
                # support old external tasks

                argspec = inspect.getfullargspec(task.build_model)
                if "from_checkpoint" in argspec.args:
                    model = task.build_model(cfg.model, from_checkpoint=True)
                else:
                    model = task.build_model(cfg.model)
                if (
                    "optimizer_history" in state
                    and len(state["optimizer_history"]) > 0
                    and "num_updates" in state["optimizer_history"][-1]
                ):
                    model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
                model.load_state_dict(
                    state["model"], strict=strict, model_cfg=cfg.model
                )

            # reset state so it gets loaded for the next model in ensemble
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                logger.info(
                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
                )

        # build model for ensemble
        ensemble.append(model)
    return ensemble, cfg, task
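
# Example (illustrative): consolidating an FSDP-sharded checkpoint. strict
# must be False whenever num_shards > 1; with the default suffix this looks
# for hypothetical files checkpoint_last-shard0.pt ... checkpoint_last-shard3.pt.
#
#     models, cfg, task = load_model_ensemble_and_task(
#         ["ckpt/checkpoint_last.pt"], strict=False, num_shards=4
#     )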


def load_model_ensemble_and_task_from_hf_hub(
    model_id,
    cache_dir: Optional[str] = None,
    arg_overrides: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
):
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        raise ImportError(
            "You need to install huggingface_hub to use `load_from_hf_hub`. "
            "See https://pypi.org/project/huggingface-hub/ for installation."
        )

    library_name = "fairseq"
    cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix()
    cache_dir = snapshot_download(
        model_id, cache_dir=cache_dir, library_name=library_name, **kwargs
    )

    _arg_overrides = arg_overrides or {}
    _arg_overrides["data"] = cache_dir
    return load_model_ensemble_and_task(
        [p.as_posix() for p in Path(cache_dir).glob("*.pt")],
        arg_overrides=_arg_overrides,
    )
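
# Example (illustrative; the model id is hypothetical): every *.pt file in
# the downloaded snapshot becomes one member of the returned ensemble, and
# the snapshot directory is wired in as the task's data root.
#
#     models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
#         "facebook/fastspeech2-en-ljspeech"
#     )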


def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern. If
    the pattern contains groups, the result will be sorted by the first group in
    descending order.
    """
    pt_regexp = re.compile(pattern)
    files = PathManager.ls(path)

    entries = []
    for i, f in enumerate(files):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            idx = float(m.group(1)) if len(m.groups()) > 0 else i
            entries.append((idx, m.group(0)))
    if keep_match:
        return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
    else:
        return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
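
# Example (illustrative): with files ["checkpoint1.pt", "checkpoint2.pt",
# "checkpoint10.pt"] in `path`, the default pattern captures the epoch
# number, so checkpoint_paths(path) sorts on it and returns the paths
# ordered 10, 2, 1 (newest first).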


def torch_persistent_save(obj, filename, async_write: bool = False):
    if async_write:
        with PathManager.opena(filename, "wb") as f:
            _torch_persistent_save(obj, f)
    else:
        if PathManager.supports_rename(filename):
            # do atomic save
            with PathManager.open(filename + ".tmp", "wb") as f:
                _torch_persistent_save(obj, f)
            PathManager.rename(filename + ".tmp", filename)
        else:
            # fallback to non-atomic save
            with PathManager.open(filename, "wb") as f:
                _torch_persistent_save(obj, f)
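
# Note (illustrative): the rename branch above is the standard
# write-to-temp-then-rename pattern; because the rename replaces the target
# in a single step, a concurrent reader never sees a partially written file.
#
#     torch_persistent_save(state, "checkpoints/checkpoint_last.pt")  # hypothetical path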


def _torch_persistent_save(obj, f):
    if isinstance(f, str):
        with PathManager.open(f, "wb") as h:
            torch_persistent_save(obj, h)
        return
    for i in range(3):
        try:
            return torch.save(obj, f)
        except Exception:
            if i == 2:
                logger.error(traceback.format_exc())
                raise
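
# Note (illustrative): when given an open file object, torch.save() is
# retried up to three times to ride out transient I/O errors; only the final
# failure is logged and re-raised.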


def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints."""

    # add optimizer_history
    if "optimizer_history" not in state:
        state["optimizer_history"] = [
            {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
        ]
        state["last_optimizer_state"] = state["optimizer"]
        del state["optimizer"]
        del state["best_loss"]
    # move extra_state into sub-dictionary
    if "epoch" in state and "extra_state" not in state:
        state["extra_state"] = {
            "epoch": state["epoch"],
            "batch_offset": state["batch_offset"],
            "val_loss": state["val_loss"],
        }
        del state["epoch"]
        del state["batch_offset"]
        del state["val_loss"]
    # reduce optimizer history's memory usage (only keep the last state)
    if "optimizer" in state["optimizer_history"][-1]:
        state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
        for optim_hist in state["optimizer_history"]:
            del optim_hist["optimizer"]
    # record the optimizer class name
    if "optimizer_name" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
    # move best_loss into lr_scheduler_state
    if "lr_scheduler_state" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["lr_scheduler_state"] = {
            "best": state["optimizer_history"][-1]["best_loss"]
        }
        del state["optimizer_history"][-1]["best_loss"]
    # keep track of number of updates
    if "num_updates" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["num_updates"] = 0
    # use stateful training data iterator
    if "train_iterator" not in state["extra_state"]:
        state["extra_state"]["train_iterator"] = {
            "epoch": state["extra_state"].get("epoch", 0),
            "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
        }

    # backward compatibility, cfg updates
    if "args" in state and state["args"] is not None:
        # old model checkpoints may not have separate source/target positions
        if hasattr(state["args"], "max_positions") and not hasattr(
            state["args"], "max_source_positions"
        ):
            state["args"].max_source_positions = state["args"].max_positions
            state["args"].max_target_positions = state["args"].max_positions
        # default to translation task
        if not hasattr(state["args"], "task"):
            state["args"].task = "translation"
        # --raw-text and --lazy-load are deprecated
        if getattr(state["args"], "raw_text", False):
            state["args"].dataset_impl = "raw"
        elif getattr(state["args"], "lazy_load", False):
            state["args"].dataset_impl = "lazy"
        # epochs start at 1
        if state["extra_state"]["train_iterator"] is not None:
            state["extra_state"]["train_iterator"]["epoch"] = max(
                state["extra_state"]["train_iterator"].get("epoch", 1), 1
            )
        # --remove-bpe ==> --postprocess
        if hasattr(state["args"], "remove_bpe"):
            state["args"].post_process = state["args"].remove_bpe
        # --min-lr ==> --stop-min-lr
        if hasattr(state["args"], "min_lr"):
            state["args"].stop_min_lr = state["args"].min_lr
            del state["args"].min_lr
        # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
        if hasattr(state["args"], "criterion") and state["args"].criterion in [
            "binary_cross_entropy",
            "kd_binary_cross_entropy",
        ]:
            state["args"].criterion = "wav2vec"
        # remove log_keys if it's None (criteria will supply a default value of [])
        if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
            delattr(state["args"], "log_keys")
        # speech_pretraining => audio pretraining
        if (
            hasattr(state["args"], "task")
            and state["args"].task == "speech_pretraining"
        ):
            state["args"].task = "audio_pretraining"
        # audio_cpc => wav2vec
        if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
            state["args"].arch = "wav2vec"
        # convert legacy float learning rate to List[float]
        if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
            state["args"].lr = [state["args"].lr]
        # convert task data arg to a string instead of List[string]
        if (
            hasattr(state["args"], "data")
            and isinstance(state["args"].data, list)
            and len(state["args"].data) > 0
        ):
            state["args"].data = state["args"].data[0]

        state["cfg"] = convert_namespace_to_omegaconf(state["args"])

    if "cfg" in state and state["cfg"] is not None:
        cfg = state["cfg"]
        with open_dict(cfg):
            # any upgrades for Hydra-based configs
            if (
                "task" in cfg
                and "eval_wer_config" in cfg.task
                and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
            ):
                cfg.task.eval_wer_config.print_alignment = "hard"
            if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
                cfg.generation.print_alignment = (
                    "hard" if cfg.generation.print_alignment else None
                )
            if (
                "model" in cfg
                and "w2v_args" in cfg.model
                and cfg.model.w2v_args is not None
                and (
                    hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
                )
                and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
         
     | 
| 705 | 
         
            +
                            and cfg.model.w2v_args.task.eval_wer_config is not None
         
     | 
| 706 | 
         
            +
                            and isinstance(
         
     | 
| 707 | 
         
            +
                                cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
         
     | 
| 708 | 
         
            +
                            )
         
     | 
| 709 | 
         
            +
                        ):
         
     | 
| 710 | 
         
            +
                            cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"
         
     | 
| 711 | 
         
            +
             
     | 
| 712 | 
         
            +
                return state
         
     | 
| 713 | 
         
            +
             
     | 
| 714 | 
         
            +
             
     | 
| 715 | 
         
            +
            def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
         
     | 
| 716 | 
         
            +
                """Prune the given state_dict if desired for LayerDrop
         
     | 
| 717 | 
         
            +
                (https://arxiv.org/abs/1909.11556).
         
     | 
| 718 | 
         
            +
             
     | 
| 719 | 
         
            +
                Training with LayerDrop allows models to be robust to pruning at inference
         
     | 
| 720 | 
         
            +
                time. This function prunes state_dict to allow smaller models to be loaded
         
     | 
| 721 | 
         
            +
                from a larger model and re-maps the existing state_dict for this to occur.
         
     | 
| 722 | 
         
            +
             
     | 
| 723 | 
         
            +
                It's called by functions that load models from checkpoints and does not
         
     | 
| 724 | 
         
            +
                need to be called directly.
         
     | 
| 725 | 
         
            +
                """
         
     | 
| 726 | 
         
            +
                arch = None
         
     | 
| 727 | 
         
            +
                if model_cfg is not None:
         
     | 
| 728 | 
         
            +
                    arch = (
         
     | 
| 729 | 
         
            +
                        model_cfg._name
         
     | 
| 730 | 
         
            +
                        if isinstance(model_cfg, DictConfig)
         
     | 
| 731 | 
         
            +
                        else getattr(model_cfg, "arch", None)
         
     | 
| 732 | 
         
            +
                    )
         
     | 
| 733 | 
         
            +
             
     | 
| 734 | 
         
            +
                if not model_cfg or arch is None or arch == "ptt_transformer":
         
     | 
| 735 | 
         
            +
                    # args should not be none, but don't crash if it is.
         
     | 
| 736 | 
         
            +
                    return state_dict
         
     | 
| 737 | 
         
            +
             
     | 
| 738 | 
         
            +
                encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
         
     | 
| 739 | 
         
            +
                decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
         
     | 
| 740 | 
         
            +
             
     | 
| 741 | 
         
            +
                if not encoder_layers_to_keep and not decoder_layers_to_keep:
         
     | 
| 742 | 
         
            +
                    return state_dict
         
     | 
| 743 | 
         
            +
             
     | 
| 744 | 
         
            +
                # apply pruning
         
     | 
| 745 | 
         
            +
                logger.info(
         
     | 
| 746 | 
         
            +
                    "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
         
     | 
| 747 | 
         
            +
                )
         
     | 
| 748 | 
         
            +
             
     | 
| 749 | 
         
            +
                def create_pruning_pass(layers_to_keep, layer_name):
         
     | 
| 750 | 
         
            +
                    keep_layers = sorted(
         
     | 
| 751 | 
         
            +
                        int(layer_string) for layer_string in layers_to_keep.split(",")
         
     | 
| 752 | 
         
            +
                    )
         
     | 
| 753 | 
         
            +
                    mapping_dict = {}
         
     | 
| 754 | 
         
            +
                    for i in range(len(keep_layers)):
         
     | 
| 755 | 
         
            +
                        mapping_dict[str(keep_layers[i])] = str(i)
         
     | 
| 756 | 
         
            +
             
     | 
| 757 | 
         
            +
                    regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
         
     | 
| 758 | 
         
            +
                    return {"substitution_regex": regex, "mapping_dict": mapping_dict}
         
     | 
| 759 | 
         
            +
             
     | 
| 760 | 
         
            +
                pruning_passes = []
         
     | 
| 761 | 
         
            +
                if encoder_layers_to_keep:
         
     | 
| 762 | 
         
            +
                    pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
         
     | 
| 763 | 
         
            +
                if decoder_layers_to_keep:
         
     | 
| 764 | 
         
            +
                    pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
         
     | 
| 765 | 
         
            +
             
     | 
| 766 | 
         
            +
                new_state_dict = {}
         
     | 
| 767 | 
         
            +
                for layer_name in state_dict.keys():
         
     | 
| 768 | 
         
            +
                    match = re.search(r"\.layers\.(\d+)\.", layer_name)
         
     | 
| 769 | 
         
            +
                    # if layer has no number in it, it is a supporting layer, such as an
         
     | 
| 770 | 
         
            +
                    # embedding
         
     | 
| 771 | 
         
            +
                    if not match:
         
     | 
| 772 | 
         
            +
                        new_state_dict[layer_name] = state_dict[layer_name]
         
     | 
| 773 | 
         
            +
                        continue
         
     | 
| 774 | 
         
            +
             
     | 
| 775 | 
         
            +
                    # otherwise, layer should be pruned.
         
     | 
| 776 | 
         
            +
                    original_layer_number = match.group(1)
         
     | 
| 777 | 
         
            +
                    # figure out which mapping dict to replace from
         
     | 
| 778 | 
         
            +
                    for pruning_pass in pruning_passes:
         
     | 
| 779 | 
         
            +
                        if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
         
     | 
| 780 | 
         
            +
                            "substitution_regex"
         
     | 
| 781 | 
         
            +
                        ].search(layer_name):
         
     | 
| 782 | 
         
            +
                            new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
         
     | 
| 783 | 
         
            +
                            substitution_match = pruning_pass["substitution_regex"].search(
         
     | 
| 784 | 
         
            +
                                layer_name
         
     | 
| 785 | 
         
            +
                            )
         
     | 
| 786 | 
         
            +
                            new_state_key = (
         
     | 
| 787 | 
         
            +
                                layer_name[: substitution_match.start(1)]
         
     | 
| 788 | 
         
            +
                                + new_layer_number
         
     | 
| 789 | 
         
            +
                                + layer_name[substitution_match.end(1) :]
         
     | 
| 790 | 
         
            +
                            )
         
     | 
| 791 | 
         
            +
                            new_state_dict[new_state_key] = state_dict[layer_name]
         
     | 
| 792 | 
         
            +
             
     | 
| 793 | 
         
            +
                # Since layers are now pruned, *_layers_to_keep are no longer needed.
         
     | 
| 794 | 
         
            +
                # This is more of "It would make it work fix" rather than a proper fix.
         
     | 
| 795 | 
         
            +
                if isinstance(model_cfg, DictConfig):
         
     | 
| 796 | 
         
            +
                    context = open_dict(model_cfg)
         
     | 
| 797 | 
         
            +
                else:
         
     | 
| 798 | 
         
            +
                    context = contextlib.ExitStack()
         
     | 
| 799 | 
         
            +
                with context:
         
     | 
| 800 | 
         
            +
                    if hasattr(model_cfg, "encoder_layers_to_keep"):
         
     | 
| 801 | 
         
            +
                        model_cfg.encoder_layers_to_keep = None
         
     | 
| 802 | 
         
            +
                    if hasattr(model_cfg, "decoder_layers_to_keep"):
         
     | 
| 803 | 
         
            +
                        model_cfg.decoder_layers_to_keep = None
         
     | 
| 804 | 
         
            +
             
     | 
| 805 | 
         
            +
                return new_state_dict
         
     | 
| 806 | 
         
            +
             
     | 
| 807 | 
         
            +
             
     | 
| 808 | 
         
            +
            def load_pretrained_component_from_model(
         
     | 
| 809 | 
         
            +
                component: Union[FairseqEncoder, FairseqDecoder],
         
     | 
| 810 | 
         
            +
                checkpoint: str,
         
     | 
| 811 | 
         
            +
                strict: bool = True,
         
     | 
| 812 | 
         
            +
            ):
         
     | 
| 813 | 
         
            +
                """
         
     | 
| 814 | 
         
            +
                Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
         
     | 
| 815 | 
         
            +
                provided `component` object. If state_dict fails to load, there may be a
         
     | 
| 816 | 
         
            +
                mismatch in the architecture of the corresponding `component` found in the
         
     | 
| 817 | 
         
            +
                `checkpoint` file.
         
     | 
| 818 | 
         
            +
                """
         
     | 
| 819 | 
         
            +
                if not PathManager.exists(checkpoint):
         
     | 
| 820 | 
         
            +
                    raise IOError("Model file not found: {}".format(checkpoint))
         
     | 
| 821 | 
         
            +
                state = load_checkpoint_to_cpu(checkpoint)
         
     | 
| 822 | 
         
            +
                if isinstance(component, FairseqEncoder):
         
     | 
| 823 | 
         
            +
                    component_type = "encoder"
         
     | 
| 824 | 
         
            +
                elif isinstance(component, FairseqDecoder):
         
     | 
| 825 | 
         
            +
                    component_type = "decoder"
         
     | 
| 826 | 
         
            +
                else:
         
     | 
| 827 | 
         
            +
                    raise ValueError(
         
     | 
| 828 | 
         
            +
                        "component to load must be either a FairseqEncoder or "
         
     | 
| 829 | 
         
            +
                        "FairseqDecoder. Loading other component types are not supported."
         
     | 
| 830 | 
         
            +
                    )
         
     | 
| 831 | 
         
            +
                component_state_dict = OrderedDict()
         
     | 
| 832 | 
         
            +
                for key in state["model"].keys():
         
     | 
| 833 | 
         
            +
                    if key.startswith(component_type):
         
     | 
| 834 | 
         
            +
                        # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
         
     | 
| 835 | 
         
            +
                        component_subkey = key[len(component_type) + 1 :]
         
     | 
| 836 | 
         
            +
                        component_state_dict[component_subkey] = state["model"][key]
         
     | 
| 837 | 
         
            +
                component.load_state_dict(component_state_dict, strict=strict)
         
     | 
| 838 | 
         
            +
                return component
         
     | 
| 839 | 
         
            +
             
     | 
| 840 | 
         
            +
             
     | 
| 841 | 
         
            +
            def verify_checkpoint_directory(save_dir: str) -> None:
         
     | 
| 842 | 
         
            +
                if not os.path.exists(save_dir):
         
     | 
| 843 | 
         
            +
                    os.makedirs(save_dir, exist_ok=True)
         
     | 
| 844 | 
         
            +
                temp_file_path = os.path.join(save_dir, "dummy")
         
     | 
| 845 | 
         
            +
                try:
         
     | 
| 846 | 
         
            +
                    with open(temp_file_path, "w"):
         
     | 
| 847 | 
         
            +
                        pass
         
     | 
| 848 | 
         
            +
                except OSError as e:
         
     | 
| 849 | 
         
            +
                    logger.warning(
         
     | 
| 850 | 
         
            +
                        "Unable to access checkpoint save directory: {}".format(save_dir)
         
     | 
| 851 | 
         
            +
                    )
         
     | 
| 852 | 
         
            +
                    raise e
         
     | 
| 853 | 
         
            +
                else:
         
     | 
| 854 | 
         
            +
                    os.remove(temp_file_path)
         
     | 
| 855 | 
         
            +
             
     | 
| 856 | 
         
            +
             
     | 
| 857 | 
         
            +
            def save_ema_as_checkpoint(src_path, dst_path):
         
     | 
| 858 | 
         
            +
                state = load_ema_from_checkpoint(src_path)
         
     | 
| 859 | 
         
            +
                torch_persistent_save(state, dst_path)
         
     | 
| 860 | 
         
            +
             
     | 
| 861 | 
         
            +
             
     | 
| 862 | 
         
            +
            def load_ema_from_checkpoint(fpath):
         
     | 
| 863 | 
         
            +
                """Loads exponential moving averaged (EMA) checkpoint from input and
         
     | 
| 864 | 
         
            +
                returns a model with ema weights.
         
     | 
| 865 | 
         
            +
             
     | 
| 866 | 
         
            +
                Args:
         
     | 
| 867 | 
         
            +
                  fpath: A string path of checkpoint to load from.
         
     | 
| 868 | 
         
            +
             
     | 
| 869 | 
         
            +
                Returns:
         
     | 
| 870 | 
         
            +
                  A dict of string keys mapping to various values. The 'model' key
         
     | 
| 871 | 
         
            +
                  from the returned dict should correspond to an OrderedDict mapping
         
     | 
| 872 | 
         
            +
                  string parameter names to torch Tensors.
         
     | 
| 873 | 
         
            +
                """
         
     | 
| 874 | 
         
            +
                params_dict = collections.OrderedDict()
         
     | 
| 875 | 
         
            +
                new_state = None
         
     | 
| 876 | 
         
            +
             
     | 
| 877 | 
         
            +
                with PathManager.open(fpath, "rb") as f:
         
     | 
| 878 | 
         
            +
                    new_state = torch.load(
         
     | 
| 879 | 
         
            +
                        f,
         
     | 
| 880 | 
         
            +
                        map_location=(
         
     | 
| 881 | 
         
            +
                            lambda s, _: torch.serialization.default_restore_location(s, "cpu")
         
     | 
| 882 | 
         
            +
                        ),
         
     | 
| 883 | 
         
            +
                    )
         
     | 
| 884 | 
         
            +
             
     | 
| 885 | 
         
            +
                    # EMA model is stored in a separate "extra state"
         
     | 
| 886 | 
         
            +
                    model_params = new_state["extra_state"]["ema"]
         
     | 
| 887 | 
         
            +
             
     | 
| 888 | 
         
            +
                    for key in list(model_params.keys()):
         
     | 
| 889 | 
         
            +
                        p = model_params[key]
         
     | 
| 890 | 
         
            +
                        if isinstance(p, torch.HalfTensor):
         
     | 
| 891 | 
         
            +
                            p = p.float()
         
     | 
| 892 | 
         
            +
                        if key not in params_dict:
         
     | 
| 893 | 
         
            +
                            params_dict[key] = p.clone()
         
     | 
| 894 | 
         
            +
                            # NOTE: clone() is needed in case of p is a shared parameter
         
     | 
| 895 | 
         
            +
                        else:
         
     | 
| 896 | 
         
            +
                            raise ValueError("Key {} is repeated in EMA model params.".format(key))
         
     | 
| 897 | 
         
            +
             
     | 
| 898 | 
         
            +
                    if len(params_dict) == 0:
         
     | 
| 899 | 
         
            +
                        raise ValueError(
         
     | 
| 900 | 
         
            +
                            f"Input checkpoint path '{fpath}' does not contain "
         
     | 
| 901 | 
         
            +
                            "ema model weights, is this model trained with EMA?"
         
     | 
| 902 | 
         
            +
                        )
         
     | 
| 903 | 
         
            +
             
     | 
| 904 | 
         
            +
                new_state["model"] = params_dict
         
     | 
| 905 | 
         
            +
                return new_state
         
     | 
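
A quick way to see what `prune_state_dict` above does is to feed it a toy state
dict. This is a minimal sketch, not part of the commit; it assumes the vendored
package is importable as `fairseq`, and the layer names and keep-list are made
up for illustration. `load_pretrained_component_from_model` applies the same
key-remapping idea (stripping the "encoder."/"decoder." prefix before loading
into the component), and `load_ema_from_checkpoint` likewise just relocates the
weights stored under extra_state["ema"].

import argparse
import torch
from fairseq import checkpoint_utils

# Toy config: keep encoder layers 0 and 2, drop layer 1.
cfg = argparse.Namespace(
    arch="transformer",
    encoder_layers_to_keep="0,2",
    decoder_layers_to_keep=None,
)
# Toy parameters; zeros stand in for trained weights.
state_dict = {
    "encoder.embed_tokens.weight": torch.zeros(4, 2),  # no layer index: kept as-is
    "encoder.layers.0.fc1.weight": torch.zeros(2, 2),
    "encoder.layers.1.fc1.weight": torch.zeros(2, 2),
    "encoder.layers.2.fc1.weight": torch.zeros(2, 2),
}
pruned = checkpoint_utils.prune_state_dict(state_dict, cfg)
# Layer 1 is dropped and old layer 2 is renumbered to 1:
assert sorted(pruned) == [
    "encoder.embed_tokens.weight",
    "encoder.layers.0.fc1.weight",
    "encoder.layers.1.fc1.weight",
]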
    	
modules/voice_conversion/fairseq/data/__init__.py
ADDED
@@ -0,0 +1,130 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

from .dictionary import Dictionary, TruncatedDictionary

from .fairseq_dataset import FairseqDataset, FairseqIterableDataset

from .base_wrapper_dataset import BaseWrapperDataset

from .add_target_dataset import AddTargetDataset
from .append_token_dataset import AppendTokenDataset
from .audio.raw_audio_dataset import BinarizedAudioDataset, FileAudioDataset
from .audio.hubert_dataset import HubertDataset
from .backtranslation_dataset import BacktranslationDataset
from .bucket_pad_length_dataset import BucketPadLengthDataset
from .colorize_dataset import ColorizeDataset
from .concat_dataset import ConcatDataset
from .concat_sentences_dataset import ConcatSentencesDataset
from .denoising_dataset import DenoisingDataset
from .id_dataset import IdDataset
from .indexed_dataset import (
    IndexedCachedDataset,
    IndexedDataset,
    IndexedRawTextDataset,
    MMapIndexedDataset,
)
from .language_pair_dataset import LanguagePairDataset
from .list_dataset import ListDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .lru_cache_dataset import LRUCacheDataset
from .mask_tokens_dataset import MaskTokensDataset
from .monolingual_dataset import MonolingualDataset
from .multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from .nested_dictionary_dataset import NestedDictionaryDataset
from .noising import NoisingDataset
from .numel_dataset import NumelDataset
from .num_samples_dataset import NumSamplesDataset
from .offset_tokens_dataset import OffsetTokensDataset
from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
from .prepend_dataset import PrependDataset
from .prepend_token_dataset import PrependTokenDataset
from .raw_label_dataset import RawLabelDataset
from .replace_dataset import ReplaceDataset
from .resampling_dataset import ResamplingDataset
from .roll_dataset import RollDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .sort_dataset import SortDataset
from .strip_token_dataset import StripTokenDataset
from .subsample_dataset import SubsampleDataset
from .token_block_dataset import TokenBlockDataset
from .transform_eos_dataset import TransformEosDataset
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
from .shorten_dataset import TruncateDataset, RandomCropDataset
from .multilingual.sampled_multi_dataset import SampledMultiDataset
from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset
from .fasta_dataset import FastaDataset, EncodedFastaDataset
from .transform_eos_concat_langpair_dataset import TransformEosConcatLangPairDataset

from .iterators import (
    CountingIterator,
    EpochBatchIterator,
    GroupedIterator,
    ShardedIterator,
)

__all__ = [
    "AddTargetDataset",
    "AppendTokenDataset",
    "BacktranslationDataset",
    "BaseWrapperDataset",
    "BinarizedAudioDataset",
    "BucketPadLengthDataset",
    "ColorizeDataset",
    "ConcatDataset",
    "ConcatSentencesDataset",
    "CountingIterator",
    "DenoisingDataset",
    "Dictionary",
    "EncodedFastaDataset",
    "EpochBatchIterator",
    "FairseqDataset",
    "FairseqIterableDataset",
    "FastaDataset",
    "FileAudioDataset",
    "GroupedIterator",
    "HubertDataset",
    "IdDataset",
    "IndexedCachedDataset",
    "IndexedDataset",
    "IndexedRawTextDataset",
    "LanguagePairDataset",
    "LeftPadDataset",
    "ListDataset",
    "LMContextWindowDataset",
    "LRUCacheDataset",
    "MaskTokensDataset",
    "MMapIndexedDataset",
    "MonolingualDataset",
    "MultiCorpusSampledDataset",
    "NestedDictionaryDataset",
    "NoisingDataset",
    "NumelDataset",
    "NumSamplesDataset",
    "OffsetTokensDataset",
    "PadDataset",
    "PrependDataset",
    "PrependTokenDataset",
    "RandomCropDataset",
    "RawLabelDataset",
    "ResamplingDataset",
    "ReplaceDataset",
    "RightPadDataset",
    "RollDataset",
    "RoundRobinZipDatasets",
    "SampledMultiDataset",
    "SampledMultiEpochDataset",
    "ShardedIterator",
    "SortDataset",
    "StripTokenDataset",
    "SubsampleDataset",
    "TokenBlockDataset",
    "TransformEosDataset",
    "TransformEosLangPairDataset",
    "TransformEosConcatLangPairDataset",
    "TruncateDataset",
    "TruncatedDictionary",
]
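
These re-exports make the common entry points importable directly from
fairseq.data. As a quick illustration (a hedged sketch, assuming the vendored
package resolves as `fairseq` on sys.path), the central Dictionary class
round-trips text through integer ids:

from fairseq.data import Dictionary

d = Dictionary()  # pre-populated with the <s>, <pad>, </s>, <unk> specials
for w in "the cat sat".split():
    d.add_symbol(w)
ids = d.encode_line("the cat sat", add_if_not_exist=False)  # IntTensor ending in eos
print(d.string(ids))  # -> "the cat sat"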
    	
modules/voice_conversion/fairseq/data/add_target_dataset.py
ADDED
@@ -0,0 +1,83 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from . import BaseWrapperDataset, data_utils
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel


class AddTargetDataset(BaseWrapperDataset):
    def __init__(
        self,
        dataset,
        labels,
        pad,
        eos,
        batch_targets,
        process_label=None,
        label_len_fn=None,
        add_to_input=False,
        text_compression_level=TextCompressionLevel.none,
    ):
        super().__init__(dataset)
        self.labels = labels
        self.batch_targets = batch_targets
        self.pad = pad
        self.eos = eos
        self.process_label = process_label
        self.label_len_fn = label_len_fn
        self.add_to_input = add_to_input
        self.text_compressor = TextCompressor(level=text_compression_level)

    def get_label(self, index, process_fn=None):
        lbl = self.labels[index]
        lbl = self.text_compressor.decompress(lbl)
        return lbl if process_fn is None else process_fn(lbl)

    def __getitem__(self, index):
        item = self.dataset[index]
        item["label"] = self.get_label(index, process_fn=self.process_label)
        return item

    def size(self, index):
        sz = self.dataset.size(index)
        own_sz = self.label_len_fn(self.get_label(index))
        return sz, own_sz

    def collater(self, samples):
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated
        indices = set(collated["id"].tolist())
        target = [s["label"] for s in samples if s["id"] in indices]

        if self.add_to_input:
            eos = torch.LongTensor([self.eos])
            prev_output_tokens = [torch.cat([eos, t], axis=-1) for t in target]
            target = [torch.cat([t, eos], axis=-1) for t in target]
            collated["net_input"]["prev_output_tokens"] = prev_output_tokens

        if self.batch_targets:
            collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
            target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
            collated["ntokens"] = collated["target_lengths"].sum().item()
            if getattr(collated["net_input"], "prev_output_tokens", None):
                collated["net_input"]["prev_output_tokens"] = data_utils.collate_tokens(
                    collated["net_input"]["prev_output_tokens"],
                    pad_idx=self.pad,
                    left_pad=False,
                )
        else:
            collated["ntokens"] = sum([len(t) for t in target])

        collated["target"] = target
        return collated

    def filter_indices_by_size(self, indices, max_sizes):
        indices, ignored = data_utils._filter_by_size_dynamic(
            indices, self.size, max_sizes
        )
        return indices, ignored
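
A hedged usage sketch for the wrapper above (illustrative names only, assuming
the vendored package imports as `fairseq`): pair a toy base dataset with byte
labels, and the collater attaches per-sample targets plus the target token
count.

import torch
from fairseq.data import AddTargetDataset, FairseqDataset

class ToyDataset(FairseqDataset):
    """Tiny stand-in for an audio dataset: items are feature tensors."""

    def __init__(self, feats):
        self.feats = feats

    def __getitem__(self, index):
        return {"id": index, "source": self.feats[index]}

    def __len__(self):
        return len(self.feats)

    def collater(self, samples):
        return {
            "id": torch.tensor([s["id"] for s in samples]),
            "net_input": {"source": [s["source"] for s in samples]},
        }

ds = AddTargetDataset(
    ToyDataset([torch.zeros(4), torch.zeros(6)]),
    labels=[b"HELLO", b"WORLD"],  # compression level "none" stores plain bytes
    pad=0,
    eos=2,
    batch_targets=False,
    process_label=lambda s: torch.tensor([ord(c) for c in s]),
)
batch = ds.collater([ds[0], ds[1]])
assert batch["ntokens"] == 10     # five label tokens per sample
assert len(batch["target"]) == 2  # one label tensor per sample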
    	
modules/voice_conversion/fairseq/data/append_token_dataset.py
ADDED
@@ -0,0 +1,41 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import BaseWrapperDataset


class AppendTokenDataset(BaseWrapperDataset):
    def __init__(self, dataset, token=None):
        super().__init__(dataset)
        self.token = token
        if token is not None:
            self._sizes = np.array(dataset.sizes) + 1
        else:
            self._sizes = dataset.sizes

    def __getitem__(self, idx):
        item = self.dataset[idx]
        if self.token is not None:
            item = torch.cat([item, item.new([self.token])])
        return item

    @property
    def sizes(self):
        return self._sizes

    def num_tokens(self, index):
        n = self.dataset.num_tokens(index)
        if self.token is not None:
            n += 1
        return n

    def size(self, index):
        n = self.dataset.size(index)
        if self.token is not None:
            n += 1
        return n
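
As with the other wrappers, the effect is easiest to see on a toy dataset (a
minimal sketch, again assuming the vendored package imports as `fairseq`; the
token value 2 stands in for an eos id):

import numpy as np
import torch
from fairseq.data import AppendTokenDataset, ListDataset

base = ListDataset([torch.tensor([5, 6, 7])], sizes=np.array([3]))
with_eos = AppendTokenDataset(base, token=2)
print(with_eos[0])        # tensor([5, 6, 7, 2])
print(with_eos.sizes[0])  # 4 -- sizes are bumped by one for the appended token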
    	
        modules/voice_conversion/fairseq/data/audio/__init__.py
    ADDED
    
    | 
         @@ -0,0 +1,93 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
from abc import ABC, abstractmethod
from typing import Dict, Optional
import importlib
import os
import numpy as np


class AudioTransform(ABC):
    @classmethod
    @abstractmethod
    def from_config_dict(cls, config: Optional[Dict] = None):
        pass


class CompositeAudioTransform(AudioTransform):
    def _from_config_dict(
        cls,
        transform_type,
        get_audio_transform,
        composite_cls,
        config=None,
        return_empty=False,
    ):
        _config = {} if config is None else config
        _transforms = _config.get(f"{transform_type}_transforms")

        if _transforms is None:
            if return_empty:
                _transforms = []
            else:
                return None

        transforms = [
            get_audio_transform(_t).from_config_dict(_config.get(_t))
            for _t in _transforms
        ]
        return composite_cls(transforms)

    def __init__(self, transforms):
        self.transforms = [t for t in transforms if t is not None]

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x

    def __repr__(self):
        format_string = (
            [self.__class__.__name__ + "("]
            + [f"    {t.__repr__()}" for t in self.transforms]
            + [")"]
        )
        return "\n".join(format_string)


def register_audio_transform(name, cls_type, registry, class_names):
    def register_audio_transform_cls(cls):
        if name in registry:
            raise ValueError(f"Cannot register duplicate transform ({name})")
        if not issubclass(cls, cls_type):
            raise ValueError(
                f"Transform ({name}: {cls.__name__}) must extend "
                f"{cls_type.__name__}"
            )
        if cls.__name__ in class_names:
            raise ValueError(
                f"Cannot register audio transform with duplicate "
                f"class name ({cls.__name__})"
            )
        registry[name] = cls
        class_names.add(cls.__name__)
        return cls

    return register_audio_transform_cls


def import_transforms(transforms_dir, transform_type):
    for file in os.listdir(transforms_dir):
        path = os.path.join(transforms_dir, file)
        if (
            not file.startswith("_")
            and not file.startswith(".")
            and (file.endswith(".py") or os.path.isdir(path))
        ):
            name = file[: file.find(".py")] if file.endswith(".py") else file
            importlib.import_module(
                f"fairseq.data.audio.{transform_type}_transforms." + name
            )


# Utility fn for uniform numbers in transforms
def rand_uniform(a, b):
    return np.random.uniform() * (b - a) + a
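To make the registry/composite machinery above concrete, here is a minimal sketch, not part of this commit: it registers a hypothetical no-op transform into a local registry and applies it through CompositeAudioTransform. The names _DUMMY_REGISTRY, _DUMMY_CLASS_NAMES, and NoopTransform are illustrative assumptions, and the import path assumes this vendored fairseq package is on sys.path.

# Illustrative sketch only (not part of this commit).
import numpy as np

from fairseq.data.audio import (
    AudioTransform,
    CompositeAudioTransform,
    register_audio_transform,
)

# Hypothetical local registry; the real transform subpackages each keep
# their own registry and pass it to the decorator factory.
_DUMMY_REGISTRY, _DUMMY_CLASS_NAMES = {}, set()


@register_audio_transform("noop", AudioTransform, _DUMMY_REGISTRY, _DUMMY_CLASS_NAMES)
class NoopTransform(AudioTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        return cls()

    def __call__(self, x):
        return x  # identity; a real transform would modify the array


composite = CompositeAudioTransform([NoopTransform()])
out = composite(np.zeros(16000, dtype=np.float32))  # passes through unchanged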
    	
modules/voice_conversion/fairseq/data/audio/audio_utils.py ADDED
@@ -0,0 +1,389 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import mmap
from pathlib import Path
import io
from typing import BinaryIO, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform

SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"}
FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}


def convert_waveform(
    waveform: Union[np.ndarray, torch.Tensor],
    sample_rate: int,
    normalize_volume: bool = False,
    to_mono: bool = False,
    to_sample_rate: Optional[int] = None,
) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
    """convert a waveform:
    - to a target sample rate
    - from multi-channel to mono channel
    - volume normalization

    Args:
        waveform (numpy.ndarray or torch.Tensor): 2D original waveform
            (channels x length)
        sample_rate (int): original sample rate
        normalize_volume (bool): perform volume normalization
        to_mono (bool): convert to mono channel if having multiple channels
        to_sample_rate (Optional[int]): target sample rate
    Returns:
        waveform (numpy.ndarray): converted 2D waveform (channels x length)
        sample_rate (float): target sample rate
    """
    try:
        import torchaudio.sox_effects as ta_sox
    except ImportError:
        raise ImportError("Please install torchaudio: pip install torchaudio")

    effects = []
    if normalize_volume:
        effects.append(["gain", "-n"])
    if to_sample_rate is not None and to_sample_rate != sample_rate:
        effects.append(["rate", f"{to_sample_rate}"])
    if to_mono and waveform.shape[0] > 1:
        effects.append(["channels", "1"])
    if len(effects) > 0:
        is_np_input = isinstance(waveform, np.ndarray)
        _waveform = torch.from_numpy(waveform) if is_np_input else waveform
        converted, converted_sample_rate = ta_sox.apply_effects_tensor(
            _waveform, sample_rate, effects
        )
        if is_np_input:
            converted = converted.numpy()
        return converted, converted_sample_rate
    return waveform, sample_rate


def get_waveform(
    path_or_fp: Union[str, BinaryIO],
    normalization: bool = True,
    mono: bool = True,
    frames: int = -1,
    start: int = 0,
    always_2d: bool = True,
    output_sample_rate: Optional[int] = None,
    normalize_volume: bool = False,
    waveform_transforms: Optional[CompositeAudioWaveformTransform] = None,
) -> Tuple[np.ndarray, int]:
    """Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.

    Args:
        path_or_fp (str or BinaryIO): the path or file-like object
        normalization (bool): normalize values to [-1, 1] (Default: True)
        mono (bool): convert multi-channel audio to mono-channel one
        frames (int): the number of frames to read. (-1 for reading all)
        start (int): Where to start reading. A negative value counts from the end.
        always_2d (bool): always return 2D array even for mono-channel audios
        output_sample_rate (Optional[int]): output sample rate
        normalize_volume (bool): normalize volume
    Returns:
        waveform (numpy.ndarray): 1D or 2D waveform (channels x length)
        sample_rate (float): sample rate
    """
    if isinstance(path_or_fp, str):
        ext = Path(path_or_fp).suffix
        if ext not in SF_AUDIO_FILE_EXTENSIONS:
            raise ValueError(f"Unsupported audio format: {ext}")

    try:
        import soundfile as sf
    except ImportError:
        raise ImportError("Please install soundfile: pip install soundfile")

    waveform, sample_rate = sf.read(
        path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start
    )
    waveform = waveform.T  # T x C -> C x T
    waveform, sample_rate = convert_waveform(
        waveform,
        sample_rate,
        normalize_volume=normalize_volume,
        to_mono=mono,
        to_sample_rate=output_sample_rate,
    )

    if not normalization:
        waveform *= 2**15  # denormalized to 16-bit signed integers

    if waveform_transforms is not None:
        waveform, sample_rate = waveform_transforms(waveform, sample_rate)

    if not always_2d:
        waveform = waveform.squeeze(axis=0)

    return waveform, sample_rate


def get_features_from_npy_or_audio(path, waveform_transforms=None):
    ext = Path(path).suffix
    if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
        raise ValueError(f'Unsupported file format for "{path}"')
    return (
        np.load(path)
        if ext == ".npy"
        else get_fbank(path, waveform_transforms=waveform_transforms)
    )


def get_features_or_waveform_from_stored_zip(
    path,
    byte_offset,
    byte_size,
    need_waveform=False,
    use_sample_rate=None,
    waveform_transforms=None,
):
    assert path.endswith(".zip")
    data = read_from_stored_zip(path, byte_offset, byte_size)
    f = io.BytesIO(data)
    if is_npy_data(data):
        features_or_waveform = np.load(f)
    elif is_sf_audio_data(data):
        features_or_waveform = (
            get_waveform(
                f,
                always_2d=False,
                output_sample_rate=use_sample_rate,
                waveform_transforms=waveform_transforms,
            )[0]
            if need_waveform
            else get_fbank(f, waveform_transforms=waveform_transforms)
        )
    else:
        raise ValueError(f'Unknown file format for "{path}"')
    return features_or_waveform


def get_features_or_waveform(
    path: str, need_waveform=False, use_sample_rate=None, waveform_transforms=None
):
    """Get speech features from .npy file or waveform from .wav/.flac file.
    The file may be inside an uncompressed ZIP file and is accessed via byte
    offset and length.

    Args:
        path (str): File path in the format of "<.npy/.wav/.flac path>" or
        "<zip path>:<byte offset>:<byte length>".
        need_waveform (bool): return waveform instead of features.
        use_sample_rate (int): change sample rate for the input wave file

    Returns:
        features_or_waveform (numpy.ndarray): speech features or waveform.
    """
    _path, slice_ptr = parse_path(path)
    if len(slice_ptr) == 0:
        if need_waveform:
            return get_waveform(
                _path,
                always_2d=False,
                output_sample_rate=use_sample_rate,
                waveform_transforms=waveform_transforms,
            )[0]
        return get_features_from_npy_or_audio(
            _path, waveform_transforms=waveform_transforms
        )
    elif len(slice_ptr) == 2:
        features_or_waveform = get_features_or_waveform_from_stored_zip(
            _path,
            slice_ptr[0],
            slice_ptr[1],
            need_waveform=need_waveform,
            use_sample_rate=use_sample_rate,
            waveform_transforms=waveform_transforms,
        )
    else:
        raise ValueError(f"Invalid path: {path}")

    return features_or_waveform


def _get_kaldi_fbank(
    waveform: np.ndarray, sample_rate: int, n_bins=80
) -> Optional[np.ndarray]:
    """Get mel-filter bank features via PyKaldi."""
    try:
        from kaldi.feat.fbank import Fbank, FbankOptions
        from kaldi.feat.mel import MelBanksOptions
        from kaldi.feat.window import FrameExtractionOptions
        from kaldi.matrix import Vector

        mel_opts = MelBanksOptions()
        mel_opts.num_bins = n_bins
        frame_opts = FrameExtractionOptions()
        frame_opts.samp_freq = sample_rate
        opts = FbankOptions()
        opts.mel_opts = mel_opts
        opts.frame_opts = frame_opts
        fbank = Fbank(opts=opts)
        features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy()
        return features
    except ImportError:
        return None


def _get_torchaudio_fbank(
    waveform: np.ndarray, sample_rate, n_bins=80
) -> Optional[np.ndarray]:
    """Get mel-filter bank features via TorchAudio."""
    try:
        import torchaudio.compliance.kaldi as ta_kaldi

        waveform = torch.from_numpy(waveform)
        features = ta_kaldi.fbank(
            waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
        )
        return features.numpy()
    except ImportError:
        return None


def get_fbank(
    path_or_fp: Union[str, BinaryIO], n_bins=80, waveform_transforms=None
) -> np.ndarray:
    """Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
    (faster CPP implementation) to TorchAudio (Python implementation). Note that
    Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
    waveform should not be normalized."""
    waveform, sample_rate = get_waveform(
        path_or_fp, normalization=False, waveform_transforms=waveform_transforms
    )

    features = _get_kaldi_fbank(waveform, sample_rate, n_bins)
    if features is None:
        features = _get_torchaudio_fbank(waveform, sample_rate, n_bins)
    if features is None:
        raise ImportError(
            "Please install pyKaldi or torchaudio to enable "
            "online filterbank feature extraction"
        )

    return features


def is_npy_data(data: bytes) -> bool:
    return data[0] == 147 and data[1] == 78


def is_sf_audio_data(data: bytes) -> bool:
    is_wav = data[0] == 82 and data[1] == 73 and data[2] == 70
    is_flac = data[0] == 102 and data[1] == 76 and data[2] == 97
    is_ogg = data[0] == 79 and data[1] == 103 and data[2] == 103
    return is_wav or is_flac or is_ogg


def mmap_read(path: str, offset: int, length: int) -> bytes:
    with open(path, "rb") as f:
        with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_o:
            data = mmap_o[offset : offset + length]
    return data


def read_from_stored_zip(zip_path: str, offset: int, length: int) -> bytes:
    return mmap_read(zip_path, offset, length)


def parse_path(path: str) -> Tuple[str, List[int]]:
    """Parse data path which is either a path to
    1. a .npy/.wav/.flac/.ogg file
    2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"

      Args:
          path (str): the data path to parse

      Returns:
          file_path (str): the file path
          slice_ptr (list of int): empty in case 1;
            byte offset and length for the slice in case 2
    """

    if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
        _path, slice_ptr = path, []
    else:
        _path, *slice_ptr = path.split(":")
        if not Path(_path).is_file():
            raise FileNotFoundError(f"File not found: {_path}")
    assert len(slice_ptr) in {0, 2}, f"Invalid path: {path}"
    slice_ptr = [int(i) for i in slice_ptr]
    return _path, slice_ptr


def get_window(window_fn: callable, n_fft: int, win_length: int) -> torch.Tensor:
    padding = n_fft - win_length
    assert padding >= 0
    return F.pad(window_fn(win_length), (padding // 2, padding - padding // 2))


def get_fourier_basis(n_fft: int) -> torch.Tensor:
    basis = np.fft.fft(np.eye(n_fft))
    basis = np.vstack(
        [np.real(basis[: n_fft // 2 + 1, :]), np.imag(basis[: n_fft // 2 + 1, :])]
    )
    return torch.from_numpy(basis).float()


def get_mel_filters(
    sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float
) -> torch.Tensor:
    try:
        import librosa
    except ImportError:
        raise ImportError("Please install librosa: pip install librosa")
    basis = librosa.filters.mel(sample_rate, n_fft, n_mels, f_min, f_max)
    return torch.from_numpy(basis).float()


class TTSSpectrogram(torch.nn.Module):
    def __init__(
        self,
        n_fft: int,
        win_length: int,
        hop_length: int,
        window_fn: callable = torch.hann_window,
        return_phase: bool = False,
    ) -> None:
        super(TTSSpectrogram, self).__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.return_phase = return_phase

        basis = get_fourier_basis(n_fft).unsqueeze(1)
        basis *= get_window(window_fn, n_fft, win_length)
        self.register_buffer("basis", basis)

    def forward(
        self, waveform: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        padding = (self.n_fft // 2, self.n_fft // 2)
        x = F.pad(waveform.unsqueeze(1), padding, mode="reflect")
        x = F.conv1d(x, self.basis, stride=self.hop_length)
        real_part = x[:, : self.n_fft // 2 + 1, :]
        imag_part = x[:, self.n_fft // 2 + 1 :, :]
        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        if self.return_phase:
            phase = torch.atan2(imag_part, real_part)
            return magnitude, phase
        return magnitude


class TTSMelScale(torch.nn.Module):
    def __init__(
        self, n_mels: int, sample_rate: int, f_min: float, f_max: float, n_stft: int
    ) -> None:
        super(TTSMelScale, self).__init__()
        basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max)
        self.register_buffer("basis", basis)

    def forward(self, specgram: torch.Tensor) -> torch.Tensor:
        return torch.matmul(self.basis, specgram)
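As a quick orientation to audio_utils above, here is a hedged sketch, not part of this commit, of the two path formats parse_path accepts and the corresponding loads. "utt1.wav" and "feats.zip" are placeholder files that must exist on disk, and soundfile (plus torchaudio or PyKaldi for filterbanks) is assumed installed.

# Illustrative sketch only (not part of this commit).
from fairseq.data.audio.audio_utils import get_features_or_waveform, parse_path

# Case 1: a plain .npy/.wav/.flac/.ogg path -> empty slice pointer.
fpath, slice_ptr = parse_path("utt1.wav")             # ("utt1.wav", [])
wav = get_features_or_waveform("utt1.wav", need_waveform=True)

# Case 2: "<zip>:<byte offset>:<byte length>" for an entry packed in an
# uncompressed zip -> two ints, read back via mmap_read.
fpath, slice_ptr = parse_path("feats.zip:1024:2048")  # ("feats.zip", [1024, 2048])
feats = get_features_or_waveform("feats.zip:1024:2048", need_waveform=False)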
    	
modules/voice_conversion/fairseq/data/audio/data_cfg.py ADDED
@@ -0,0 +1,387 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from argparse import Namespace
from copy import deepcopy
from pathlib import Path
from typing import Dict, Optional

from fairseq.data import Dictionary

logger = logging.getLogger(__name__)


def get_config_from_yaml(yaml_path: Path):
    try:
        import yaml
    except ImportError:
        print("Please install PyYAML: pip install PyYAML")
    config = {}
    if yaml_path.is_file():
        try:
            with open(yaml_path) as f:
                config = yaml.load(f, Loader=yaml.FullLoader)
        except Exception as e:
            raise Exception(f"Failed to load config from {yaml_path.as_posix()}: {e}")
    else:
        raise FileNotFoundError(f"{yaml_path.as_posix()} not found")

    return config


class S2TDataConfig(object):
    """Wrapper class for data config YAML"""

    def __init__(self, yaml_path: Path):
        self.config = get_config_from_yaml(yaml_path)
        self.root = yaml_path.parent

    def _auto_convert_to_abs_path(self, x):
        if isinstance(x, str):
            if not Path(x).exists() and (self.root / x).exists():
                return (self.root / x).as_posix()
        elif isinstance(x, dict):
            return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()}
        return x

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", "dict.txt")

    @property
    def speaker_set_filename(self):
        """speaker set file under data root"""
        return self.config.get("speaker_set_filename", None)

    @property
    def shuffle(self) -> bool:
        """Shuffle dataset samples before batching"""
        return self.config.get("shuffle", False)

    @property
    def pre_tokenizer(self) -> Dict:
        """Pre-tokenizer to apply before subword tokenization. Returning
        a dictionary with `tokenizer` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def bpe_tokenizer(self) -> Dict:
        """Subword tokenizer to apply after pre-tokenization. Returning
        a dictionary with `bpe` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("bpe_tokenizer", {"bpe": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def prepend_tgt_lang_tag(self) -> bool:
        """Prepend target lang ID token as the target BOS (e.g. for to-many
        multilingual setting). During inference, this requires `--prefix-size 1`
        to force BOS to be lang ID token."""
        return self.config.get("prepend_tgt_lang_tag", False)

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def input_feat_per_channel(self):
        """The dimension of input features (per audio channel)"""
        return self.config.get("input_feat_per_channel", 80)

    @property
    def input_channels(self):
        """The number of channels in the input audio"""
        return self.config.get("input_channels", 1)

    @property
    def sample_rate(self):
        return self.config.get("sample_rate", 16_000)

    @property
    def sampling_alpha(self):
        """Hyper-parameter alpha = 1/T for temperature-based resampling.
        (alpha = 1 for no resampling)"""
        return self.config.get("sampling_alpha", 1.0)

    @property
    def use_audio_input(self):
        """Needed by the dataset loader to see if the model requires
        raw audio as inputs."""
        return self.config.get("use_audio_input", False)

    def standardize_audio(self) -> bool:
        return self.use_audio_input and self.config.get("standardize_audio", False)

    @property
    def use_sample_rate(self):
        """Needed by the dataset loader to see if the model requires
        raw audio with specific sample rate as inputs."""
        return self.config.get("use_sample_rate", 16000)

    @property
    def audio_root(self):
        """Audio paths in the manifest TSV can be relative and this provides
        the root path. Set this to empty string when using absolute paths."""
        return self.config.get("audio_root", "")

    def get_transforms(self, transform_type, split, is_train):
        """Split-specific feature transforms. Allowing train set
        wildcard `_train`, evaluation set wildcard `_eval` and general
        wildcard `*` for matching."""
        from copy import deepcopy

        cfg = deepcopy(self.config)
        _cur = cfg.get(f"{transform_type}transforms", {})
        cur = _cur.get(split)
        cur = _cur.get("_train") if cur is None and is_train else cur
        cur = _cur.get("_eval") if cur is None and not is_train else cur
        cur = _cur.get("*") if cur is None else cur
        return cur

    def get_feature_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        # TODO: deprecate transforms
        cur = self.get_transforms("", split, is_train)
        if cur is not None:
            logger.warning(
                "Auto converting transforms into feature_transforms, "
                "but transforms will be deprecated in the future. Please "
                "update this in the config."
            )
            ft_transforms = self.get_transforms("feature_", split, is_train)
            if ft_transforms:
                cur.extend(ft_transforms)
        else:
            cur = self.get_transforms("feature_", split, is_train)
        cfg["feature_transforms"] = cur
        return cfg

    def get_waveform_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        cfg["waveform_transforms"] = self.get_transforms("waveform_", split, is_train)
        return cfg

    def get_dataset_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        cfg["dataset_transforms"] = self.get_transforms("dataset_", split, is_train)
        return cfg

    @property
    def global_cmvn_stats_npz(self) -> Optional[str]:
        path = self.config.get("global_cmvn", {}).get("stats_npz_path", None)
        return self._auto_convert_to_abs_path(path)

    @property
    def vocoder(self) -> Dict[str, str]:
        vocoder = self.config.get("vocoder", {"type": "griffin_lim"})
        return self._auto_convert_to_abs_path(vocoder)

    @property
    def hub(self) -> Dict[str, str]:
        return self.config.get("hub", {})


class S2SDataConfig(S2TDataConfig):
    """Wrapper class for data config YAML"""

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", None)

    @property
    def pre_tokenizer(self) -> Dict:
        return None

    @property
    def bpe_tokenizer(self) -> Dict:
        return None

    @property
    def input_transformed_channels(self):
        """The number of channels in the audio after feature transforms"""
        # TODO: move this into individual transforms
        # TODO: deprecate transforms
        _cur = self.config.get("transforms", {})
        ft_transforms = self.config.get("feature_transforms", {})
        if _cur and ft_transforms:
            _cur.update(ft_transforms)
        else:
            _cur = self.config.get("feature_transforms", {})
        cur = _cur.get("_train", [])

        _channels = self.input_channels
        if "delta_deltas" in cur:
            _channels *= 3

        return _channels

    @property
    def output_sample_rate(self):
        """The audio sample rate of output target speech"""
        return self.config.get("output_sample_rate", 22050)

    @property
    def target_speaker_embed(self):
        """Target speaker embedding file (one line per target audio sample)"""
        return self.config.get("target_speaker_embed", None)

    @property
    def prepend_tgt_lang_tag_as_bos(self) -> bool:
        """Prepend target lang ID token as the target BOS."""
        return self.config.get("prepend_tgt_lang_tag_as_bos", False)


class MultitaskConfig(object):
    """Wrapper class for data config YAML"""

    def __init__(self, yaml_path: Path):
        config = get_config_from_yaml(yaml_path)
        self.config = {}
        for k, v in config.items():
            self.config[k] = SingleTaskConfig(k, v)

    def get_all_tasks(self):
        return self.config

    def get_single_task(self, name):
        assert name in self.config, f"multitask '{name}' does not exist!"
        return self.config[name]

    @property
    def first_pass_decoder_task_index(self):
        """Return the task index of the first-pass text decoder.
        If there are multiple 'is_first_pass_decoder: True' in the config file,
            the last task is used for the first-pass decoder.
        If there is no 'is_first_pass_decoder: True' in the config file,
            the last task whose task_name includes 'target' and decoder_type is not ctc.
        """
        idx = -1
        for i, (k, v) in enumerate(self.config.items()):
            if v.is_first_pass_decoder:
                idx = i
        if idx < 0:
            for i, (k, v) in enumerate(self.config.items()):
                if k.startswith("target") and v.decoder_type == "transformer":
                    idx = i
        return idx


class SingleTaskConfig(object):
    def __init__(self, name, config):
        self.task_name = name
        self.config = config
        dict_path = config.get("dict", "")
        self.tgt_dict = Dictionary.load(dict_path) if Path(dict_path).exists() else None

    @property
    def data(self):
        return self.config.get("data", "")

    @property
    def decoder_type(self):
        return self.config.get("decoder_type", "transformer")

    @property
    def decoder_args(self):
        """Decoder arch related args"""
        args = self.config.get("decoder_args", {})
        return Namespace(**args)

    @property
    def criterion_cfg(self):
        """cfg for the multitask criterion"""
        if self.decoder_type == "ctc":
            from fairseq.criterions.ctc import CtcCriterionConfig

            cfg = CtcCriterionConfig
            cfg.zero_infinity = self.config.get("zero_infinity", True)
        else:
            from fairseq.criterions.label_smoothed_cross_entropy import (
                LabelSmoothedCrossEntropyCriterionConfig,
            )

            cfg = LabelSmoothedCrossEntropyCriterionConfig
            cfg.label_smoothing = self.config.get("label_smoothing", 0.2)
        return cfg

    @property
    def input_from(self):
        """Condition on encoder/decoder of the main model"""
        return "decoder" if "decoder_layer" in self.config else "encoder"

    @property
    def input_layer(self):
        if self.input_from == "decoder":
            return self.config["decoder_layer"] - 1
        else:
            # default using the output from the last encoder layer (-1)
            return self.config.get("encoder_layer", 0) - 1

    @property
    def loss_weight_schedule(self):
        return (
            "decay"
            if "loss_weight_max" in self.config
            and "loss_weight_decay_steps" in self.config
            else "fixed"
        )

    def get_loss_weight(self, num_updates):
        if self.loss_weight_schedule == "fixed":
            weight = self.config.get("loss_weight", 1.0)
        else:  # "decay"
            assert (
                self.config.get("loss_weight_decay_steps", 0) > 0
            ), "loss_weight_decay_steps must be greater than 0 for a decay schedule"
            loss_weight_min = self.config.get("loss_weight_min", 0.0001)
            loss_weight_decay_stepsize = (
                self.config["loss_weight_max"] - loss_weight_min
            ) / self.config["loss_weight_decay_steps"]
            weight = max(
                self.config["loss_weight_max"]
                - loss_weight_decay_stepsize * num_updates,
                loss_weight_min,
            )
        return weight

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def eos_token(self):
        """EOS token during generation"""
        return self.config.get("eos_token", "<eos>")

    @property
    def rdrop_alpha(self):
        return self.config.get("rdrop_alpha", 0.0)

    @property
    def is_first_pass_decoder(self):
        flag = self.config.get("is_first_pass_decoder", False)
        if flag:
            if self.decoder_type == "ctc":
                raise ValueError(
                    "First-pass decoder in the multi-decoder model must not be CTC."
                )
            if "target" not in self.task_name:
                raise Warning(
                    'The name of the first-pass decoder does not include "target".'
                )
        return flag

    @property
    def get_lang_tag_mapping(self):
        return self.config.get("lang_tag_mapping", {})
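
The wrappers above read every key lazily and fall back to defaults, so a data config YAML can stay minimal. For reference, a short sketch (not part of this commit) of how S2TDataConfig is meant to be driven; the YAML keys mirror the defaults read above, and /tmp/config.yaml is a hypothetical path:

from pathlib import Path

from fairseq.data.audio.data_cfg import S2TDataConfig

# Hypothetical config exercising the defaults and the transform wildcards.
yaml_text = """\
vocab_filename: dict.txt
shuffle: true
sample_rate: 16000
transforms:  # legacy key, auto-converted into feature_transforms
  _train: [utterance_cmvn, specaugment]
  '*': [utterance_cmvn]
"""
cfg_path = Path("/tmp/config.yaml")
cfg_path.write_text(yaml_text)

cfg = S2TDataConfig(cfg_path)
print(cfg.shuffle)      # True (the default would be False)
print(cfg.sample_rate)  # 16000
# Wildcard resolution in get_transforms: an exact split name wins,
# then `_train`/`_eval` depending on is_train, then `*`.
print(cfg.get_transforms("", "train_clean", True))  # ['utterance_cmvn', 'specaugment']
print(cfg.get_transforms("", "dev_other", False))   # ['utterance_cmvn']

On the multitask side, SingleTaskConfig.get_loss_weight is a linear decay: with loss_weight_max 1.0, loss_weight_min 0.1 and loss_weight_decay_steps 1000, update 500 yields max(1.0 - (1.0 - 0.1) / 1000 * 500, 0.1) = 0.55.
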
        modules/voice_conversion/fairseq/data/audio/dataset_transforms/__init__.py
    ADDED
    
         @@ -0,0 +1,53 @@ 
import os
from fairseq.data.audio import (
    AudioTransform,
    CompositeAudioTransform,
    import_transforms,
    register_audio_transform,
)


class AudioDatasetTransform(AudioTransform):
    pass


AUDIO_DATASET_TRANSFORM_REGISTRY = {}
AUDIO_DATASET_TRANSFORM_CLASS_NAMES = set()


def get_audio_dataset_transform(name):
    return AUDIO_DATASET_TRANSFORM_REGISTRY[name]


def register_audio_dataset_transform(name):
    return register_audio_transform(
        name,
        AudioDatasetTransform,
        AUDIO_DATASET_TRANSFORM_REGISTRY,
        AUDIO_DATASET_TRANSFORM_CLASS_NAMES,
    )


import_transforms(os.path.dirname(__file__), "dataset")


class CompositeAudioDatasetTransform(CompositeAudioTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        return super()._from_config_dict(
            cls,
            "dataset",
            get_audio_dataset_transform,
            CompositeAudioDatasetTransform,
            config,
            return_empty=True,
        )

    def get_transform(self, cls):
        for t in self.transforms:
            if isinstance(t, cls):
                return t
        return None

    def has_transform(self, cls):
        return self.get_transform(cls) is not None
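
For reference, a minimal sketch (not part of this commit) of how this registry is meant to be extended; MyDatasetTransform and the "mytransform" name are hypothetical, and it assumes the parent _from_config_dict helper resolves the "dataset_transforms" name list the way its arguments suggest:

from fairseq.data.audio.dataset_transforms import (
    AudioDatasetTransform,
    CompositeAudioDatasetTransform,
    register_audio_dataset_transform,
)


# The decorator adds the class to AUDIO_DATASET_TRANSFORM_REGISTRY
# under the given name.
@register_audio_dataset_transform("mytransform")
class MyDatasetTransform(AudioDatasetTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        return cls()

    def __repr__(self):
        return self.__class__.__name__


composite = CompositeAudioDatasetTransform.from_config_dict(
    {"dataset_transforms": ["mytransform"]}
)
print(composite.has_transform(MyDatasetTransform))  # True
print(composite.get_transform(MyDatasetTransform))  # MyDatasetTransform
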
        modules/voice_conversion/fairseq/data/audio/dataset_transforms/concataugment.py
    ADDED
    
         @@ -0,0 +1,61 @@ 
from typing import List
import numpy as np

from fairseq.data.audio.dataset_transforms import (
    AudioDatasetTransform,
    register_audio_dataset_transform,
)

_DEFAULTS = {"rate": 0.25, "max_tokens": 3000, "attempts": 5}


@register_audio_dataset_transform("concataugment")
class ConcatAugment(AudioDatasetTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return ConcatAugment(
            _config.get("rate", _DEFAULTS["rate"]),
            _config.get("max_tokens", _DEFAULTS["max_tokens"]),
            _config.get("attempts", _DEFAULTS["attempts"]),
        )

    def __init__(
        self,
        rate=_DEFAULTS["rate"],
        max_tokens=_DEFAULTS["max_tokens"],
        attempts=_DEFAULTS["attempts"],
    ):
        self.rate, self.max_tokens, self.attempts = rate, max_tokens, attempts

    def __repr__(self):
        return (
            self.__class__.__name__
            + "("
            + ", ".join(
                [
                    f"rate={self.rate}",
                    f"max_tokens={self.max_tokens}",
                    f"attempts={self.attempts}",
                ]
            )
            + ")"
        )

    def find_indices(self, index: int, n_frames: List[int], n_samples: int):
        # skip conditions: application rate, max_tokens limit exceeded
        if np.random.random() > self.rate:
            return [index]
        if self.max_tokens and n_frames[index] > self.max_tokens:
            return [index]

        # pick second sample to concatenate
        for _ in range(self.attempts):
            index2 = np.random.randint(0, n_samples)
            if index2 != index and (
                not self.max_tokens
                or n_frames[index] + n_frames[index2] < self.max_tokens
            ):
                return [index, index2]

        return [index]
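
find_indices is the whole interface here: a dataset asks the transform which item indices should be concatenated into a single training sample and gets back either [index] or [index, index2]. A small sketch (not part of this commit; the frame counts are made up):

import numpy as np
from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment

np.random.seed(0)
aug = ConcatAugment(rate=1.0, max_tokens=3000, attempts=5)  # rate=1.0: always try

n_frames = [1200, 800, 2600, 400]  # hypothetical per-item frame counts
print(aug.find_indices(0, n_frames, len(n_frames)))
# -> [0, j] with n_frames[0] + n_frames[j] < 3000 (so j in {1, 3}),
#    or [0] if no partner is found within `attempts` tries.
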
        modules/voice_conversion/fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py
    ADDED
    
         @@ -0,0 +1,105 @@ 
import numpy as np
import torch

from fairseq.data.audio import rand_uniform
from fairseq.data.audio.dataset_transforms import (
    AudioDatasetTransform,
    register_audio_dataset_transform,
)
from fairseq.data.audio.waveform_transforms.noiseaugment import (
    NoiseAugmentTransform,
)

_DEFAULTS = {
    "rate": 0.25,
    "mixing_noise_rate": 0.1,
    "noise_path": "",
    "noise_snr_min": -5,
    "noise_snr_max": 5,
    "utterance_snr_min": -5,
    "utterance_snr_max": 5,
}


@register_audio_dataset_transform("noisyoverlapaugment")
class NoisyOverlapAugment(AudioDatasetTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return NoisyOverlapAugment(
            _config.get("rate", _DEFAULTS["rate"]),
            _config.get("mixing_noise_rate", _DEFAULTS["mixing_noise_rate"]),
            _config.get("noise_path", _DEFAULTS["noise_path"]),
            _config.get("noise_snr_min", _DEFAULTS["noise_snr_min"]),
            _config.get("noise_snr_max", _DEFAULTS["noise_snr_max"]),
            _config.get("utterance_snr_min", _DEFAULTS["utterance_snr_min"]),
            _config.get("utterance_snr_max", _DEFAULTS["utterance_snr_max"]),
        )

    def __init__(
        self,
        rate=_DEFAULTS["rate"],
        mixing_noise_rate=_DEFAULTS["mixing_noise_rate"],
        noise_path=_DEFAULTS["noise_path"],
        noise_snr_min=_DEFAULTS["noise_snr_min"],
        noise_snr_max=_DEFAULTS["noise_snr_max"],
        utterance_snr_min=_DEFAULTS["utterance_snr_min"],
        utterance_snr_max=_DEFAULTS["utterance_snr_max"],
    ):
        self.rate = rate
        self.mixing_noise_rate = mixing_noise_rate
        self.noise_shaper = NoiseAugmentTransform(noise_path)
        self.noise_snr_min = noise_snr_min
        self.noise_snr_max = noise_snr_max
        self.utterance_snr_min = utterance_snr_min
        self.utterance_snr_max = utterance_snr_max

    def __repr__(self):
        return (
            self.__class__.__name__
            + "("
            + ", ".join(
                [
                    f"rate={self.rate}",
                    f"mixing_noise_rate={self.mixing_noise_rate}",
                    f"noise_snr_min={self.noise_snr_min}",
                    f"noise_snr_max={self.noise_snr_max}",
                    f"utterance_snr_min={self.utterance_snr_min}",
                    f"utterance_snr_max={self.utterance_snr_max}",
                ]
            )
            + ")"
        )

    def __call__(self, sources):
        for i, source in enumerate(sources):
            if np.random.random() > self.rate:
                continue

            pri = source.numpy()

            if np.random.random() > self.mixing_noise_rate:
                sec = sources[np.random.randint(0, len(sources))].numpy()
                snr = rand_uniform(self.utterance_snr_min, self.utterance_snr_max)
            else:
                sec = self.noise_shaper.pick_sample(source.shape)
                snr = rand_uniform(self.noise_snr_min, self.noise_snr_max)

            L1 = pri.shape[-1]
            L2 = sec.shape[-1]
            l = np.random.randint(0, min(round(L1 / 2), L2))  # mix len
            s_source = np.random.randint(0, L1 - l)
            s_sec = np.random.randint(0, L2 - l)

            get_power = lambda x: np.mean(x**2)
            if get_power(sec) == 0:
                continue

            scl = np.sqrt(get_power(pri) / (np.power(10, snr / 10) * get_power(sec)))

            pri[s_source : s_source + l] = np.add(
         
            +
                            pri[s_source : s_source + l], np.multiply(scl, sec[s_sec : s_sec + l])
         
     | 
| 102 | 
         
            +
                        )
         
     | 
| 103 | 
         
            +
                        sources[i] = torch.from_numpy(pri).float()
         
     | 
| 104 | 
         
            +
             
     | 
| 105 | 
         
            +
                    return sources
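A minimal usage sketch (not part of the diff) for the transform above, assuming the vendored fairseq package is importable under its usual module path and that a noise manifest exists at the hypothetical path "my_noise_manifest". Each waveform is overlapped either with another utterance from the batch or with an external noise clip, at a randomly sampled SNR:

# Hypothetical usage sketch; NoisyOverlapAugment and its defaults come from
# the diff above, the noise manifest path is an assumption.
import torch
from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import (
    NoisyOverlapAugment,
)

augment = NoisyOverlapAugment(
    rate=0.5,                        # augment roughly half of the sources
    mixing_noise_rate=0.3,           # ~30% external noise, ~70% utterance overlap
    noise_path="my_noise_manifest",  # hypothetical noise manifest path
)
sources = [torch.randn(16000) for _ in range(4)]  # four 1-second mono waveforms
sources = augment(sources)           # mixed in place and returned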
         
modules/voice_conversion/fairseq/data/audio/feature_transforms/__init__.py
ADDED
@@ -0,0 +1,43 @@
import os
from fairseq.data.audio import (
    AudioTransform,
    CompositeAudioTransform,
    import_transforms,
    register_audio_transform,
)


class AudioFeatureTransform(AudioTransform):
    pass


AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()


def get_audio_feature_transform(name):
    return AUDIO_FEATURE_TRANSFORM_REGISTRY[name]


def register_audio_feature_transform(name):
    return register_audio_transform(
        name,
        AudioFeatureTransform,
        AUDIO_FEATURE_TRANSFORM_REGISTRY,
        AUDIO_FEATURE_TRANSFORM_CLASS_NAMES,
    )


import_transforms(os.path.dirname(__file__), "feature")


class CompositeAudioFeatureTransform(CompositeAudioTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        return super()._from_config_dict(
            cls,
            "feature",
            get_audio_feature_transform,
            CompositeAudioFeatureTransform,
            config,
        )
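A sketch (not part of the diff) of how this registry is meant to be used: a transform registers itself under a name via the decorator, implements from_config_dict, and can then be instantiated by name through the composite. "ScaleFeatures" and its "scale" config key are hypothetical; the pattern mirrors the registered transforms added elsewhere in this commit:

# Hypothetical custom transform registered with the registry defined above.
from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform,
    register_audio_feature_transform,
)


@register_audio_feature_transform("scale_features")
class ScaleFeatures(AudioFeatureTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return ScaleFeatures(_config.get("scale", 1.0))

    def __init__(self, scale=1.0):
        self.scale = scale

    def __call__(self, x):
        return x * self.scale  # elementwise scaling of the feature matrix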
         
modules/voice_conversion/fairseq/data/audio/feature_transforms/delta_deltas.py
ADDED
@@ -0,0 +1,37 @@
import numpy as np
import torch
from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform,
    register_audio_feature_transform,
)


@register_audio_feature_transform("delta_deltas")
class DeltaDeltas(AudioFeatureTransform):
    """Expand delta-deltas features from spectrum."""

    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return DeltaDeltas(_config.get("win_length", 5))

    def __init__(self, win_length=5):
        # Note: win_length is stored but not forwarded below; compute_deltas
        # runs with its own default window (also 5).
        self.win_length = win_length

    def __repr__(self):
        return self.__class__.__name__

    def __call__(self, spectrogram):
        from torchaudio.functional import compute_deltas

        assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
        # spectrogram is T x F, while compute_deltas takes (…, F, T)
        spectrogram = torch.from_numpy(spectrogram).transpose(0, 1)
        delta = compute_deltas(spectrogram)
        delta_delta = compute_deltas(delta)

        out_feat = np.concatenate(
            [spectrogram, delta.numpy(), delta_delta.numpy()], axis=0
        )
        out_feat = np.transpose(out_feat)
        return out_feat
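A minimal sketch (not part of the diff): DeltaDeltas stacks the input features with their first- and second-order deltas, so a T x F spectrogram comes back as T x 3F. It expects a NumPy array and requires torchaudio:

# Hypothetical usage sketch for the transform above.
import numpy as np
from fairseq.data.audio.feature_transforms.delta_deltas import DeltaDeltas

spec = np.random.randn(120, 80).astype(np.float32)  # 120 frames, 80 mel bins
out = DeltaDeltas()(spec)
assert out.shape == (120, 240)  # features, deltas, and delta-deltas stacked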
         
modules/voice_conversion/fairseq/data/audio/feature_transforms/global_cmvn.py
ADDED
@@ -0,0 +1,29 @@
import numpy as np
from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform,
    register_audio_feature_transform,
)


@register_audio_feature_transform("global_cmvn")
class GlobalCMVN(AudioFeatureTransform):
    """Global CMVN (cepstral mean and variance normalization). The global mean
    and variance need to be pre-computed and stored in NumPy format (.npz)."""

    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return GlobalCMVN(_config.get("stats_npz_path"))

    def __init__(self, stats_npz_path):
        self.stats_npz_path = stats_npz_path
        stats = np.load(stats_npz_path)
        self.mean, self.std = stats["mean"], stats["std"]

    def __repr__(self):
        return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")'

    def __call__(self, x):
        x = np.subtract(x, self.mean)
        x = np.divide(x, self.std)
        return x
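A minimal sketch (not part of the diff) of pre-computing the statistics file that GlobalCMVN loads. The .npz keys must be "mean" and "std"; here "feats" stands in for all training features stacked along the time axis:

# Hypothetical stats preparation for the transform above.
import numpy as np
from fairseq.data.audio.feature_transforms.global_cmvn import GlobalCMVN

feats = np.random.randn(10000, 80).astype(np.float32)  # stand-in training features
np.savez("gcmvn_stats.npz", mean=feats.mean(axis=0), std=feats.std(axis=0))

cmvn = GlobalCMVN("gcmvn_stats.npz")
normalized = cmvn(feats[:100])  # ~zero-mean, unit-variance per feature dim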
         
modules/voice_conversion/fairseq/data/audio/feature_transforms/specaugment.py
ADDED
@@ -0,0 +1,131 @@
import math
import numbers
from typing import Optional

import numpy as np
from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform,
    register_audio_feature_transform,
)


@register_audio_feature_transform("specaugment")
class SpecAugmentTransform(AudioFeatureTransform):
    """SpecAugment (https://arxiv.org/abs/1904.08779)"""

    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return SpecAugmentTransform(
            _config.get("time_warp_W", 0),
            _config.get("freq_mask_N", 0),
            _config.get("freq_mask_F", 0),
            _config.get("time_mask_N", 0),
            _config.get("time_mask_T", 0),
            _config.get("time_mask_p", 0.0),
            _config.get("mask_value", None),
        )

    def __init__(
        self,
        time_warp_w: int = 0,
        freq_mask_n: int = 0,
        freq_mask_f: int = 0,
        time_mask_n: int = 0,
        time_mask_t: int = 0,
        time_mask_p: float = 0.0,
        mask_value: Optional[float] = 0.0,
    ):
        # Sanity checks
        assert mask_value is None or isinstance(
            mask_value, numbers.Number
        ), f"mask_value (type: {type(mask_value)}) must be None or a number"
        if freq_mask_n > 0:
            assert freq_mask_f > 0, (
                f"freq_mask_F ({freq_mask_f}) "
                f"must be larger than 0 when doing freq masking."
            )
        if time_mask_n > 0:
            assert time_mask_t > 0, (
                f"time_mask_T ({time_mask_t}) must be larger than 0 when "
                f"doing time masking."
            )

        self.time_warp_w = time_warp_w
        self.freq_mask_n = freq_mask_n
        self.freq_mask_f = freq_mask_f
        self.time_mask_n = time_mask_n
        self.time_mask_t = time_mask_t
        self.time_mask_p = time_mask_p
        self.mask_value = mask_value

    def __repr__(self):
        return (
            self.__class__.__name__
            + "("
            + ", ".join(
                [
                    f"time_warp_w={self.time_warp_w}",
                    f"freq_mask_n={self.freq_mask_n}",
                    f"freq_mask_f={self.freq_mask_f}",
                    f"time_mask_n={self.time_mask_n}",
                    f"time_mask_t={self.time_mask_t}",
                    f"time_mask_p={self.time_mask_p}",
                ]
            )
            + ")"
        )

    def __call__(self, spectrogram):
        assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."

        distorted = spectrogram.copy()  # make a copy of input spectrogram.
        num_frames = spectrogram.shape[0]  # or 'tau' in the paper.
        num_freqs = spectrogram.shape[1]  # or 'miu' in the paper.
        mask_value = self.mask_value

        if mask_value is None:  # if no value was specified, use local mean.
            mask_value = spectrogram.mean()

        if num_frames == 0:
            return spectrogram

        if num_freqs < self.freq_mask_f:
            return spectrogram

        if self.time_warp_w > 0:
            if 2 * self.time_warp_w < num_frames:
                import cv2

                w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
                w = np.random.randint(-self.time_warp_w + 1, self.time_warp_w)
                upper, lower = distorted[:w0, :], distorted[w0:, :]
                upper = cv2.resize(
                    upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
                )
                lower = cv2.resize(
                    lower,
                    dsize=(num_freqs, num_frames - w0 - w),
                    interpolation=cv2.INTER_LINEAR,
                )
                distorted = np.concatenate((upper, lower), axis=0)

        for _i in range(self.freq_mask_n):
            f = np.random.randint(0, self.freq_mask_f)
            f0 = np.random.randint(0, num_freqs - f)
            if f != 0:
                distorted[:, f0 : f0 + f] = mask_value

        max_time_mask_t = min(
            self.time_mask_t, math.floor(num_frames * self.time_mask_p)
        )
        if max_time_mask_t < 1:
            return distorted

        for _i in range(self.time_mask_n):
            t = np.random.randint(0, max_time_mask_t)
            t0 = np.random.randint(0, num_frames - t)
            if t != 0:
                distorted[t0 : t0 + t, :] = mask_value

        return distorted
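A minimal sketch (not part of the diff) of building the transform from a config dict, with values loosely following the LibriSpeech "LD" policy from the SpecAugment paper (the numbers are illustrative). The keys match those read by from_config_dict above; since mask_value is unset, masked regions take the spectrogram mean, and time warping additionally requires OpenCV (cv2):

# Hypothetical configuration sketch for the transform above.
import numpy as np
from fairseq.data.audio.feature_transforms.specaugment import SpecAugmentTransform

specaug = SpecAugmentTransform.from_config_dict(
    {
        "time_warp_W": 80,
        "freq_mask_N": 2,
        "freq_mask_F": 27,
        "time_mask_N": 2,
        "time_mask_T": 100,
        "time_mask_p": 1.0,
    }
)
spec = np.random.randn(500, 80).astype(np.float32)  # T x F log-mel features
distorted = specaug(spec)  # same shape, with warped and masked regions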
         
modules/voice_conversion/fairseq/data/audio/feature_transforms/utterance_cmvn.py
ADDED
@@ -0,0 +1,41 @@
import numpy as np

from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform,
    register_audio_feature_transform,
)


@register_audio_feature_transform("utterance_cmvn")
class UtteranceCMVN(AudioFeatureTransform):
    """Utterance-level CMVN (cepstral mean and variance normalization)"""

    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return UtteranceCMVN(
            _config.get("norm_means", True),
            _config.get("norm_vars", True),
        )

    def __init__(self, norm_means=True, norm_vars=True):
        self.norm_means, self.norm_vars = norm_means, norm_vars

    def __repr__(self):
        return (
            self.__class__.__name__
            + f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})"
        )

    def __call__(self, x):
        mean = x.mean(axis=0)
        square_sums = (x**2).sum(axis=0)

        if self.norm_means:
            x = np.subtract(x, mean)
        if self.norm_vars:
            var = square_sums / x.shape[0] - mean**2
            std = np.sqrt(np.maximum(var, 1e-10))
            x = np.divide(x, std)

        return x
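A minimal sketch (not part of the diff): unlike GlobalCMVN, this variant needs no pre-computed statistics, since mean and variance come from the utterance itself:

# Hypothetical usage sketch for the transform above.
import numpy as np
from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN

cmvn = UtteranceCMVN(norm_means=True, norm_vars=True)
x = np.random.randn(200, 80).astype(np.float32)  # one utterance, T x F
y = cmvn(x)  # per-dimension mean ~0 and std ~1 within this utterance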
         
modules/voice_conversion/fairseq/data/audio/frm_text_to_speech_dataset.py
ADDED
@@ -0,0 +1,205 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import csv
import logging
import os.path as op
from typing import List, Optional

import numpy as np
import torch
from fairseq.data import Dictionary
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
from fairseq.data.audio.text_to_speech_dataset import (
    TextToSpeechDataset,
    TextToSpeechDatasetCreator,
)

logger = logging.getLogger(__name__)


class FrmTextToSpeechDataset(TextToSpeechDataset):
    def __init__(
        self,
        split: str,
        is_train_split: bool,
        data_cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        do_chunk=False,
        chunk_bound=-1,
        chunk_init=50,
        chunk_incr=5,
        add_eos=True,
        dedup=True,
        ref_fpu=-1,
    ):
        # It assumes texts are encoded at a fixed frame-rate
        super().__init__(
            split=split,
            is_train_split=is_train_split,
            data_cfg=data_cfg,
            audio_paths=audio_paths,
            n_frames=n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
        )

        self.do_chunk = do_chunk
        self.chunk_bound = chunk_bound
        self.chunk_init = chunk_init
        self.chunk_incr = chunk_incr
        self.add_eos = add_eos
        self.dedup = dedup
        self.ref_fpu = ref_fpu

        self.chunk_size = -1

        if do_chunk:
            assert self.chunk_incr >= 0
            assert self.pre_tokenizer is None

    def __getitem__(self, index):
        index, source, target, speaker_id, _, _, _ = super().__getitem__(index)
        if target[-1].item() == self.tgt_dict.eos_index:
            target = target[:-1]

        fpu = source.size(0) / target.size(0)  # frame-per-unit
        fps = self.n_frames_per_step
        assert (
            self.ref_fpu == -1 or abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
        ), f"{fpu*fps} != {self.ref_fpu}"

        # only chunk training split
        if self.is_train_split and self.do_chunk and self.chunk_size > 0:
            lang = target[: int(self.data_cfg.prepend_tgt_lang_tag)]
            text = target[int(self.data_cfg.prepend_tgt_lang_tag) :]
            size = len(text)
            chunk_size = min(self.chunk_size, size)
            chunk_start = np.random.randint(size - chunk_size + 1)
            text = text[chunk_start : chunk_start + chunk_size]
            target = torch.cat((lang, text), 0)

            f_size = int(np.floor(chunk_size * fpu))
            f_start = int(np.floor(chunk_start * fpu))
            assert f_size > 0
            source = source[f_start : f_start + f_size, :]

        if self.dedup:
            target = torch.unique_consecutive(target)

        if self.add_eos:
            eos_idx = self.tgt_dict.eos_index
            target = torch.cat((target, torch.LongTensor([eos_idx])), 0)

        return index, source, target, speaker_id

    def set_epoch(self, epoch):
        if self.is_train_split and self.do_chunk:
            old = self.chunk_size
            self.chunk_size = self.chunk_init + epoch * self.chunk_incr
            if self.chunk_bound > 0:
                self.chunk_size = min(self.chunk_size, self.chunk_bound)
            logger.info(
                (
                    f"{self.split}: setting chunk size "
                    f"from {old} to {self.chunk_size}"
                )
            )


class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
    # inherit for key names
    @classmethod
    def from_tsv(
        cls,
        root: str,
        data_cfg: S2TDataConfig,
        split: str,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        is_train_split: bool,
        n_frames_per_step: int,
        speaker_to_id,
        do_chunk: bool = False,
        chunk_bound: int = -1,
        chunk_init: int = 50,
        chunk_incr: int = 5,
        add_eos: bool = True,
        dedup: bool = True,
        ref_fpu: float = -1,
    ) -> FrmTextToSpeechDataset:
        tsv_path = op.join(root, f"{split}.tsv")
        if not op.isfile(tsv_path):
            raise FileNotFoundError(f"Dataset not found: {tsv_path}")
        with open(tsv_path) as f:
            reader = csv.DictReader(
                f,
                delimiter="\t",
                quotechar=None,
                doublequote=False,
                lineterminator="\n",
                quoting=csv.QUOTE_NONE,
            )
            s = [dict(e) for e in reader]
            assert len(s) > 0

        ids = [ss[cls.KEY_ID] for ss in s]
        audio_paths = [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s]
        n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s]
        tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s]
        src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
        speakers = [ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s]
        src_langs = [ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s]
        tgt_langs = [ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s]

        return FrmTextToSpeechDataset(
            split=split,
            is_train_split=is_train_split,
            data_cfg=data_cfg,
            audio_paths=audio_paths,
            n_frames=n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
            do_chunk=do_chunk,
            chunk_bound=chunk_bound,
            chunk_init=chunk_init,
            chunk_incr=chunk_incr,
            add_eos=add_eos,
            dedup=dedup,
            ref_fpu=ref_fpu,
        )
         
     | 
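For reference, `from_tsv` reads `{split}.tsv` with `csv.DictReader`, so the header row must name the columns addressed by the `KEY_*` class constants; only the id, audio, n_frames and tgt_text columns are required, while the speaker/source/language columns fall back to the `DEFAULT_*` values. A minimal manifest sketch, assuming the usual fairseq column names (the constants themselves are defined on the parent dataset class, not shown here):

id	audio	n_frames	tgt_text
utt_0001	clips/utt_0001.wav	48000	HH AH L OW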
    	
modules/voice_conversion/fairseq/data/audio/hubert_dataset.py
ADDED
@@ -0,0 +1,356 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import os
import sys
from typing import Any, List, Optional, Union

import numpy as np

import torch
import torch.nn.functional as F
from fairseq.data import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset
from fairseq.data.audio.audio_utils import (
    parse_path,
    read_from_stored_zip,
)
import io

logger = logging.getLogger(__name__)


def load_audio(manifest_path, max_keep, min_keep):
    n_long, n_short = 0, 0
    names, inds, sizes = [], [], []
    with open(manifest_path) as f:
        root = f.readline().strip()
        for ind, line in enumerate(f):
            items = line.strip().split("\t")
            assert len(items) == 2, line
            sz = int(items[1])
            if min_keep is not None and sz < min_keep:
                n_short += 1
            elif max_keep is not None and sz > max_keep:
                n_long += 1
            else:
                names.append(items[0])
                inds.append(ind)
                sizes.append(sz)
    tot = ind + 1
    logger.info(
        (
            f"max_keep={max_keep}, min_keep={min_keep}, "
            f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
            f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
        )
    )
    return root, names, inds, tot, sizes
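For orientation, `load_audio` expects a manifest whose first line is the audio root directory and whose remaining lines are tab-separated `<relative-path>	<num-samples>` pairs; the paths and sizes below are illustrative:

/data/audio_root
spk_a/utt_0001.wav	93680
spk_a/utt_0002.wav	77040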
         
            +
            def load_label(label_path, inds, tot):
         
     | 
| 56 | 
         
            +
                with open(label_path) as f:
         
     | 
| 57 | 
         
            +
                    labels = [line.rstrip() for line in f]
         
     | 
| 58 | 
         
            +
                    assert (
         
     | 
| 59 | 
         
            +
                        len(labels) == tot
         
     | 
| 60 | 
         
            +
                    ), f"number of labels does not match ({len(labels)} != {tot})"
         
     | 
| 61 | 
         
            +
                    labels = [labels[i] for i in inds]
         
     | 
| 62 | 
         
            +
                return labels
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
            def load_label_offset(label_path, inds, tot):
         
     | 
| 66 | 
         
            +
                with open(label_path) as f:
         
     | 
| 67 | 
         
            +
                    code_lengths = [len(line.encode("utf-8")) for line in f]
         
     | 
| 68 | 
         
            +
                    assert (
         
     | 
| 69 | 
         
            +
                        len(code_lengths) == tot
         
     | 
| 70 | 
         
            +
                    ), f"number of labels does not match ({len(code_lengths)} != {tot})"
         
     | 
| 71 | 
         
            +
                    offsets = list(itertools.accumulate([0] + code_lengths))
         
     | 
| 72 | 
         
            +
                    offsets = [(offsets[i], offsets[i + 1]) for i in inds]
         
     | 
| 73 | 
         
            +
                return offsets
         
     | 
| 74 | 
         
            +
             
     | 
| 75 | 
         
            +
             
     | 
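The `(start, end)` byte spans returned by `load_label_offset` (note that `len(line.encode("utf-8"))` counts the trailing newline) support lazy label access when the dataset is built with `store_labels=False`; `get_label` below seeks into the file instead of holding every label in memory. A minimal sketch with an illustrative file name and index:

# Illustrative only: mirrors HubertDataset.get_label's store_labels=False path.
offsets = load_label_offset("train.km", inds=[0, 1, 2], tot=3)
start, end = offsets[1]            # byte span of the second kept label line
with open("train.km") as f:
    f.seek(start)
    label = f.read(end - start)    # includes the trailing newline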
def verify_label_lengths(
    audio_sizes,
    audio_rate,
    label_path,
    label_rate,
    inds,
    tot,
    tol=0.1,  # tolerance in seconds
):
    if label_rate < 0:
        logger.info(f"{label_path} is sequence label. skipped")
        return

    with open(label_path) as f:
        lengths = [len(line.rstrip().split()) for line in f]
        assert len(lengths) == tot
        lengths = [lengths[i] for i in inds]
    num_invalid = 0
    for i, ind in enumerate(inds):
        dur_from_audio = audio_sizes[i] / audio_rate
        dur_from_label = lengths[i] / label_rate
        if abs(dur_from_audio - dur_from_label) > tol:
            logger.warning(
                (
                    f"audio and label duration differ too much "
                    f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
                    f"in line {ind+1} of {label_path}. Check if `label_rate` "
                    f"is correctly set (currently {label_rate}). "
                    f"num. of samples = {audio_sizes[i]}; "
                    f"label length = {lengths[i]}"
                )
            )
            num_invalid += 1
    if num_invalid > 0:
        logger.warning(
            f"total {num_invalid} (audio, label) pairs with mismatched lengths"
        )
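A worked instance of the duration check above (values illustrative): 16 kHz audio with 160,000 samples lasts 10.0 s, and 50 Hz labels with 498 frames cover 9.96 s, so the 0.04 s mismatch stays within the default `tol=0.1` and no warning is emitted.

dur_from_audio = 160_000 / 16_000   # 10.0 s
dur_from_label = 498 / 50           # 9.96 s
assert abs(dur_from_audio - dur_from_label) <= 0.1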
class HubertDataset(FairseqDataset):
    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_paths: List[str],
        label_rates: Union[List[float], float],  # -1 for sequence labels
        pad_list: List[str],
        eos_list: List[str],
        label_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        max_sample_size: Optional[int] = None,
        shuffle: bool = True,
        pad_audio: bool = False,
        normalize: bool = False,
        store_labels: bool = True,
        random_crop: bool = False,
        single_target: bool = False,
    ):
        self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio(
            manifest_path, max_keep_sample_size, min_keep_sample_size
        )
        self.sample_rate = sample_rate
        self.shuffle = shuffle
        self.random_crop = random_crop

        self.num_labels = len(label_paths)
        self.pad_list = pad_list
        self.eos_list = eos_list
        self.label_processors = label_processors
        self.single_target = single_target
        self.label_rates = (
            [label_rates for _ in range(len(label_paths))]
            if isinstance(label_rates, float)
            else label_rates
        )
        self.store_labels = store_labels
        if store_labels:
            self.label_list = [load_label(p, inds, tot) for p in label_paths]
        else:
            self.label_paths = label_paths
            self.label_offsets_list = [
                load_label_offset(p, inds, tot) for p in label_paths
            ]
        assert label_processors is None or len(label_processors) == self.num_labels
        for label_path, label_rate in zip(label_paths, self.label_rates):
            verify_label_lengths(
                self.sizes, sample_rate, label_path, label_rate, inds, tot
            )

        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.pad_audio = pad_audio
        self.normalize = normalize
        logger.info(
            f"pad_audio={pad_audio}, random_crop={random_crop}, "
            f"normalize={normalize}, max_sample_size={self.max_sample_size}"
        )
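A minimal construction sketch for the class above; the file names, dictionary and encoder are assumptions, not part of this repo. In fairseq's HuBERT pretraining setup the labels are typically frame-level k-means cluster ids, one line per utterance, and `pad_list`/`eos_list` carry the pad/eos indices of the label dictionary:

# Illustrative only.
from fairseq.data import Dictionary

label_dict = Dictionary.load("dict.km.txt")   # hypothetical label dictionary
encode = lambda line: label_dict.encode_line(
    line, append_eos=False, add_if_not_exist=False
)
dataset = HubertDataset(
    manifest_path="train.tsv",
    sample_rate=16_000,
    label_paths=["train.km"],
    label_rates=50.0,                 # -1 would mean sequence-level labels
    pad_list=[label_dict.pad()],
    eos_list=[label_dict.eos()],
    label_processors=[encode],
    max_sample_size=250_000,
    pad_audio=False,
    normalize=False,
)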
    def get_audio(self, index):
        import soundfile as sf

        wav_path = os.path.join(self.audio_root, self.audio_names[index])
        _path, slice_ptr = parse_path(wav_path)
        if len(slice_ptr) == 0:
            wav, cur_sample_rate = sf.read(_path)
        else:
            assert _path.endswith(".zip")
            data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
            f = io.BytesIO(data)
            wav, cur_sample_rate = sf.read(f)
        wav = torch.from_numpy(wav).float()
        wav = self.postprocess(wav, cur_sample_rate)
        return wav

    def get_label(self, index, label_idx):
        if self.store_labels:
            label = self.label_list[label_idx][index]
        else:
            with open(self.label_paths[label_idx]) as f:
                offset_s, offset_e = self.label_offsets_list[label_idx][index]
                f.seek(offset_s)
                label = f.read(offset_e - offset_s)

        if self.label_processors is not None:
            label = self.label_processors[label_idx](label)
        return label

    def get_labels(self, index):
        return [self.get_label(index, i) for i in range(self.num_labels)]

    def __getitem__(self, index):
        wav = self.get_audio(index)
        labels = self.get_labels(index)
        return {"id": index, "source": wav, "label_list": labels}

    def __len__(self):
        return len(self.sizes)

    def crop_to_max_size(self, wav, target_size):
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav, 0

        start, end = 0, target_size
        if self.random_crop:
            start = np.random.randint(0, diff + 1)
            end = size - diff + start
        return wav[start:end], start
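Note that `get_audio` accepts two manifest path forms via `parse_path`: a plain audio file path read directly with soundfile, or fairseq's packed-zip form `<archive>.zip:<byte-offset>:<byte-length>`, whose bytes are fetched with `read_from_stored_zip` and decoded from an in-memory buffer. The offsets below are illustrative:

spk_a/utt_0001.wav        -> sf.read on the file
audio.zip:1024:93680      -> read 93,680 bytes at offset 1,024, then sf.read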
    def collater(self, samples):
        # target = max(sizes) -> random_crop not used
        # target = max_sample_size -> random_crop used for long
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        audios = [s["source"] for s in samples]
        audio_sizes = [len(s) for s in audios]
        if self.pad_audio:
            audio_size = min(max(audio_sizes), self.max_sample_size)
        else:
            audio_size = min(min(audio_sizes), self.max_sample_size)
        collated_audios, padding_mask, audio_starts = self.collater_audio(
            audios, audio_size
        )

        targets_by_label = [
            [s["label_list"][i] for s in samples] for i in range(self.num_labels)
        ]
        targets_list, lengths_list, ntokens_list = self.collater_label(
            targets_by_label, audio_size, audio_starts
        )

        net_input = {"source": collated_audios, "padding_mask": padding_mask}
        batch = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": net_input,
        }

        if self.single_target:
            batch["target_lengths"] = lengths_list[0]
            batch["ntokens"] = ntokens_list[0]
            batch["target"] = targets_list[0]
        else:
            batch["target_lengths_list"] = lengths_list
            batch["ntokens_list"] = ntokens_list
            batch["target_list"] = targets_list
        return batch
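For orientation, a shape sketch of the batch built by `collater` when `single_target=False`, with B utterances cropped or padded to T samples and L label frames per stream (implied by the code above):

# batch = {
#     "id": LongTensor [B],
#     "net_input": {
#         "source": FloatTensor [B, T],
#         "padding_mask": BoolTensor [B, T],
#     },
#     "target_list": [LongTensor [B, L], ...],        # one per label stream
#     "target_lengths_list": [LongTensor [B], ...],
#     "ntokens_list": [int, ...],
# }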
    def collater_audio(self, audios, audio_size):
        collated_audios = audios[0].new_zeros(len(audios), audio_size)
        padding_mask = (
            torch.BoolTensor(collated_audios.shape).fill_(False)
            # if self.pad_audio else None
        )
        audio_starts = [0 for _ in audios]
        for i, audio in enumerate(audios):
            diff = len(audio) - audio_size
            if diff == 0:
                collated_audios[i] = audio
            elif diff < 0:
                assert self.pad_audio
                collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
                padding_mask[i, diff:] = True
            else:
                collated_audios[i], audio_starts[i] = self.crop_to_max_size(
                    audio, audio_size
                )
        return collated_audios, padding_mask, audio_starts

    def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
        assert label_rate > 0
        s2f = label_rate / self.sample_rate
        frm_starts = [int(round(s * s2f)) for s in audio_starts]
        frm_size = int(round(audio_size * s2f))
        if not self.pad_audio:
            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
            frm_size = min(frm_size, *rem_size)
        targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
        logger.debug(f"audio_starts={audio_starts}")
        logger.debug(f"frame_starts={frm_starts}")
        logger.debug(f"frame_size={frm_size}")

        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_seq_label(self, targets, pad):
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_label(self, targets_by_label, audio_size, audio_starts):
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, self.label_rates, self.pad_list)
        for targets, label_rate, pad in itr:
            if label_rate == -1.0:
                targets, lengths, ntokens = self.collater_seq_label(targets, pad)
            else:
                targets, lengths, ntokens = self.collater_frm_label(
                    targets, audio_size, audio_starts, label_rate, pad
                )
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list
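A worked instance of the sample-to-frame conversion in `collater_frm_label` (values illustrative):

s2f = 50.0 / 16_000                     # 0.003125 label frames per audio sample
frm_start = int(round(12_800 * s2f))    # crop starting at sample 12,800 -> frame 40
frm_size = int(round(160_000 * s2f))    # a 160,000-sample window -> 500 frames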
    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        if self.pad_audio:
            return self.sizes[index]
        return min(self.sizes[index], self.max_sample_size)

    def ordered_indices(self):
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        order.append(self.sizes)
        return np.lexsort(order)[::-1]

    def postprocess(self, wav, cur_sample_rate):
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()

        if cur_sample_rate != self.sample_rate:
            raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")

        if self.normalize:
            with torch.no_grad():
                wav = F.layer_norm(wav, wav.shape)
        return wav
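Note that `ordered_indices` sorts primarily by `self.sizes` in descending order; because `np.lexsort` treats its last key as the primary one, the random permutation (or `arange`) only breaks ties among equal-length utterances. A small check:

import numpy as np
sizes = np.array([100, 300, 100, 200])
np.lexsort([np.arange(4), sizes])[::-1]  # -> array([1, 3, 2, 0]): longest first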
    	
modules/voice_conversion/fairseq/data/audio/multi_modality_dataset.py
ADDED
@@ -0,0 +1,284 @@
# Copyright (c) 2021-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
import math
from typing import List, Optional, NamedTuple

import numpy as np
from fairseq.data.resampling_dataset import ResamplingDataset
import torch
from fairseq.data import (
    ConcatDataset,
    LanguagePairDataset,
    FileAudioDataset,
    data_utils,
)
from fairseq.data import FairseqDataset

logger = logging.getLogger(__name__)


class ModalityDatasetItem(NamedTuple):
    datasetname: str
    dataset: any
    max_positions: List[int]
    max_tokens: Optional[int] = None
    max_sentences: Optional[int] = None


def resampling_dataset_present(ds):
    if isinstance(ds, ResamplingDataset):
        return True
    if isinstance(ds, ConcatDataset):
        return any(resampling_dataset_present(d) for d in ds.datasets)
    if hasattr(ds, "dataset"):
        return resampling_dataset_present(ds.dataset)
    return False
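A minimal sketch of how `ModalityDatasetItem` entries are assembled before wrapping them in the `MultiModalityDataset` defined below; the two datasets, mode names and limits are assumptions standing in for real speech and text datasets:

# Illustrative only.
items = [
    ModalityDatasetItem(
        "sup_speech", speech_ds, max_positions=[600_000, 1024], max_tokens=1_200_000
    ),
    ModalityDatasetItem(
        "text_to_text", text_ds, max_positions=[1024, 1024], max_tokens=8_192
    ),
]
mm_dataset = MultiModalityDataset(items)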
# MultiModalityDataset concatenates multiple datasets with different modalities.
# Compared with ConcatDataset it can 1) sample data given per-dataset ratios and
# 2) attach a mode to indicate which type of dataset each sample comes from.
# It is used together with GroupedEpochBatchIterator to generate mini-batches
# whose samples all come from the same type of dataset.
# If only one dataset is used, it behaves like the original dataset, with the mode added.
class MultiModalityDataset(ConcatDataset):
    def __init__(self, datasets: List[ModalityDatasetItem]):
        id_to_mode = []
        dsets = []
        max_tokens = []
        max_sentences = []
        max_positions = []
        for dset in datasets:
            id_to_mode.append(dset.datasetname)
            dsets.append(dset.dataset)
            max_tokens.append(dset.max_tokens)
            max_positions.append(dset.max_positions)
            max_sentences.append(dset.max_sentences)
        weights = [1.0 for s in dsets]
        super().__init__(dsets, weights)
        self.max_tokens = max_tokens
        self.max_positions = max_positions
        self.max_sentences = max_sentences
        self.id_to_mode = id_to_mode
        self.raw_sub_batch_samplers = []
        self._cur_epoch = 0

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        self._cur_epoch = epoch

    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        sample = self.datasets[dataset_idx][sample_idx]
        return (dataset_idx, sample)

    def collater(self, samples):
        if len(samples) == 0:
            return {}
        dataset_idx = samples[0][0]
        # make sure all samples in samples are from same dataset
        assert sum([0 if dataset_idx == s[0] else 1 for s in samples]) == 0
        samples = self.datasets[dataset_idx].collater([x[1] for x in samples])
        # add mode
        samples["net_input"]["mode"] = self.id_to_mode[dataset_idx]

        return samples

    def size(self, index: int):
        if len(self.datasets) == 1:
            return self.datasets[0].size(index)
        return super().size(index)

    @property
    def sizes(self):
        if len(self.datasets) == 1:
            return self.datasets[0].sizes
        return super().sizes

    def ordered_indices(self):
        """
        Returns indices sorted by length, so that less padding is needed.
        """
        if len(self.datasets) == 1:
            return self.datasets[0].ordered_indices()
        indices_group = []
        for d_idx, ds in enumerate(self.datasets):
            sample_num = self.cumulative_sizes[d_idx]
            if d_idx > 0:
                sample_num = sample_num - self.cumulative_sizes[d_idx - 1]
            assert sample_num == len(ds)
            indices_group.append(ds.ordered_indices())
        return indices_group

    def get_raw_batch_samplers(self, required_batch_size_multiple, seed):
        with data_utils.numpy_seed(seed):
            indices = self.ordered_indices()
        for i, ds in enumerate(self.datasets):
            # If we have a ResamplingDataset, the same id can correspond to a
            # different sample in the next epoch, so we need to rebuild this at
            # every epoch
            if i < len(self.raw_sub_batch_samplers) and not resampling_dataset_present(
                ds
            ):
                logger.info(f"dataset {i} is valid and it is not re-sampled")
                continue
            indices[i] = ds.filter_indices_by_size(
                indices[i],
                self.max_positions[i],
            )[0]
                        sub_batch_sampler = ds.batch_by_size(
         
     | 
| 135 | 
         
            +
                            indices[i],
         
     | 
| 136 | 
         
            +
                            max_tokens=self.max_tokens[i],
         
     | 
| 137 | 
         
            +
                            max_sentences=self.max_sentences[i],
         
     | 
| 138 | 
         
            +
                            required_batch_size_multiple=required_batch_size_multiple,
         
     | 
| 139 | 
         
            +
                        )
         
     | 
| 140 | 
         
            +
                        if i < len(self.raw_sub_batch_samplers):
         
     | 
| 141 | 
         
            +
                            self.raw_sub_batch_samplers[i] = sub_batch_sampler
         
     | 
| 142 | 
         
            +
                        else:
         
     | 
| 143 | 
         
            +
                            self.raw_sub_batch_samplers.append(sub_batch_sampler)
         
     | 
| 144 | 
         
            +
             
     | 
| 145 | 
         
            +
                def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed):
         
     | 
| 146 | 
         
            +
                    self.get_raw_batch_samplers(required_batch_size_multiple, seed)
         
     | 
| 147 | 
         
            +
                    batch_samplers = []
         
     | 
| 148 | 
         
            +
                    for i, _ in enumerate(self.datasets):
         
     | 
| 149 | 
         
            +
                        if i > 0:
         
     | 
| 150 | 
         
            +
                            sub_batch_sampler = [
         
     | 
| 151 | 
         
            +
                                [y + self.cumulative_sizes[i - 1] for y in x]
         
     | 
| 152 | 
         
            +
                                for x in self.raw_sub_batch_samplers[i]
         
     | 
| 153 | 
         
            +
                            ]
         
     | 
| 154 | 
         
            +
                        else:
         
     | 
| 155 | 
         
            +
                            sub_batch_sampler = list(self.raw_sub_batch_samplers[i])
         
     | 
| 156 | 
         
            +
                        smp_r = mult_ratios[i]
         
     | 
| 157 | 
         
            +
                        if smp_r != 1:
         
     | 
| 158 | 
         
            +
                            is_increase = "increased" if smp_r > 1 else "decreased"
         
     | 
| 159 | 
         
            +
                            logger.info(
         
     | 
| 160 | 
         
            +
                                "number of batch for the dataset {} is {} from {} to {}".format(
         
     | 
| 161 | 
         
            +
                                    self.id_to_mode[i],
         
     | 
| 162 | 
         
            +
                                    is_increase,
         
     | 
| 163 | 
         
            +
                                    len(sub_batch_sampler),
         
     | 
| 164 | 
         
            +
                                    int(len(sub_batch_sampler) * smp_r),
         
     | 
| 165 | 
         
            +
                                )
         
     | 
| 166 | 
         
            +
                            )
         
     | 
| 167 | 
         
            +
                            mul_samplers = []
         
     | 
| 168 | 
         
            +
                            for _ in range(math.floor(smp_r)):
         
     | 
| 169 | 
         
            +
                                mul_samplers = mul_samplers + sub_batch_sampler
         
     | 
| 170 | 
         
            +
                            if math.floor(smp_r) != smp_r:
         
     | 
| 171 | 
         
            +
                                with data_utils.numpy_seed(seed + self._cur_epoch):
         
     | 
| 172 | 
         
            +
                                    np.random.shuffle(sub_batch_sampler)
         
     | 
| 173 | 
         
            +
                                    smp_num = int(
         
     | 
| 174 | 
         
            +
                                        (smp_r - math.floor(smp_r)) * len(sub_batch_sampler)
         
     | 
| 175 | 
         
            +
                                    )
         
     | 
| 176 | 
         
            +
                                mul_samplers = mul_samplers + sub_batch_sampler[:smp_num]
         
     | 
| 177 | 
         
            +
                            sub_batch_sampler = mul_samplers
         
     | 
| 178 | 
         
            +
                        else:
         
     | 
| 179 | 
         
            +
                            logger.info(
         
     | 
| 180 | 
         
            +
                                "dataset {} batch number is {} ".format(
         
     | 
| 181 | 
         
            +
                                    self.id_to_mode[i], len(sub_batch_sampler)
         
     | 
| 182 | 
         
            +
                                )
         
     | 
| 183 | 
         
            +
                            )
         
     | 
| 184 | 
         
            +
                        batch_samplers.append(sub_batch_sampler)
         
     | 
| 185 | 
         
            +
             
     | 
| 186 | 
         
            +
                    return batch_samplers
         
     | 
| 187 | 
         
            +
             
     | 
| 188 | 
         
            +
             
     | 
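Aside (not part of the file): get_batch_samplers honors the fractional part of each mult_ratios entry by duplicating the batch list floor(r) times and then appending a shuffled slice for the remainder. A minimal standalone sketch of that resampling step, using a toy batch list and an illustrative helper name:

import math
import numpy as np

def resample_batches(batches, ratio, seed=0):
    # Duplicate the full list floor(ratio) times, then add a random
    # fraction of it to cover the remainder -- mirrors the loop above.
    out = []
    for _ in range(math.floor(ratio)):
        out = out + batches
    if math.floor(ratio) != ratio:
        rng = np.random.RandomState(seed)
        shuffled = list(batches)
        rng.shuffle(shuffled)
        out = out + shuffled[: int((ratio - math.floor(ratio)) * len(shuffled))]
    return out

# 4 toy batches upsampled by 1.5 -> 6 batches
print(len(resample_batches([[0], [1], [2], [3]], 1.5)))  # 6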

class LangPairMaskDataset(FairseqDataset):
    def __init__(
        self,
        dataset: LanguagePairDataset,
        src_eos: int,
        src_bos: Optional[int] = None,
        noise_id: Optional[int] = -1,
        mask_ratio: Optional[float] = 0,
        mask_type: Optional[str] = "random",
    ):
        self.dataset = dataset
        self.src_eos = src_eos
        self.src_bos = src_bos
        self.noise_id = noise_id
        self.mask_ratio = mask_ratio
        self.mask_type = mask_type
        assert mask_type in ("random", "tail")

    @property
    def src_sizes(self):
        return self.dataset.src_sizes

    @property
    def tgt_sizes(self):
        return self.dataset.tgt_sizes

    @property
    def sizes(self):
        # dataset.sizes can be dynamically computed
        return self.dataset.sizes

    def get_batch_shapes(self):
        if hasattr(self.dataset, "get_batch_shapes"):
            return self.dataset.get_batch_shapes()
        return self.dataset.buckets

    def num_tokens_vec(self, indices):
        return self.dataset.num_tokens_vec(indices)

    def __len__(self):
        return len(self.dataset)

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)

    def mask_src_tokens(self, sample):
        src_item = sample["source"]
        mask = None
        if self.mask_type == "random":
            mask = torch.rand(len(src_item)).le(self.mask_ratio)
        else:
            mask = torch.ones(len(src_item))
            mask[: int(len(src_item) * (1 - self.mask_ratio))] = 0
            mask = mask.eq(1)
        if src_item[0] == self.src_bos:
            mask[0] = False
        if src_item[-1] == self.src_eos:
            mask[-1] = False
        mask_src_item = src_item.masked_fill(mask, self.noise_id)
        smp = {"id": sample["id"], "source": mask_src_item, "target": sample["target"]}
        return smp

    def __getitem__(self, index):
        sample = self.dataset[index]
        if self.mask_ratio > 0:
            sample = self.mask_src_tokens(sample)
        return sample

    def collater(self, samples, pad_to_length=None):
        return self.dataset.collater(samples, pad_to_length)

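Aside (not part of the file): for intuition, the two mask_type modes of mask_src_tokens behave as follows on a toy source sequence. This is a standalone sketch mirroring the method above; the token ids for bos/eos are illustrative assumptions:

import torch

src = torch.tensor([0, 11, 12, 13, 14, 2])  # toy ids; 0 = bos, 2 = eos (assumed)
mask_ratio, noise_id = 0.5, -1

# "random": each position is masked independently with probability mask_ratio
rand_mask = torch.rand(len(src)).le(mask_ratio)

# "tail": only the last mask_ratio fraction of positions is masked
tail_mask = torch.ones(len(src))
tail_mask[: int(len(src) * (1 - mask_ratio))] = 0
tail_mask = tail_mask.eq(1)

# bos/eos positions are always preserved, as in LangPairMaskDataset
for m in (rand_mask, tail_mask):
    m[0], m[-1] = False, False

print(src.masked_fill(tail_mask, noise_id))  # tensor([ 0, 11, 12, -1, -1,  2])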

class FileAudioDatasetWrapper(FileAudioDataset):
    def collater(self, samples):
        samples = super().collater(samples)
        if len(samples) == 0:
            return {}
        samples["net_input"]["src_tokens"] = samples["net_input"]["source"]
        samples["net_input"]["prev_output_tokens"] = None
        del samples["net_input"]["source"]
        samples["net_input"]["src_lengths"] = None
        samples["net_input"]["alignment"] = None
        return samples
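Aside (not part of the file): the wrapper above only reshapes the collated batch so raw audio looks like a translation-style net_input. A minimal sketch of the key mapping it performs, with placeholder values standing in for tensors:

batch = {"net_input": {"source": "B x T tensor", "padding_mask": "B x T mask"}}
ni = batch["net_input"]
ni["src_tokens"] = ni.pop("source")  # audio takes the src_tokens slot
ni.update(prev_output_tokens=None, src_lengths=None, alignment=None)
print(sorted(ni))
# ['alignment', 'padding_mask', 'prev_output_tokens', 'src_lengths', 'src_tokens']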
    	
modules/voice_conversion/fairseq/data/audio/raw_audio_dataset.py
ADDED
@@ -0,0 +1,393 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import logging
import os
import sys
import io

import numpy as np
import torch
import torch.nn.functional as F

from .. import FairseqDataset
from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
from fairseq.data.audio.audio_utils import (
    parse_path,
    read_from_stored_zip,
    is_sf_audio_data,
)
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel


logger = logging.getLogger(__name__)


class RawAudioDataset(FairseqDataset):
    def __init__(
        self,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        compute_mask_indices=False,
        **mask_compute_kwargs,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.sizes = []
        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.min_sample_size = min_sample_size
        self.pad = pad
        self.shuffle = shuffle
        self.normalize = normalize
        self.compute_mask_indices = compute_mask_indices
        if self.compute_mask_indices:
            self.mask_compute_kwargs = mask_compute_kwargs
            self._features_size_map = {}
            self._C = mask_compute_kwargs["encoder_embed_dim"]
            self._conv_feature_layers = eval(mask_compute_kwargs["conv_feature_layers"])

    def __getitem__(self, index):
        raise NotImplementedError()

    def __len__(self):
        return len(self.sizes)

    def postprocess(self, feats, curr_sample_rate):
        if feats.dim() == 2:
            feats = feats.mean(-1)

        if curr_sample_rate != self.sample_rate:
            raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")

        assert feats.dim() == 1, feats.dim()

        if self.normalize:
            with torch.no_grad():
                feats = F.layer_norm(feats, feats.shape)
        return feats

    def crop_to_max_size(self, wav, target_size):
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav

        start = np.random.randint(0, diff + 1)
        end = size - diff + start
        return wav[start:end]

    def _compute_mask_indices(self, dims, padding_mask):
        B, T, C = dims
        mask_indices, mask_channel_indices = None, None
        if self.mask_compute_kwargs["mask_prob"] > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_compute_kwargs["mask_prob"],
                self.mask_compute_kwargs["mask_length"],
                self.mask_compute_kwargs["mask_selection"],
                self.mask_compute_kwargs["mask_other"],
                min_masks=2,
                no_overlap=self.mask_compute_kwargs["no_mask_overlap"],
                min_space=self.mask_compute_kwargs["mask_min_space"],
            )
            mask_indices = torch.from_numpy(mask_indices)
        if self.mask_compute_kwargs["mask_channel_prob"] > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_compute_kwargs["mask_channel_prob"],
                self.mask_compute_kwargs["mask_channel_length"],
                self.mask_compute_kwargs["mask_channel_selection"],
                self.mask_compute_kwargs["mask_channel_other"],
                no_overlap=self.mask_compute_kwargs["no_mask_channel_overlap"],
                min_space=self.mask_compute_kwargs["mask_channel_min_space"],
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices).unsqueeze(1).expand(-1, T, -1)
            )

        return mask_indices, mask_channel_indices

    @staticmethod
    def _bucket_tensor(tensor, num_pad, value):
        return F.pad(tensor, (0, num_pad), value=value)

    def collater(self, samples):
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        sources = [s["source"] for s in samples]
        sizes = [len(s) for s in sources]

        if self.pad:
            target_size = min(max(sizes), self.max_sample_size)
        else:
            target_size = min(min(sizes), self.max_sample_size)

        collated_sources = sources[0].new_zeros(len(sources), target_size)
        padding_mask = (
            torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
        )
        for i, (source, size) in enumerate(zip(sources, sizes)):
            diff = size - target_size
            if diff == 0:
                collated_sources[i] = source
            elif diff < 0:
                assert self.pad
                collated_sources[i] = torch.cat(
                    [source, source.new_full((-diff,), 0.0)]
                )
                padding_mask[i, diff:] = True
            else:
                collated_sources[i] = self.crop_to_max_size(source, target_size)

        input = {"source": collated_sources}
        out = {"id": torch.LongTensor([s["id"] for s in samples])}
        if self.pad:
            input["padding_mask"] = padding_mask

        if hasattr(self, "num_buckets") and self.num_buckets > 0:
            assert self.pad, "Cannot bucket without padding first."
            bucket = max(self._bucketed_sizes[s["id"]] for s in samples)
            num_pad = bucket - collated_sources.size(-1)
            if num_pad:
                input["source"] = self._bucket_tensor(collated_sources, num_pad, 0)
                input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True)

        if self.compute_mask_indices:
            B = input["source"].size(0)
            T = self._get_mask_indices_dims(input["source"].size(-1))
            padding_mask_reshaped = input["padding_mask"].clone()
            extra = padding_mask_reshaped.size(1) % T
            if extra > 0:
                padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
            padding_mask_reshaped = padding_mask_reshaped.view(
                padding_mask_reshaped.size(0), T, -1
            )
            padding_mask_reshaped = padding_mask_reshaped.all(-1)
            input["padding_count"] = padding_mask_reshaped.sum(-1).max().item()
            mask_indices, mask_channel_indices = self._compute_mask_indices(
                (B, T, self._C),
                padding_mask_reshaped,
            )
            input["mask_indices"] = mask_indices
            input["mask_channel_indices"] = mask_channel_indices
            out["sample_size"] = mask_indices.sum().item()

        out["net_input"] = input
        return out

    def _get_mask_indices_dims(self, size, padding=0, dilation=1):
        if size not in self._features_size_map:
            L_in = size
            for (_, kernel_size, stride) in self._conv_feature_layers:
                L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
                L_out = 1 + L_out // stride
                L_in = L_out
            self._features_size_map[size] = L_out
        return self._features_size_map[size]

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        if self.pad:
            return self.sizes[index]
        return min(self.sizes[index], self.max_sample_size)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""

        if self.shuffle:
            order = [np.random.permutation(len(self))]
            order.append(
                np.minimum(
                    np.array(self.sizes),
                    self.max_sample_size,
                )
            )
            return np.lexsort(order)[::-1]
        else:
            return np.arange(len(self))

    def set_bucket_info(self, num_buckets):
        self.num_buckets = num_buckets
        if self.num_buckets > 0:
            self._collated_sizes = np.minimum(
                np.array(self.sizes),
                self.max_sample_size,
            )
            self.buckets = get_buckets(
                self._collated_sizes,
                self.num_buckets,
            )
            self._bucketed_sizes = get_bucketed_sizes(
                self._collated_sizes, self.buckets
            )
            logger.info(
                f"{len(self.buckets)} bucket(s) for the audio dataset: "
                f"{self.buckets}"
            )

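Aside (not part of the file): in the collater above, pad=False randomly crops every waveform in a batch down to the shortest one, while pad=True zero-pads up to the longest, with max_sample_size capping both. A tiny standalone sketch of that sizing rule on toy lengths:

sizes = [16000, 20000, 24000]        # toy waveform lengths in samples
max_sample_size = 22000

crop_target = min(min(sizes), max_sample_size)  # pad=False -> 16000
pad_target = min(max(sizes), max_sample_size)   # pad=True  -> 22000
print(crop_target, pad_target)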

class FileAudioDataset(RawAudioDataset):
    def __init__(
        self,
        manifest_path,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        num_buckets=0,
        compute_mask_indices=False,
        text_compression_level=TextCompressionLevel.none,
        **mask_compute_kwargs,
    ):
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=pad,
            normalize=normalize,
            compute_mask_indices=compute_mask_indices,
            **mask_compute_kwargs,
        )

        self.text_compressor = TextCompressor(level=text_compression_level)

        skipped = 0
        self.fnames = []
        sizes = []
        self.skipped_indices = set()

        with open(manifest_path, "r") as f:
            self.root_dir = f.readline().strip()
            for i, line in enumerate(f):
                items = line.strip().split("\t")
                assert len(items) == 2, line
                sz = int(items[1])
                if min_sample_size is not None and sz < min_sample_size:
                    skipped += 1
                    self.skipped_indices.add(i)
                    continue
                self.fnames.append(self.text_compressor.compress(items[0]))
                sizes.append(sz)
        logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")

        self.sizes = np.array(sizes, dtype=np.int64)

        try:
            import pyarrow

            self.fnames = pyarrow.array(self.fnames)
        except:
            logger.debug(
                "Could not create a pyarrow array. Please install pyarrow for better performance"
            )
            pass

        self.set_bucket_info(num_buckets)

    def __getitem__(self, index):
        import soundfile as sf

        fn = self.fnames[index]
        fn = fn if isinstance(self.fnames, list) else fn.as_py()
        fn = self.text_compressor.decompress(fn)
        path_or_fp = os.path.join(self.root_dir, fn)
        _path, slice_ptr = parse_path(path_or_fp)
        if len(slice_ptr) == 2:
            byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
            assert is_sf_audio_data(byte_data)
            path_or_fp = io.BytesIO(byte_data)

        wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")

        feats = torch.from_numpy(wav).float()
        feats = self.postprocess(feats, curr_sample_rate)
        return {"id": index, "source": feats}

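Aside (not part of the file): the parsing loop above implies that FileAudioDataset expects a tab-separated manifest whose first line is the audio root directory and whose remaining lines are "relative/path<TAB>num_samples". A sketch of building one and loading it; the paths and the min_sample_size value are illustrative:

# manifest.tsv: first line is the root dir, then one "path\tnum_samples" per clip
manifest = "\n".join([
    "/data/audio",
    "clips/a.wav\t48000",
    "clips/b.wav\t16000",
])
with open("manifest.tsv", "w") as f:
    f.write(manifest + "\n")

# dataset = FileAudioDataset("manifest.tsv", sample_rate=16000,
#                            min_sample_size=32000)  # would skip b.wav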

class BinarizedAudioDataset(RawAudioDataset):
    def __init__(
        self,
        data_dir,
        split,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        num_buckets=0,
        compute_mask_indices=False,
        **mask_compute_kwargs,
    ):
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=pad,
            normalize=normalize,
            compute_mask_indices=compute_mask_indices,
            **mask_compute_kwargs,
        )

        from fairseq.data import data_utils, Dictionary

        self.fnames_dict = Dictionary.load(os.path.join(data_dir, "dict.txt"))

        root_path = os.path.join(data_dir, f"{split}.root")
        if os.path.exists(root_path):
            with open(root_path, "r") as f:
                self.root_dir = next(f).strip()
        else:
            self.root_dir = None

        fnames_path = os.path.join(data_dir, split)
        self.fnames = data_utils.load_indexed_dataset(fnames_path, self.fnames_dict)
        lengths_path = os.path.join(data_dir, f"{split}.lengths")
     | 
| 369 | 
         
            +
             
     | 
| 370 | 
         
            +
                    with open(lengths_path, "r") as f:
         
     | 
| 371 | 
         
            +
                        for line in f:
         
     | 
| 372 | 
         
            +
                            sz = int(line.rstrip())
         
     | 
| 373 | 
         
            +
                            assert (
         
     | 
| 374 | 
         
            +
                                sz >= min_sample_size
         
     | 
| 375 | 
         
            +
                            ), f"Min sample size is not supported for binarized dataset, but found a sample with size {sz}"
         
     | 
| 376 | 
         
            +
                            self.sizes.append(sz)
         
     | 
| 377 | 
         
            +
             
     | 
| 378 | 
         
            +
                    self.sizes = np.array(self.sizes, dtype=np.int64)
         
     | 
| 379 | 
         
            +
             
     | 
| 380 | 
         
            +
                    self.set_bucket_info(num_buckets)
         
     | 
| 381 | 
         
            +
                    logger.info(f"loaded {len(self.fnames)} samples")
         
     | 
| 382 | 
         
            +
             
     | 
| 383 | 
         
            +
                def __getitem__(self, index):
         
     | 
| 384 | 
         
            +
                    import soundfile as sf
         
     | 
| 385 | 
         
            +
             
     | 
| 386 | 
         
            +
                    fname = self.fnames_dict.string(self.fnames[index], separator="")
         
     | 
| 387 | 
         
            +
                    if self.root_dir:
         
     | 
| 388 | 
         
            +
                        fname = os.path.join(self.root_dir, fname)
         
     | 
| 389 | 
         
            +
             
     | 
| 390 | 
         
            +
                    wav, curr_sample_rate = sf.read(fname)
         
     | 
| 391 | 
         
            +
                    feats = torch.from_numpy(wav).float()
         
     | 
| 392 | 
         
            +
                    feats = self.postprocess(feats, curr_sample_rate)
         
     | 
| 393 | 
         
            +
                    return {"id": index, "source": feats}
         
     | 
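A note on the manifest entries consumed by `FileAudioDataset.__getitem__` above: an entry may be a plain audio path, or it may point into an uncompressed zip archive, in which case `parse_path` returns the archive path plus a (byte offset, byte length) slice pointer and `read_from_stored_zip` fetches those bytes for `soundfile` to decode in memory. Below is a minimal standalone sketch of that convention; the `read_slice` helper and the `clips.zip` entry are hypothetical stand-ins, not the fairseq implementation.

import io

import soundfile as sf  # same decoder used by __getitem__ above


def read_slice(entry: str) -> io.BytesIO:
    # Hypothetical helper mirroring parse_path + read_from_stored_zip:
    # "clips.zip:1024:2048" -> the 2048 bytes starting at byte offset 1024.
    # This only works when the zip member is stored uncompressed, which is
    # why __getitem__ checks the bytes with is_sf_audio_data before decoding.
    path, offset, length = entry.rsplit(":", 2)
    with open(path, "rb") as f:
        f.seek(int(offset))
        return io.BytesIO(f.read(int(length)))


wav, sample_rate = sf.read(read_slice("clips.zip:1024:2048"), dtype="float32")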
    	
modules/voice_conversion/fairseq/data/audio/speech_to_speech_dataset.py ADDED
@@ -0,0 +1,379 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch

from fairseq.data import ConcatDataset, Dictionary
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.data_cfg import S2SDataConfig
from fairseq.data.audio.speech_to_text_dataset import (
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
    TextTargetMultitaskData,
    _collate_frames,
)

logger = logging.getLogger(__name__)


@dataclass
class SpeechToSpeechDatasetItem(object):
    index: int
    source: torch.Tensor
    target: Optional[torch.Tensor] = None
    target_speaker: Optional[torch.Tensor] = None
    tgt_lang_tag: Optional[int] = None


class SpeechToSpeechDataset(SpeechToTextDataset):
    def __init__(
        self,
        split: str,
        is_train_split: bool,
        data_cfg: S2SDataConfig,
        src_audio_paths: List[str],
        src_n_frames: List[int],
        tgt_audio_paths: List[str],
        tgt_n_frames: List[int],
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        target_is_code: bool = False,
        tgt_dict: Dictionary = None,
        n_frames_per_step: int = 1,
    ):
        tgt_texts = tgt_audio_paths if target_is_code else None
        super().__init__(
            split=split,
            is_train_split=is_train_split,
            cfg=data_cfg,
            audio_paths=src_audio_paths,
            n_frames=src_n_frames,
            ids=ids,
            tgt_dict=tgt_dict,
            tgt_texts=tgt_texts,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            n_frames_per_step=n_frames_per_step,
        )

        self.tgt_audio_paths = tgt_audio_paths
        self.tgt_lens = [t // self.n_frames_per_step for t in tgt_n_frames]

        assert not target_is_code or tgt_dict is not None
        self.target_is_code = target_is_code

        assert len(tgt_audio_paths) == self.n_samples
        assert len(tgt_n_frames) == self.n_samples

        self.tgt_speakers = None
        if self.cfg.target_speaker_embed:
            samples = SpeechToTextDatasetCreator._load_samples_from_tsv(
                self.cfg.target_speaker_embed, split
            )
            spk_emb_dict = {s["id"]: s["speaker_embed"] for s in samples}
            self.tgt_speakers = [spk_emb_dict[id] for id in self.ids]
            assert len(self.tgt_speakers) == self.n_samples

        logger.info(self.__repr__())

    def pack_units(self, input: torch.Tensor) -> torch.Tensor:
        if self.n_frames_per_step <= 1:
            return input

        offset = 4
        vocab_size = (
            len(self.tgt_dict) - offset
        )  # remove offset from <bos>, <pad>, <eos>, <unk>, which is specific to fairseq dictionary

        assert input.dim() == 1
        stacked_input = (
            input[:-1].view(-1, self.n_frames_per_step) - offset
        )  # remove <eos>
        scale = [
            pow(vocab_size, self.n_frames_per_step - 1 - i)
            for i in range(self.n_frames_per_step)
        ]
        scale = torch.LongTensor(scale).squeeze(0)
        res = input.new((len(input) - 1) // self.n_frames_per_step + 1).fill_(input[-1])
        res[:-1] = (stacked_input * scale).sum(dim=1) + offset

        return res

    def __getitem__(self, index: int) -> SpeechToSpeechDatasetItem:
        source = self._get_source_audio(index)

        tgt_lang_tag = None
        if self.cfg.prepend_tgt_lang_tag_as_bos:
            # prepend_tgt_lang_tag_as_bos: put tgt_lang_tag as bos of target
            tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)

        if not self.target_is_code:
            target = get_features_or_waveform(self.tgt_audio_paths[index])
            target = torch.from_numpy(target).float()
            target = self.pack_frames(target)
        else:
            target = self.tgt_dict.encode_line(
                self.tgt_audio_paths[index],
                add_if_not_exist=False,
                append_eos=True,
            ).long()
            if self.n_frames_per_step > 1:
                n_tgt_frame = target.size(0) - 1  # exclude <eos>
                keep_n_tgt_frame = n_tgt_frame - n_tgt_frame % self.n_frames_per_step
                target = torch.cat(
                    (
                        target[:keep_n_tgt_frame],
                        target.new_full((1,), self.tgt_dict.eos()),
                    ),
                    dim=0,
                )

        if self.tgt_speakers:
            tgt_spk = get_features_or_waveform(self.tgt_speakers[index])
            tgt_spk = torch.from_numpy(tgt_spk).float()
        else:
            tgt_spk = torch.FloatTensor([])

        return SpeechToSpeechDatasetItem(
            index=index,
            source=source,
            target=target,
            target_speaker=tgt_spk,
            tgt_lang_tag=tgt_lang_tag,
        )

    def _collate_target(self, samples: List[SpeechToSpeechDatasetItem]) -> torch.Tensor:
        if self.target_is_code:
            target = fairseq_data_utils.collate_tokens(
                [x.target for x in samples],
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            # convert stacked units to a single id
            pack_targets = [self.pack_units(x.target) for x in samples]
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                pack_targets,
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=True,
            )
            target_lengths = torch.tensor(
                [x.size(0) for x in pack_targets], dtype=torch.long
            )
        else:
            target = _collate_frames([x.target for x in samples], is_audio_input=False)
            bsz, _, d = target.size()
            prev_output_tokens = torch.cat(
                (target.new_full((bsz, 1, d), 0.0), target[:, :-1, :]), dim=1
            )
            target_lengths = torch.tensor(
                [x.target.size(0) for x in samples], dtype=torch.long
            )

        return target, prev_output_tokens, target_lengths

    def collater(
        self, samples: List[SpeechToSpeechDatasetItem], return_order: bool = False
    ) -> Dict:
        if len(samples) == 0:
            return {}
        indices = torch.tensor([x.index for x in samples], dtype=torch.long)
        frames = _collate_frames([x.source for x in samples], self.cfg.use_audio_input)
        # sort samples by descending number of frames
        n_frames = torch.tensor([x.source.size(0) for x in samples], dtype=torch.long)
        n_frames, order = n_frames.sort(descending=True)
        indices = indices.index_select(0, order)
        frames = frames.index_select(0, order)

        target, prev_output_tokens, target_lengths = self._collate_target(samples)
        target = target.index_select(0, order)
        target_lengths = target_lengths.index_select(0, order)
        prev_output_tokens = prev_output_tokens.index_select(0, order)
        ntokens = sum(x.target.size(0) for x in samples)

        tgt_speakers = None
        if self.cfg.target_speaker_embed:
            tgt_speakers = _collate_frames(
                [x.target_speaker for x in samples], is_audio_input=True
            ).index_select(0, order)

        net_input = {
            "src_tokens": frames,
            "src_lengths": n_frames,
            "prev_output_tokens": prev_output_tokens,
            "tgt_speaker": tgt_speakers,  # TODO: unify "speaker" and "tgt_speaker"
        }
        if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
            for i in range(len(samples)):
                net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
        out = {
            "id": indices,
            "net_input": net_input,
            "speaker": tgt_speakers,  # to support Tacotron2 loss for speech-to-spectrogram model
            "target": target,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
            "nsentences": len(samples),
        }
        if return_order:
            out["order"] = order
        return out


class SpeechToSpeechMultitaskDataset(SpeechToSpeechDataset):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.multitask_data = {}

    def add_multitask_dataset(self, task_name, task_data):
        self.multitask_data[task_name] = task_data

    def __getitem__(
        self, index: int
    ) -> Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]:
        s2s_data = super().__getitem__(index)

        multitask_target = {}
        sample_id = self.ids[index]
        tgt_lang = self.tgt_langs[index]
        for task_name, task_dataset in self.multitask_data.items():
            multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)

        return s2s_data, multitask_target

    def collater(
        self, samples: List[Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]]
    ) -> Dict:
        if len(samples) == 0:
            return {}

        out = super().collater([s for s, _ in samples], return_order=True)
        order = out["order"]
        del out["order"]

        for task_name, task_dataset in self.multitask_data.items():
            if "multitask" not in out:
                out["multitask"] = {}
            d = [s[task_name] for _, s in samples]
            task_target = task_dataset.collater(d)
            out["multitask"][task_name] = {
                "target": task_target["target"].index_select(0, order),
                "target_lengths": task_target["target_lengths"].index_select(0, order),
                "ntokens": task_target["ntokens"],
            }
            out["multitask"][task_name]["net_input"] = {
                "prev_output_tokens": task_target["prev_output_tokens"].index_select(
                    0, order
                ),
            }

        return out


class SpeechToSpeechDatasetCreator(object):
    # mandatory columns
    KEY_ID, KEY_SRC_AUDIO, KEY_SRC_N_FRAMES = "id", "src_audio", "src_n_frames"
    KEY_TGT_AUDIO, KEY_TGT_N_FRAMES = "tgt_audio", "tgt_n_frames"
    # optional columns
    KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
    # default values
    DEFAULT_LANG = ""

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        data_cfg: S2SDataConfig,
        target_is_code: bool = False,
        tgt_dict: Dictionary = None,
        n_frames_per_step: int = 1,
        multitask: Optional[Dict] = None,
    ) -> SpeechToSpeechDataset:
        audio_root = Path(data_cfg.audio_root)
        ids = [s[cls.KEY_ID] for s in samples]
        src_audio_paths = [
            (audio_root / s[cls.KEY_SRC_AUDIO]).as_posix() for s in samples
        ]
        tgt_audio_paths = [
            s[cls.KEY_TGT_AUDIO]
            if target_is_code
            else (audio_root / s[cls.KEY_TGT_AUDIO]).as_posix()
            for s in samples
        ]
        src_n_frames = [int(s[cls.KEY_SRC_N_FRAMES]) for s in samples]
        tgt_n_frames = [int(s[cls.KEY_TGT_N_FRAMES]) for s in samples]
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]

        has_multitask = multitask is not None and len(multitask.keys()) > 0
        dataset_cls = (
            SpeechToSpeechMultitaskDataset if has_multitask else SpeechToSpeechDataset
        )

        ds = dataset_cls(
            split=split_name,
            is_train_split=is_train_split,
            data_cfg=data_cfg,
            src_audio_paths=src_audio_paths,
            src_n_frames=src_n_frames,
            tgt_audio_paths=tgt_audio_paths,
            tgt_n_frames=tgt_n_frames,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            target_is_code=target_is_code,
            tgt_dict=tgt_dict,
            n_frames_per_step=n_frames_per_step,
        )

        if has_multitask:
            for task_name, task_obj in multitask.items():
                task_data = TextTargetMultitaskData(
                    task_obj.args, split_name, task_obj.target_dictionary
                )
                ds.add_multitask_dataset(task_name, task_data)
        return ds

    @classmethod
    def from_tsv(
        cls,
        root: str,
        data_cfg: S2SDataConfig,
        splits: str,
        is_train_split: bool,
        epoch: int,
        seed: int,
        target_is_code: bool = False,
        tgt_dict: Dictionary = None,
        n_frames_per_step: int = 1,
        multitask: Optional[Dict] = None,
    ) -> SpeechToSpeechDataset:
        datasets = []
        for split in splits.split(","):
            samples = SpeechToTextDatasetCreator._load_samples_from_tsv(root, split)
            ds = cls._from_list(
                split_name=split,
                is_train_split=is_train_split,
                samples=samples,
                data_cfg=data_cfg,
                target_is_code=target_is_code,
                tgt_dict=tgt_dict,
                n_frames_per_step=n_frames_per_step,
                multitask=multitask,
            )
            datasets.append(ds)
        return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
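To make the `pack_units` arithmetic above concrete: when targets are reduced-rate unit codes, each window of `n_frames_per_step` consecutive unit ids is collapsed into a single id by reading the window as a base-`vocab_size` number, after subtracting the 4 fairseq special symbols (`<bos>`, `<pad>`, `<eos>`, `<unk>`) and re-adding that offset at the end. A worked sketch with hypothetical sizes (100 units behind the 4 specials, two frames per step):

import torch

offset, vocab_size, n_frames_per_step = 4, 100, 2  # hypothetical sizes
units = torch.tensor([11, 27, 53, 80, 2])  # four unit ids plus <eos> (index 2)

stacked = units[:-1].view(-1, n_frames_per_step) - offset  # [[7, 23], [49, 76]]
scale = torch.tensor([vocab_size, 1])                      # base-100 positional weights
packed = (stacked * scale).sum(dim=1) + offset             # tensor([ 727, 4980])

Each pair becomes `u0 * 100 + u1 + 4` (here 7 * 100 + 23 + 4 = 727), so packed ids live in a `vocab_size ** n_frames_per_step` space; `pack_units` itself additionally carries the trailing `<eos>` through unchanged. `_collate_target` feeds these packed ids to `prev_output_tokens` while keeping the unpacked sequence as the training target.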
    	
modules/voice_conversion/fairseq/data/audio/speech_to_text_dataset.py ADDED
@@ -0,0 +1,733 @@
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import csv
import logging
import re
from argparse import Namespace
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data import encoders
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.data_cfg import S2TDataConfig
from fairseq.data.audio.dataset_transforms import CompositeAudioDatasetTransform
from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment
from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import (
    NoisyOverlapAugment,
)
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform

logger = logging.getLogger(__name__)


def _collate_frames(
    frames: List[torch.Tensor], is_audio_input: bool = False
) -> torch.Tensor:
    """
    Convert a list of 2D frames into a padded 3D tensor
    Args:
        frames (list): list of 2D frames of size L[i]*f_dim, where L[i] is the
            length of the i-th frame and f_dim is the static feature dimension
    Returns:
        3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
    """
    max_len = max(frame.size(0) for frame in frames)
    if is_audio_input:
        out = frames[0].new_zeros((len(frames), max_len))
    else:
        out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
    for i, v in enumerate(frames):
        out[i, : v.size(0)] = v
    return out
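
# A minimal usage sketch (not part of the commit), assuming 80-dim features:
#
#   a = torch.randn(3, 80)                    # 3 frames
#   b = torch.randn(5, 80)                    # 5 frames
#   batch = _collate_frames([a, b])           # -> (2, 5, 80), `a` zero-padded
#   wavs = _collate_frames(
#       [torch.randn(16000), torch.randn(8000)], is_audio_input=True
#   )                                         # -> (2, 16000)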


def _is_int_or_np_int(n):
    return isinstance(n, int) or (
        isinstance(n, np.generic) and isinstance(n.item(), int)
    )


@dataclass
class SpeechToTextDatasetItem(object):
    index: int
    source: torch.Tensor
    target: Optional[torch.Tensor] = None
    speaker_id: Optional[int] = None


class SpeechToTextDataset(FairseqDataset):
    LANG_TAG_TEMPLATE = "<lang:{}>"

    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        append_eos=True,
    ):
        self.split, self.is_train_split = split, is_train_split
        self.cfg = cfg
        self.audio_paths, self.n_frames = audio_paths, n_frames
        self.n_samples = len(audio_paths)
        assert len(n_frames) == self.n_samples > 0
        assert src_texts is None or len(src_texts) == self.n_samples
        assert tgt_texts is None or len(tgt_texts) == self.n_samples
        assert speakers is None or len(speakers) == self.n_samples
        assert src_langs is None or len(src_langs) == self.n_samples
        assert tgt_langs is None or len(tgt_langs) == self.n_samples
        assert ids is None or len(ids) == self.n_samples
        assert (tgt_dict is None and tgt_texts is None) or (
            tgt_dict is not None and tgt_texts is not None
        )
        self.src_texts, self.tgt_texts = src_texts, tgt_texts
        self.src_langs, self.tgt_langs = src_langs, tgt_langs
        self.speakers = speakers
        self.tgt_dict = tgt_dict
        self.check_tgt_lang_tag()
        self.ids = ids
        self.shuffle = cfg.shuffle if is_train_split else False

        self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
            self.cfg.get_feature_transforms(split, is_train_split)
        )
        self.waveform_transforms = CompositeAudioWaveformTransform.from_config_dict(
            self.cfg.get_waveform_transforms(split, is_train_split)
        )
        # TODO: add these to data_cfg.py
        self.dataset_transforms = CompositeAudioDatasetTransform.from_config_dict(
            self.cfg.get_dataset_transforms(split, is_train_split)
        )

        # check proper usage of transforms
        if self.feature_transforms and self.cfg.use_audio_input:
            logger.warning(
                "Feature transforms will not be applied. To use feature transforms, "
                "set use_audio_input as False in config."
            )

        self.pre_tokenizer = pre_tokenizer
        self.bpe_tokenizer = bpe_tokenizer
        self.n_frames_per_step = n_frames_per_step
        self.speaker_to_id = speaker_to_id

        self.tgt_lens = self.get_tgt_lens_and_check_oov()
        self.append_eos = append_eos

        logger.info(self.__repr__())
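
    # Illustrative sketch (not part of the commit): instantiating a split by
    # hand, assuming `cfg` is an S2TDataConfig with prepend_tgt_lang_tag
    # disabled and `tgt_dict` a fairseq Dictionary; normally
    # SpeechToTextDatasetCreator builds this from a manifest TSV:
    #
    #   ds = SpeechToTextDataset(
    #       "train", True, cfg,
    #       audio_paths=["a.wav", "b.wav"], n_frames=[16000, 8000],
    #       tgt_texts=["hello world", "good day"], tgt_dict=tgt_dict,
    #   )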

    def get_tgt_lens_and_check_oov(self):
        if self.tgt_texts is None:
            return [0 for _ in range(self.n_samples)]
        tgt_lens = []
        n_tokens, n_oov_tokens = 0, 0
        for i in range(self.n_samples):
            tokenized = self.get_tokenized_tgt_text(i).split(" ")
            oov_tokens = [
                t
                for t in tokenized
                if self.tgt_dict.index(t) == self.tgt_dict.unk_index
            ]
            n_tokens += len(tokenized)
            n_oov_tokens += len(oov_tokens)
            tgt_lens.append(len(tokenized))
        logger.info(f"'{self.split}' has {n_oov_tokens / n_tokens * 100:.2f}% OOV")
        return tgt_lens

    def __repr__(self):
        return (
            self.__class__.__name__
            + f'(split="{self.split}", n_samples={self.n_samples:_}, '
            f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, "
            f"n_frames_per_step={self.n_frames_per_step}, "
            f"shuffle={self.shuffle}, "
            f"feature_transforms={self.feature_transforms}, "
            f"waveform_transforms={self.waveform_transforms}, "
            f"dataset_transforms={self.dataset_transforms})"
        )

    @classmethod
    def is_lang_tag(cls, token):
        pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
        return re.match(pattern, token)

    def check_tgt_lang_tag(self):
        if self.cfg.prepend_tgt_lang_tag:
            assert self.tgt_langs is not None and self.tgt_dict is not None
            tgt_lang_tags = [
                self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
            ]
            assert all(t in self.tgt_dict for t in tgt_lang_tags)

    @classmethod
    def tokenize(cls, tokenizer, text: str):
        return text if tokenizer is None else tokenizer.encode(text)

    def get_tokenized_tgt_text(self, index: Union[int, List[int]]):
        if _is_int_or_np_int(index):
            text = self.tgt_texts[index]
        else:
            text = " ".join([self.tgt_texts[i] for i in index])

        text = self.tokenize(self.pre_tokenizer, text)
        text = self.tokenize(self.bpe_tokenizer, text)
        return text
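
    # Sketch (not part of the commit) of the two-stage tokenization, assuming a
    # moses pre-tokenizer and a sentencepiece bpe_tokenizer are configured:
    #
    #   "Hello, world!"  --pre_tokenizer-->  "Hello , world !"
    #                    --bpe_tokenizer-->  "_Hello _, _world _!"
    #
    # is_lang_tag("<lang:de>") matches, while is_lang_tag("hello") does not.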

    def pack_frames(self, feature: torch.Tensor):
        if self.n_frames_per_step == 1:
            return feature
        n_packed_frames = feature.shape[0] // self.n_frames_per_step
        feature = feature[: self.n_frames_per_step * n_packed_frames]
        return feature.reshape(n_packed_frames, -1)
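
    # Sketch (not part of the commit): with n_frames_per_step=2, a (7, 80)
    # feature matrix becomes (3, 160); the trailing odd frame is dropped.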

    @classmethod
    def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary):
        lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang))
        assert lang_tag_idx != dictionary.unk()
        return lang_tag_idx

    def _get_source_audio(self, index: Union[int, List[int]]) -> torch.Tensor:
        """
        Gives source audio for given index with any relevant transforms
        applied. For ConcatAug, source audios for given indices are
        concatenated in given order.
        Args:
            index (int or List[int]): index (or, for ConcatAug, indices)
                to pull the source audio for
        Returns:
            source audios concatenated for given indices with
            relevant transforms applied
        """
        if _is_int_or_np_int(index):
            source = get_features_or_waveform(
                self.audio_paths[index],
                need_waveform=self.cfg.use_audio_input,
                use_sample_rate=self.cfg.use_sample_rate,
                waveform_transforms=self.waveform_transforms,
            )
        else:
            source = np.concatenate(
                [
                    get_features_or_waveform(
                        self.audio_paths[i],
                        need_waveform=self.cfg.use_audio_input,
                        use_sample_rate=self.cfg.use_sample_rate,
                        waveform_transforms=self.waveform_transforms,
                    )
                    for i in index
                ]
            )
        if self.cfg.use_audio_input:
            source = torch.from_numpy(source).float()
            if self.cfg.standardize_audio:
                with torch.no_grad():
                    source = F.layer_norm(source, source.shape)
        else:
            if self.feature_transforms is not None:
                source = self.feature_transforms(source)
            source = torch.from_numpy(source).float()
        return source
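
    # Note (not part of the commit): cfg.use_audio_input selects the branch
    # above. With raw audio, the waveform is optionally layer-normalized
    # (cfg.standardize_audio); otherwise pre-extracted features are loaded and
    # feature_transforms (e.g. SpecAugment) are applied before conversion.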

    def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
        has_concat = self.dataset_transforms.has_transform(ConcatAugment)
        if has_concat:
            concat = self.dataset_transforms.get_transform(ConcatAugment)
            indices = concat.find_indices(index, self.n_frames, self.n_samples)

        source = self._get_source_audio(indices if has_concat else index)
        source = self.pack_frames(source)

        target = None
        if self.tgt_texts is not None:
            tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
            target = self.tgt_dict.encode_line(
                tokenized, add_if_not_exist=False, append_eos=self.append_eos
            ).long()
            if self.cfg.prepend_tgt_lang_tag:
                lang_tag_idx = self.get_lang_tag_idx(
                    self.tgt_langs[index], self.tgt_dict
                )
                target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)

            if self.cfg.prepend_bos_and_append_tgt_lang_tag:
                bos = torch.LongTensor([self.tgt_dict.bos()])
                lang_tag_idx = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
                assert lang_tag_idx != self.tgt_dict.unk()
                lang_tag_idx = torch.LongTensor([lang_tag_idx])
                target = torch.cat((bos, target, lang_tag_idx), 0)

        speaker_id = None
        if self.speaker_to_id is not None:
            speaker_id = self.speaker_to_id[self.speakers[index]]
        return SpeechToTextDatasetItem(
            index=index, source=source, target=target, speaker_id=speaker_id
        )
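
    # Sketch (not part of the commit) of the resulting target layout:
    #   prepend_tgt_lang_tag:                [<lang:de>, t1, ..., tn, </s>]
    #   prepend_bos_and_append_tgt_lang_tag: [<s>, t1, ..., tn, </s>, <lang:de>]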

    def __len__(self):
        return self.n_samples

    def collater(
        self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
    ) -> Dict:
        if len(samples) == 0:
            return {}
        indices = torch.tensor([x.index for x in samples], dtype=torch.long)

        sources = [x.source for x in samples]
        has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
        if has_NOAug and self.cfg.use_audio_input:
            NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
            sources = NOAug(sources)

        frames = _collate_frames(sources, self.cfg.use_audio_input)
        # sort samples by descending number of frames
        n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
        n_frames, order = n_frames.sort(descending=True)
        indices = indices.index_select(0, order)
        frames = frames.index_select(0, order)

        target, target_lengths = None, None
        prev_output_tokens = None
        ntokens = None
        if self.tgt_texts is not None:
            target = fairseq_data_utils.collate_tokens(
                [x.target for x in samples],
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            target = target.index_select(0, order)
            target_lengths = torch.tensor(
                [x.target.size(0) for x in samples], dtype=torch.long
            ).index_select(0, order)
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                [x.target for x in samples],
                self.tgt_dict.pad(),
                eos_idx=None,
                left_pad=False,
                move_eos_to_beginning=True,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, order)
            ntokens = sum(x.target.size(0) for x in samples)

        speaker = None
        if self.speaker_to_id is not None:
            speaker = (
                torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
                .index_select(0, order)
                .view(-1, 1)
            )

        net_input = {
            "src_tokens": frames,
            "src_lengths": n_frames,
            "prev_output_tokens": prev_output_tokens,
        }
        out = {
            "id": indices,
            "net_input": net_input,
            "speaker": speaker,
            "target": target,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
            "nsentences": len(samples),
        }
        if return_order:
            out["order"] = order
        return out
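
    # Sketch (not part of the commit) of a collated batch (B samples, sorted
    # by descending source length):
    #
    #   {
    #       "id": LongTensor(B),
    #       "net_input": {
    #           "src_tokens": (B, T, F) features, or (B, T) raw audio,
    #           "src_lengths": LongTensor(B),
    #           "prev_output_tokens": (B, U),
    #       },
    #       "target": (B, U), "target_lengths": LongTensor(B),
    #       "ntokens": int, "nsentences": B,
    #   }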

    def num_tokens(self, index):
        return self.n_frames[index]

    def size(self, index):
        return self.n_frames[index], self.tgt_lens[index]

    @property
    def sizes(self):
        return np.array(self.n_frames)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True

    def ordered_indices(self):
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        # first by descending order of # of frames then by original/random order
        order.append([-n for n in self.n_frames])
        return np.lexsort(order)
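
    # Note (not part of the commit): np.lexsort treats the *last* key as
    # primary, so indices are sorted by descending n_frames, with ties broken
    # by the random (or original) permutation pushed into `order` first.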

    def prefetch(self, indices):
        raise NotImplementedError  # prefetching is not supported


class TextTargetMultitaskData(object):
    # mandatory columns
    KEY_ID, KEY_TEXT = "id", "tgt_text"
    LANG_TAG_TEMPLATE = "<lang:{}>"

    def __init__(self, args, split, tgt_dict):
        samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split)
        self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples}
        self.dict = tgt_dict
        self.append_eos = args.decoder_type != "ctc"
        self.pre_tokenizer = self.build_tokenizer(args)
        self.bpe_tokenizer = self.build_bpe(args)
        self.prepend_bos_and_append_tgt_lang_tag = (
            args.prepend_bos_and_append_tgt_lang_tag
        )
        self.eos_token = args.eos_token
        self.lang_tag_mapping = args.get_lang_tag_mapping

    @classmethod
    def is_lang_tag(cls, token):
        pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
        return re.match(pattern, token)

    @classmethod
    def tokenize(cls, tokenizer, text: str):
        return text if tokenizer is None else tokenizer.encode(text)

    def get_tokenized_tgt_text(self, index: int):
        text = self.tokenize(self.pre_tokenizer, self.data[index])
        text = self.tokenize(self.bpe_tokenizer, text)
        return text

    def get_lang_tag_idx(self, lang: str, dictionary: Dictionary):
        lang_tag = self.LANG_TAG_TEMPLATE.format(lang)
        lang_tag = self.lang_tag_mapping.get(lang_tag, lang_tag)
        lang_tag_idx = dictionary.index(lang_tag)
        assert lang_tag_idx != dictionary.unk(), (lang, lang_tag)
        return lang_tag_idx

    def build_tokenizer(self, args):
        pre_tokenizer = args.config.get("pre_tokenizer")
        if pre_tokenizer is not None:
            logger.info(f"pre-tokenizer: {pre_tokenizer}")
            return encoders.build_tokenizer(Namespace(**pre_tokenizer))
        else:
            return None

    def build_bpe(self, args):
        bpe_tokenizer = args.config.get("bpe_tokenizer")
        if bpe_tokenizer is not None:
            logger.info(f"tokenizer: {bpe_tokenizer}")
            return encoders.build_bpe(Namespace(**bpe_tokenizer))
        else:
            return None

    def get(self, sample_id, tgt_lang=None):
        if sample_id in self.data:
            tokenized = self.get_tokenized_tgt_text(sample_id)
            target = self.dict.encode_line(
                tokenized,
                add_if_not_exist=False,
                append_eos=self.append_eos,
            )
            if self.prepend_bos_and_append_tgt_lang_tag:
                bos = torch.LongTensor([self.dict.bos()])
                lang_tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict)
                assert lang_tag_idx != self.dict.unk()
                lang_tag_idx = torch.LongTensor([lang_tag_idx])
                target = torch.cat((bos, target, lang_tag_idx), 0)
            return target
        else:
            logger.warning(f"no target for {sample_id}")
            return torch.IntTensor([])

    def collater(self, samples: List[torch.Tensor]) -> torch.Tensor:
        out = fairseq_data_utils.collate_tokens(
            samples,
            self.dict.pad(),
            eos_idx=None,
            left_pad=False,
            move_eos_to_beginning=False,
        ).long()

        prev_out = fairseq_data_utils.collate_tokens(
            samples,
            self.dict.pad(),
            eos_idx=None,
            left_pad=False,
            move_eos_to_beginning=True,
        ).long()

        target_lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long)
        ntokens = sum(t.size(0) for t in samples)

        output = {
            "prev_output_tokens": prev_out,
            "target": out,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
        }

        return output
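
    # Note (not part of the commit): targets are read from a TSV with
    # mandatory "id" and "tgt_text" columns; get() logs a warning and returns
    # an empty tensor when a sample id has no target row.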


class SpeechToTextMultitaskDataset(SpeechToTextDataset):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.multitask_data = {}

    def add_multitask_dataset(self, task_name, task_data):
        self.multitask_data[task_name] = task_data

    def __getitem__(
        self, index: int
    ) -> Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]:
        s2t_data = super().__getitem__(index)

        multitask_target = {}
        sample_id = self.ids[index]
        tgt_lang = self.tgt_langs[index]
        for task_name, task_dataset in self.multitask_data.items():
            multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)

        return s2t_data, multitask_target

    def collater(
        self, samples: List[Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]]
    ) -> Dict:
        if len(samples) == 0:
            return {}

        out = super().collater([s for s, _ in samples], return_order=True)
        order = out["order"]
        del out["order"]

        for task_name, task_dataset in self.multitask_data.items():
            if "multitask" not in out:
                out["multitask"] = {}
            d = [s[task_name] for _, s in samples]
            task_target = task_dataset.collater(d)
            out["multitask"][task_name] = {
                "target": task_target["target"].index_select(0, order),
                "target_lengths": task_target["target_lengths"].index_select(0, order),
                "ntokens": task_target["ntokens"],
            }
            out["multitask"][task_name]["net_input"] = {
                "prev_output_tokens": task_target["prev_output_tokens"].index_select(
                    0, order
                ),
            }

        return out
| 543 | 
         
            +
             
     | 
| 544 | 
         
            +
             
     | 
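A minimal sketch of how this wrapper is driven (names are illustrative; the task object only needs the get(sample_id, tgt_lang) / collater(items) interface that TextTargetMultitaskData provides):

    # ds: a SpeechToTextMultitaskDataset; asr_task_data: any object with
    # .get(sample_id, tgt_lang) and .collater(items), e.g. TextTargetMultitaskData.
    ds.add_multitask_dataset("asr", asr_task_data)

    s2t_item, multitask_target = ds[0]        # __getitem__ returns a pair
    batch = ds.collater([ds[i] for i in range(4)])
    # batch["multitask"]["asr"] carries target / target_lengths / ntokens plus
    # net_input["prev_output_tokens"], all re-ordered to match the s2t batch.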
+class SpeechToTextDatasetCreator(object):
+    # mandatory columns
+    KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
+    KEY_TGT_TEXT = "tgt_text"
+    # optional columns
+    KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
+    KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
+    # default values
+    DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""
+
+    @classmethod
+    def _from_list(
+        cls,
+        split_name: str,
+        is_train_split,
+        samples: List[Dict],
+        cfg: S2TDataConfig,
+        tgt_dict,
+        pre_tokenizer,
+        bpe_tokenizer,
+        n_frames_per_step,
+        speaker_to_id,
+        multitask: Optional[Dict] = None,
+    ) -> SpeechToTextDataset:
+        audio_root = Path(cfg.audio_root)
+        ids = [s[cls.KEY_ID] for s in samples]
+        audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
+        n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
+        tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
+        src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
+        speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
+        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
+        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
+
+        has_multitask = multitask is not None and len(multitask.keys()) > 0
+        dataset_cls = (
+            SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset
+        )
+
+        ds = dataset_cls(
+            split=split_name,
+            is_train_split=is_train_split,
+            cfg=cfg,
+            audio_paths=audio_paths,
+            n_frames=n_frames,
+            src_texts=src_texts,
+            tgt_texts=tgt_texts,
+            speakers=speakers,
+            src_langs=src_langs,
+            tgt_langs=tgt_langs,
+            ids=ids,
+            tgt_dict=tgt_dict,
+            pre_tokenizer=pre_tokenizer,
+            bpe_tokenizer=bpe_tokenizer,
+            n_frames_per_step=n_frames_per_step,
+            speaker_to_id=speaker_to_id,
+        )
+
+        if has_multitask:
+            for task_name, task_obj in multitask.items():
+                task_data = TextTargetMultitaskData(
+                    task_obj.args, split_name, task_obj.target_dictionary
+                )
+                ds.add_multitask_dataset(task_name, task_data)
+        return ds
+
+    @classmethod
+    def get_size_ratios(
+        cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0
+    ) -> List[float]:
+        """Size ratios for temperature-based sampling
+        (https://arxiv.org/abs/1907.05019)"""
+
+        id_to_lp, lp_to_sz = {}, defaultdict(int)
+        for ds in datasets:
+            lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)}
+            assert len(lang_pairs) == 1
+            lang_pair = list(lang_pairs)[0]
+            id_to_lp[ds.split] = lang_pair
+            lp_to_sz[lang_pair] += sum(ds.n_frames)
+
+        sz_sum = sum(v for v in lp_to_sz.values())
+        lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()}
+        lp_to_tgt_prob = {k: v**alpha for k, v in lp_to_prob.items()}
+        prob_sum = sum(v for v in lp_to_tgt_prob.values())
+        lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()}
+        lp_to_sz_ratio = {
+            k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items()
+        }
+        size_ratio = [lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets]
+
+        p_formatted = {
+            k: f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz
+        }
+        logger.info(f"sampling probability balancing: {p_formatted}")
+        sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)}
+        logger.info(f"balanced sampling size ratio: {sr_formatted}")
+        return size_ratio
+
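The size-ratio arithmetic is easiest to see on a toy case. A self-contained sketch of the same computation for two language pairs with alpha = 0.5 (frame counts are made up):

    # Standalone replica of the math in get_size_ratios.
    lp_to_sz = {"en->de": 900_000, "en->fr": 100_000}   # frames per pair
    alpha = 0.5

    sz_sum = sum(lp_to_sz.values())                               # 1_000_000
    lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()}     # 0.9 / 0.1
    lp_to_tgt_prob = {k: v**alpha for k, v in lp_to_prob.items()}
    prob_sum = sum(lp_to_tgt_prob.values())
    lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()}
    # en->de: ~0.75, en->fr: ~0.25 -- flattened toward uniform
    ratios = {k: lp_to_tgt_prob[k] * sz_sum / v for k, v in lp_to_sz.items()}
    # en->de: ~0.83 (downsampled), en->fr: ~2.5 (upsampled)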
+    @classmethod
+    def _load_samples_from_tsv(cls, root: str, split: str):
+        tsv_path = Path(root) / f"{split}.tsv"
+        if not tsv_path.is_file():
+            raise FileNotFoundError(f"Dataset not found: {tsv_path}")
+        with open(tsv_path) as f:
+            reader = csv.DictReader(
+                f,
+                delimiter="\t",
+                quotechar=None,
+                doublequote=False,
+                lineterminator="\n",
+                quoting=csv.QUOTE_NONE,
+            )
+            samples = [dict(e) for e in reader]
+        if len(samples) == 0:
+            raise ValueError(f"Empty manifest: {tsv_path}")
+        return samples
+
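For reference, a manifest this loader accepts is a plain tab-separated file with a header row whose column names match the KEY_* constants above (rows here are invented):

    id	audio	n_frames	tgt_text	speaker	src_lang	tgt_lang
    utt_0001	clips/utt_0001.wav	48000	guten tag	spk1	en	de
    utt_0002	clips/utt_0002.wav	32000	danke schön	spk1	en	de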
+    @classmethod
+    def _from_tsv(
+        cls,
+        root: str,
+        cfg: S2TDataConfig,
+        split: str,
+        tgt_dict,
+        is_train_split: bool,
+        pre_tokenizer,
+        bpe_tokenizer,
+        n_frames_per_step,
+        speaker_to_id,
+        multitask: Optional[Dict] = None,
+    ) -> SpeechToTextDataset:
+        samples = cls._load_samples_from_tsv(root, split)
+        return cls._from_list(
+            split,
+            is_train_split,
+            samples,
+            cfg,
+            tgt_dict,
+            pre_tokenizer,
+            bpe_tokenizer,
+            n_frames_per_step,
+            speaker_to_id,
+            multitask,
+        )
+
+    @classmethod
+    def from_tsv(
+        cls,
+        root: str,
+        cfg: S2TDataConfig,
+        splits: str,
+        tgt_dict,
+        pre_tokenizer,
+        bpe_tokenizer,
+        is_train_split: bool,
+        epoch: int,
+        seed: int,
+        n_frames_per_step: int = 1,
+        speaker_to_id=None,
+        multitask: Optional[Dict] = None,
+    ) -> SpeechToTextDataset:
+        datasets = [
+            cls._from_tsv(
+                root=root,
+                cfg=cfg,
+                split=split,
+                tgt_dict=tgt_dict,
+                is_train_split=is_train_split,
+                pre_tokenizer=pre_tokenizer,
+                bpe_tokenizer=bpe_tokenizer,
+                n_frames_per_step=n_frames_per_step,
+                speaker_to_id=speaker_to_id,
+                multitask=multitask,
+            )
+            for split in splits.split(",")
+        ]
+
+        if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
+            # temperature-based sampling
+            size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
+            datasets = [
+                ResamplingDataset(
+                    d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
+                )
+                for r, d in zip(size_ratios, datasets)
+            ]
+
+        return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
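A hedged usage sketch of the entry point above; the paths, dictionary file, and tokenizers are placeholders (in fairseq they are normally built by the speech-to-text task setup):

    from pathlib import Path

    from fairseq.data import Dictionary
    from fairseq.data.audio.speech_to_text_dataset import (
        S2TDataConfig,
        SpeechToTextDatasetCreator,
    )

    cfg = S2TDataConfig(Path("/data/s2t/config.yaml"))        # placeholder path
    tgt_dict = Dictionary.load("/data/s2t/spm_unigram.txt")   # placeholder path

    train_set = SpeechToTextDatasetCreator.from_tsv(
        root="/data/s2t",
        cfg=cfg,
        splits="train_de,train_fr",   # comma-separated -> ConcatDataset
        tgt_dict=tgt_dict,
        pre_tokenizer=None,           # or encoders.build_tokenizer(...)
        bpe_tokenizer=None,           # or encoders.build_bpe(...)
        is_train_split=True,
        epoch=1,
        seed=1,
    )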
    	
modules/voice_conversion/fairseq/data/audio/speech_to_text_joint_dataset.py ADDED
@@ -0,0 +1,359 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from pathlib import Path
+from typing import Dict, List, NamedTuple, Optional
+
+import torch
+
+from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset
+from fairseq.data import data_utils as fairseq_data_utils
+from fairseq.data.audio.speech_to_text_dataset import (
+    S2TDataConfig,
+    SpeechToTextDataset,
+    SpeechToTextDatasetCreator,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class S2TJointDataConfig(S2TDataConfig):
+    """Wrapper class for data config YAML"""
+
+    @property
+    def src_vocab_filename(self):
+        """fairseq vocabulary file under data root"""
+        return self.config.get("src_vocab_filename", "src_dict.txt")
+
+    @property
+    def src_pre_tokenizer(self) -> Dict:
+        """Pre-tokenizer to apply before subword tokenization. Returning
+        a dictionary with `tokenizer` providing the tokenizer name and
+        the other items providing the tokenizer-specific arguments.
+        Tokenizers are defined in `fairseq.data.encoders.*`"""
+        return self.config.get("src_pre_tokenizer", {"tokenizer": None})
+
+    @property
+    def src_bpe_tokenizer(self) -> Dict:
+        """Subword tokenizer to apply on source text after pre-tokenization.
+        Returning a dictionary with `bpe` providing the tokenizer name and
+        the other items providing the tokenizer-specific arguments.
+        Tokenizers are defined in `fairseq.data.encoders.*`"""
+        return self.config.get("src_bpe_tokenizer", {"bpe": None})
+
+    @property
+    def prepend_tgt_lang_tag_no_change(self) -> bool:
+        """Prepend target lang ID token as the prev_output_tokens BOS (e.g. for
+        to-many multilingual setting). No change needed during inference.
+        This option is deprecated and replaced by prepend_tgt_lang_tag_as_bos.
+        """
+        value = self.config.get("prepend_tgt_lang_tag_no_change", None)
+        if value is None:
+            return self.config.get("prepend_tgt_lang_tag_as_bos", False)
+        return value
+
+    @property
+    def sampling_text_alpha(self):
+        """Hyper-parameter alpha = 1/T for temperature-based resampling. (text
+        input only) (alpha = 1 for no resampling)"""
+        return self.config.get("sampling_text_alpha", 1.0)
+
+
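These properties read keys from the data-config YAML. An illustrative config showing the source-side fields this joint variant adds on top of the base S2T config (all values are examples, not shipped defaults):

    src_vocab_filename: src_dict.txt
    src_pre_tokenizer:
      tokenizer: moses              # any tokenizer under fairseq.data.encoders
    src_bpe_tokenizer:
      bpe: sentencepiece            # any subword scheme under fairseq.data.encoders
      sentencepiece_model: spm_src.model
    prepend_tgt_lang_tag_no_change: true
    sampling_text_alpha: 0.7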
+class SpeechToTextJointDatasetItem(NamedTuple):
+    index: int
+    source: torch.Tensor
+    target: Optional[torch.Tensor] = None
+    src_txt_tokens: Optional[torch.Tensor] = None
+    tgt_lang_tag: Optional[int] = None
+    src_lang_tag: Optional[int] = None
+    tgt_alignment: Optional[torch.Tensor] = None
+
+
+# use_src_lang_id:
+#   0: don't use src_lang_id
+#   1: attach src_lang_id to the src_txt_tokens as eos
+class SpeechToTextJointDataset(SpeechToTextDataset):
+    def __init__(
+        self,
+        split: str,
+        is_train_split: bool,
+        cfg: S2TJointDataConfig,
+        audio_paths: List[str],
+        n_frames: List[int],
+        src_texts: Optional[List[str]] = None,
+        tgt_texts: Optional[List[str]] = None,
+        speakers: Optional[List[str]] = None,
+        src_langs: Optional[List[str]] = None,
+        tgt_langs: Optional[List[str]] = None,
+        ids: Optional[List[str]] = None,
+        tgt_dict: Optional[Dictionary] = None,
+        src_dict: Optional[Dictionary] = None,
+        pre_tokenizer=None,
+        bpe_tokenizer=None,
+        src_pre_tokenizer=None,
+        src_bpe_tokenizer=None,
+        append_eos: Optional[bool] = True,
+        alignment: Optional[List[str]] = None,
+        use_src_lang_id: Optional[int] = 0,
+    ):
+        super().__init__(
+            split,
+            is_train_split,
+            cfg,
+            audio_paths,
+            n_frames,
+            src_texts=src_texts,
+            tgt_texts=tgt_texts,
+            speakers=speakers,
+            src_langs=src_langs,
+            tgt_langs=tgt_langs,
+            ids=ids,
+            tgt_dict=tgt_dict,
+            pre_tokenizer=pre_tokenizer,
+            bpe_tokenizer=bpe_tokenizer,
+            append_eos=append_eos,
+        )
+
+        self.src_dict = src_dict
+        self.src_pre_tokenizer = src_pre_tokenizer
+        self.src_bpe_tokenizer = src_bpe_tokenizer
+        self.alignment = None
+        self.use_src_lang_id = use_src_lang_id
+        if alignment is not None:
+            self.alignment = [
+                [float(s) for s in sample.split()] for sample in alignment
+            ]
+
+    def get_tokenized_src_text(self, index: int):
+        text = self.tokenize(self.src_pre_tokenizer, self.src_texts[index])
+        text = self.tokenize(self.src_bpe_tokenizer, text)
+        return text
+
+    def __getitem__(self, index: int) -> SpeechToTextJointDatasetItem:
+        s2t_dataset_item = super().__getitem__(index)
+        src_tokens = None
+        src_lang_tag = None
+        if self.src_texts is not None and self.src_dict is not None:
+            src_tokens = self.get_tokenized_src_text(index)
+            src_tokens = self.src_dict.encode_line(
+                src_tokens, add_if_not_exist=False, append_eos=True
+            ).long()
+            if self.use_src_lang_id > 0:
+                src_lang_tag = self.get_lang_tag_idx(
+                    self.src_langs[index], self.src_dict
+                )
+        tgt_lang_tag = None
+        if self.cfg.prepend_tgt_lang_tag_no_change:
+            # prepend_tgt_lang_tag_no_change: modify prev_output_tokens instead
+            tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
+        ali = None
+        if self.alignment is not None:
+            ali = torch.Tensor(self.alignment[index]).float()
+
+        return SpeechToTextJointDatasetItem(
+            index=index,
+            source=s2t_dataset_item.source,
+            target=s2t_dataset_item.target,
+            src_txt_tokens=src_tokens,
+            tgt_lang_tag=tgt_lang_tag,
+            src_lang_tag=src_lang_tag,
+            tgt_alignment=ali,
+        )
+
+    def __len__(self):
+        return self.n_samples
+
+    def collater(self, samples: List[SpeechToTextJointDatasetItem]) -> Dict:
+        s2t_out = super().collater(samples, return_order=True)
+        if s2t_out == {}:
+            return s2t_out
+        net_input, order = s2t_out["net_input"], s2t_out["order"]
+
+        if self.src_texts is not None and self.src_dict is not None:
+            src_txt_tokens = fairseq_data_utils.collate_tokens(
+                [x.src_txt_tokens for x in samples],
+                self.src_dict.pad(),
+                self.src_dict.eos(),
+                left_pad=False,
+                move_eos_to_beginning=False,
+            )
+            src_txt_lengths = torch.tensor(
+                [x.src_txt_tokens.size()[0] for x in samples], dtype=torch.long
+            )
+            if self.use_src_lang_id > 0:
+                src_lang_idxs = torch.tensor(
+                    [s.src_lang_tag for s in samples], dtype=src_txt_tokens.dtype
+                )
+                if self.use_src_lang_id == 1:  # replace eos with lang_id
+                    eos_idx = src_txt_lengths - 1
+                    src_txt_tokens.scatter_(
+                        1, eos_idx.view(-1, 1), src_lang_idxs.view(-1, 1)
+                    )
+                else:
+                    raise NotImplementedError("Implementation is required")
+
+            src_txt_tokens = src_txt_tokens.index_select(0, order)
+            src_txt_lengths = src_txt_lengths.index_select(0, order)
+            net_input["src_txt_tokens"] = src_txt_tokens
+            net_input["src_txt_lengths"] = src_txt_lengths
+
+        net_input["alignment"] = None
+        if self.alignment is not None:
+            max_len = max([s.tgt_alignment.size(0) for s in samples])
+            alignment = torch.ones(len(samples), max_len).float()
+            for i, s in enumerate(samples):
+                cur_len = s.tgt_alignment.size(0)
+                alignment[i][:cur_len].copy_(s.tgt_alignment)
+            net_input["alignment"] = alignment.index_select(0, order)
+
+        if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
+            for i in range(len(samples)):
+                net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
+
+        out = {
+            "id": s2t_out["id"],
+            "net_input": net_input,
+            "target": s2t_out["target"],
+            "target_lengths": s2t_out["target_lengths"],
+            "ntokens": s2t_out["ntokens"],
+            "nsentences": len(samples),
+        }
+        return out
+
+
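The use_src_lang_id == 1 branch above overwrites each row's eos position with the source language tag in place. A tiny standalone demonstration of that scatter_ call (token ids are made up):

    import torch

    # Two padded source-text rows; 2 marks eos, 0 marks pad.
    src_txt_tokens = torch.tensor([[5, 6, 7, 2],
                                   [8, 2, 0, 0]])
    src_txt_lengths = torch.tensor([4, 2])
    src_lang_idxs = torch.tensor([101, 102])    # pretend lang-tag ids

    eos_idx = src_txt_lengths - 1               # eos position per row
    src_txt_tokens.scatter_(1, eos_idx.view(-1, 1), src_lang_idxs.view(-1, 1))
    # -> tensor([[  5,   6,   7, 101],
    #            [  8, 102,   0,   0]])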
+class SpeechToTextJointDatasetCreator(SpeechToTextDatasetCreator):
+    KEY_ALIGN = "align"
+
+    @classmethod
+    def _from_list(
+        cls,
+        split_name: str,
+        is_train_split,
+        samples: List[Dict],
+        cfg: S2TJointDataConfig,
+        tgt_dict,
+        src_dict,
+        pre_tokenizer,
+        bpe_tokenizer,
+        src_pre_tokenizer,
+        src_bpe_tokenizer,
+        append_eos,
+        use_src_lang_id,
+    ) -> SpeechToTextJointDataset:
+        audio_root = Path(cfg.audio_root)
+        ids = [s[cls.KEY_ID] for s in samples]
+        audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
+        n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
+        tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
+        src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
+        speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
+        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
+        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
+        tgt_alignment = None
+        if cls.KEY_ALIGN in samples[0].keys():
+            tgt_alignment = [s[cls.KEY_ALIGN] for s in samples]
+        return SpeechToTextJointDataset(
+            split_name,
+            is_train_split,
+            cfg,
+            audio_paths,
+            n_frames,
+            src_texts=src_texts,
+            tgt_texts=tgt_texts,
+            speakers=speakers,
+            src_langs=src_langs,
+            tgt_langs=tgt_langs,
+            ids=ids,
+            tgt_dict=tgt_dict,
+            src_dict=src_dict,
+            pre_tokenizer=pre_tokenizer,
+            bpe_tokenizer=bpe_tokenizer,
+            src_pre_tokenizer=src_pre_tokenizer,
+            src_bpe_tokenizer=src_bpe_tokenizer,
+            append_eos=append_eos,
+            alignment=tgt_alignment,
+            use_src_lang_id=use_src_lang_id,
+        )
+
+    @classmethod
+    def _from_tsv(
+        cls,
+        root: str,
+        cfg: S2TJointDataConfig,
+        split: str,
+        tgt_dict,
+        src_dict,
+        is_train_split: bool,
+        pre_tokenizer,
+        bpe_tokenizer,
+        src_pre_tokenizer,
+        src_bpe_tokenizer,
+        append_eos: bool,
+        use_src_lang_id: int,
+    ) -> SpeechToTextJointDataset:
+        samples = cls._load_samples_from_tsv(root, split)
+        return cls._from_list(
+            split,
+            is_train_split,
+            samples,
+            cfg,
+            tgt_dict,
+            src_dict,
+            pre_tokenizer,
+            bpe_tokenizer,
+            src_pre_tokenizer,
+            src_bpe_tokenizer,
+            append_eos,
+            use_src_lang_id,
+        )
+
| 313 | 
         
            +
                @classmethod
         
     | 
| 314 | 
         
            +
                def from_tsv(
         
     | 
| 315 | 
         
            +
                    cls,
         
     | 
| 316 | 
         
            +
                    root: str,
         
     | 
| 317 | 
         
            +
                    cfg: S2TJointDataConfig,
         
     | 
| 318 | 
         
            +
                    splits: str,
         
     | 
| 319 | 
         
            +
                    tgt_dict,
         
     | 
| 320 | 
         
            +
                    src_dict,
         
     | 
| 321 | 
         
            +
                    pre_tokenizer,
         
     | 
| 322 | 
         
            +
                    bpe_tokenizer,
         
     | 
| 323 | 
         
            +
                    src_pre_tokenizer,
         
     | 
| 324 | 
         
            +
                    src_bpe_tokenizer,
         
     | 
| 325 | 
         
            +
                    is_train_split: bool,
         
     | 
| 326 | 
         
            +
                    epoch: int,
         
     | 
| 327 | 
         
            +
                    seed: int,
         
     | 
| 328 | 
         
            +
                    append_eos: Optional[bool] = True,
         
     | 
| 329 | 
         
            +
                    use_src_lang_id: Optional[int] = 0,
         
     | 
| 330 | 
         
            +
                ) -> SpeechToTextJointDataset:
         
     | 
| 331 | 
         
            +
                    datasets = [
         
     | 
| 332 | 
         
            +
                        cls._from_tsv(
         
     | 
| 333 | 
         
            +
                            root,
         
     | 
| 334 | 
         
            +
                            cfg,
         
     | 
| 335 | 
         
            +
                            split,
         
     | 
| 336 | 
         
            +
                            tgt_dict,
         
     | 
| 337 | 
         
            +
                            src_dict,
         
     | 
| 338 | 
         
            +
                            is_train_split,
         
     | 
| 339 | 
         
            +
                            pre_tokenizer,
         
     | 
| 340 | 
         
            +
                            bpe_tokenizer,
         
     | 
| 341 | 
         
            +
                            src_pre_tokenizer,
         
     | 
| 342 | 
         
            +
                            src_bpe_tokenizer,
         
     | 
| 343 | 
         
            +
                            append_eos=append_eos,
         
     | 
| 344 | 
         
            +
                            use_src_lang_id=use_src_lang_id,
         
     | 
| 345 | 
         
            +
                        )
         
     | 
| 346 | 
         
            +
                        for split in splits.split(",")
         
     | 
| 347 | 
         
            +
                    ]
         
     | 
| 348 | 
         
            +
             
     | 
| 349 | 
         
            +
                    if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
         
     | 
| 350 | 
         
            +
                        # temperature-based sampling
         
     | 
| 351 | 
         
            +
                        size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
         
     | 
| 352 | 
         
            +
                        datasets = [
         
     | 
| 353 | 
         
            +
                            ResamplingDataset(
         
     | 
| 354 | 
         
            +
                                d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
         
     | 
| 355 | 
         
            +
                            )
         
     | 
| 356 | 
         
            +
                            for r, d in zip(size_ratios, datasets)
         
     | 
| 357 | 
         
            +
                        ]
         
     | 
| 358 | 
         
            +
             
     | 
| 359 | 
         
            +
                    return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
         
     | 
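The `sampling_alpha` branch above is temperature-based resampling over multiple comma-separated train splits: sampling probabilities are smoothed toward uniform before each split is up- or down-sampled. A minimal sketch of the ratio computation, assuming the probabilities-raised-to-`alpha` formula used by fairseq's `get_size_ratios` (the split sizes are made up):

import numpy as np

# Hypothetical sizes of three train splits; alpha < 1.0 upsamples the small ones.
sizes = np.array([100_000, 10_000, 1_000], dtype=float)
alpha = 0.5  # cfg.sampling_alpha

prob = sizes / sizes.sum()                    # original sampling distribution
smoothed = prob**alpha / (prob**alpha).sum()  # temperature-smoothed distribution
size_ratios = smoothed / prob                 # per-split resampling ratios
print(size_ratios.round(2))                   # [0.78 2.48 7.84]

Each `ResamplingDataset` then draws `size_ratio * len(dataset)` examples per epoch, sampling with replacement whenever a split has to be upsampled (`replace=(r >= 1.0)` above).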
    	
modules/voice_conversion/fairseq/data/audio/text_to_speech_dataset.py ADDED
@@ -0,0 +1,250 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import torch
+
+from fairseq.data import Dictionary
+from fairseq.data import data_utils as fairseq_data_utils
+from fairseq.data.audio.audio_utils import get_features_or_waveform
+from fairseq.data.audio.speech_to_text_dataset import (
+    S2TDataConfig,
+    SpeechToTextDataset,
+    SpeechToTextDatasetCreator,
+    _collate_frames,
+)
+
+
+@dataclass
+class TextToSpeechDatasetItem(object):
+    index: int
+    source: torch.Tensor
+    target: Optional[torch.Tensor] = None
+    speaker_id: Optional[int] = None
+    duration: Optional[torch.Tensor] = None
+    pitch: Optional[torch.Tensor] = None
+    energy: Optional[torch.Tensor] = None
+
+
+class TextToSpeechDataset(SpeechToTextDataset):
+    def __init__(
+        self,
+        split: str,
+        is_train_split: bool,
+        cfg: S2TDataConfig,
+        audio_paths: List[str],
+        n_frames: List[int],
+        src_texts: Optional[List[str]] = None,
+        tgt_texts: Optional[List[str]] = None,
+        speakers: Optional[List[str]] = None,
+        src_langs: Optional[List[str]] = None,
+        tgt_langs: Optional[List[str]] = None,
+        ids: Optional[List[str]] = None,
+        tgt_dict: Optional[Dictionary] = None,
+        pre_tokenizer=None,
+        bpe_tokenizer=None,
+        n_frames_per_step=1,
+        speaker_to_id=None,
+        durations: Optional[List[List[int]]] = None,
+        pitches: Optional[List[str]] = None,
+        energies: Optional[List[str]] = None,
+    ):
+        super(TextToSpeechDataset, self).__init__(
+            split,
+            is_train_split,
+            cfg,
+            audio_paths,
+            n_frames,
+            src_texts=src_texts,
+            tgt_texts=tgt_texts,
+            speakers=speakers,
+            src_langs=src_langs,
+            tgt_langs=tgt_langs,
+            ids=ids,
+            tgt_dict=tgt_dict,
+            pre_tokenizer=pre_tokenizer,
+            bpe_tokenizer=bpe_tokenizer,
+            n_frames_per_step=n_frames_per_step,
+            speaker_to_id=speaker_to_id,
+        )
+        self.durations = durations
+        self.pitches = pitches
+        self.energies = energies
+
+    def __getitem__(self, index: int) -> TextToSpeechDatasetItem:
+        s2t_item = super().__getitem__(index)
+
+        duration, pitch, energy = None, None, None
+        if self.durations is not None:
+            duration = torch.tensor(
+                self.durations[index] + [0], dtype=torch.long  # pad 0 for EOS
+            )
+        if self.pitches is not None:
+            pitch = get_features_or_waveform(self.pitches[index])
+            pitch = torch.from_numpy(
+                np.concatenate((pitch, [0]))  # pad 0 for EOS
+            ).float()
+        if self.energies is not None:
+            energy = get_features_or_waveform(self.energies[index])
+            energy = torch.from_numpy(
+                np.concatenate((energy, [0]))  # pad 0 for EOS
+            ).float()
+        return TextToSpeechDatasetItem(
+            index=index,
+            source=s2t_item.source,
+            target=s2t_item.target,
+            speaker_id=s2t_item.speaker_id,
+            duration=duration,
+            pitch=pitch,
+            energy=energy,
+        )
+
+    def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
+        if len(samples) == 0:
+            return {}
+
+        src_lengths, order = torch.tensor(
+            [s.target.shape[0] for s in samples], dtype=torch.long
+        ).sort(descending=True)
+        id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
+            0, order
+        )
+        feat = _collate_frames(
+            [s.source for s in samples], self.cfg.use_audio_input
+        ).index_select(0, order)
+        target_lengths = torch.tensor(
+            [s.source.shape[0] for s in samples], dtype=torch.long
+        ).index_select(0, order)
+
+        src_tokens = fairseq_data_utils.collate_tokens(
+            [s.target for s in samples],
+            self.tgt_dict.pad(),
+            self.tgt_dict.eos(),
+            left_pad=False,
+            move_eos_to_beginning=False,
+        ).index_select(0, order)
+
+        speaker = None
+        if self.speaker_to_id is not None:
+            speaker = (
+                torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
+                .index_select(0, order)
+                .view(-1, 1)
+            )
+
+        bsz, _, d = feat.size()
+        prev_output_tokens = torch.cat(
+            (feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1
+        )
+
+        durations, pitches, energies = None, None, None
+        if self.durations is not None:
+            durations = fairseq_data_utils.collate_tokens(
+                [s.duration for s in samples], 0
+            ).index_select(0, order)
+            assert src_tokens.shape[1] == durations.shape[1]
+        if self.pitches is not None:
+            pitches = _collate_frames([s.pitch for s in samples], True)
+            pitches = pitches.index_select(0, order)
+            assert src_tokens.shape[1] == pitches.shape[1]
+        if self.energies is not None:
+            energies = _collate_frames([s.energy for s in samples], True)
+            energies = energies.index_select(0, order)
+            assert src_tokens.shape[1] == energies.shape[1]
+        src_texts = [self.tgt_dict.string(samples[i].target) for i in order]
+
+        return {
+            "id": id_,
+            "net_input": {
+                "src_tokens": src_tokens,
+                "src_lengths": src_lengths,
+                "prev_output_tokens": prev_output_tokens,
+            },
+            "speaker": speaker,
+            "target": feat,
+            "durations": durations,
+            "pitches": pitches,
+            "energies": energies,
+            "target_lengths": target_lengths,
+            "ntokens": sum(target_lengths).item(),
+            "nsentences": len(samples),
+            "src_texts": src_texts,
+        }
+
+
+class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
+    KEY_DURATION = "duration"
+    KEY_PITCH = "pitch"
+    KEY_ENERGY = "energy"
+
+    @classmethod
+    def _from_list(
+        cls,
+        split_name: str,
+        is_train_split,
+        samples: List[Dict],
+        cfg: S2TDataConfig,
+        tgt_dict,
+        pre_tokenizer,
+        bpe_tokenizer,
+        n_frames_per_step,
+        speaker_to_id,
+        multitask=None,
+    ) -> TextToSpeechDataset:
+        audio_root = Path(cfg.audio_root)
+        ids = [s[cls.KEY_ID] for s in samples]
+        audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
+        n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
+        tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
+        src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
+        speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
+        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
+        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
+
+        durations = [s.get(cls.KEY_DURATION, None) for s in samples]
+        durations = [
+            None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations
+        ]
+        durations = None if any(dd is None for dd in durations) else durations
+
+        pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
+        pitches = [
+            None if pp is None else (audio_root / pp).as_posix() for pp in pitches
+        ]
+        pitches = None if any(pp is None for pp in pitches) else pitches
+
+        energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
+        energies = [
+            None if ee is None else (audio_root / ee).as_posix() for ee in energies
+        ]
+        energies = None if any(ee is None for ee in energies) else energies
+
+        return TextToSpeechDataset(
+            split_name,
+            is_train_split,
+            cfg,
+            audio_paths,
+            n_frames,
+            src_texts,
+            tgt_texts,
+            speakers,
+            src_langs,
+            tgt_langs,
+            ids,
+            tgt_dict,
+            pre_tokenizer,
+            bpe_tokenizer,
+            n_frames_per_step,
+            speaker_to_id,
+            durations,
+            pitches,
+            energies,
+        )
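`TextToSpeechDatasetCreator` extends the speech-to-text manifest with three optional columns: `duration` (space-separated per-token frame counts), plus `pitch` and `energy` (paths resolved against `cfg.audio_root`). The `any(... is None)` logic above is all-or-nothing: if a single row in a split lacks one of these columns, that field is dropped to `None` for the entire dataset. For illustration, one parsed manifest row as it would reach `_from_list` (key names follow the `SpeechToTextDatasetCreator` constants, e.g. `KEY_ID = "id"`; every value here is hypothetical):

sample = {
    "id": "utt_0001",
    "audio": "feat/utt_0001.npy",     # target spectrogram, relative to cfg.audio_root
    "n_frames": "420",
    "tgt_text": "HH AH0 L OW1",       # pre-tokenized text to synthesize
    "speaker": "spk_3",
    "duration": "3 10 7 12 1",        # frames per token
    "pitch": "pitch/utt_0001.npy",    # per-token pitch contour
    "energy": "energy/utt_0001.npy",  # per-token energy contour
}

Note how `collater` inverts the S2T convention: the token IDs become `net_input.src_tokens` (and the batch sort key), while the padded frames become `target`, with `prev_output_tokens` the frames shifted right by one step for teacher forcing.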
    	
modules/voice_conversion/fairseq/data/audio/waveform_transforms/__init__.py ADDED
@@ -0,0 +1,48 @@
+import os
+from fairseq.data.audio import (
+    AudioTransform,
+    CompositeAudioTransform,
+    import_transforms,
+    register_audio_transform,
+)
+
+
+class AudioWaveformTransform(AudioTransform):
+    pass
+
+
+AUDIO_WAVEFORM_TRANSFORM_REGISTRY = {}
+AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES = set()
+
+
+def get_audio_waveform_transform(name):
+    return AUDIO_WAVEFORM_TRANSFORM_REGISTRY[name]
+
+
+def register_audio_waveform_transform(name):
+    return register_audio_transform(
+        name,
+        AudioWaveformTransform,
+        AUDIO_WAVEFORM_TRANSFORM_REGISTRY,
+        AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES,
+    )
+
+
+import_transforms(os.path.dirname(__file__), "waveform")
+
+
+class CompositeAudioWaveformTransform(CompositeAudioTransform):
+    @classmethod
+    def from_config_dict(cls, config=None):
+        return super()._from_config_dict(
+            cls,
+            "waveform",
+            get_audio_waveform_transform,
+            CompositeAudioWaveformTransform,
+            config,
+        )
+
+    def __call__(self, x, sample_rate):
+        for t in self.transforms:
+            x, sample_rate = t(x, sample_rate)
+        return x, sample_rate
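Waveform transforms follow the same registry pattern as the feature transforms elsewhere in `fairseq.data.audio`: each transform registers under a string name, and `CompositeAudioWaveformTransform` chains whichever names the data config lists. A rough usage sketch, assuming the parent `_from_config_dict` reads a `"waveform_transforms"` name list plus one sub-dict per named transform (the directory and settings are hypothetical; `musicaugment` is registered in `noiseaugment.py` below):

from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform

config = {
    "waveform_transforms": ["musicaugment"],
    "musicaugment": {
        "samples_path": "/data/noise/music",  # directory of .wav noise clips
        "snr_min": 5.0,
        "snr_max": 15.0,
        "rate": 0.25,  # augment roughly a quarter of the utterances
    },
}

transform = CompositeAudioWaveformTransform.from_config_dict(config)
# augmented, sample_rate = transform(waveform, sample_rate)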
    	
modules/voice_conversion/fairseq/data/audio/waveform_transforms/noiseaugment.py ADDED
@@ -0,0 +1,201 @@
+from pathlib import Path
+import numpy as np
+from math import ceil
+
+from fairseq.data.audio import rand_uniform
+from fairseq.data.audio.waveform_transforms import (
+    AudioWaveformTransform,
+    register_audio_waveform_transform,
+)
+
+SNR_MIN = 5.0
+SNR_MAX = 15.0
+RATE = 0.25
+
+NOISE_RATE = 1.0
+NOISE_LEN_MEAN = 0.2
+NOISE_LEN_STD = 0.05
+
+
+class NoiseAugmentTransform(AudioWaveformTransform):
+    @classmethod
+    def from_config_dict(cls, config=None):
+        _config = {} if config is None else config
+        return cls(
+            _config.get("samples_path", None),
+            _config.get("snr_min", SNR_MIN),
+            _config.get("snr_max", SNR_MAX),
+            _config.get("rate", RATE),
+        )
+
+    def __init__(
+        self,
+        samples_path: str,
+        snr_min: float = SNR_MIN,
+        snr_max: float = SNR_MAX,
+        rate: float = RATE,
+    ):
+        # Sanity checks
+        assert (
+            samples_path
+        ), "need to provide path to audio samples for noise augmentation"
+        assert snr_max >= snr_min, f"empty signal-to-noise range ({snr_min}, {snr_max})"
+        assert rate >= 0 and rate <= 1, "rate should be a float between 0 and 1"
+
+        self.paths = list(Path(samples_path).glob("**/*.wav"))  # load music
+        self.n_samples = len(self.paths)
+        assert self.n_samples > 0, f"no audio files found in {samples_path}"
+
+        self.snr_min = snr_min
+        self.snr_max = snr_max
+        self.rate = rate
+
+    def __repr__(self):
+        return (
+            self.__class__.__name__
+            + "("
+            + ", ".join(
+                [
+                    f"n_samples={self.n_samples}",
+                    f"snr={self.snr_min}-{self.snr_max}dB",
+                    f"rate={self.rate}",
+                ]
+            )
+            + ")"
+        )
+
+    def pick_sample(self, goal_shape, always_2d=False, use_sample_rate=None):
+        from fairseq.data.audio.audio_utils import get_waveform
+
+        path = self.paths[np.random.randint(0, self.n_samples)]
+        sample = get_waveform(
+            path, always_2d=always_2d, output_sample_rate=use_sample_rate
+        )[0]
+
+        # Check dimensions match, else silently skip adding noise to sample
+        # NOTE: SHOULD THIS QUIT WITH AN ERROR?
+        is_2d = len(goal_shape) == 2
+        if len(goal_shape) != sample.ndim or (
+            is_2d and goal_shape[0] != sample.shape[0]
+        ):
+            return np.zeros(goal_shape)
+
+        # Cut/repeat sample to size
+        len_dim = len(goal_shape) - 1
+        n_repeat = ceil(goal_shape[len_dim] / sample.shape[len_dim])
+        repeated = np.tile(sample, [1, n_repeat] if is_2d else n_repeat)
+        start = np.random.randint(0, repeated.shape[len_dim] - goal_shape[len_dim] + 1)
+        return (
+            repeated[:, start : start + goal_shape[len_dim]]
+            if is_2d
+            else repeated[start : start + goal_shape[len_dim]]
+        )
+
+    def _mix(self, source, noise, snr):
+        get_power = lambda x: np.mean(x**2)
+        if get_power(noise):
+            scl = np.sqrt(
+                get_power(source) / (np.power(10, snr / 10) * get_power(noise))
+            )
+        else:
+            scl = 0
+        return 1 * source + scl * noise
+
+    def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
+        return self.pick_sample(goal_shape, always_2d, use_sample_rate)
+
+    def __call__(self, source, sample_rate):
+        if np.random.random() > self.rate:
+            return source, sample_rate
+
+        noise = self._get_noise(
+            source.shape, always_2d=True, use_sample_rate=sample_rate
+        )
+
+        return (
+            self._mix(source, noise, rand_uniform(self.snr_min, self.snr_max)),
+            sample_rate,
+        )
+
+
             
     | 
| 121 | 
         
            +
            @register_audio_waveform_transform("musicaugment")
         
     | 
| 122 | 
         
            +
            class MusicAugmentTransform(NoiseAugmentTransform):
         
     | 
| 123 | 
         
            +
                pass
         
     | 
| 124 | 
         
            +
             
     | 
| 125 | 
         
            +
             
     | 
| 126 | 
         
            +
            @register_audio_waveform_transform("backgroundnoiseaugment")
         
     | 
| 127 | 
         
            +
            class BackgroundNoiseAugmentTransform(NoiseAugmentTransform):
         
     | 
| 128 | 
         
            +
                pass
         
     | 
| 129 | 
         
            +
             
     | 
| 130 | 
         
            +
             
     | 
| 131 | 
         
            +
            @register_audio_waveform_transform("babbleaugment")
         
     | 
| 132 | 
         
            +
            class BabbleAugmentTransform(NoiseAugmentTransform):
         
     | 
| 133 | 
         
            +
                def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
         
     | 
| 134 | 
         
            +
                    for i in range(np.random.randint(3, 8)):
         
     | 
| 135 | 
         
            +
                        speech = self.pick_sample(goal_shape, always_2d, use_sample_rate)
         
     | 
| 136 | 
         
            +
                        if i == 0:
         
     | 
| 137 | 
         
            +
                            agg_noise = speech
         
     | 
| 138 | 
         
            +
                        else:  # SNR scaled by i (how many noise signals already in agg_noise)
         
     | 
| 139 | 
         
            +
                            agg_noise = self._mix(agg_noise, speech, i)
         
     | 
| 140 | 
         
            +
                    return agg_noise
         
     | 
| 141 | 
         
            +
             
     | 
| 142 | 
         
            +
             
     | 
| 143 | 
         
            +
            @register_audio_waveform_transform("sporadicnoiseaugment")
         
     | 
| 144 | 
         
            +
            class SporadicNoiseAugmentTransform(NoiseAugmentTransform):
         
     | 
| 145 | 
         
            +
                @classmethod
         
     | 
| 146 | 
         
            +
                def from_config_dict(cls, config=None):
         
     | 
| 147 | 
         
            +
                    _config = {} if config is None else config
         
     | 
| 148 | 
         
            +
                    return cls(
         
     | 
| 149 | 
         
            +
                        _config.get("samples_path", None),
         
     | 
| 150 | 
         
            +
                        _config.get("snr_min", SNR_MIN),
         
     | 
| 151 | 
         
            +
                        _config.get("snr_max", SNR_MAX),
         
     | 
| 152 | 
         
            +
                        _config.get("rate", RATE),
         
     | 
| 153 | 
         
            +
                        _config.get("noise_rate", NOISE_RATE),
         
     | 
| 154 | 
         
            +
                        _config.get("noise_len_mean", NOISE_LEN_MEAN),
         
     | 
| 155 | 
         
            +
                        _config.get("noise_len_std", NOISE_LEN_STD),
         
     | 
| 156 | 
         
            +
                    )
         
     | 
| 157 | 
         
            +
             
     | 
| 158 | 
         
            +
                def __init__(
         
     | 
| 159 | 
         
            +
                    self,
         
     | 
| 160 | 
         
            +
                    samples_path: str,
         
     | 
| 161 | 
         
            +
                    snr_min: float = SNR_MIN,
         
     | 
| 162 | 
         
            +
                    snr_max: float = SNR_MAX,
         
     | 
| 163 | 
         
            +
                    rate: float = RATE,
         
     | 
| 164 | 
         
            +
                    noise_rate: float = NOISE_RATE,  # noises per second
         
     | 
| 165 | 
         
            +
                    noise_len_mean: float = NOISE_LEN_MEAN,  # length of noises in seconds
         
     | 
| 166 | 
         
            +
                    noise_len_std: float = NOISE_LEN_STD,
         
     | 
| 167 | 
         
            +
                ):
         
     | 
| 168 | 
         
            +
                    super().__init__(samples_path, snr_min, snr_max, rate)
         
     | 
| 169 | 
         
            +
                    self.noise_rate = noise_rate
         
     | 
| 170 | 
         
            +
                    self.noise_len_mean = noise_len_mean
         
     | 
| 171 | 
         
            +
                    self.noise_len_std = noise_len_std
         
     | 
| 172 | 
         
            +
             
     | 
| 173 | 
         
            +
                def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
         
     | 
| 174 | 
         
            +
                    agg_noise = np.zeros(goal_shape)
         
     | 
| 175 | 
         
            +
                    len_dim = len(goal_shape) - 1
         
     | 
| 176 | 
         
            +
                    is_2d = len(goal_shape) == 2
         
     | 
| 177 | 
         
            +
             
     | 
| 178 | 
         
            +
                    n_noises = round(self.noise_rate * goal_shape[len_dim] / use_sample_rate)
         
     | 
| 179 | 
         
            +
                    start_pointers = [
         
     | 
| 180 | 
         
            +
                        round(rand_uniform(0, goal_shape[len_dim])) for _ in range(n_noises)
         
     | 
| 181 | 
         
            +
                    ]
         
     | 
| 182 | 
         
            +
             
     | 
| 183 | 
         
            +
                    for start_pointer in start_pointers:
         
     | 
| 184 | 
         
            +
                        noise_shape = list(goal_shape)
         
     | 
| 185 | 
         
            +
                        len_seconds = np.random.normal(self.noise_len_mean, self.noise_len_std)
         
     | 
| 186 | 
         
            +
                        noise_shape[len_dim] = round(max(0, len_seconds) * use_sample_rate)
         
     | 
| 187 | 
         
            +
                        end_pointer = start_pointer + noise_shape[len_dim]
         
     | 
| 188 | 
         
            +
                        if end_pointer >= goal_shape[len_dim]:
         
     | 
| 189 | 
         
            +
                            continue
         
     | 
| 190 | 
         
            +
             
     | 
| 191 | 
         
            +
                        noise = self.pick_sample(noise_shape, always_2d, use_sample_rate)
         
     | 
| 192 | 
         
            +
                        if is_2d:
         
     | 
| 193 | 
         
            +
                            agg_noise[:, start_pointer:end_pointer] = (
         
     | 
| 194 | 
         
            +
                                agg_noise[:, start_pointer:end_pointer] + noise
         
     | 
| 195 | 
         
            +
                            )
         
     | 
| 196 | 
         
            +
                        else:
         
     | 
| 197 | 
         
            +
                            agg_noise[start_pointer:end_pointer] = (
         
     | 
| 198 | 
         
            +
                                agg_noise[start_pointer:end_pointer] + noise
         
     | 
| 199 | 
         
            +
                            )
         
     | 
| 200 | 
         
            +
             
     | 
| 201 | 
         
            +
                    return agg_noise
         
     | 
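Note: the sporadic variant drops short noise bursts at random offsets rather than corrupting the whole utterance. A minimal standalone sketch of just that placement logic, assuming plain NumPy and invented rate/length defaults (the real transform above additionally draws bursts from a sample bank and rescales them by a sampled SNR via _mix):

import numpy as np

def place_sporadic_noise(wav, sample_rate, noise_rate=0.25, len_mean=0.2, len_std=0.05):
    """Build a noise track with ~noise_rate short bursts per second."""
    out = np.zeros_like(wav)
    n_bursts = round(noise_rate * len(wav) / sample_rate)
    for _ in range(n_bursts):
        start = np.random.randint(0, len(wav))
        length = round(max(0.0, np.random.normal(len_mean, len_std)) * sample_rate)
        end = start + length
        if end >= len(wav):
            continue  # same policy as above: skip bursts that would overrun the clip
        out[start:end] += np.random.randn(end - start)  # stand-in for a sampled noise clip
    return out

# e.g. noisy = wav + 0.1 * place_sporadic_noise(wav, 16000)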
    	
modules/voice_conversion/fairseq/data/backtranslation_dataset.py
ADDED
@@ -0,0 +1,165 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from fairseq import utils

from . import FairseqDataset


def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
    """Backtranslate a list of samples.

    Given an input (*samples*) of the form:

        [{'id': 1, 'source': 'hallo welt'}]

    this will return:

        [{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}]

    Args:
        samples (List[dict]): samples to backtranslate. Individual samples are
            expected to have a 'source' key, which will become the 'target'
            after backtranslation.
        collate_fn (callable): function to collate samples into a mini-batch
        generate_fn (callable): function to generate backtranslations
        cuda (bool): use GPU for generation (default: ``True``)

    Returns:
        List[dict]: an updated list of samples with a backtranslated source
    """
    collated_samples = collate_fn(samples)
    s = utils.move_to_cuda(collated_samples) if cuda else collated_samples
    generated_sources = generate_fn(s)

    id_to_src = {sample["id"]: sample["source"] for sample in samples}

    # Go through each tgt sentence in batch and its corresponding best
    # generated hypothesis and create a backtranslation data pair
    # {id: id, source: generated backtranslation, target: original tgt}
    return [
        {
            "id": id.item(),
            "target": id_to_src[id.item()],
            "source": hypos[0]["tokens"].cpu(),
        }
        for id, hypos in zip(collated_samples["id"], generated_sources)
    ]


class BacktranslationDataset(FairseqDataset):
    """
    Sets up a backtranslation dataset which takes a tgt batch, generates
    a src using a tgt-src backtranslation function (*backtranslation_fn*),
    and returns the corresponding `{generated src, input tgt}` batch.

    Args:
        tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
            backtranslated. Only the source side of this dataset will be used.
            After backtranslation, the source sentences in this dataset will be
            returned as the targets.
        src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated
            sentences.
        tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of
            sentences to be backtranslated.
        backtranslation_fn (callable, optional): function to call to generate
            backtranslations. This is typically the `generate` method of a
            :class:`~fairseq.sequence_generator.SequenceGenerator` object.
            Pass in None when it is not available at initialization time, and
            use set_backtranslation_fn to set it when available.
        output_collater (callable, optional): function to call on the
            backtranslated samples to create the final batch
            (default: ``tgt_dataset.collater``).
        cuda: use GPU for generation
    """

    def __init__(
        self,
        tgt_dataset,
        src_dict,
        tgt_dict=None,
        backtranslation_fn=None,
        output_collater=None,
        cuda=True,
        **kwargs
    ):
        self.tgt_dataset = tgt_dataset
        self.backtranslation_fn = backtranslation_fn
        self.output_collater = (
            output_collater if output_collater is not None else tgt_dataset.collater
        )
        self.cuda = cuda if torch.cuda.is_available() else False
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    def __getitem__(self, index):
        """
        Returns a single sample from *tgt_dataset*. Note that backtranslation is
        not applied in this step; use :func:`collater` instead to backtranslate
        a batch of samples.
        """
        return self.tgt_dataset[index]

    def __len__(self):
        return len(self.tgt_dataset)

    def set_backtranslation_fn(self, backtranslation_fn):
        self.backtranslation_fn = backtranslation_fn

    def collater(self, samples):
        """Merge and backtranslate a list of samples to form a mini-batch.

        Using the samples from *tgt_dataset*, load a collated target sample to
        feed to the backtranslation model. Then take the backtranslation with
        the best score as the source and the original input as the target.

        Note: we expect *tgt_dataset* to provide a function `collater()` that
        will collate samples into the format expected by *backtranslation_fn*.
        After backtranslation, we will feed the new list of samples (i.e., the
        `(backtranslated source, original source)` pairs) to *output_collater*
        and return the result.

        Args:
            samples (List[dict]): samples to backtranslate and collate

        Returns:
            dict: a mini-batch with keys coming from *output_collater*
        """
        if samples[0].get("is_dummy", False):
            return samples
        samples = backtranslate_samples(
            samples=samples,
            collate_fn=self.tgt_dataset.collater,
            generate_fn=(lambda net_input: self.backtranslation_fn(net_input)),
            cuda=self.cuda,
        )
        return self.output_collater(samples)

    def num_tokens(self, index):
        """Just use the tgt dataset num_tokens"""
        return self.tgt_dataset.num_tokens(index)

    def ordered_indices(self):
        """Just use the tgt dataset ordered_indices"""
        return self.tgt_dataset.ordered_indices()

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used
        when filtering a dataset with ``--max-positions``.

        Note: we use *tgt_dataset* to approximate the length of the source
        sentence, since we do not know the actual length until after
        backtranslation.
        """
        tgt_size = self.tgt_dataset.size(index)[0]
        return (tgt_size, tgt_size)

    @property
    def supports_prefetch(self):
        return getattr(self.tgt_dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.tgt_dataset.prefetch(indices)
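Note: to make the data flow of backtranslate_samples concrete, here is a small self-contained mimic with stub collate and generate functions standing in for tgt_dataset.collater and a SequenceGenerator (all names and token values below are invented for illustration):

import torch

def collate_stub(samples):
    # Real code uses tgt_dataset.collater; here we only keep the ids as a tensor.
    return {"id": torch.tensor([s["id"] for s in samples])}

def generate_stub(batch):
    # Real code runs a tgt->src generator; return one hypothesis list per sample.
    return [[{"tokens": torch.tensor([7, 8, 9])}] for _ in batch["id"]]

samples = [{"id": 1, "source": torch.tensor([4, 5, 6])}]
collated = collate_stub(samples)
id_to_src = {s["id"]: s["source"] for s in samples}
pairs = [
    {"id": i.item(), "target": id_to_src[i.item()], "source": hypos[0]["tokens"]}
    for i, hypos in zip(collated["id"], generate_stub(collated))
]
# pairs[0]["source"] is the generated backtranslation; "target" is the original text,
# i.e. the (generated src, input tgt) pairing the class docstring describes.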
    	
modules/voice_conversion/fairseq/data/base_wrapper_dataset.py
ADDED
@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torch.utils.data.dataloader import default_collate

from . import FairseqDataset


class BaseWrapperDataset(FairseqDataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        if hasattr(self.dataset, "collater"):
            return self.dataset.collater(samples)
        else:
            return default_collate(samples)

    @property
    def sizes(self):
        return self.dataset.sizes

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def attr(self, attr: str, index: int):
        return self.dataset.attr(attr, index)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)

    def get_batch_shapes(self):
        return self.dataset.get_batch_shapes()

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        return self.dataset.batch_by_size(
            indices,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        return self.dataset.filter_indices_by_size(indices, max_sizes)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return self.dataset.can_reuse_epoch_itr_across_epochs

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)
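Note: BaseWrapperDataset exists so that dataset transforms (like the bucketing wrapper in the next file) only override the one or two methods they change and inherit delegation for everything else. A minimal illustration of the same pattern outside fairseq, with hypothetical class names:

from torch.utils.data import Dataset

class SquareDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, i):
        return i * i

class PlusOne(Dataset):
    """Wrapper: delegate __len__, override only __getitem__."""
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        return self.dataset[i] + 1

ds = PlusOne(SquareDataset())
assert [ds[i] for i in range(len(ds))] == [1, 2, 5, 10]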
    	
modules/voice_conversion/fairseq/data/bucket_pad_length_dataset.py
ADDED
@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch.nn.functional as F
from fairseq.data import BaseWrapperDataset
from fairseq.data.data_utils import get_buckets, get_bucketed_sizes


class BucketPadLengthDataset(BaseWrapperDataset):
    """
    Bucket and pad item lengths to the nearest bucket size. This can be used to
    reduce the number of unique batch shapes, which is important on TPUs since
    each new batch shape requires a recompilation.

    Args:
        dataset (FairseqDataset): dataset to bucket
        sizes (List[int]): all item sizes
        num_buckets (int): number of buckets to create
        pad_idx (int): padding symbol
        left_pad (bool): if True, pad on the left; otherwise right pad
    """

    def __init__(
        self,
        dataset,
        sizes,
        num_buckets,
        pad_idx,
        left_pad,
        tensor_key=None,
    ):
        super().__init__(dataset)
        self.pad_idx = pad_idx
        self.left_pad = left_pad

        assert num_buckets > 0
        self.buckets = get_buckets(sizes, num_buckets)
        self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
        self._tensor_key = tensor_key

    def _set_tensor(self, item, val):
        if self._tensor_key is None:
            return val
        item[self._tensor_key] = val
        return item

    def _get_tensor(self, item):
        if self._tensor_key is None:
            return item
        return item[self._tensor_key]

    def _pad(self, tensor, bucket_size, dim=-1):
        num_pad = bucket_size - tensor.size(dim)
        return F.pad(
            tensor,
            (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
            value=self.pad_idx,
        )

    def __getitem__(self, index):
        item = self.dataset[index]
        bucket_size = self._bucketed_sizes[index]
        tensor = self._get_tensor(item)
        padded = self._pad(tensor, bucket_size)
        return self._set_tensor(item, padded)

    @property
    def sizes(self):
        return self._bucketed_sizes

    def num_tokens(self, index):
        return self._bucketed_sizes[index]

    def size(self, index):
        return self._bucketed_sizes[index]
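Note: the padding arithmetic is easy to check in isolation. A quick sketch with hand-picked bucket sizes (fairseq's get_buckets derives them from the size distribution; here they are hard-coded purely for illustration):

import torch
import torch.nn.functional as F

buckets = [8, 16, 32]          # hypothetical bucket sizes (normally from get_buckets)
pad_idx, left_pad = 1, False

t = torch.tensor([5, 6, 7])    # length 3 -> smallest bucket that fits is 8
bucket = min(b for b in buckets if b >= t.size(-1))
num_pad = bucket - t.size(-1)
padded = F.pad(t, (num_pad, 0) if left_pad else (0, num_pad), value=pad_idx)
assert padded.tolist() == [5, 6, 7, 1, 1, 1, 1, 1]
# Every item now has one of only three lengths, so batch shapes stay stable.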
    	
modules/voice_conversion/fairseq/data/codedataset.py
ADDED
@@ -0,0 +1,576 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import json
import logging
import os
import random
from pathlib import Path

import numpy as np
import torch
import torch.utils.data

from . import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset

F0_FRAME_SPACE = 0.005  # sec


logger = logging.getLogger(__name__)


class ExpressiveCodeDataConfig(object):
    def __init__(self, json_path):
        with open(json_path, "r") as f:
            self.config = json.load(f)
        self._manifests = self.config["manifests"]

    @property
    def manifests(self):
        return self._manifests

    @property
    def n_units(self):
        return self.config["n_units"]

    @property
    def sampling_rate(self):
        return self.config["sampling_rate"]

    @property
    def code_hop_size(self):
        return self.config["code_hop_size"]

    @property
    def f0_stats(self):
        """pre-computed f0 statistics path"""
        return self.config.get("f0_stats", None)

    @property
    def f0_vq_type(self):
        """naive or precomp"""
        return self.config["f0_vq_type"]

    @property
    def f0_vq_name(self):
        return self.config["f0_vq_name"]

    def get_f0_vq_naive_quantizer(self, log, norm_mean, norm_std):
        key = "log" if log else "linear"
        if norm_mean and norm_std:
            key += "_mean_std_norm"
        elif norm_mean:
            key += "_mean_norm"
        else:
            key += "_none_norm"
        return self.config["f0_vq_naive_quantizer"][key]

    @property
    def f0_vq_n_units(self):
        return self.config["f0_vq_n_units"]

    @property
    def multispkr(self):
        """how to parse speaker label from audio path"""
        return self.config.get("multispkr", None)
     | 
def get_f0(audio, rate=16000):
    try:
        import amfm_decompy.basic_tools as basic
        import amfm_decompy.pYAAPT as pYAAPT
        from librosa.util import normalize
    except ImportError:
        # raising a plain string is a TypeError in Python 3; raise a proper exception
        raise ImportError(
            "Please install amfm_decompy (`pip install AMFM-decompy`)"
            " and librosa (`pip install librosa`)."
        )

    assert audio.ndim == 1
    frame_length = 20.0  # ms
    to_pad = int(frame_length / 1000 * rate) // 2

    audio = normalize(audio) * 0.95
    audio = np.pad(audio, (to_pad, to_pad), "constant", constant_values=0)
    audio = basic.SignalObj(audio, rate)
    pitch = pYAAPT.yaapt(
        audio,
        frame_length=frame_length,
        frame_space=F0_FRAME_SPACE * 1000,
        nccf_thresh1=0.25,
        tda_frame_length=25.0,
    )
    f0 = pitch.samp_values
    return f0
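A minimal usage sketch (not part of the committed file, and assuming AMFM-decompy and librosa are installed): run get_f0 on one second of synthetic 16 kHz audio. F0_FRAME_SPACE is the module-level constant defined earlier in this file.

import numpy as np

sr = 16000
t = np.arange(sr) / sr
wav = 0.5 * np.sin(2 * np.pi * 120.0 * t)  # 1 s mono tone at 120 Hz

f0 = get_f0(wav, rate=sr)          # one value per F0_FRAME_SPACE seconds
print(f0.shape)                    # number of pitch frames
print(f0[f0 > 0].mean())           # voiced estimates should sit near 120 Hz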
def interpolate_f0(f0):
    try:
        from scipy.interpolate import interp1d
    except ImportError:
        # raising a plain string is a TypeError in Python 3; raise a proper exception
        raise ImportError("Please install scipy (`pip install scipy`)")

    orig_t = np.arange(f0.shape[0])
    f0_interp = f0[:]
    ii = f0_interp != 0
    if ii.sum() > 1:
        f0_interp = interp1d(
            orig_t[ii], f0_interp[ii], bounds_error=False, kind="linear", fill_value=0
        )(orig_t)
        f0_interp = torch.Tensor(f0_interp).type_as(f0).to(f0.device)
    return f0_interp
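For intuition, a small sketch (not part of the file) of what interpolate_f0 does to unvoiced (zero) frames between voiced ones:

import torch

f0 = torch.tensor([0.0, 100.0, 0.0, 0.0, 120.0, 0.0])
print(interpolate_f0(f0))
# -> approximately [0, 100, 106.7, 113.3, 120, 0]: interior zeros are linearly
#    interpolated between voiced neighbors; zeros outside the voiced span stay 0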
def naive_quantize(x, edges):
    bin_idx = (x.view(-1, 1) > edges.view(1, -1)).long().sum(dim=1)
    return bin_idx
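naive_quantize assigns each value the index of the bin it falls into, giving n_edges + 1 possible bins. A quick standalone check:

import torch

edges = torch.tensor([0.0, 1.0, 2.0])      # 3 edges -> bin ids 0..3
x = torch.tensor([-0.5, 0.5, 1.5, 2.5])
print(naive_quantize(x, edges))            # tensor([0, 1, 2, 3])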
def load_wav(full_path):
    try:
        import soundfile as sf
    except ImportError:
        # raising a plain string is a TypeError in Python 3; raise a proper exception
        raise ImportError("Please install soundfile (`pip install SoundFile`)")
    data, sampling_rate = sf.read(full_path)
    return data, sampling_rate
def parse_code(code_str, dictionary, append_eos):
    code, duration = torch.unique_consecutive(
        torch.ShortTensor(list(map(int, code_str.split()))), return_counts=True
    )
    code = " ".join(map(str, code.tolist()))
    # pass append_eos by keyword: the second positional parameter of
    # Dictionary.encode_line is line_tokenizer, not append_eos
    code = dictionary.encode_line(code, append_eos=append_eos).short()

    if append_eos:
        duration = torch.cat((duration, duration.new_zeros((1,))), dim=0)  # eos
    duration = duration.short()
    return code, duration
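The core of parse_code is run-length encoding via torch.unique_consecutive; the fairseq Dictionary then maps the deduplicated units to ids. The RLE half in isolation:

import torch

code_str = "5 5 5 9 9 7"
code, dur = torch.unique_consecutive(
    torch.ShortTensor(list(map(int, code_str.split()))), return_counts=True
)
print(code.tolist(), dur.tolist())   # [5, 9, 7] [3, 2, 1]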
def parse_manifest(manifest, dictionary):
    audio_files = []
    codes = []
    durations = []
    speakers = []

    with open(manifest) as info:
        for line in info.readlines():
            # each manifest line is expected to be a Python dict literal
            sample = eval(line.strip())
            if "cpc_km100" in sample:
                k = "cpc_km100"
            elif "hubert_km100" in sample:
                k = "hubert_km100"
            elif "phone" in sample:
                k = "phone"
            else:
                assert False, "unknown format"
            code = sample[k]
            code, duration = parse_code(code, dictionary, append_eos=True)

            codes.append(code)
            durations.append(duration)
            audio_files.append(sample["audio"])
            speakers.append(sample.get("speaker", None))

    return audio_files, codes, durations, speakers
def parse_speaker(path, method):
    if isinstance(path, str):
        path = Path(path)

    if method == "parent_name":
        return path.parent.name
    elif method == "parent_parent_name":
        return path.parent.parent.name
    elif method == "_":
        return path.name.split("_")[0]
    elif method == "single":
        return "A"
    elif callable(method):
        return method(path)
    else:
        raise NotImplementedError()
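A sketch of the path-based speaker parsing, using hypothetical file names:

from pathlib import Path

print(parse_speaker(Path("data/spk12/utt_003.wav"), "parent_name"))  # 'spk12'
print(parse_speaker("spk7_utt1.wav", "_"))                           # 'spk7'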
def get_f0_by_filename(filename, tgt_sampling_rate):
    audio, sampling_rate = load_wav(filename)
    if sampling_rate != tgt_sampling_rate:
        raise ValueError(
            "{} SR doesn't match target {} SR".format(sampling_rate, tgt_sampling_rate)
        )

    # compute un-interpolated f0, and use Ann's interp in __getitem__ if set
    f0 = get_f0(audio, rate=tgt_sampling_rate)
    f0 = torch.from_numpy(f0.astype(np.float32))
    return f0
def align_f0_to_durations(f0, durations, f0_code_ratio, tol=1):
    code_len = durations.sum()
    targ_len = int(f0_code_ratio * code_len)
    diff = f0.size(0) - targ_len
    assert abs(diff) <= tol, (
        f"Cannot subsample F0: |{f0.size(0)} - {f0_code_ratio}*{code_len}|"
        f" > {tol} (dur=\n{durations})"
    )
    if diff > 0:
        f0 = f0[:targ_len]
    elif diff < 0:
        f0 = torch.cat((f0, f0.new_full((-diff,), f0[-1])), 0)

    f0_offset = 0.0
    seg_f0s = []
    for dur in durations:
        f0_dur = dur.item() * f0_code_ratio
        seg_f0 = f0[int(f0_offset) : int(f0_offset + f0_dur)]
        seg_f0 = seg_f0[seg_f0 != 0]
        if len(seg_f0) == 0:
            seg_f0 = torch.tensor(0).type(seg_f0.type())
        else:
            seg_f0 = seg_f0.mean()
        seg_f0s.append(seg_f0)
        f0_offset += f0_dur

    assert int(f0_offset) == f0.size(0), f"{f0_offset} {f0.size()} {durations.sum()}"
    return torch.tensor(seg_f0s)
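align_f0_to_durations mean-pools the frame-level F0 track into one value per code, skipping unvoiced frames. With a 1:1 frame ratio and made-up numbers:

import torch

durations = torch.tensor([2, 3])
f0 = torch.tensor([100.0, 110.0, 0.0, 120.0, 130.0])
print(align_f0_to_durations(f0, durations, f0_code_ratio=1.0))
# tensor([105., 125.]) -- the zero (unvoiced) frame is excluded from the mean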
class Paddings(object):
    def __init__(self, code_val, dur_val=0, f0_val=-2.0):
        self.code = code_val
        self.dur = dur_val
        self.f0 = f0_val
class Shifts(object):
    def __init__(self, shifts_str, pads):
        self._shifts = list(map(int, shifts_str.split(",")))
        assert len(self._shifts) == 2, self._shifts
        assert all(s >= 0 for s in self._shifts)
        self.extra_length = max(s for s in self._shifts)
        self.pads = pads

    @property
    def dur(self):
        return self._shifts[0]

    @property
    def f0(self):
        return self._shifts[1]

    @staticmethod
    def shift_one(seq, left_pad_num, right_pad_num, pad):
        assert seq.ndim == 1
        bos = seq.new_full((left_pad_num,), pad)
        eos = seq.new_full((right_pad_num,), pad)
        seq = torch.cat([bos, seq, eos])
        mask = torch.ones_like(seq).bool()
        mask[left_pad_num : len(seq) - right_pad_num] = 0
        return seq, mask

    def __call__(self, code, dur, f0):
        if self.extra_length == 0:
            code_mask = torch.zeros_like(code).bool()
            dur_mask = torch.zeros_like(dur).bool()
            f0_mask = torch.zeros_like(f0).bool()
            return code, code_mask, dur, dur_mask, f0, f0_mask

        code, code_mask = self.shift_one(code, 0, self.extra_length, self.pads.code)
        dur, dur_mask = self.shift_one(
            dur, self.dur, self.extra_length - self.dur, self.pads.dur
        )
        f0, f0_mask = self.shift_one(
            f0, self.f0, self.extra_length - self.f0, self.pads.f0
        )
        return code, code_mask, dur, dur_mask, f0, f0_mask
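shift_one pads one side of a sequence and returns a mask marking the padded positions; Shifts applies it so the duration and F0 streams can be delayed relative to the codes. A standalone check (not part of the file):

import torch

seq, mask = Shifts.shift_one(torch.tensor([7, 8, 9]), 1, 1, pad=0)
print(seq.tolist())    # [0, 7, 8, 9, 0]
print(mask.tolist())   # [True, False, False, False, True]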
class CodeDataset(FairseqDataset):
    def __init__(
        self,
        manifest,
        dictionary,
        dur_dictionary,
        f0_dictionary,
        config,
        discrete_dur,
        discrete_f0,
        log_f0,
        normalize_f0_mean,
        normalize_f0_std,
        interpolate_f0,
        return_filename=False,
        strip_filename=True,
        shifts="0,0",
        return_continuous_f0=False,
    ):
        random.seed(1234)
        self.dictionary = dictionary
        self.dur_dictionary = dur_dictionary
        self.f0_dictionary = f0_dictionary
        self.config = config

        # duration config
        self.discrete_dur = discrete_dur

        # pitch config
        self.discrete_f0 = discrete_f0
        self.log_f0 = log_f0
        self.normalize_f0_mean = normalize_f0_mean
        self.normalize_f0_std = normalize_f0_std
        self.interpolate_f0 = interpolate_f0

        self.return_filename = return_filename
        self.strip_filename = strip_filename
        self.f0_code_ratio = config.code_hop_size / (
            config.sampling_rate * F0_FRAME_SPACE
        )

        # use lazy loading to avoid sharing file handlers across workers
        self.manifest = manifest
        self._codes = None
        self._durs = None
        self._f0s = None
        with open(f"{manifest}.leng.txt", "r") as f:
            lengs = [int(line.rstrip()) for line in f]
            edges = np.cumsum([0] + lengs)
            self.starts, self.ends = edges[:-1], edges[1:]
        with open(f"{manifest}.path.txt", "r") as f:
            self.file_names = [line.rstrip() for line in f]
        logger.info(f"num entries: {len(self.starts)}")

        if os.path.exists(f"{manifest}.f0_stat.pt"):
            self.f0_stats = torch.load(f"{manifest}.f0_stat.pt")
        elif config.f0_stats:
            self.f0_stats = torch.load(config.f0_stats)

        self.multispkr = config.multispkr
        if config.multispkr:
            with open(f"{manifest}.speaker.txt", "r") as f:
                self.spkrs = [line.rstrip() for line in f]
            self.id_to_spkr = sorted(self.spkrs)
            self.spkr_to_id = {k: v for v, k in enumerate(self.id_to_spkr)}

        self.pads = Paddings(
            dictionary.pad(),
            0,  # use 0 for duration padding
            f0_dictionary.pad() if discrete_f0 else -5.0,
        )
        self.shifts = Shifts(shifts, pads=self.pads)
        self.return_continuous_f0 = return_continuous_f0
    def get_data_handlers(self):
        logging.info(f"loading data for {self.manifest}")
        self._codes = np.load(f"{self.manifest}.code.npy", mmap_mode="r")
        self._durs = np.load(f"{self.manifest}.dur.npy", mmap_mode="r")

        if self.discrete_f0:
            if self.config.f0_vq_type == "precomp":
                self._f0s = np.load(
                    f"{self.manifest}.{self.config.f0_vq_name}.npy", mmap_mode="r"
                )
            elif self.config.f0_vq_type == "naive":
                self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
                quantizers_path = self.config.get_f0_vq_naive_quantizer(
                    self.log_f0, self.normalize_f0_mean, self.normalize_f0_std
                )
                quantizers = torch.load(quantizers_path)
                n_units = self.config.f0_vq_n_units
                self._f0_quantizer = torch.from_numpy(quantizers[n_units])
            else:
                raise ValueError(f"f0_vq_type {self.config.f0_vq_type} not supported")
        else:
            self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
    def preprocess_f0(self, f0, stats):
        """
        1. interpolate
        2. log transform (keep unvoiced frame 0)
        """
        # TODO: change this to be dependent on config for naive quantizer
        f0 = f0.clone()
        if self.interpolate_f0:
            f0 = interpolate_f0(f0)

        mask = f0 != 0  # only process voiced frames
        if self.log_f0:
            f0[mask] = f0[mask].log()
        if self.normalize_f0_mean:
            mean = stats["logf0_mean"] if self.log_f0 else stats["f0_mean"]
            f0[mask] = f0[mask] - mean
        if self.normalize_f0_std:
            std = stats["logf0_std"] if self.log_f0 else stats["f0_std"]
            f0[mask] = f0[mask] / std
        return f0
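The voiced-frame masking in preprocess_f0 keeps unvoiced frames at exactly 0 through the log/normalize steps. The same pattern in isolation, with made-up stats:

import torch

f0 = torch.tensor([0.0, 100.0, 200.0])
mask = f0 != 0
f0[mask] = f0[mask].log()
f0[mask] = (f0[mask] - 5.0) / 0.5   # hypothetical logf0_mean / logf0_std
print(f0)                           # first entry is still exactly 0.0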
    def _get_raw_item(self, index):
        start, end = self.starts[index], self.ends[index]
        if self._codes is None:
            self.get_data_handlers()
        code = torch.from_numpy(np.array(self._codes[start:end])).long()
        dur = torch.from_numpy(np.array(self._durs[start:end]))
        f0 = torch.from_numpy(np.array(self._f0s[start:end]))
        return code, dur, f0
    def __getitem__(self, index):
        code, dur, f0 = self._get_raw_item(index)
        code = torch.cat([code.new([self.dictionary.bos()]), code])

        # use 0 for eos and bos
        dur = torch.cat([dur.new([0]), dur])
        if self.discrete_dur:
            dur = self.dur_dictionary.encode_line(
                " ".join(map(str, dur.tolist())), append_eos=False
            ).long()
        else:
            dur = dur.float()

        # TODO: find a more elegant approach
        raw_f0 = None
        if self.discrete_f0:
            if self.config.f0_vq_type == "precomp":
                f0 = self.f0_dictionary.encode_line(
                    " ".join(map(str, f0.tolist())), append_eos=False
                ).long()
            else:
                f0 = f0.float()
                f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
                if self.return_continuous_f0:
                    raw_f0 = f0
                    raw_f0 = torch.cat([raw_f0.new([self.f0_dictionary.bos()]), raw_f0])
                f0 = naive_quantize(f0, self._f0_quantizer)
            f0 = torch.cat([f0.new([self.f0_dictionary.bos()]), f0])
        else:
            f0 = f0.float()
            if self.multispkr:
                f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
            else:
                f0 = self.preprocess_f0(f0, self.f0_stats)
            f0 = torch.cat([f0.new([0]), f0])

        if raw_f0 is not None:
            *_, raw_f0, raw_f0_mask = self.shifts(code, dur, raw_f0)
        else:
            raw_f0_mask = None

        code, code_mask, dur, dur_mask, f0, f0_mask = self.shifts(code, dur, f0)
        if raw_f0_mask is not None:
            assert (raw_f0_mask == f0_mask).all()

        # is a padded frame if either input or output is padded
        feats = {
            "source": code[:-1],
            "target": code[1:],
            "mask": code_mask[1:].logical_or(code_mask[:-1]),
            "dur_source": dur[:-1],
            "dur_target": dur[1:],
            "dur_mask": dur_mask[1:].logical_or(dur_mask[:-1]),
            "f0_source": f0[:-1],
            "f0_target": f0[1:],
            "f0_mask": f0_mask[1:].logical_or(f0_mask[:-1]),
        }

        if raw_f0 is not None:
            feats["raw_f0"] = raw_f0[1:]

        if self.return_filename:
            fname = self.file_names[index]
            feats["filename"] = (
                fname if not self.strip_filename else Path(fname).with_suffix("").name
            )
        return feats
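Note how __getitem__ builds next-step prediction pairs from one padded sequence: source is the sequence without its last element, target without its first. Schematically:

import torch

code = torch.tensor([0, 5, 9, 7, 2])      # hypothetical: bos, units..., eos
print(code[:-1].tolist(), code[1:].tolist())
# source [0, 5, 9, 7] lines up with target [5, 9, 7, 2], one step ahead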
    def __len__(self):
        return len(self.starts)

    def size(self, index):
        return self.ends[index] - self.starts[index] + self.shifts.extra_length

    def num_tokens(self, index):
        return self.size(index)
    def collater(self, samples):
        pad_idx, eos_idx = self.dictionary.pad(), self.dictionary.eos()
        if len(samples) == 0:
            return {}

        src_tokens = data_utils.collate_tokens(
            [s["source"] for s in samples], pad_idx, eos_idx, left_pad=False
        )

        tgt_tokens = data_utils.collate_tokens(
            [s["target"] for s in samples],
            pad_idx=pad_idx,
            eos_idx=pad_idx,  # appending padding, eos is there already
            left_pad=False,
        )

        src_durs, tgt_durs = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=self.pads.dur,
                eos_idx=self.pads.dur,
                left_pad=False,
            )
            for k in ["dur_source", "dur_target"]
        ]

        src_f0s, tgt_f0s = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=self.pads.f0,
                eos_idx=self.pads.f0,
                left_pad=False,
            )
            for k in ["f0_source", "f0_target"]
        ]

        mask, dur_mask, f0_mask = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=1,
                eos_idx=1,
                left_pad=False,
            )
            for k in ["mask", "dur_mask", "f0_mask"]
        ]

        src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
        n_tokens = sum(len(s["source"]) for s in samples)

        result = {
            "nsentences": len(samples),
            "ntokens": n_tokens,
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "dur_src": src_durs,
                "f0_src": src_f0s,
            },
            "target": tgt_tokens,
            "dur_target": tgt_durs,
            "f0_target": tgt_f0s,
            "mask": mask,
            "dur_mask": dur_mask,
            "f0_mask": f0_mask,
        }

        if "filename" in samples[0]:
            result["filename"] = [s["filename"] for s in samples]

        # TODO: remove this hack into the inference dataset
        if "prefix" in samples[0]:
            result["prefix"] = [s["prefix"] for s in samples]

        if "raw_f0" in samples[0]:
            raw_f0s = data_utils.collate_tokens(
                [s["raw_f0"] for s in samples],
                pad_idx=self.pads.f0,
                eos_idx=self.pads.f0,
                left_pad=False,
            )
            result["raw_f0"] = raw_f0s
        return result
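collater relies on fairseq's data_utils.collate_tokens to right-pad variable-length items into a batch. The padding behavior it expects, sketched by hand (not part of the file):

import torch

a, b = torch.tensor([5, 9]), torch.tensor([7, 7, 7])
batch = torch.full((2, 3), 1)     # pad_idx = 1 (hypothetical)
batch[0, : len(a)] = a
batch[1, : len(b)] = b
print(batch.tolist())             # [[5, 9, 1], [7, 7, 7]]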
    	
modules/voice_conversion/fairseq/data/colorize_dataset.py
ADDED
@@ -0,0 +1,25 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from . import BaseWrapperDataset


class ColorizeDataset(BaseWrapperDataset):
    """Adds 'colors' property to net input that is obtained from the provided color getter for use by models"""

    def __init__(self, dataset, color_getter):
        super().__init__(dataset)
        self.color_getter = color_getter

    def collater(self, samples):
        base_collate = super().collater(samples)
        if len(base_collate) > 0:
            base_collate["net_input"]["colors"] = torch.tensor(
                list(self.color_getter(self.dataset, s["id"]) for s in samples),
                dtype=torch.long,
            )
        return base_collate
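The color_getter receives the wrapped dataset and a sample id and returns an integer label. What ColorizeDataset.collater attaches, reproduced standalone with a hypothetical getter:

import torch

samples = [{"id": 0}, {"id": 1}, {"id": 2}]
color_getter = lambda dataset, sample_id: sample_id % 2   # hypothetical
colors = torch.tensor(
    [color_getter(None, s["id"]) for s in samples], dtype=torch.long
)
print(colors)   # tensor([0, 1, 0]) -- becomes net_input["colors"]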
    	
modules/voice_conversion/fairseq/data/concat_dataset.py
ADDED
@@ -0,0 +1,124 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import bisect

import numpy as np
from torch.utils.data.dataloader import default_collate

from . import FairseqDataset


class ConcatDataset(FairseqDataset):
    @staticmethod
    def cumsum(sequence, sample_ratios):
        r, s = [], 0
        for e, ratio in zip(sequence, sample_ratios):
            curr_len = int(ratio * len(e))
            r.append(curr_len + s)
            s += curr_len
        return r

    def __init__(self, datasets, sample_ratios=1):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, "datasets should not be an empty iterable"
        self.datasets = list(datasets)
        if isinstance(sample_ratios, int):
            sample_ratios = [sample_ratios] * len(self.datasets)
        self.sample_ratios = sample_ratios
        self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
        self.real_sizes = [len(d) for d in self.datasets]

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx][sample_idx]

    def _get_dataset_and_sample_index(self, idx: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        sample_idx = sample_idx % self.real_sizes[dataset_idx]
        return dataset_idx, sample_idx

    def collater(self, samples, **extra_args):
        # For now only supports datasets with same underlying collater implementations
        if hasattr(self.datasets[0], "collater"):
            return self.datasets[0].collater(samples, **extra_args)
        else:
            return default_collate(samples, **extra_args)

    def size(self, idx: int):
        """
        Return an example's size as a float or tuple.
        """
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx].size(sample_idx)

    def num_tokens(self, index: int):
        return np.max(self.size(index))

    def attr(self, attr: str, index: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
        return getattr(self.datasets[dataset_idx], attr, None)

    @property
    def sizes(self):
        _dataset_sizes = []
        for ds, sr in zip(self.datasets, self.sample_ratios):
            if isinstance(ds.sizes, np.ndarray):
                _dataset_sizes.append(np.tile(ds.sizes, sr))
            else:
                # Only support underlying dataset with single size array.
                assert isinstance(ds.sizes, list)
                _dataset_sizes.append(np.tile(ds.sizes[0], sr))
        return np.concatenate(_dataset_sizes)
| 82 | 
         
            +
             
     | 
| 83 | 
         
            +
                @property
         
     | 
| 84 | 
         
            +
                def supports_prefetch(self):
         
     | 
| 85 | 
         
            +
                    return all(d.supports_prefetch for d in self.datasets)
         
     | 
| 86 | 
         
            +
             
     | 
| 87 | 
         
            +
                def ordered_indices(self):
         
     | 
| 88 | 
         
            +
                    """
         
     | 
| 89 | 
         
            +
                    Returns indices sorted by length. So less padding is needed.
         
     | 
| 90 | 
         
            +
                    """
         
     | 
| 91 | 
         
            +
                    if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
         
     | 
| 92 | 
         
            +
                        # special handling for concatenating lang_pair_datasets
         
     | 
| 93 | 
         
            +
                        indices = np.arange(len(self))
         
     | 
| 94 | 
         
            +
                        sizes = self.sizes
         
     | 
| 95 | 
         
            +
                        tgt_sizes = (
         
     | 
| 96 | 
         
            +
                            sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
         
     | 
| 97 | 
         
            +
                        )
         
     | 
| 98 | 
         
            +
                        src_sizes = (
         
     | 
| 99 | 
         
            +
                            sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
         
     | 
| 100 | 
         
            +
                        )
         
     | 
| 101 | 
         
            +
                        # sort by target length, then source length
         
     | 
| 102 | 
         
            +
                        if tgt_sizes is not None:
         
     | 
| 103 | 
         
            +
                            indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
         
     | 
| 104 | 
         
            +
                        return indices[np.argsort(src_sizes[indices], kind="mergesort")]
         
     | 
| 105 | 
         
            +
                    else:
         
     | 
| 106 | 
         
            +
                        return np.argsort(self.sizes)
         
     | 
| 107 | 
         
            +
             
     | 
| 108 | 
         
            +
                def prefetch(self, indices):
         
     | 
| 109 | 
         
            +
                    frm = 0
         
     | 
| 110 | 
         
            +
                    for to, ds in zip(self.cumulative_sizes, self.datasets):
         
     | 
| 111 | 
         
            +
                        real_size = len(ds)
         
     | 
| 112 | 
         
            +
                        if getattr(ds, "supports_prefetch", False):
         
     | 
| 113 | 
         
            +
                            ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
         
     | 
| 114 | 
         
            +
                        frm = to
         
     | 
| 115 | 
         
            +
             
     | 
| 116 | 
         
            +
                @property
         
     | 
| 117 | 
         
            +
                def can_reuse_epoch_itr_across_epochs(self):
         
     | 
| 118 | 
         
            +
                    return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
         
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
                def set_epoch(self, epoch):
         
     | 
| 121 | 
         
            +
                    super().set_epoch(epoch)
         
     | 
| 122 | 
         
            +
                    for ds in self.datasets:
         
     | 
| 123 | 
         
            +
                        if hasattr(ds, "set_epoch"):
         
     | 
| 124 | 
         
            +
                            ds.set_epoch(epoch)
         
     |
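For context, a minimal sketch of how the flat index maps back to a (dataset, sample) pair under sample_ratios: cumsum builds per-dataset cumulative boundaries scaled by each ratio, bisect_right picks the dataset, and the modulo by real_sizes repeats samples to realize upsampling. The toy lists and the upstream `fairseq.data` import path below are assumptions for illustration (in this repo the module lives under modules/voice_conversion); any sequence supporting len() and indexing behaves the same way here.

# Hypothetical usage sketch, assuming fairseq is importable.
from fairseq.data import ConcatDataset

a = ["a0", "a1", "a2"]  # 3 samples
b = ["b0", "b1"]        # 2 samples

# ratio 2 upsamples `a`: indices past its real size wrap via modulo
ds = ConcatDataset([a, b], sample_ratios=[2, 1])

print(len(ds))              # 2*3 + 1*2 = 8
print(ds.cumulative_sizes)  # [6, 8]
print([ds[i] for i in range(len(ds))])
# ['a0', 'a1', 'a2', 'a0', 'a1', 'a2', 'b0', 'b1']

Note the design choice this illustrates: upsampling is done lazily through index arithmetic rather than by materializing repeated copies, so a large ratio costs no extra memory.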