Commit 26925fd

Duplicate from zlc99/M4Singer

Co-authored-by: Lichao Zhang <[email protected]>

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +34 -0
- README.md +14 -0
- checkpoints/m4singer_diff_e2e/config.yaml +348 -0
- checkpoints/m4singer_diff_e2e/model_ckpt_steps_900000.ckpt +3 -0
- checkpoints/m4singer_fs2_e2e/config.yaml +347 -0
- checkpoints/m4singer_fs2_e2e/model_ckpt_steps_320000.ckpt +3 -0
- checkpoints/m4singer_hifigan/config.yaml +246 -0
- checkpoints/m4singer_hifigan/model_ckpt_steps_1970000.ckpt +3 -0
- checkpoints/m4singer_pe/config.yaml +172 -0
- checkpoints/m4singer_pe/model_ckpt_steps_280000.ckpt +3 -0
- configs/config_base.yaml +42 -0
- configs/singing/base.yaml +42 -0
- configs/singing/fs2.yaml +3 -0
- configs/tts/base.yaml +95 -0
- configs/tts/base_zh.yaml +3 -0
- configs/tts/fs2.yaml +80 -0
- configs/tts/hifigan.yaml +21 -0
- configs/tts/lj/base_mel2wav.yaml +3 -0
- configs/tts/lj/base_text2mel.yaml +13 -0
- configs/tts/lj/fs2.yaml +3 -0
- configs/tts/lj/hifigan.yaml +3 -0
- configs/tts/lj/pwg.yaml +3 -0
- configs/tts/pwg.yaml +110 -0
- data_gen/singing/binarize.py +393 -0
- data_gen/tts/base_binarizer.py +224 -0
- data_gen/tts/bin/binarize.py +20 -0
- data_gen/tts/binarizer_zh.py +59 -0
- data_gen/tts/data_gen_utils.py +347 -0
- data_gen/tts/txt_processors/base_text_processor.py +8 -0
- data_gen/tts/txt_processors/en.py +78 -0
- data_gen/tts/txt_processors/zh.py +41 -0
- data_gen/tts/txt_processors/zh_g2pM.py +71 -0
- inference/m4singer/base_svs_infer.py +242 -0
- inference/m4singer/ds_e2e.py +67 -0
- inference/m4singer/gradio/gradio_settings.yaml +48 -0
- inference/m4singer/gradio/infer.py +143 -0
- inference/m4singer/gradio/share_btn.py +86 -0
- inference/m4singer/m4singer/m4singer_pinyin2ph.txt +413 -0
- inference/m4singer/m4singer/map.py +7 -0
- modules/__init__.py +0 -0
- modules/commons/common_layers.py +668 -0
- modules/commons/espnet_positional_embedding.py +113 -0
- modules/commons/ssim.py +391 -0
- modules/diffsinger_midi/fs2.py +118 -0
- modules/fastspeech/fs2.py +255 -0
- modules/fastspeech/pe.py +149 -0
- modules/fastspeech/tts_modules.py +357 -0
- modules/hifigan/hifigan.py +370 -0
- modules/hifigan/mel_utils.py +81 -0
- modules/parallel_wavegan/__init__.py +0 -0
    	
.gitattributes ADDED
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
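Aside: every rule above carrying filter=lfs routes matching files through Git LFS, which is why the commit's .ckpt files below appear only as small pointer files. A minimal Python sketch of that matching, assuming the rules are saved locally as ".gitattributes" and simplifying gitattributes glob semantics to fnmatch (so "saved_model/**/*" and leading-slash rules are not handled exactly as git would):

    # Sketch: list the LFS patterns and test a couple of paths from this commit.
    from fnmatch import fnmatch

    def lfs_patterns(path=".gitattributes"):
        patterns = []
        with open(path) as f:
            for line in f:
                parts = line.split()
                # pattern, then attributes; keep rules that set filter=lfs
                if len(parts) > 1 and "filter=lfs" in parts[1:]:
                    patterns.append(parts[0])
        return patterns

    if __name__ == "__main__":
        pats = lfs_patterns()
        for name in ["checkpoints/m4singer_pe/model_ckpt_steps_280000.ckpt",
                     "configs/config_base.yaml"]:
            base = name.rsplit("/", 1)[-1]
            hit = any(fnmatch(base, p) or fnmatch(name, p) for p in pats)
            print(name, "-> LFS" if hit else "-> regular git")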
    	
README.md ADDED
@@ -0,0 +1,14 @@
+---
+title: 111
+emoji: 🎶
+colorFrom: yellow
+colorTo: green
+sdk: gradio
+sdk_version: 3.8.1
+app_file: inference/m4singer/gradio/infer.py
+pinned: false
+duplicated_from: zlc99/M4Singer
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
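The YAML front matter above is what Hugging Face Spaces reads: sdk and sdk_version select Gradio 3.8.1, and app_file points the Space at inference/m4singer/gradio/infer.py. For orientation only, a minimal sketch of what a Gradio 3.x app_file entry point looks like; this is not the repository's actual infer.py, and synthesize() is a placeholder:

    # Hypothetical app_file sketch; a real SVS demo would return audio, not text.
    import gradio as gr

    def synthesize(text):
        return f"Would sing: {text}"  # placeholder body

    demo = gr.Interface(fn=synthesize, inputs="text", outputs="text",
                        title="M4Singer demo (sketch)")

    if __name__ == "__main__":
        demo.launch()  # Spaces runs the app_file and serves this interface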
    	
checkpoints/m4singer_diff_e2e/config.yaml ADDED
@@ -0,0 +1,348 @@
+K_step: 1000
+accumulate_grad_batches: 1
+audio_num_mel_bins: 80
+audio_sample_rate: 24000
+base_config:
+- usr/configs/m4singer/base.yaml
+binarization_args:
+  shuffle: false
+  with_align: true
+  with_f0: true
+  with_f0cwt: true
+  with_spk_embed: true
+  with_txt: true
+  with_wav: false
+binarizer_cls: data_gen.singing.binarize.M4SingerBinarizer
+binary_data_dir: data/binary/m4singer
+check_val_every_n_epoch: 10
+clip_grad_norm: 1
+content_cond_steps: []
+cwt_add_f0_loss: false
+cwt_hidden_size: 128
+cwt_layers: 2
+cwt_loss: l1
+cwt_std_scale: 0.8
+datasets:
+- m4singer
+debug: false
+dec_ffn_kernel_size: 9
+dec_layers: 4
+decay_steps: 100000
+decoder_type: fft
+dict_dir: ''
+diff_decoder_type: wavenet
+diff_loss_type: l1
+dilation_cycle_length: 4
+dropout: 0.1
+ds_workers: 4
+dur_enc_hidden_stride_kernel:
+- 0,2,3
+- 0,2,3
+- 0,1,3
+dur_loss: mse
+dur_predictor_kernel: 3
+dur_predictor_layers: 5
+enc_ffn_kernel_size: 9
+enc_layers: 4
+encoder_K: 8
+encoder_type: fft
+endless_ds: true
+ffn_act: gelu
+ffn_padding: SAME
+fft_size: 512
+fmax: 12000
+fmin: 30
+fs2_ckpt: checkpoints/m4singer_fs2_e2e
+gaussian_start: true
+gen_dir_name: ''
+gen_tgt_spk_id: -1
+hidden_size: 256
+hop_size: 128
+infer: false
+keep_bins: 80
+lambda_commit: 0.25
+lambda_energy: 0.0
+lambda_f0: 0.0
+lambda_ph_dur: 1.0
+lambda_sent_dur: 1.0
+lambda_uv: 0.0
+lambda_word_dur: 1.0
+load_ckpt: ''
+log_interval: 100
+loud_norm: false
+lr: 0.001
+max_beta: 0.02
+max_epochs: 1000
+max_eval_sentences: 1
+max_eval_tokens: 60000
+max_frames: 5000
+max_input_tokens: 1550
+max_sentences: 28
+max_tokens: 36000
+max_updates: 900000
+mel_loss: ssim:0.5|l1:0.5
+mel_vmax: 1.5
+mel_vmin: -6.0
+min_level_db: -120
+norm_type: gn
+num_ckpt_keep: 3
+num_heads: 2
+num_sanity_val_steps: 1
+num_spk: 20
+num_test_samples: 0
+num_valid_plots: 10
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.98
+out_wav_norm: false
+pe_ckpt: checkpoints/m4singer_pe
+pe_enable: true
+pitch_ar: false
+pitch_enc_hidden_stride_kernel:
+- 0,2,5
+- 0,2,5
+- 0,2,5
+pitch_extractor: parselmouth
+pitch_loss: l1
+pitch_norm: log
+pitch_type: frame
+pndm_speedup: 10
+pre_align_args:
+  allow_no_txt: false
+  denoise: false
+  forced_align: mfa
+  txt_processor: zh_g2pM
+  use_sox: true
+  use_tone: false
+pre_align_cls: data_gen.singing.pre_align.SingingPreAlign
+predictor_dropout: 0.5
+predictor_grad: 0.1
+predictor_hidden: -1
+predictor_kernel: 5
+predictor_layers: 5
+prenet_dropout: 0.5
+prenet_hidden_size: 256
+pretrain_fs_ckpt: ''
+processed_data_dir: xxx
+profile_infer: false
+raw_data_dir: data/raw/m4singer
+ref_norm_layer: bn
+rel_pos: true
+reset_phone_dict: true
+residual_channels: 256
+residual_layers: 20
+save_best: false
+save_ckpt: true
+save_codes:
+- configs
+- modules
+- tasks
+- utils
+- usr
+save_f0: true
+save_gt: true
+schedule_type: linear
+seed: 1234
+sort_by_len: true
+spec_max:
+- -0.3894500136375427
+- -0.3796464204788208
+- -0.2914905250072479
+- -0.15550297498703003
+- -0.08502643555402756
+- 0.10698417574167252
+- -0.0739326998591423
+- -0.0541548952460289
+- 0.15501998364925385
+- 0.06483431905508041
+- 0.03054228238761425
+- -0.013737732544541359
+- -0.004876468330621719
+- 0.04368264228105545
+- 0.13329921662807465
+- 0.16471388936042786
+- 0.04605761915445328
+- -0.05680707097053528
+- 0.0542571023106575
+- -0.0076539707370102406
+- -0.00953489076346159
+- -0.04434828832745552
+- 0.001293870504014194
+- -0.12238839268684387
+- 0.06418416649103165
+- 0.02843189612030983
+- 0.08505241572856903
+- 0.07062800228595734
+- 0.00120724702719599
+- -0.07675088942050934
+- 0.03785804659128189
+- 0.04890783503651619
+- -0.06888376921415329
+- -0.0839693546295166
+- -0.17545585334300995
+- -0.2911079525947571
+- -0.4238220453262329
+- -0.262084037065506
+- -0.3002263605594635
+- -0.3845032751560211
+- -0.3906497061252594
+- -0.6550108790397644
+- -0.7810799479484558
+- -0.7503029704093933
+- -0.7995198965072632
+- -0.8092347383499146
+- -0.6196113228797913
+- -0.6684317588806152
+- -0.7735874056816101
+- -0.8324533104896545
+- -0.9601566791534424
+- -0.955253541469574
+- -0.748817503452301
+- -0.9106167554855347
+- -0.9707801342010498
+- -1.053107500076294
+- -1.0448424816131592
+- -1.1082794666290283
+- -1.1296544075012207
+- -1.071642279624939
+- -1.1003081798553467
+- -1.166810154914856
+- -1.1408926248550415
+- -1.1330615282058716
+- -1.1167492866516113
+- -1.0716774463653564
+- -1.035891056060791
+- -1.0092483758926392
+- -0.9675999879837036
+- -0.938962996006012
+- -1.0120564699172974
+- -0.9777995347976685
+- -1.029313564300537
+- -0.9459163546562195
+- -0.8519706130027771
+- -0.7751091122627258
+- -0.7933766841888428
+- -0.9019735455513
+- -0.9983296990394592
+- -1.505873441696167
+spec_min:
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+spk_cond_steps: []
+stop_token_weight: 5.0
+task_cls: usr.diffsinger_task.DiffSingerMIDITask
+test_ids: []
+test_input_dir: ''
+test_num: 0
+test_prefixes:
+- "Alto-2#\u5C81\u6708\u795E\u5077"
+- "Alto-2#\u5947\u5999\u80FD\u529B\u6B4C"
+- "Tenor-1#\u4E00\u5343\u5E74\u4EE5\u540E"
+- "Tenor-1#\u7AE5\u8BDD"
+- "Tenor-2#\u6D88\u6101"
+- "Tenor-2#\u4E00\u8364\u4E00\u7D20"
+- "Soprano-1#\u5FF5\u5974\u5A07\u8D64\u58C1\u6000\u53E4"
+- "Soprano-1#\u95EE\u6625"
+test_set_name: test
+timesteps: 1000
+train_set_name: train
+use_denoise: false
+use_energy_embed: false
+use_gt_dur: false
+use_gt_f0: false
+use_midi: true
+use_nsf: true
+use_pitch_embed: false
+use_pos_embed: true
+use_spk_embed: false
+use_spk_id: true
+use_split_spk_id: false
+use_uv: true
+use_var_enc: false
+val_check_interval: 2000
+valid_num: 0
+valid_set_name: valid
+vocoder: vocoders.hifigan.HifiGAN
+vocoder_ckpt: checkpoints/m4singer_hifigan
+warmup_updates: 2000
+wav2spec_eps: 1e-6
+weight_decay: 0
+win_size: 512
+work_dir: checkpoints/m4singer_diff_e2e
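This config drives the diffusion acoustic model (task_cls: usr.diffsinger_task.DiffSingerMIDITask) with a full 1000-step chain (K_step and timesteps are both 1000) and pndm_speedup: 10 for accelerated sampling, which suggests roughly timesteps / pndm_speedup = 100 denoising iterations at inference. A small sketch, assuming PyYAML is installed and the file sits at the path recorded in this commit, that loads the config and checks a few of these relationships:

    # Sketch: inspect the diffusion settings of the checkpoint config above.
    import yaml

    with open("checkpoints/m4singer_diff_e2e/config.yaml") as f:
        cfg = yaml.safe_load(f)

    print(cfg["task_cls"])                  # usr.diffsinger_task.DiffSingerMIDITask
    print(cfg["timesteps"], cfg["K_step"])  # 1000 1000: full-chain diffusion
    # Hedged reading: accelerated sampling visits about this many steps.
    print(cfg["timesteps"] // cfg["pndm_speedup"])  # 100
    # spec_min/spec_max give the mel normalization range, one value per bin.
    assert len(cfg["spec_min"]) == len(cfg["spec_max"]) == cfg["audio_num_mel_bins"]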
    	
checkpoints/m4singer_diff_e2e/model_ckpt_steps_900000.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dbea4e8b9712d2cca54cc07915859472a17f2f3b97a86f33a6c9974192bb5b47
+size 392239086
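The checkpoint weights themselves are not in this commit; the three lines above are a Git LFS pointer (spec version, sha256 oid, byte size). A sketch, with a hypothetical local blob path, that parses such a pointer and verifies a separately downloaded file against it (str.removeprefix needs Python 3.9+):

    # Sketch: verify a downloaded blob against an LFS pointer file.
    import hashlib

    def read_pointer(path):
        # Each pointer line is "key value"; split on the first space.
        fields = dict(line.split(" ", 1) for line in open(path).read().splitlines())
        return fields["oid"].removeprefix("sha256:"), int(fields["size"])

    def verify(blob_path, pointer_path):
        oid, size = read_pointer(pointer_path)
        h, n = hashlib.sha256(), 0
        with open(blob_path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                h.update(chunk)
                n += len(chunk)
        return n == size and h.hexdigest() == oid

    # Hypothetical usage: blob fetched elsewhere, pointer from this repo.
    # verify("downloads/model_ckpt_steps_900000.ckpt",
    #        "checkpoints/m4singer_diff_e2e/model_ckpt_steps_900000.ckpt")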
    	
        checkpoints/m4singer_fs2_e2e/config.yaml
    ADDED
    
    | @@ -0,0 +1,347 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            K_step: 51
         | 
| 2 | 
            +
            accumulate_grad_batches: 1
         | 
| 3 | 
            +
            audio_num_mel_bins: 80
         | 
| 4 | 
            +
            audio_sample_rate: 24000
         | 
| 5 | 
            +
            base_config:
         | 
| 6 | 
            +
            - configs/singing/fs2.yaml
         | 
| 7 | 
            +
            - usr/configs/m4singer/base.yaml
         | 
| 8 | 
            +
            binarization_args:
         | 
| 9 | 
            +
              shuffle: false
         | 
| 10 | 
            +
              with_align: true
         | 
| 11 | 
            +
              with_f0: true
         | 
| 12 | 
            +
              with_f0cwt: true
         | 
| 13 | 
            +
              with_spk_embed: true
         | 
| 14 | 
            +
              with_txt: true
         | 
| 15 | 
            +
              with_wav: false
         | 
| 16 | 
            +
            binarizer_cls: data_gen.singing.binarize.M4SingerBinarizer
         | 
| 17 | 
            +
            binary_data_dir: data/binary/m4singer
         | 
| 18 | 
            +
            check_val_every_n_epoch: 10
         | 
| 19 | 
            +
            clip_grad_norm: 1
         | 
| 20 | 
            +
            content_cond_steps: []
         | 
| 21 | 
            +
            cwt_add_f0_loss: false
         | 
| 22 | 
            +
            cwt_hidden_size: 128
         | 
| 23 | 
            +
            cwt_layers: 2
         | 
| 24 | 
            +
            cwt_loss: l1
         | 
| 25 | 
            +
            cwt_std_scale: 0.8
         | 
| 26 | 
            +
            datasets:
         | 
| 27 | 
            +
            - m4singer
         | 
| 28 | 
            +
            debug: false
         | 
| 29 | 
            +
            dec_ffn_kernel_size: 9
         | 
| 30 | 
            +
            dec_layers: 4
         | 
| 31 | 
            +
            decay_steps: 50000
         | 
| 32 | 
            +
            decoder_type: fft
         | 
| 33 | 
            +
            dict_dir: ''
         | 
| 34 | 
            +
            diff_decoder_type: wavenet
         | 
| 35 | 
            +
            diff_loss_type: l1
         | 
| 36 | 
            +
            dilation_cycle_length: 1
         | 
| 37 | 
            +
            dropout: 0.1
         | 
| 38 | 
            +
            ds_workers: 4
         | 
| 39 | 
            +
            dur_enc_hidden_stride_kernel:
         | 
| 40 | 
            +
            - 0,2,3
         | 
| 41 | 
            +
            - 0,2,3
         | 
| 42 | 
            +
            - 0,1,3
         | 
| 43 | 
            +
            dur_loss: mse
         | 
| 44 | 
            +
            dur_predictor_kernel: 3
         | 
| 45 | 
            +
            dur_predictor_layers: 5
         | 
| 46 | 
            +
            enc_ffn_kernel_size: 9
         | 
| 47 | 
            +
            enc_layers: 4
         | 
| 48 | 
            +
            encoder_K: 8
         | 
| 49 | 
            +
            encoder_type: fft
         | 
| 50 | 
            +
            endless_ds: true
         | 
| 51 | 
            +
            ffn_act: gelu
         | 
| 52 | 
            +
            ffn_padding: SAME
         | 
| 53 | 
            +
            fft_size: 512
         | 
| 54 | 
            +
            fmax: 12000
         | 
| 55 | 
            +
            fmin: 30
         | 
| 56 | 
            +
            fs2_ckpt: ''
         | 
| 57 | 
            +
            gen_dir_name: ''
         | 
| 58 | 
            +
            gen_tgt_spk_id: -1
         | 
| 59 | 
            +
            hidden_size: 256
         | 
| 60 | 
            +
            hop_size: 128
         | 
| 61 | 
            +
            infer: false
         | 
| 62 | 
            +
            keep_bins: 80
         | 
| 63 | 
            +
            lambda_commit: 0.25
         | 
| 64 | 
            +
            lambda_energy: 0.0
         | 
| 65 | 
            +
            lambda_f0: 1.0
         | 
| 66 | 
            +
            lambda_ph_dur: 1.0
         | 
| 67 | 
            +
            lambda_sent_dur: 1.0
         | 
| 68 | 
            +
            lambda_uv: 1.0
         | 
| 69 | 
            +
            lambda_word_dur: 1.0
         | 
| 70 | 
            +
            load_ckpt: ''
         | 
| 71 | 
            +
            log_interval: 100
         | 
| 72 | 
            +
            loud_norm: false
         | 
| 73 | 
            +
            lr: 1
         | 
| 74 | 
            +
            max_beta: 0.06
         | 
| 75 | 
            +
            max_epochs: 1000
         | 
| 76 | 
            +
            max_eval_sentences: 1
         | 
| 77 | 
            +
            max_eval_tokens: 60000
         | 
| 78 | 
            +
            max_frames: 5000
         | 
| 79 | 
            +
            max_input_tokens: 1550
         | 
| 80 | 
            +
            max_sentences: 12
         | 
| 81 | 
            +
            max_tokens: 40000
         | 
| 82 | 
            +
            max_updates: 320000
         | 
| 83 | 
            +
            mel_loss: ssim:0.5|l1:0.5
         | 
| 84 | 
            +
            mel_vmax: 1.5
         | 
| 85 | 
            +
            mel_vmin: -6.0
         | 
| 86 | 
            +
            min_level_db: -120
         | 
| 87 | 
            +
            norm_type: gn
         | 
| 88 | 
            +
            num_ckpt_keep: 3
         | 
| 89 | 
            +
            num_heads: 2
         | 
| 90 | 
            +
            num_sanity_val_steps: 1
         | 
| 91 | 
            +
            num_spk: 20
         | 
| 92 | 
            +
            num_test_samples: 0
         | 
| 93 | 
            +
            num_valid_plots: 10
         | 
| 94 | 
            +
            optimizer_adam_beta1: 0.9
         | 
| 95 | 
            +
            optimizer_adam_beta2: 0.98
         | 
| 96 | 
            +
            out_wav_norm: false
         | 
| 97 | 
            +
            pe_ckpt: checkpoints/m4singer_pe
         | 
| 98 | 
            +
            pe_enable: true
         | 
| 99 | 
            +
            pitch_ar: false
         | 
| 100 | 
            +
            pitch_enc_hidden_stride_kernel:
         | 
| 101 | 
            +
            - 0,2,5
         | 
| 102 | 
            +
            - 0,2,5
         | 
| 103 | 
            +
            - 0,2,5
         | 
| 104 | 
            +
            pitch_extractor: parselmouth
         | 
| 105 | 
            +
            pitch_loss: l1
         | 
| 106 | 
            +
            pitch_norm: log
         | 
| 107 | 
            +
            pitch_type: frame
         | 
| 108 | 
            +
            pre_align_args:
         | 
| 109 | 
            +
              allow_no_txt: false
         | 
| 110 | 
            +
              denoise: false
         | 
| 111 | 
            +
              forced_align: mfa
         | 
| 112 | 
            +
              txt_processor: zh_g2pM
         | 
| 113 | 
            +
              use_sox: true
         | 
| 114 | 
            +
              use_tone: false
         | 
| 115 | 
            +
            pre_align_cls: data_gen.singing.pre_align.SingingPreAlign
         | 
| 116 | 
            +
            predictor_dropout: 0.5
         | 
| 117 | 
            +
            predictor_grad: 0.1
         | 
| 118 | 
            +
            predictor_hidden: -1
         | 
| 119 | 
            +
            predictor_kernel: 5
         | 
| 120 | 
            +
            predictor_layers: 5
         | 
| 121 | 
            +
            prenet_dropout: 0.5
         | 
| 122 | 
            +
            prenet_hidden_size: 256
         | 
| 123 | 
            +
            pretrain_fs_ckpt: ''
         | 
| 124 | 
            +
            processed_data_dir: xxx
         | 
| 125 | 
            +
            profile_infer: false
         | 
| 126 | 
            +
            raw_data_dir: data/raw/m4singer
         | 
| 127 | 
            +
            ref_norm_layer: bn
         | 
| 128 | 
            +
            rel_pos: true
         | 
| 129 | 
            +
            reset_phone_dict: true
         | 
| 130 | 
            +
            residual_channels: 256
         | 
| 131 | 
            +
            residual_layers: 20
         | 
| 132 | 
            +
            save_best: false
         | 
| 133 | 
            +
            save_ckpt: true
         | 
| 134 | 
            +
            save_codes:
         | 
| 135 | 
            +
            - configs
         | 
| 136 | 
            +
            - modules
         | 
| 137 | 
            +
            - tasks
         | 
| 138 | 
            +
            - utils
         | 
| 139 | 
            +
            - usr
         | 
| 140 | 
            +
            save_f0: true
         | 
| 141 | 
            +
            save_gt: true
         | 
| 142 | 
            +
            schedule_type: linear
         | 
| 143 | 
            +
            seed: 1234
         | 
| 144 | 
            +
            sort_by_len: true
         | 
| 145 | 
            +
            spec_max:
         | 
| 146 | 
            +
            - -0.3894500136375427
         | 
| 147 | 
            +
            - -0.3796464204788208
         | 
| 148 | 
            +
            - -0.2914905250072479
         | 
| 149 | 
            +
            - -0.15550297498703003
         | 
| 150 | 
            +
            - -0.08502643555402756
         | 
| 151 | 
            +
            - 0.10698417574167252
         | 
| 152 | 
            +
            - -0.0739326998591423
         | 
| 153 | 
            +
            - -0.0541548952460289
         | 
| 154 | 
            +
            - 0.15501998364925385
         | 
| 155 | 
            +
            - 0.06483431905508041
         | 
| 156 | 
            +
            - 0.03054228238761425
         | 
| 157 | 
            +
            - -0.013737732544541359
         | 
| 158 | 
            +
            - -0.004876468330621719
         | 
| 159 | 
            +
            - 0.04368264228105545
         | 
| 160 | 
            +
            - 0.13329921662807465
         | 
| 161 | 
            +
            - 0.16471388936042786
         | 
| 162 | 
            +
            - 0.04605761915445328
         | 
| 163 | 
            +
            - -0.05680707097053528
         | 
| 164 | 
            +
            - 0.0542571023106575
         | 
| 165 | 
            +
            - -0.0076539707370102406
         | 
| 166 | 
            +
            - -0.00953489076346159
         | 
| 167 | 
            +
            - -0.04434828832745552
         | 
| 168 | 
            +
            - 0.001293870504014194
         | 
| 169 | 
            +
            - -0.12238839268684387
         | 
| 170 | 
            +
            - 0.06418416649103165
         | 
| 171 | 
            +
            - 0.02843189612030983
         | 
| 172 | 
            +
            - 0.08505241572856903
         | 
| 173 | 
            +
            - 0.07062800228595734
         | 
| 174 | 
            +
            - 0.00120724702719599
         | 
| 175 | 
            +
            - -0.07675088942050934
         | 
| 176 | 
            +
            - 0.03785804659128189
         | 
| 177 | 
            +
            - 0.04890783503651619
         | 
| 178 | 
            +
            - -0.06888376921415329
         | 
| 179 | 
            +
            - -0.0839693546295166
         | 
| 180 | 
            +
            - -0.17545585334300995
         | 
| 181 | 
            +
            - -0.2911079525947571
         | 
| 182 | 
            +
            - -0.4238220453262329
         | 
| 183 | 
            +
            - -0.262084037065506
         | 
| 184 | 
            +
            - -0.3002263605594635
         | 
| 185 | 
            +
            - -0.3845032751560211
         | 
| 186 | 
            +
            - -0.3906497061252594
         | 
| 187 | 
            +
            - -0.6550108790397644
         | 
| 188 | 
            +
            - -0.7810799479484558
         | 
| 189 | 
            +
            - -0.7503029704093933
         | 
| 190 | 
            +
            - -0.7995198965072632
         | 
| 191 | 
            +
            - -0.8092347383499146
         | 
| 192 | 
            +
            - -0.6196113228797913
         | 
| 193 | 
            +
            - -0.6684317588806152
         | 
| 194 | 
            +
            - -0.7735874056816101
         | 
| 195 | 
            +
            - -0.8324533104896545
         | 
| 196 | 
            +
            - -0.9601566791534424
         | 
| 197 | 
            +
            - -0.955253541469574
         | 
| 198 | 
            +
            - -0.748817503452301
         | 
| 199 | 
            +
            - -0.9106167554855347
         | 
| 200 | 
            +
            - -0.9707801342010498
         | 
| 201 | 
            +
            - -1.053107500076294
         | 
| 202 | 
            +
            - -1.0448424816131592
         | 
| 203 | 
            +
            - -1.1082794666290283
         | 
| 204 | 
            +
            - -1.1296544075012207
         | 
| 205 | 
            +
            - -1.071642279624939
         | 
| 206 | 
            +
            - -1.1003081798553467
         | 
| 207 | 
            +
            - -1.166810154914856
         | 
| 208 | 
            +
            - -1.1408926248550415
         | 
| 209 | 
            +
            - -1.1330615282058716
         | 
| 210 | 
            +
            - -1.1167492866516113
         | 
| 211 | 
            +
            - -1.0716774463653564
         | 
| 212 | 
            +
            - -1.035891056060791
         | 
| 213 | 
            +
            - -1.0092483758926392
         | 
| 214 | 
            +
            - -0.9675999879837036
         | 
| 215 | 
            +
            - -0.938962996006012
         | 
| 216 | 
            +
            - -1.0120564699172974
         | 
| 217 | 
            +
            - -0.9777995347976685
         | 
| 218 | 
            +
            - -1.029313564300537
         | 
| 219 | 
            +
            - -0.9459163546562195
         | 
| 220 | 
            +
            - -0.8519706130027771
         | 
| 221 | 
            +
            - -0.7751091122627258
         | 
| 222 | 
            +
            - -0.7933766841888428
         | 
| 223 | 
            +
            - -0.9019735455513
         | 
| 224 | 
            +
            - -0.9983296990394592
         | 
| 225 | 
            +
            - -1.505873441696167
         | 
| 226 | 
            +
            spec_min:
         | 
| 227 | 
            +
            - -6.0
         | 
| 228 | 
            +
            - -6.0
         | 
| 229 | 
            +
            - -6.0
         | 
| 230 | 
            +
            - -6.0
         | 
| 231 | 
            +
            - -6.0
         | 
| 232 | 
            +
            - -6.0
         | 
| 233 | 
            +
            - -6.0
         | 
| 234 | 
            +
            - -6.0
         | 
| 235 | 
            +
            - -6.0
         | 
| 236 | 
            +
            - -6.0
         | 
| 237 | 
            +
            - -6.0
         | 
| 238 | 
            +
            - -6.0
         | 
| 239 | 
            +
            - -6.0
         | 
| 240 | 
            +
            - -6.0
         | 
| 241 | 
            +
            - -6.0
         | 
| 242 | 
            +
            - -6.0
         | 
| 243 | 
            +
            - -6.0
         | 
| 244 | 
            +
            - -6.0
         | 
| 245 | 
            +
            - -6.0
         | 
| 246 | 
            +
            - -6.0
         | 
| 247 | 
            +
            - -6.0
         | 
| 248 | 
            +
            - -6.0
         | 
| 249 | 
            +
            - -6.0
         | 
| 250 | 
            +
            - -6.0
         | 
| 251 | 
            +
            - -6.0
         | 
| 252 | 
            +
            - -6.0
         | 
| 253 | 
            +
            - -6.0
         | 
| 254 | 
            +
            - -6.0
         | 
| 255 | 
            +
            - -6.0
         | 
| 256 | 
            +
            - -6.0
         | 
| 257 | 
            +
            - -6.0
         | 
| 258 | 
            +
            - -6.0
         | 
| 259 | 
            +
            - -6.0
         | 
| 260 | 
            +
            - -6.0
         | 
| 261 | 
            +
            - -6.0
         | 
| 262 | 
            +
            - -6.0
         | 
| 263 | 
            +
            - -6.0
         | 
| 264 | 
            +
            - -6.0
         | 
| 265 | 
            +
            - -6.0
         | 
| 266 | 
            +
            - -6.0
         | 
| 267 | 
            +
            - -6.0
         | 
| 268 | 
            +
            - -6.0
         | 
| 269 | 
            +
            - -6.0
         | 
| 270 | 
            +
            - -6.0
         | 
| 271 | 
            +
            - -6.0
         | 
| 272 | 
            +
            - -6.0
         | 
| 273 | 
            +
            - -6.0
         | 
| 274 | 
            +
            - -6.0
         | 
| 275 | 
            +
            - -6.0
         | 
| 276 | 
            +
            - -6.0
         | 
| 277 | 
            +
            - -6.0
         | 
| 278 | 
            +
            - -6.0
         | 
| 279 | 
            +
            - -6.0
         | 
| 280 | 
            +
            - -6.0
         | 
| 281 | 
            +
            - -6.0
         | 
| 282 | 
            +
            - -6.0
         | 
| 283 | 
            +
            - -6.0
         | 
| 284 | 
            +
            - -6.0
         | 
| 285 | 
            +
            - -6.0
         | 
| 286 | 
            +
            - -6.0
         | 
| 287 | 
            +
            - -6.0
         | 
| 288 | 
            +
            - -6.0
         | 
| 289 | 
            +
            - -6.0
         | 
| 290 | 
            +
            - -6.0
         | 
| 291 | 
            +
            - -6.0
         | 
| 292 | 
            +
            - -6.0
         | 
| 293 | 
            +
            - -6.0
         | 
| 294 | 
            +
            - -6.0
- -6.0
- -6.0
- -6.0
- -6.0
- -6.0
- -6.0
- -6.0
- -6.0
- -6.0
- -6.0
- -6.0
- -6.0
spk_cond_steps: []
stop_token_weight: 5.0
task_cls: usr.diffsinger_task.AuxDecoderMIDITask
test_ids: []
test_input_dir: ''
test_num: 0
test_prefixes:
- "Alto-2#\u5C81\u6708\u795E\u5077"
- "Alto-2#\u5947\u5999\u80FD\u529B\u6B4C"
- "Tenor-1#\u4E00\u5343\u5E74\u4EE5\u540E"
- "Tenor-1#\u7AE5\u8BDD"
- "Tenor-2#\u6D88\u6101"
- "Tenor-2#\u4E00\u8364\u4E00\u7D20"
- "Soprano-1#\u5FF5\u5974\u5A07\u8D64\u58C1\u6000\u53E4"
- "Soprano-1#\u95EE\u6625"
test_set_name: test
timesteps: 100
train_set_name: train
use_denoise: false
use_energy_embed: false
use_gt_dur: false
use_gt_f0: false
use_midi: true
use_nsf: true
use_pitch_embed: false
use_pos_embed: true
use_spk_embed: false
use_spk_id: true
use_split_spk_id: false
use_uv: true
use_var_enc: false
val_check_interval: 2000
valid_num: 0
valid_set_name: valid
vocoder: vocoders.hifigan.HifiGAN
vocoder_ckpt: checkpoints/m4singer_hifigan
warmup_updates: 2000
wav2spec_eps: 1e-6
weight_decay: 0
win_size: 512
work_dir: checkpoints/m4singer_fs2_e2e
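A note on the test_prefixes above: they are ordinary YAML double-quoted scalars whose \u escapes encode Chinese "Singer#SongTitle" labels, e.g. "Alto-2#\u5C81\u6708\u795E\u5077" decodes to "Alto-2#岁月神偷". A minimal Python sketch that lets PyYAML do the unescaping:

import yaml

# YAML decodes \uXXXX escapes inside double-quoted scalars, so loading
# the config yields readable "Singer#SongTitle" strings.
doc = yaml.safe_load('test_prefixes: ["Alto-2#\\u5C81\\u6708\\u795E\\u5077", "Tenor-1#\\u7AE5\\u8BDD"]')
for prefix in doc["test_prefixes"]:
    singer, song = prefix.split("#", 1)
    print(singer, song)  # Alto-2 岁月神偷, then Tenor-1 童话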
    	
checkpoints/m4singer_fs2_e2e/model_ckpt_steps_320000.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:993d7063a1773bd29d2810591f98152218a4cf8440e2b10c4761516a28f9d566
size 290456153
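The .ckpt entries in this commit are Git LFS pointer stubs, not the weights themselves: per the spec URL on their first line, each records only the spec version, the sha256 of the real blob, and its size in bytes (290456153 bytes, about 277 MiB, here); git-lfs swaps in the actual file on checkout. A small sketch that reads such a pointer:

# Each pointer line is "key value"; parse the lines into a dict.
def parse_lfs_pointer(path: str) -> dict:
    with open(path) as f:
        return dict(line.strip().split(" ", 1) for line in f if line.strip())

ptr = parse_lfs_pointer("checkpoints/m4singer_fs2_e2e/model_ckpt_steps_320000.ckpt")
print(ptr["oid"])                       # sha256:993d7063...
print(int(ptr["size"]) / 2**20, "MiB")  # ~277 MiB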
    	
checkpoints/m4singer_hifigan/config.yaml
ADDED
@@ -0,0 +1,246 @@
max_eval_tokens: 60000
max_eval_sentences: 1
save_ckpt: true
log_interval: 100
accumulate_grad_batches: 1
adam_b1: 0.8
adam_b2: 0.99
amp: false
audio_num_mel_bins: 80
audio_sample_rate: 24000
aux_context_window: 0
#base_config:
#- egs/egs_bases/singing/pwg.yaml
#- egs/egs_bases/tts/vocoder/hifigan.yaml
binarization_args:
  reset_phone_dict: true
  reset_word_dict: true
  shuffle: false
  trim_eos_bos: false
  trim_sil: false
  with_align: false
  with_f0: true
  with_f0cwt: false
  with_linear: false
  with_spk_embed: false
  with_spk_id: true
  with_txt: false
  with_wav: true
  with_word: false
binarizer_cls: data_gen.tts.singing.binarize.SingingBinarizer
binary_data_dir: data/binary/m4singer_vocoder
check_val_every_n_epoch: 10
clip_grad_norm: 1
clip_grad_value: 0
datasets: []
debug: false
dec_ffn_kernel_size: 9
dec_layers: 4
dict_dir: ''
disc_start_steps: 40000
discriminator_grad_norm: 1
discriminator_optimizer_params:
  eps: 1.0e-06
  lr: 0.0002
  weight_decay: 0.0
discriminator_params:
  bias: true
  conv_channels: 64
  in_channels: 1
  kernel_size: 3
  layers: 10
  nonlinear_activation: LeakyReLU
  nonlinear_activation_params:
    negative_slope: 0.2
  out_channels: 1
  use_weight_norm: true
discriminator_scheduler_params:
  gamma: 0.999
  step_size: 600
dropout: 0.1
ds_workers: 1
enc_ffn_kernel_size: 9
enc_layers: 4
endless_ds: true
ffn_act: gelu
ffn_padding: SAME
fft_size: 512
fmax: 12000
fmin: 30
frames_multiple: 1
gen_dir_name: ''
generator_grad_norm: 10
generator_optimizer_params:
  eps: 1.0e-06
  lr: 0.0002
  weight_decay: 0.0
generator_params:
  aux_context_window: 0
  aux_channels: 80
  dropout: 0.0
  gate_channels: 128
  in_channels: 1
  kernel_size: 3
  layers: 30
  out_channels: 1
  residual_channels: 64
  skip_channels: 64
  stacks: 3
  upsample_net: ConvInUpsampleNetwork
  upsample_params:
    upsample_scales:
    - 2
    - 4
    - 4
    - 4
  use_nsf: false
  use_pitch_embed: true
  use_weight_norm: true
generator_scheduler_params:
  gamma: 0.999
  step_size: 600
griffin_lim_iters: 60
hidden_size: 256
hop_size: 128
infer: false
lambda_adv: 1.0
lambda_cdisc: 4.0
lambda_energy: 0.0
lambda_f0: 0.0
lambda_mel: 5.0
lambda_mel_adv: 1.0
lambda_ph_dur: 0.0
lambda_sent_dur: 0.0
lambda_uv: 0.0
lambda_word_dur: 0.0
load_ckpt: 'checkpoints/m4singer_hifigan'
loud_norm: false
lr: 2.0
max_epochs: 1000
max_frames: 2400
max_input_tokens: 1550
max_samples: 8192
max_sentences: 20
max_tokens: 24000
max_updates: 3000000
max_valid_sentences: 1
max_valid_tokens: 60000
mel_loss: ssim:0.5|l1:0.5
mel_vmax: 1.5
mel_vmin: -6
min_frames: 0
min_level_db: -120
num_ckpt_keep: 3
num_heads: 2
num_mels: 80
num_sanity_val_steps: 5
num_spk: 100
num_test_samples: 0
num_valid_plots: 10
optimizer_adam_beta1: 0.9
optimizer_adam_beta2: 0.98
out_wav_norm: false
pitch_extractor: parselmouth
pitch_type: frame
pre_align_args:
  allow_no_txt: false
  denoise: false
  sox_resample: true
  sox_to_wav: false
  trim_sil: false
  txt_processor: zh
  use_tone: false
pre_align_cls: data_gen.tts.singing.pre_align.SingingPreAlign
predictor_grad: 0.0
print_nan_grads: false
processed_data_dir: ''
profile_infer: false
raw_data_dir: ''
ref_level_db: 20
rename_tmux: true
rerun_gen: true
resblock: '1'
resblock_dilation_sizes:
- - 1
  - 3
  - 5
- - 1
  - 3
  - 5
- - 1
  - 3
  - 5
resblock_kernel_sizes:
- 3
- 7
- 11
resume_from_checkpoint: 0
save_best: true
save_codes: []
save_f0: true
save_gt: true
scheduler: rsqrt
seed: 1234
sort_by_len: true
stft_loss_params:
  fft_sizes:
  - 1024
  - 2048
  - 512
  hop_sizes:
  - 120
  - 240
  - 50
  win_lengths:
  - 600
  - 1200
  - 240
  window: hann_window
task_cls: tasks.vocoder.hifigan.HifiGanTask
tb_log_interval: 100
test_ids: []
test_input_dir: ''
test_num: 50
test_prefixes: []
test_set_name: test
train_set_name: train
train_sets: ''
upsample_initial_channel: 512
upsample_kernel_sizes:
- 16
- 16
- 4
- 4
upsample_rates:
- 8
- 4
- 2
- 2
use_cdisc: false
use_cond_disc: false
use_fm_loss: false
use_gt_dur: true
use_gt_f0: true
use_mel_loss: true
use_ms_stft: false
use_pitch_embed: true
use_ref_enc: true
use_spec_disc: false
use_spk_embed: false
use_spk_id: false
use_split_spk_id: false
val_check_interval: 2000
valid_infer_interval: 10000
valid_monitor_key: val_loss
valid_monitor_mode: min
valid_set_name: valid
vocoder: pwg
vocoder_ckpt: ''
vocoder_denoise_c: 0.0
warmup_updates: 8000
weight_decay: 0
win_length: null
win_size: 512
window: hann
word_size: 3000
work_dir: checkpoints/m4singer_hifigan
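A quick sanity check on the analysis settings in this vocoder config: at audio_sample_rate: 24000 with hop_size: 128, mel frames come 24000/128 = 187.5 per second (about 5.3 ms apart), the win_size: 512 window spans about 21.3 ms, and fmax: 12000 sits exactly at the Nyquist frequency:

sr, hop, win = 24000, 128, 512  # values from the config above

print(sr / hop)         # 187.5 mel frames per second
print(1000 * hop / sr)  # ~5.33 ms hop
print(1000 * win / sr)  # ~21.33 ms analysis window
print(sr / 2)           # 12000 Hz Nyquist == fmax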
    	
checkpoints/m4singer_hifigan/model_ckpt_steps_1970000.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c3e859bd2b1e125fe661aedfd6fa3e97e10e06f3ec3d03b7735a041984402f89
size 1016324099
    	
checkpoints/m4singer_pe/config.yaml
ADDED
@@ -0,0 +1,172 @@
accumulate_grad_batches: 1
audio_num_mel_bins: 80
audio_sample_rate: 24000
base_config:
- configs/tts/lj/fs2.yaml
binarization_args:
  shuffle: false
  with_align: true
  with_f0: true
  with_f0cwt: true
  with_spk_embed: true
  with_txt: true
  with_wav: false
binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer
binary_data_dir: data/binary/m4singer
check_val_every_n_epoch: 10
clip_grad_norm: 1
cwt_add_f0_loss: false
cwt_hidden_size: 128
cwt_layers: 2
cwt_loss: l1
cwt_std_scale: 0.8
debug: false
dec_ffn_kernel_size: 9
dec_layers: 4
decoder_type: fft
dict_dir: ''
dropout: 0.1
ds_workers: 4
dur_enc_hidden_stride_kernel:
- 0,2,3
- 0,2,3
- 0,1,3
dur_loss: mse
dur_predictor_kernel: 3
dur_predictor_layers: 2
enc_ffn_kernel_size: 9
enc_layers: 4
encoder_K: 8
encoder_type: fft
endless_ds: true
ffn_act: gelu
ffn_padding: SAME
fft_size: 512
fmax: 12000
fmin: 30
gen_dir_name: ''
hidden_size: 256
hop_size: 128
infer: false
lambda_commit: 0.25
lambda_energy: 0.1
lambda_f0: 1.0
lambda_ph_dur: 1.0
lambda_sent_dur: 1.0
lambda_uv: 1.0
lambda_word_dur: 1.0
load_ckpt: ''
log_interval: 100
loud_norm: false
lr: 0.1
max_epochs: 1000
max_eval_sentences: 1
max_eval_tokens: 60000
max_frames: 5000
max_input_tokens: 1550
max_sentences: 100000
max_tokens: 20000
max_updates: 280000
mel_loss: l1
mel_vmax: 1.5
mel_vmin: -6
min_level_db: -120
norm_type: gn
num_ckpt_keep: 3
num_heads: 2
num_sanity_val_steps: 5
num_spk: 1
num_test_samples: 20
num_valid_plots: 10
optimizer_adam_beta1: 0.9
optimizer_adam_beta2: 0.98
out_wav_norm: false
pitch_ar: false
pitch_enc_hidden_stride_kernel:
- 0,2,5
- 0,2,5
- 0,2,5
pitch_extractor_conv_layers: 2
pitch_loss: l1
pitch_norm: log
pitch_type: frame
pre_align_args:
  allow_no_txt: false
  denoise: false
  forced_align: mfa
  txt_processor: en
  use_sox: false
  use_tone: true
pre_align_cls: data_gen.tts.lj.pre_align.LJPreAlign
predictor_dropout: 0.5
predictor_grad: 0.1
predictor_hidden: -1
predictor_kernel: 5
predictor_layers: 2
prenet_dropout: 0.5
prenet_hidden_size: 256
pretrain_fs_ckpt: ''
processed_data_dir: data/processed/ljspeech
profile_infer: false
raw_data_dir: data/raw/LJSpeech-1.1
ref_norm_layer: bn
reset_phone_dict: true
save_best: false
save_ckpt: true
save_codes:
- configs
- modules
- tasks
- utils
- usr
save_f0: false
save_gt: false
seed: 1234
sort_by_len: true
stop_token_weight: 5.0
task_cls: tasks.tts.pe.PitchExtractionTask
test_ids:
- 68
- 70
- 74
- 87
- 110
- 172
- 190
- 215
- 231
- 294
- 316
- 324
- 402
- 422
- 485
- 500
- 505
- 508
- 509
- 519
test_input_dir: ''
test_num: 523
test_set_name: test
train_set_name: train
use_denoise: false
use_energy_embed: false
use_gt_dur: false
use_gt_f0: false
use_pitch_embed: true
use_pos_embed: true
use_spk_embed: false
use_spk_id: false
use_split_spk_id: false
use_uv: true
use_var_enc: false
val_check_interval: 2000
valid_num: 348
valid_set_name: valid
vocoder: pwg
vocoder_ckpt: ''
warmup_updates: 2000
weight_decay: 0
win_size: 512
work_dir: checkpoints/m4singer_pe
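The pitch_enc_hidden_stride_kernel entries above pack conv_hidden_size,conv_stride,conv_kernel_size triplets into strings (the comment in configs/tts/fs2.yaml further down documents this; a hidden size of 0 means "use hidden_size"). A hedged sketch of unpacking them; how the project's modules consume the triplets is an assumption here:

hidden_size = 256                    # model-wide default in this config
specs = ["0,2,5", "0,2,5", "0,2,5"]  # pitch_enc_hidden_stride_kernel

layers = []
for spec in specs:
    hidden, stride, kernel = map(int, spec.split(","))
    layers.append((hidden or hidden_size, stride, kernel))
print(layers)  # [(256, 2, 5), (256, 2, 5), (256, 2, 5)]

total_stride = 1
for _, stride, _ in layers:
    total_stride *= stride
print(total_stride)  # strides 2*2*2: the pitch encoder downsamples frames 8x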
    	
checkpoints/m4singer_pe/model_ckpt_steps_280000.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:10cbf382bf82ecf335fbf68ba226f93c9c715b0476f6604351cbad9783f529fe
size 39146292
    	
configs/config_base.yaml
ADDED
@@ -0,0 +1,42 @@
# task
binary_data_dir: ''
work_dir: '' # experiment directory.
infer: false # infer
seed: 1234
debug: false
save_codes:
  - configs
  - modules
  - tasks
  - utils
  - usr

#############
# dataset
#############
ds_workers: 1
test_num: 100
valid_num: 100
endless_ds: false
sort_by_len: true

#########
# train and eval
#########
load_ckpt: ''
save_ckpt: true
save_best: false
num_ckpt_keep: 3
clip_grad_norm: 0
accumulate_grad_batches: 1
log_interval: 100
num_sanity_val_steps: 5  # steps of validation at the beginning
check_val_every_n_epoch: 10
val_check_interval: 2000
max_epochs: 1000
max_updates: 160000
max_tokens: 31250
max_sentences: 100000
max_eval_tokens: -1
max_eval_sentences: -1
test_input_dir: ''
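configs/config_base.yaml is the root of the base_config inheritance chain used throughout these files: each config lists its parents and overrides whatever it redefines. A minimal sketch of how such a chain could be resolved, assuming a shallow dict merge (the project's own loader may merge more deeply):

import yaml

def load_config(path: str) -> dict:
    # Merge a config file over its base_config parents, child keys winning.
    with open(path) as f:
        cfg = yaml.safe_load(f) or {}
    bases = cfg.pop("base_config", [])
    if isinstance(bases, str):  # base_config may be a single path or a list
        bases = [bases]
    merged = {}
    for base in bases:          # later bases override earlier ones
        merged.update(load_config(base))
    merged.update(cfg)          # the child file overrides all parents
    return merged

# e.g. configs/singing/fs2.yaml pulls in configs/tts/fs2.yaml and
# configs/singing/base.yaml, which in turn reach configs/config_base.yaml.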
    	
configs/singing/base.yaml
ADDED
@@ -0,0 +1,42 @@
base_config:
  - configs/tts/base.yaml
  - configs/tts/base_zh.yaml


datasets: []
test_prefixes: []
test_num: 0
valid_num: 0

pre_align_cls: data_gen.singing.pre_align.SingingPreAlign
binarizer_cls: data_gen.singing.binarize.SingingBinarizer
pre_align_args:
  use_tone: false # for ZH
  forced_align: mfa
  use_sox: true
hop_size: 128            # Hop size.
fft_size: 512           # FFT size.
win_size: 512           # Window size.
max_frames: 8000
fmin: 50                 # Minimum freq in mel basis calculation.
fmax: 11025               # Maximum frequency in mel basis calculation.
pitch_type: frame

hidden_size: 256
mel_loss: "ssim:0.5|l1:0.5"
lambda_f0: 0.0
lambda_uv: 0.0
lambda_energy: 0.0
lambda_ph_dur: 0.0
lambda_sent_dur: 0.0
lambda_word_dur: 0.0
predictor_grad: 0.0
use_spk_embed: true
use_spk_id: false

max_tokens: 20000
max_updates: 400000
num_spk: 100
save_f0: true
use_gt_dur: true
use_gt_f0: true
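The mel_loss: "ssim:0.5|l1:0.5" value above packs weighted loss terms into one string: |-separated name:weight pairs, with a bare name (like the mel_loss: l1 elsewhere in these configs) implying weight 1. A sketch of the parsing; the exact convention in the project's loss code is an assumption:

def parse_loss_spec(spec: str) -> dict:
    # "ssim:0.5|l1:0.5" -> {"ssim": 0.5, "l1": 0.5}; bare "l1" -> {"l1": 1.0}
    weights = {}
    for term in spec.split("|"):
        name, _, weight = term.partition(":")
        weights[name] = float(weight) if weight else 1.0
    return weights

print(parse_loss_spec("ssim:0.5|l1:0.5"))
print(parse_loss_spec("l1"))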
    	
configs/singing/fs2.yaml
ADDED
@@ -0,0 +1,3 @@
base_config:
  - configs/tts/fs2.yaml
  - configs/singing/base.yaml
    	
configs/tts/base.yaml
ADDED
@@ -0,0 +1,95 @@
# task
base_config: configs/config_base.yaml
task_cls: ''
#############
# dataset
#############
raw_data_dir: ''
processed_data_dir: ''
binary_data_dir: ''
dict_dir: ''
pre_align_cls: ''
binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer
pre_align_args:
  use_tone: true # for ZH
  forced_align: mfa
  use_sox: false
  txt_processor: en
  allow_no_txt: false
  denoise: false
binarization_args:
  shuffle: false
  with_txt: true
  with_wav: false
  with_align: true
  with_spk_embed: true
  with_f0: true
  with_f0cwt: true

loud_norm: false
endless_ds: true
reset_phone_dict: true

test_num: 100
valid_num: 100
max_frames: 1550
max_input_tokens: 1550
audio_num_mel_bins: 80
audio_sample_rate: 22050
hop_size: 256  # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate)
win_size: 1024  # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate)
fmin: 80  # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
fmax: 7600  # To be increased/reduced depending on data.
fft_size: 1024  # Extra window size is filled with 0 paddings to match this parameter
min_level_db: -100
num_spk: 1
mel_vmin: -6
mel_vmax: 1.5
ds_workers: 4

#########
# model
#########
dropout: 0.1
enc_layers: 4
dec_layers: 4
hidden_size: 384
num_heads: 2
prenet_dropout: 0.5
prenet_hidden_size: 256
stop_token_weight: 5.0
enc_ffn_kernel_size: 9
dec_ffn_kernel_size: 9
ffn_act: gelu
ffn_padding: 'SAME'


###########
# optimization
###########
lr: 2.0
warmup_updates: 8000
optimizer_adam_beta1: 0.9
optimizer_adam_beta2: 0.98
weight_decay: 0
clip_grad_norm: 1


###########
# train and eval
###########
max_tokens: 30000
max_sentences: 100000
max_eval_sentences: 1
max_eval_tokens: 60000
train_set_name: 'train'
valid_set_name: 'valid'
test_set_name: 'test'
vocoder: pwg
vocoder_ckpt: ''
profile_infer: false
out_wav_norm: false
save_gt: false
save_f0: false
gen_dir_name: ''
use_denoise: false
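In the optimization block above, lr: 2.0 is far too large to be a literal learning rate; together with warmup_updates: 8000 (and the scheduler: rsqrt key in the m4singer_hifigan config earlier) it reads as the scale factor of a Noam-style warmup-then-inverse-square-root schedule. A sketch under that assumption, taking hidden_size as the model dimension; the project's actual scheduler may differ:

# Noam/rsqrt schedule: linear warmup to a peak, then step**-0.5 decay.
# Reading lr as a dimensionless scale factor is an assumption from the
# config values, not something this commit states explicitly.
def rsqrt_lr(step: int, scale: float = 2.0,
             hidden_size: int = 384, warmup: int = 8000) -> float:
    step = max(step, 1)
    return scale * hidden_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

print(rsqrt_lr(8_000))    # peak: ~1.1e-3
print(rsqrt_lr(160_000))  # decayed by sqrt(20): ~2.6e-4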
    	
configs/tts/base_zh.yaml
ADDED
@@ -0,0 +1,3 @@
pre_align_args:
  txt_processor: zh_g2pM
binarizer_cls: data_gen.tts.binarizer_zh.ZhBinarizer
    	
configs/tts/fs2.yaml
ADDED
@@ -0,0 +1,80 @@
base_config: configs/tts/base.yaml
task_cls: tasks.tts.fs2.FastSpeech2Task

# model
hidden_size: 256
dropout: 0.1
encoder_type: fft # fft|tacotron|tacotron2|conformer
encoder_K: 8 # for tacotron encoder
decoder_type: fft # fft|rnn|conv|conformer
use_pos_embed: true

# duration
predictor_hidden: -1
predictor_kernel: 5
predictor_layers: 2
dur_predictor_kernel: 3
dur_predictor_layers: 2
predictor_dropout: 0.5

# pitch and energy
use_pitch_embed: true
pitch_type: ph # frame|ph|cwt
use_uv: true
cwt_hidden_size: 128
cwt_layers: 2
cwt_loss: l1
cwt_add_f0_loss: false
cwt_std_scale: 0.8

pitch_ar: false
#pitch_embed_type: 0q
pitch_loss: 'l1' # l1|l2|ssim
pitch_norm: log
use_energy_embed: false

# reference encoder and speaker embedding
use_spk_id: false
use_split_spk_id: false
use_spk_embed: false
use_var_enc: false
lambda_commit: 0.25
ref_norm_layer: bn
pitch_enc_hidden_stride_kernel:
  - 0,2,5 # conv_hidden_size, conv_stride, conv_kernel_size. conv_hidden_size=0: use hidden_size
  - 0,2,5
  - 0,2,5
dur_enc_hidden_stride_kernel:
  - 0,2,3 # conv_hidden_size, conv_stride, conv_kernel_size. conv_hidden_size=0: use hidden_size
  - 0,2,3
  - 0,1,3


# mel
mel_loss: l1:0.5|ssim:0.5 # l1|l2|gdl|ssim or l1:0.5|ssim:0.5

# loss lambda
lambda_f0: 1.0
lambda_uv: 1.0
lambda_energy: 0.1
lambda_ph_dur: 1.0
lambda_sent_dur: 1.0
lambda_word_dur: 1.0
predictor_grad: 0.1

# train and eval
pretrain_fs_ckpt: ''
warmup_updates: 2000
max_tokens: 32000
max_sentences: 100000
max_eval_sentences: 1
max_updates: 120000
num_valid_plots: 5
num_test_samples: 0
test_ids: []
use_gt_dur: false
use_gt_f0: false

# exp
dur_loss: mse # huber|mol
norm_type: gn
    	
configs/tts/hifigan.yaml
ADDED
@@ -0,0 +1,21 @@
base_config: configs/tts/pwg.yaml
task_cls: tasks.vocoder.hifigan.HifiGanTask
resblock: "1"
adam_b1: 0.8
adam_b2: 0.99
upsample_rates: [ 8,8,2,2 ]
upsample_kernel_sizes: [ 16,16,4,4 ]
upsample_initial_channel: 128
resblock_kernel_sizes: [ 3,7,11 ]
resblock_dilation_sizes: [ [ 1,3,5 ], [ 1,3,5 ], [ 1,3,5 ] ]

lambda_mel: 45.0

max_samples: 8192
max_sentences: 16

generator_params:
  lr: 0.0002            # Generator's learning rate.
  aux_context_window: 0 # Context window size for auxiliary feature.
discriminator_optimizer_params:
  lr: 0.0002            # Discriminator's learning rate.
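One hard constraint hidden in the HiFi-GAN settings above: the product of upsample_rates must equal the hop size, because the generator expands each mel frame into exactly hop_size waveform samples. Here 8*8*2*2 = 256 matches the hop_size: 256 of these 22050 Hz configs, just as the 8*4*2*2 rates in the m4singer_hifigan checkpoint config match its hop_size: 128. As a check:

from math import prod

# Each mel frame must become exactly hop_size waveform samples.
assert prod([8, 8, 2, 2]) == 256  # configs/tts/hifigan.yaml at hop_size 256
assert prod([8, 4, 2, 2]) == 128  # checkpoints/m4singer_hifigan at hop_size 128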
    	
configs/tts/lj/base_mel2wav.yaml
ADDED
@@ -0,0 +1,3 @@
raw_data_dir: 'data/raw/LJSpeech-1.1'
processed_data_dir: 'data/processed/ljspeech'
binary_data_dir: 'data/binary/ljspeech_wav'
    	
configs/tts/lj/base_text2mel.yaml
ADDED
@@ -0,0 +1,13 @@
raw_data_dir: 'data/raw/LJSpeech-1.1'
processed_data_dir: 'data/processed/ljspeech'
binary_data_dir: 'data/binary/ljspeech'
pre_align_cls: data_gen.tts.lj.pre_align.LJPreAlign

pitch_type: cwt
mel_loss: l1
num_test_samples: 20
test_ids: [ 68, 70, 74, 87, 110, 172, 190, 215, 231, 294,
            316, 324, 402, 422, 485, 500, 505, 508, 509, 519 ]
use_energy_embed: false
test_num: 523
valid_num: 348
    	
configs/tts/lj/fs2.yaml
ADDED
@@ -0,0 +1,3 @@
base_config:
  - configs/tts/fs2.yaml
  - configs/tts/lj/base_text2mel.yaml
    	
configs/tts/lj/hifigan.yaml
ADDED
@@ -0,0 +1,3 @@
base_config:
  - configs/tts/hifigan.yaml
  - configs/tts/lj/base_mel2wav.yaml
    	
configs/tts/lj/pwg.yaml
ADDED
@@ -0,0 +1,3 @@
base_config:
  - configs/tts/pwg.yaml
  - configs/tts/lj/base_mel2wav.yaml
    	
configs/tts/pwg.yaml
ADDED
@@ -0,0 +1,110 @@
+base_config: configs/tts/base.yaml
+task_cls: tasks.vocoder.pwg.PwgTask
+
+binarization_args:
+  with_wav: true
+  with_spk_embed: false
+  with_align: false
+test_input_dir: ''
+
+###########
+# train and eval
+###########
+max_samples: 25600
+max_sentences: 5
+max_eval_sentences: 1
+max_updates: 1000000
+val_check_interval: 2000
+
+
+###########################################################
+#                FEATURE EXTRACTION SETTING               #
+###########################################################
+sampling_rate: 22050     # Sampling rate.
+fft_size: 1024           # FFT size.
+hop_size: 256            # Hop size.
+win_length: null         # Window length.
+                         # If set to null, it will be the same as fft_size.
+window: "hann"           # Window function.
+num_mels: 80             # Number of mel basis.
+fmin: 80                 # Minimum frequency in mel basis calculation.
+fmax: 7600               # Maximum frequency in mel basis calculation.
+format: "hdf5"           # Feature file format. "npy" or "hdf5" is supported.
+
+###########################################################
+#         GENERATOR NETWORK ARCHITECTURE SETTING          #
+###########################################################
+generator_params:
+  in_channels: 1        # Number of input channels.
+  out_channels: 1       # Number of output channels.
+  kernel_size: 3        # Kernel size of dilated convolution.
+  layers: 30            # Number of residual block layers.
+  stacks: 3             # Number of stacks, i.e., dilation cycles.
+  residual_channels: 64 # Number of channels in residual conv.
+  gate_channels: 128    # Number of channels in gated conv.
+  skip_channels: 64     # Number of channels in skip conv.
+  aux_channels: 80      # Number of channels for auxiliary feature conv.
+                        # Must be the same as num_mels.
+  aux_context_window: 2 # Context window size for auxiliary feature.
+                        # If set to 2, previous 2 and future 2 frames will be considered.
+  dropout: 0.0          # Dropout rate. 0.0 means no dropout applied.
+  use_weight_norm: true # Whether to use weight norm.
+                        # If set to true, it will be applied to all of the conv layers.
+  upsample_net: "ConvInUpsampleNetwork" # Upsampling network architecture.
+  upsample_params:                      # Upsampling network parameters.
+    upsample_scales: [4, 4, 4, 4]     # Upsampling scales. Product of these must be the same as hop size.
+  use_pitch_embed: false
+
+###########################################################
+#       DISCRIMINATOR NETWORK ARCHITECTURE SETTING        #
+###########################################################
+discriminator_params:
+  in_channels: 1        # Number of input channels.
+  out_channels: 1       # Number of output channels.
+  kernel_size: 3        # Kernel size of conv layers.
+  layers: 10            # Number of conv layers.
+  conv_channels: 64     # Number of channels in conv layers.
+  bias: true            # Whether to use bias parameter in conv.
+  use_weight_norm: true # Whether to use weight norm.
+                        # If set to true, it will be applied to all of the conv layers.
+  nonlinear_activation: "LeakyReLU" # Nonlinear function after each conv.
+  nonlinear_activation_params:      # Nonlinear function parameters.
+    negative_slope: 0.2           # Alpha in LeakyReLU.
+
+###########################################################
+#                   STFT LOSS SETTING                     #
+###########################################################
+stft_loss_params:
+  fft_sizes: [1024, 2048, 512]  # List of FFT sizes for STFT-based loss.
+  hop_sizes: [120, 240, 50]     # List of hop sizes for STFT-based loss.
+  win_lengths: [600, 1200, 240] # List of window lengths for STFT-based loss.
+  window: "hann_window"         # Window function for STFT-based loss.
+use_mel_loss: false
+
+###########################################################
+#               ADVERSARIAL LOSS SETTING                  #
+###########################################################
+lambda_adv: 4.0  # Loss balancing coefficient.
+
+###########################################################
+#             OPTIMIZER & SCHEDULER SETTING               #
+###########################################################
+generator_optimizer_params:
+  lr: 0.0001             # Generator's learning rate.
+  eps: 1.0e-6            # Generator's epsilon.
+  weight_decay: 0.0      # Generator's weight decay coefficient.
+generator_scheduler_params:
+  step_size: 200000      # Generator's scheduler step size.
+  gamma: 0.5             # Generator's scheduler gamma.
+                         # At each step size, lr will be multiplied by this parameter.
+generator_grad_norm: 10    # Generator's gradient norm.
+discriminator_optimizer_params:
+  lr: 0.00005            # Discriminator's learning rate.
+  eps: 1.0e-6            # Discriminator's epsilon.
+  weight_decay: 0.0      # Discriminator's weight decay coefficient.
+discriminator_scheduler_params:
+  step_size: 200000      # Discriminator's scheduler step size.
+  gamma: 0.5             # Discriminator's scheduler gamma.
+                         # At each step size, lr will be multiplied by this parameter.
+discriminator_grad_norm: 1 # Discriminator's gradient norm.
+disc_start_steps: 40000 # Number of steps before the discriminator starts training.
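The comment on upsample_scales is a hard constraint: the vocoder turns each mel frame into hop_size waveform samples, so the product of the scales must equal hop_size (4 * 4 * 4 * 4 = 256 here). A quick standalone check (illustrative, not part of this commit):

    import yaml

    # Verify the upsampling-factor constraint stated in the config comment.
    with open('configs/tts/pwg.yaml') as f:
        cfg = yaml.safe_load(f)
    scales = cfg['generator_params']['upsample_params']['upsample_scales']
    prod = 1
    for s in scales:
        prod *= s
    assert prod == cfg['hop_size'], (prod, cfg['hop_size'])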
    	
data_gen/singing/binarize.py
ADDED
@@ -0,0 +1,393 @@
+import os
+import random
+from copy import deepcopy
+import pandas as pd
+import logging
+from tqdm import tqdm
+import json
+import glob
+import re
+from resemblyzer import VoiceEncoder
+import traceback
+import numpy as np
+import pretty_midi
+import librosa
+from scipy.interpolate import interp1d
+import torch
+from textgrid import TextGrid
+
+from utils.hparams import hparams
+from data_gen.tts.data_gen_utils import build_phone_encoder, get_pitch
+from utils.pitch_utils import f0_to_coarse
+from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError
+from data_gen.tts.binarizer_zh import ZhBinarizer
+from data_gen.tts.txt_processors.zh_g2pM import ALL_YUNMU
+from vocoders.base_vocoder import VOCODERS
+
+
+class SingingBinarizer(BaseBinarizer):
+    def __init__(self, processed_data_dir=None):
+        if processed_data_dir is None:
+            processed_data_dir = hparams['processed_data_dir']
+        self.processed_data_dirs = processed_data_dir.split(",")
+        self.binarization_args = hparams['binarization_args']
+        self.pre_align_args = hparams['pre_align_args']
+        self.item2txt = {}
+        self.item2ph = {}
+        self.item2wavfn = {}
+        self.item2f0fn = {}
+        self.item2tgfn = {}
+        self.item2spk = {}
+
+    def split_train_test_set(self, item_names):
+        item_names = deepcopy(item_names)
+        test_item_names = [x for x in item_names if any([ts in x for ts in hparams['test_prefixes']])]
+        train_item_names = [x for x in item_names if x not in set(test_item_names)]
+        logging.info("train {}".format(len(train_item_names)))
+        logging.info("test {}".format(len(test_item_names)))
+        return train_item_names, test_item_names
+
+    def load_meta_data(self):
+        for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
+            wav_suffix = '_wf0.wav'
+            txt_suffix = '.txt'
+            ph_suffix = '_ph.txt'
+            tg_suffix = '.TextGrid'
+            all_wav_pieces = glob.glob(f'{processed_data_dir}/*/*{wav_suffix}')
+
+            for piece_path in all_wav_pieces:
+                item_name = raw_item_name = piece_path[len(processed_data_dir)+1:].replace('/', '-')[:-len(wav_suffix)]
+                if len(self.processed_data_dirs) > 1:
+                    item_name = f'ds{ds_id}_{item_name}'
+                self.item2txt[item_name] = open(f'{piece_path.replace(wav_suffix, txt_suffix)}').readline()
+                self.item2ph[item_name] = open(f'{piece_path.replace(wav_suffix, ph_suffix)}').readline()
+                self.item2wavfn[item_name] = piece_path
+
+                self.item2spk[item_name] = re.split('-|#', piece_path.split('/')[-2])[0]
+                if len(self.processed_data_dirs) > 1:
+                    self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}"
+                self.item2tgfn[item_name] = piece_path.replace(wav_suffix, tg_suffix)
+        print('speakers: ', set(self.item2spk.values()))
+        self.item_names = sorted(list(self.item2txt.keys()))
+        if self.binarization_args['shuffle']:
+            random.seed(1234)
+            random.shuffle(self.item_names)
+        self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)
+
+    @property
+    def train_item_names(self):
+        return self._train_item_names
+
+    @property
+    def valid_item_names(self):
+        return self._test_item_names
+
+    @property
+    def test_item_names(self):
+        return self._test_item_names
+
+    def process(self):
+        self.load_meta_data()
+        os.makedirs(hparams['binary_data_dir'], exist_ok=True)
+        self.spk_map = self.build_spk_map()
+        print("| spk_map: ", self.spk_map)
+        spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json"
+        json.dump(self.spk_map, open(spk_map_fn, 'w'))
+
+        self.phone_encoder = self._phone_encoder()
+        self.process_data('valid')
+        self.process_data('test')
+        self.process_data('train')
+
+    def _phone_encoder(self):
+        ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json"
+        ph_set = []
+        if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn):
+            for ph_sent in self.item2ph.values():
+                ph_set += ph_sent.split(' ')
+            ph_set = sorted(set(ph_set))
+            json.dump(ph_set, open(ph_set_fn, 'w'))
+            print("| Build phone set: ", ph_set)
+        else:
+            ph_set = json.load(open(ph_set_fn, 'r'))
+            print("| Load phone set: ", ph_set)
+        return build_phone_encoder(hparams['binary_data_dir'])
+
+    # Legacy variant kept for reference: loads a precomputed *_f0.npy file and
+    # resamples it to the mel length instead of re-estimating f0 from the wav.
+    # @staticmethod
+    # def get_pitch(wav_fn, spec, res):
+    #     wav_suffix = '_wf0.wav'
+    #     f0_suffix = '_f0.npy'
+    #     f0fn = wav_fn.replace(wav_suffix, f0_suffix)
+    #     pitch_info = np.load(f0fn)
+    #     f0 = [x[1] for x in pitch_info]
+    #     spec_x_coor = np.arange(0, 1, 1 / len(spec))[:len(spec)]
+    #     f0_x_coor = np.arange(0, 1, 1 / len(f0))[:len(f0)]
+    #     f0 = interp1d(f0_x_coor, f0, 'nearest', fill_value='extrapolate')(spec_x_coor)[:len(spec)]
+    #     # f0_x_coor = np.arange(0, 1, 1 / len(f0))
+    #     # f0_x_coor[-1] = 1
+    #     # f0 = interp1d(f0_x_coor, f0, 'nearest')(spec_x_coor)[:len(spec)]
+    #     if sum(f0) == 0:
+    #         raise BinarizationError("Empty f0")
+    #     assert len(f0) == len(spec), (len(f0), len(spec))
+    #     pitch_coarse = f0_to_coarse(f0)
+    #
+    #     # vis f0
+    #     # import matplotlib.pyplot as plt
+    #     # from textgrid import TextGrid
+    #     # tg_fn = wav_fn.replace(wav_suffix, '.TextGrid')
+    #     # fig = plt.figure(figsize=(12, 6))
+    #     # plt.pcolor(spec.T, vmin=-5, vmax=0)
+    #     # ax = plt.gca()
+    #     # ax2 = ax.twinx()
+    #     # ax2.plot(f0, color='red')
+    #     # ax2.set_ylim(0, 800)
+    #     # itvs = TextGrid.fromFile(tg_fn)[0]
+    #     # for itv in itvs:
+    #     #     x = itv.maxTime * hparams['audio_sample_rate'] / hparams['hop_size']
+    #     #     plt.vlines(x=x, ymin=0, ymax=80, color='black')
+    #     #     plt.text(x=x, y=20, s=itv.mark, color='black')
+    #     # plt.savefig('tmp/20211229_singing_plots_test.png')
+    #
+    #     res['f0'] = f0
+    #     res['pitch'] = pitch_coarse
+
+    @classmethod
+    def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
+        if hparams['vocoder'] in VOCODERS:
+            wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
+        else:
+            wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
+        res = {
+            'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
+            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
+        }
+        try:
+            if binarization_args['with_f0']:
+                # cls.get_pitch(wav_fn, mel, res)
+                cls.get_pitch(wav, mel, res)
+            if binarization_args['with_txt']:
+                try:
+                    phone_encoded = res['phone'] = encoder.encode(ph)
+                except Exception:
+                    traceback.print_exc()
+                    raise BinarizationError("Empty phoneme")
+                if binarization_args['with_align']:
+                    cls.get_align(tg_fn, ph, mel, phone_encoded, res)
+        except BinarizationError as e:
+            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
+            return None
+        return res
+
+
+class MidiSingingBinarizer(SingingBinarizer):
+    item2midi = {}
+    item2midi_dur = {}
+    item2is_slur = {}
+    item2ph_durs = {}
+    item2wdb = {}
+
+    def load_meta_data(self):
+        for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
+            meta_midi = json.load(open(os.path.join(processed_data_dir, 'meta.json')))  # [list of dict]
+
+            for song_item in meta_midi:
+                item_name = raw_item_name = song_item['item_name']
+                if len(self.processed_data_dirs) > 1:
+                    item_name = f'ds{ds_id}_{item_name}'
+                self.item2wavfn[item_name] = song_item['wav_fn']
+                self.item2txt[item_name] = song_item['txt']
+
+                self.item2ph[item_name] = ' '.join(song_item['phs'])
+                self.item2wdb[item_name] = [1 if x in ALL_YUNMU + ['AP', 'SP', '<SIL>'] else 0 for x in song_item['phs']]
+                self.item2ph_durs[item_name] = song_item['ph_dur']
+
+                self.item2midi[item_name] = song_item['notes']
+                self.item2midi_dur[item_name] = song_item['notes_dur']
+                self.item2is_slur[item_name] = song_item['is_slur']
+                self.item2spk[item_name] = 'pop-cs'
+                if len(self.processed_data_dirs) > 1:
+                    self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}"
+
+        print('speakers: ', set(self.item2spk.values()))
+        self.item_names = sorted(list(self.item2txt.keys()))
+        if self.binarization_args['shuffle']:
+            random.seed(1234)
+            random.shuffle(self.item_names)
+        self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)
+
+    @staticmethod
+    def get_pitch(wav_fn, wav, spec, ph, res):
+        wav_suffix = '.wav'
+        # midi_suffix = '.mid'
+        wav_dir = 'wavs'
+        f0_dir = 'f0'
+
+        item_name = '/'.join(os.path.splitext(wav_fn)[0].split('/')[-2:]).replace('_wf0', '')
+        res['pitch_midi'] = np.asarray(MidiSingingBinarizer.item2midi[item_name])
+        res['midi_dur'] = np.asarray(MidiSingingBinarizer.item2midi_dur[item_name])
+        res['is_slur'] = np.asarray(MidiSingingBinarizer.item2is_slur[item_name])
+        res['word_boundary'] = np.asarray(MidiSingingBinarizer.item2wdb[item_name])
+        assert res['pitch_midi'].shape == res['midi_dur'].shape == res['is_slur'].shape, (
+            res['pitch_midi'].shape, res['midi_dur'].shape, res['is_slur'].shape)
+
+        # gt f0.
+        gt_f0, gt_pitch_coarse = get_pitch(wav, spec, hparams)
+        if sum(gt_f0) == 0:
+            raise BinarizationError("Empty **gt** f0")
+        res['f0'] = gt_f0
+        res['pitch'] = gt_pitch_coarse
+
+    @staticmethod
+    def get_align(ph_durs, mel, phone_encoded, res, hop_size=hparams['hop_size'], audio_sample_rate=hparams['audio_sample_rate']):
+        mel2ph = np.zeros([mel.shape[0]], int)
+        startTime = 0
+
+        for i_ph in range(len(ph_durs)):
+            start_frame = int(startTime * audio_sample_rate / hop_size + 0.5)
+            end_frame = int((startTime + ph_durs[i_ph]) * audio_sample_rate / hop_size + 0.5)
+            mel2ph[start_frame:end_frame] = i_ph + 1
+            startTime = startTime + ph_durs[i_ph]
+
+        # print('ph durs: ', ph_durs)
+        # print('mel2ph: ', mel2ph, len(mel2ph))
+        res['mel2ph'] = mel2ph
+        # res['dur'] = None
+
+    @classmethod
+    def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
+        if hparams['vocoder'] in VOCODERS:
+            wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
+        else:
+            wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
+        res = {
+            'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
+            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
+        }
+        try:
+            if binarization_args['with_f0']:
+                cls.get_pitch(wav_fn, wav, mel, ph, res)
+            if binarization_args['with_txt']:
+                try:
+                    phone_encoded = res['phone'] = encoder.encode(ph)
+                except Exception:
+                    traceback.print_exc()
+                    raise BinarizationError("Empty phoneme")
+                if binarization_args['with_align']:
+                    cls.get_align(MidiSingingBinarizer.item2ph_durs[item_name], mel, phone_encoded, res)
+        except BinarizationError as e:
+            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
+            return None
+        return res
+
+
+class ZhSingingBinarizer(ZhBinarizer, SingingBinarizer):
+    pass
+
+
+class M4SingerBinarizer(MidiSingingBinarizer):
+    item2midi = {}
+    item2midi_dur = {}
+    item2is_slur = {}
+    item2ph_durs = {}
+    item2wdb = {}
+
+    def split_train_test_set(self, item_names):
+        item_names = deepcopy(item_names)
+        test_item_names = [x for x in item_names if any([x.startswith(ts) for ts in hparams['test_prefixes']])]
+        train_item_names = [x for x in item_names if x not in set(test_item_names)]
+        logging.info("train {}".format(len(train_item_names)))
+        logging.info("test {}".format(len(test_item_names)))
+        return train_item_names, test_item_names
+
+    def load_meta_data(self):
+        raw_data_dir = hparams['raw_data_dir']
+        song_items = json.load(open(os.path.join(raw_data_dir, 'meta.json')))  # [list of dict]
+        for song_item in song_items:
+            item_name = raw_item_name = song_item['item_name']
+            singer, song_name, sent_id = item_name.split("#")
+            self.item2wavfn[item_name] = f'{raw_data_dir}/{singer}#{song_name}/{sent_id}.wav'
+            self.item2txt[item_name] = song_item['txt']
+
+            self.item2ph[item_name] = ' '.join(song_item['phs'])
+            self.item2ph_durs[item_name] = song_item['ph_dur']
+
+            self.item2midi[item_name] = song_item['notes']
+            self.item2midi_dur[item_name] = song_item['notes_dur']
+            self.item2is_slur[item_name] = song_item['is_slur']
+            self.item2wdb[item_name] = [
+                1 if (0 < i < len(song_item['phs']) - 1 and p in ALL_YUNMU + ['<SP>', '<AP>'])
+                or i == len(song_item['phs']) - 1 else 0
+                for i, p in enumerate(song_item['phs'])]
+            self.item2spk[item_name] = singer
+
+        print('speakers: ', set(self.item2spk.values()))
+        self.item_names = sorted(list(self.item2txt.keys()))
+        if self.binarization_args['shuffle']:
+            random.seed(1234)
+            random.shuffle(self.item_names)
+        self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)
+
+    @staticmethod
+    def get_pitch(item_name, wav, spec, ph, res):
+        wav_suffix = '.wav'
+        # midi_suffix = '.mid'
+        wav_dir = 'wavs'
+        f0_dir = 'text_f0_align'
+
+        # item_name = os.path.splitext(os.path.basename(wav_fn))[0]
+        res['pitch_midi'] = np.asarray(M4SingerBinarizer.item2midi[item_name])
+        res['midi_dur'] = np.asarray(M4SingerBinarizer.item2midi_dur[item_name])
+        res['is_slur'] = np.asarray(M4SingerBinarizer.item2is_slur[item_name])
+        res['word_boundary'] = np.asarray(M4SingerBinarizer.item2wdb[item_name])
+        assert res['pitch_midi'].shape == res['midi_dur'].shape == res['is_slur'].shape, (
+            res['pitch_midi'].shape, res['midi_dur'].shape, res['is_slur'].shape)
+
+        # Legacy variant kept for reference: load a precomputed f0 file and
+        # resample it to the mel length.
+        # f0 = None
+        # f0_suffix = '_f0.npy'
+        # f0fn = wav_fn.replace(wav_suffix, f0_suffix).replace(wav_dir, f0_dir)
+        # pitch_info = np.load(f0fn)
+        # f0 = [x[1] for x in pitch_info]
+        # spec_x_coor = np.arange(0, 1, 1 / len(spec))[:len(spec)]
+        #
+        # f0_x_coor = np.arange(0, 1, 1 / len(f0))[:len(f0)]
+        # f0 = interp1d(f0_x_coor, f0, 'nearest', fill_value='extrapolate')(spec_x_coor)[:len(spec)]
+        # if sum(f0) == 0:
+        #     raise BinarizationError("Empty **gt** f0")
+        #
+        # pitch_coarse = f0_to_coarse(f0)
+        # res['f0'] = f0
+        # res['pitch'] = pitch_coarse
+
+        # gt f0.
+        gt_f0, gt_pitch_coarse = get_pitch(wav, spec, hparams)
+        if sum(gt_f0) == 0:
+            raise BinarizationError("Empty **gt** f0")
+        res['f0'] = gt_f0
+        res['pitch'] = gt_pitch_coarse
+
+    @classmethod
+    def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
+        if hparams['vocoder'] in VOCODERS:
+            wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
+        else:
+            wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
+        res = {
+            'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
+            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
+        }
+        try:
+            if binarization_args['with_f0']:
+                cls.get_pitch(item_name, wav, mel, ph, res)
+            if binarization_args['with_txt']:
+                try:
+                    phone_encoded = res['phone'] = encoder.encode(ph)
+                except Exception:
+                    traceback.print_exc()
+                    raise BinarizationError("Empty phoneme")
+                if binarization_args['with_align']:
+                    cls.get_align(M4SingerBinarizer.item2ph_durs[item_name], mel, phone_encoded, res)
+        except BinarizationError as e:
+            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
+            return None
+        return res
+
+
+if __name__ == "__main__":
+    SingingBinarizer().process()
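The __main__ guard above runs the plain SingingBinarizer; for M4Singer data, M4SingerBinarizer is the relevant class. Note that MidiSingingBinarizer.get_align reads hparams['hop_size'] in its default arguments, i.e. at import time, so hparams must be populated before this module is imported. A minimal driver sketch (the config path is a placeholder; set_hparams in utils/hparams.py parses --config from the command line in this codebase):

    # Sketch: populate hparams first, then import and run the binarizer.
    from utils.hparams import set_hparams

    set_hparams()  # e.g. invoked as: python sketch.py --config <m4singer config>

    from data_gen.singing.binarize import M4SingerBinarizer

    # Writes spk_map.json, phone_set.json and the train/valid/test
    # IndexedDataset shards into hparams['binary_data_dir'].
    M4SingerBinarizer().process()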
    	
data_gen/tts/base_binarizer.py
ADDED
@@ -0,0 +1,224 @@
+import os
+os.environ["OMP_NUM_THREADS"] = "1"
+
+from utils.multiprocess_utils import chunked_multiprocess_run
+import random
+import traceback
+import json
+from resemblyzer import VoiceEncoder
+from tqdm import tqdm
+from data_gen.tts.data_gen_utils import get_mel2ph, get_pitch, build_phone_encoder
+from utils.hparams import set_hparams, hparams
+import numpy as np
+from utils.indexed_datasets import IndexedDatasetBuilder
+from vocoders.base_vocoder import VOCODERS
+import pandas as pd
+
+
+class BinarizationError(Exception):
+    pass
+
+
+class BaseBinarizer:
+    def __init__(self, processed_data_dir=None):
+        if processed_data_dir is None:
+            processed_data_dir = hparams['processed_data_dir']
+        self.processed_data_dirs = processed_data_dir.split(",")
+        self.binarization_args = hparams['binarization_args']
+        self.pre_align_args = hparams['pre_align_args']
+        self.forced_align = self.pre_align_args['forced_align']
+        tg_dir = None
+        if self.forced_align == 'mfa':
+            tg_dir = 'mfa_outputs'
+        if self.forced_align == 'kaldi':
+            tg_dir = 'kaldi_outputs'
+        self.item2txt = {}
+        self.item2ph = {}
+        self.item2wavfn = {}
+        self.item2tgfn = {}
+        self.item2spk = {}
+        for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
+            self.meta_df = pd.read_csv(f"{processed_data_dir}/metadata_phone.csv", dtype=str)
+            for r_idx, r in self.meta_df.iterrows():
+                item_name = raw_item_name = r['item_name']
+                if len(self.processed_data_dirs) > 1:
+                    item_name = f'ds{ds_id}_{item_name}'
+                self.item2txt[item_name] = r['txt']
+                self.item2ph[item_name] = r['ph']
+                self.item2wavfn[item_name] = os.path.join(hparams['raw_data_dir'], 'wavs', os.path.basename(r['wav_fn']).split('_')[1])
+                self.item2spk[item_name] = r.get('spk', 'SPK1')
+                if len(self.processed_data_dirs) > 1:
+                    self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}"
+                if tg_dir is not None:
+                    self.item2tgfn[item_name] = f"{processed_data_dir}/{tg_dir}/{raw_item_name}.TextGrid"
+        self.item_names = sorted(list(self.item2txt.keys()))
+        if self.binarization_args['shuffle']:
+            random.seed(1234)
+            random.shuffle(self.item_names)
+
+    @property
+    def train_item_names(self):
+        return self.item_names[hparams['test_num']+hparams['valid_num']:]
+
+    @property
+    def valid_item_names(self):
+        return self.item_names[0: hparams['test_num']+hparams['valid_num']]
+
+    @property
+    def test_item_names(self):
+        return self.item_names[0: hparams['test_num']]  # Audios for MOS testing are in 'test_ids'
+
+    def build_spk_map(self):
+        spk_map = set()
+        for item_name in self.item_names:
+            spk_name = self.item2spk[item_name]
+            spk_map.add(spk_name)
+        spk_map = {x: i for i, x in enumerate(sorted(list(spk_map)))}
+        assert len(spk_map) == 0 or len(spk_map) <= hparams['num_spk'], len(spk_map)
+        return spk_map
+
+    def item_name2spk_id(self, item_name):
+        return self.spk_map[self.item2spk[item_name]]
+
+    def _phone_encoder(self):
+        ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json"
+        ph_set = []
+        if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn):
+            for processed_data_dir in self.processed_data_dirs:
+                ph_set += [x.split(' ')[0] for x in open(f'{processed_data_dir}/dict.txt').readlines()]
+            ph_set = sorted(set(ph_set))
+            json.dump(ph_set, open(ph_set_fn, 'w'))
+        else:
+            ph_set = json.load(open(ph_set_fn, 'r'))
+        print("| phone set: ", ph_set)
+        return build_phone_encoder(hparams['binary_data_dir'])
+
+    def meta_data(self, prefix):
+        if prefix == 'valid':
+            item_names = self.valid_item_names
+        elif prefix == 'test':
+            item_names = self.test_item_names
+        else:
+            item_names = self.train_item_names
+        for item_name in item_names:
+            ph = self.item2ph[item_name]
+            txt = self.item2txt[item_name]
+            tg_fn = self.item2tgfn.get(item_name)
+            wav_fn = self.item2wavfn[item_name]
+            spk_id = self.item_name2spk_id(item_name)
+            yield item_name, ph, txt, tg_fn, wav_fn, spk_id
+
+    def process(self):
+        os.makedirs(hparams['binary_data_dir'], exist_ok=True)
+        self.spk_map = self.build_spk_map()
+        print("| spk_map: ", self.spk_map)
+        spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json"
+        json.dump(self.spk_map, open(spk_map_fn, 'w'))
+
+        self.phone_encoder = self._phone_encoder()
+        self.process_data('valid')
+        self.process_data('test')
+        self.process_data('train')
+
+    def process_data(self, prefix):
+        data_dir = hparams['binary_data_dir']
+        args = []
+        builder = IndexedDatasetBuilder(f'{data_dir}/{prefix}')
+        lengths = []
+        f0s = []
+        total_sec = 0
+        if self.binarization_args['with_spk_embed']:
+            voice_encoder = VoiceEncoder().cuda()
+
+        meta_data = list(self.meta_data(prefix))
+        for m in meta_data:
+            args.append(list(m) + [self.phone_encoder, self.binarization_args])
         | 
| 136 | 
            +
                    num_workers = int(os.getenv('N_PROC', os.cpu_count() // 3))
         | 
| 137 | 
            +
                    for f_id, (_, item) in enumerate(
         | 
| 138 | 
            +
                            zip(tqdm(meta_data), chunked_multiprocess_run(self.process_item, args, num_workers=num_workers))):
         | 
| 139 | 
            +
                        if item is None:
         | 
| 140 | 
            +
                            continue
         | 
| 141 | 
            +
                        item['spk_embed'] = voice_encoder.embed_utterance(item['wav']) \
         | 
| 142 | 
            +
                            if self.binarization_args['with_spk_embed'] else None
         | 
| 143 | 
            +
                        if not self.binarization_args['with_wav'] and 'wav' in item:
         | 
| 144 | 
            +
                            #print("del wav")
         | 
| 145 | 
            +
                            del item['wav']
         | 
| 146 | 
            +
                        builder.add_item(item)
         | 
| 147 | 
            +
                        lengths.append(item['len'])
         | 
| 148 | 
            +
                        total_sec += item['sec']
         | 
| 149 | 
            +
                        if item.get('f0') is not None:
         | 
| 150 | 
            +
                            f0s.append(item['f0'])
         | 
| 151 | 
            +
                    builder.finalize()
         | 
| 152 | 
            +
                    np.save(f'{data_dir}/{prefix}_lengths.npy', lengths)
         | 
| 153 | 
            +
                    if len(f0s) > 0:
         | 
| 154 | 
            +
                        f0s = np.concatenate(f0s, 0)
         | 
| 155 | 
            +
                        f0s = f0s[f0s != 0]
         | 
| 156 | 
            +
                        np.save(f'{data_dir}/{prefix}_f0s_mean_std.npy', [np.mean(f0s).item(), np.std(f0s).item()])
         | 
| 157 | 
            +
                    print(f"| {prefix} total duration: {total_sec:.3f}s")
         | 
| 158 | 
            +
             | 
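The loop above pairs an in-order tqdm pass over meta_data with the generator returned by chunked_multiprocess_run, so results must come back in submission order. A minimal stand-in for that pattern, for exposition only (the repo's actual helper in utils adds chunking and error handling; the sketch assumes fn and its argument tuples are picklable):

    import multiprocessing as mp

    def simple_multiprocess_run(fn, args_list, num_workers):
        # ordered parallel map: starmap unpacks each argument tuple and
        # returns results in input order, matching the zip() above
        with mp.Pool(num_workers) as pool:
            for res in pool.starmap(fn, args_list):
                yield res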
    @classmethod
    def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
        if hparams['vocoder'] in VOCODERS:
            wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
        else:
            wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
        res = {
            'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
        }
        try:
            if binarization_args['with_f0']:
                cls.get_pitch(wav, mel, res)
                if binarization_args['with_f0cwt']:
                    cls.get_f0cwt(res['f0'], res)
            if binarization_args['with_txt']:
                try:
                    phone_encoded = res['phone'] = encoder.encode(ph)
                except:
                    traceback.print_exc()
                    raise BinarizationError("Empty phoneme")
                if binarization_args['with_align']:
                    cls.get_align(tg_fn, ph, mel, phone_encoded, res)
        except BinarizationError as e:
            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
            return None
        return res

    @staticmethod
    def get_align(tg_fn, ph, mel, phone_encoded, res):
        if tg_fn is not None and os.path.exists(tg_fn):
            mel2ph, dur = get_mel2ph(tg_fn, ph, mel, hparams)
        else:
            raise BinarizationError("Align not found")
        if mel2ph.max() - 1 >= len(phone_encoded):
            raise BinarizationError(
                f"Align does not match: mel2ph.max() - 1: {mel2ph.max() - 1}, len(phone_encoded): {len(phone_encoded)}")
        res['mel2ph'] = mel2ph
        res['dur'] = dur

    @staticmethod
    def get_pitch(wav, mel, res):
        f0, pitch_coarse = get_pitch(wav, mel, hparams)
        if sum(f0) == 0:
            raise BinarizationError("Empty f0")
        res['f0'] = f0
        res['pitch'] = pitch_coarse

    @staticmethod
    def get_f0cwt(f0, res):
        from utils.cwt import get_cont_lf0, get_lf0_cwt
        uv, cont_lf0_lpf = get_cont_lf0(f0)
        logf0s_mean_org, logf0s_std_org = np.mean(cont_lf0_lpf), np.std(cont_lf0_lpf)
        cont_lf0_lpf_norm = (cont_lf0_lpf - logf0s_mean_org) / logf0s_std_org
        Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_lpf_norm)
        if np.any(np.isnan(Wavelet_lf0)):
            raise BinarizationError("NaN CWT")
        res['cwt_spec'] = Wavelet_lf0
        res['cwt_scales'] = scales
        res['f0_mean'] = logf0s_mean_org
        res['f0_std'] = logf0s_std_org


if __name__ == "__main__":
    set_hparams()
    BaseBinarizer().process()
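process_data saves the mean and standard deviation of voiced-frame f0 per split; a hedged sketch of the downstream normalization these stats enable (the path is illustrative — the real location is f"{hparams['binary_data_dir']}/{prefix}_f0s_mean_std.npy"):

    import numpy as np

    f0_mean, f0_std = np.load('data/binary/train_f0s_mean_std.npy')
    f0 = np.array([0.0, 220.0, 230.0, 0.0])       # toy contour; 0 marks unvoiced frames
    voiced = f0 > 0
    f0[voiced] = (f0[voiced] - f0_mean) / f0_std  # standard z-score normalization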
    	
data_gen/tts/bin/binarize.py
ADDED

@@ -0,0 +1,20 @@
import os

os.environ["OMP_NUM_THREADS"] = "1"

import importlib
from utils.hparams import set_hparams, hparams


def binarize():
    binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
    pkg = ".".join(binarizer_cls.split(".")[:-1])
    cls_name = binarizer_cls.split(".")[-1]
    binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
    print("| Binarizer: ", binarizer_cls)
    binarizer_cls().process()


if __name__ == '__main__':
    set_hparams()
    binarize()
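binarize() resolves the binarizer class from a dotted path in hparams, so a config can swap in another implementation (e.g. one defined in data_gen/singing/binarize.py) without touching this script. The resolution step, isolated as a standalone sketch:

    import importlib

    def load_class(dotted_path):
        # 'pkg.module.ClassName' -> the ClassName object
        pkg, cls_name = dotted_path.rsplit(".", 1)
        return getattr(importlib.import_module(pkg), cls_name)

    # load_class('data_gen.tts.base_binarizer.BaseBinarizer')().process()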
    	
data_gen/tts/binarizer_zh.py
ADDED

@@ -0,0 +1,59 @@
import os

os.environ["OMP_NUM_THREADS"] = "1"

from data_gen.tts.txt_processors.zh_g2pM import ALL_SHENMU
from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError
from data_gen.tts.data_gen_utils import get_mel2ph
from utils.hparams import set_hparams, hparams
import numpy as np


class ZhBinarizer(BaseBinarizer):
    @staticmethod
    def get_align(tg_fn, ph, mel, phone_encoded, res):
        if tg_fn is not None and os.path.exists(tg_fn):
            _, dur = get_mel2ph(tg_fn, ph, mel, hparams)
        else:
            raise BinarizationError("Align not found")
        ph_list = ph.split(" ")
        assert len(dur) == len(ph_list)
        mel2ph = []
        # assign the duration of separators to the preceding final (yunmu)
        dur_cumsum = np.pad(np.cumsum(dur), [1, 0], mode='constant', constant_values=0)
        for i in range(len(dur)):
            p = ph_list[i]
            if p[0] != '<' and not p[0].isalpha():
                uv_ = res['f0'][dur_cumsum[i]:dur_cumsum[i + 1]] == 0
                j = 0
                while j < len(uv_) and not uv_[j]:
                    j += 1
                dur[i - 1] += j
                dur[i] -= j
                if dur[i] < 100:
                    dur[i - 1] += dur[i]
                    dur[i] = 0
        # give the initial (shengmu) and the final (yunmu) equal durations
        for i in range(len(dur)):
            p = ph_list[i]
            if p in ALL_SHENMU:
                p_next = ph_list[i + 1]
                if not (dur[i] > 0 and p_next[0].isalpha() and p_next not in ALL_SHENMU):
                    print(f"assert dur[i] > 0 and p_next[0].isalpha() and p_next not in ALL_SHENMU, "
                          f"dur[i]: {dur[i]}, p: {p}, p_next: {p_next}.")
                    continue
                total = dur[i + 1] + dur[i]
                dur[i] = total // 2
                dur[i + 1] = total - dur[i]
        for i in range(len(dur)):
            mel2ph += [i + 1] * dur[i]
        mel2ph = np.array(mel2ph)
        if mel2ph.max() - 1 >= len(phone_encoded):
            raise BinarizationError(f"| Align does not match: {(mel2ph.max() - 1, len(phone_encoded))}")
        res['mel2ph'] = mel2ph
        res['dur'] = dur


if __name__ == "__main__":
    set_hparams()
    ZhBinarizer().process()
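The mel2ph built above is a frame-level phoneme index: phoneme i occupies dur[i] consecutive frames. A toy check of that expansion (illustrative values only):

    import numpy as np

    dur = [3, 2, 4]                                   # frames per phoneme
    mel2ph = np.concatenate([[i + 1] * d for i, d in enumerate(dur)])
    print(mel2ph)                                     # [1 1 1 2 2 3 3 3 3]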
    	
data_gen/tts/data_gen_utils.py
ADDED

@@ -0,0 +1,347 @@
import warnings

warnings.filterwarnings("ignore")

import parselmouth
import os
import torch
from skimage.transform import resize
from utils.text_encoder import TokenTextEncoder
from utils.pitch_utils import f0_to_coarse
import struct
import webrtcvad
from scipy.ndimage.morphology import binary_dilation
import librosa
import numpy as np
from utils import audio
import pyloudnorm as pyln
import re
import json
from collections import OrderedDict

PUNCS = '!,.?;:'

int16_max = (2 ** 15) - 1


def trim_long_silences(path, sr=None, return_raw_wav=False, norm=True, vad_max_silence_length=12):
    """
    Ensures that segments without voice in the waveform remain no longer than a
    threshold determined by the VAD parameters below.
    :param path: path to the input waveform
    :param vad_max_silence_length: maximum number of consecutive silent frames a segment can have
    :return: the same waveform with silences trimmed away (length <= original wav length)
    """

    ## Voice Activity Detection
    # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
    # This sets the granularity of the VAD. Should not need to be changed.
    sampling_rate = 16000
    wav_raw, sr = librosa.core.load(path, sr=sr)

    if norm:
        meter = pyln.Meter(sr)  # create BS.1770 meter
        loudness = meter.integrated_loudness(wav_raw)
        wav_raw = pyln.normalize.loudness(wav_raw, loudness, -20.0)
        if np.abs(wav_raw).max() > 1.0:
            wav_raw = wav_raw / np.abs(wav_raw).max()

    wav = librosa.resample(wav_raw, sr, sampling_rate, res_type='kaiser_best')

    vad_window_length = 30  # In milliseconds
    # Number of frames to average together when performing the moving average smoothing.
    # The larger this value, the larger the VAD variations must be to not get smoothed out.
    vad_moving_average_width = 8

    # Compute the voice detection window size
    samples_per_window = (vad_window_length * sampling_rate) // 1000

    # Trim the end of the audio to have a multiple of the window size
    wav = wav[:len(wav) - (len(wav) % samples_per_window)]

    # Convert the float waveform to 16-bit mono PCM
    pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))

    # Perform voice activity detection
    voice_flags = []
    vad = webrtcvad.Vad(mode=3)
    for window_start in range(0, len(wav), samples_per_window):
        window_end = window_start + samples_per_window
        voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
                                         sample_rate=sampling_rate))
    voice_flags = np.array(voice_flags)

    # Smooth the voice detection with a moving average
    def moving_average(array, width):
        array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
        ret = np.cumsum(array_padded, dtype=float)
        ret[width:] = ret[width:] - ret[:-width]
        return ret[width - 1:] / width

    audio_mask = moving_average(voice_flags, vad_moving_average_width)
    audio_mask = np.round(audio_mask).astype(np.bool)

    # Dilate the voiced regions
    audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
    audio_mask = np.repeat(audio_mask, samples_per_window)
    audio_mask = resize(audio_mask, (len(wav_raw),)) > 0
    if return_raw_wav:
        return wav_raw, audio_mask, sr
    return wav_raw[audio_mask], audio_mask, sr


def process_utterance(wav_path,
                      fft_size=1024,
                      hop_size=256,
                      win_length=1024,
                      window="hann",
                      num_mels=80,
                      fmin=80,
                      fmax=7600,
                      eps=1e-6,
                      sample_rate=22050,
                      loud_norm=False,
                      min_level_db=-100,
                      return_linear=False,
                      trim_long_sil=False, vocoder='pwg'):
    if isinstance(wav_path, str):
        if trim_long_sil:
            wav, _, _ = trim_long_silences(wav_path, sample_rate)
        else:
            wav, _ = librosa.core.load(wav_path, sr=sample_rate)
    else:
        wav = wav_path

    if loud_norm:
        meter = pyln.Meter(sample_rate)  # create BS.1770 meter
        loudness = meter.integrated_loudness(wav)
        wav = pyln.normalize.loudness(wav, loudness, -22.0)
        if np.abs(wav).max() > 1:
            wav = wav / np.abs(wav).max()

    # get amplitude spectrogram
    x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
                          win_length=win_length, window=window, pad_mode="constant")
    spc = np.abs(x_stft)  # (n_bins, T)

    # get mel basis
    fmin = 0 if fmin == -1 else fmin
    fmax = sample_rate / 2 if fmax == -1 else fmax
    mel_basis = librosa.filters.mel(sample_rate, fft_size, num_mels, fmin, fmax)
    mel = mel_basis @ spc

    if vocoder == 'pwg':
        mel = np.log10(np.maximum(eps, mel))  # (n_mel_bins, T)
    else:
        assert False, f'"{vocoder}" is not in ["pwg"].'

    l_pad, r_pad = audio.librosa_pad_lr(wav, fft_size, hop_size, 1)
    wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
    wav = wav[:mel.shape[1] * hop_size]

    if not return_linear:
        return wav, mel
    else:
        spc = audio.amp_to_db(spc)
        spc = audio.normalize(spc, {'min_level_db': min_level_db})
        return wav, mel, spc


def get_pitch(wav_data, mel, hparams):
    """
    :param wav_data: [T]
    :param mel: [T, 80]
    :param hparams:
    :return:
    """
    time_step = hparams['hop_size'] / hparams['audio_sample_rate'] * 1000
    f0_min = 80
    f0_max = 750

    if hparams['hop_size'] == 128:
        pad_size = 4
    elif hparams['hop_size'] == 256:
        pad_size = 2
    else:
        assert False

    f0 = parselmouth.Sound(wav_data, hparams['audio_sample_rate']).to_pitch_ac(
        time_step=time_step / 1000, voicing_threshold=0.6,
        pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
    lpad = pad_size * 2
    rpad = len(mel) - len(f0) - lpad
    f0 = np.pad(f0, [[lpad, rpad]], mode='constant')
    # mel and f0 are extracted by 2 different libraries. we should force them to have the same length.
    # Attention: we find that new versions of some libraries could cause ``rpad'' to be a negative value...
    # Just to be safe, we recommend users to set up the same environments as ours in requirements_auto.txt (via Anaconda).
    delta_l = len(mel) - len(f0)
    assert np.abs(delta_l) <= 8
    if delta_l > 0:
        f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0)
    f0 = f0[:len(mel)]
    pitch_coarse = f0_to_coarse(f0)
    return f0, pitch_coarse


def remove_empty_lines(text):
    """remove empty lines"""
    assert (len(text) > 0)
    assert (isinstance(text, list))
    text = [t.strip() for t in text]
    text = [t for t in text if t != ""]  # drop all empty lines, not just the first
    return text


class TextGrid(object):
    def __init__(self, text):
        text = remove_empty_lines(text)
        self.text = text
        self.line_count = 0
        self._get_type()
        self._get_time_intval()
        self._get_size()
        self.tier_list = []
        self._get_item_list()

    def _extract_pattern(self, pattern, inc):
        """
        Parameters
        ----------
        pattern : regex to extract pattern
        inc : increment of line count after extraction
        Returns
        -------
        group : extracted info
        """
        try:
            group = re.match(pattern, self.text[self.line_count]).group(1)
            self.line_count += inc
        except AttributeError:
            raise ValueError("File format error at line %d:%s" % (self.line_count, self.text[self.line_count]))
        return group

    def _get_type(self):
        self.file_type = self._extract_pattern(r"File type = \"(.*)\"", 2)

    def _get_time_intval(self):
        self.xmin = self._extract_pattern(r"xmin = (.*)", 1)
        self.xmax = self._extract_pattern(r"xmax = (.*)", 2)

    def _get_size(self):
        self.size = int(self._extract_pattern(r"size = (.*)", 2))

    def _get_item_list(self):
        """Only supports IntervalTier currently"""
        for itemIdx in range(1, self.size + 1):
            tier = OrderedDict()
            item_list = []
            tier_idx = self._extract_pattern(r"item \[(.*)\]:", 1)
            tier_class = self._extract_pattern(r"class = \"(.*)\"", 1)
            if tier_class != "IntervalTier":
                raise NotImplementedError("Only IntervalTier class is supported currently")
            tier_name = self._extract_pattern(r"name = \"(.*)\"", 1)
            tier_xmin = self._extract_pattern(r"xmin = (.*)", 1)
            tier_xmax = self._extract_pattern(r"xmax = (.*)", 1)
            tier_size = self._extract_pattern(r"intervals: size = (.*)", 1)
            for i in range(int(tier_size)):
                item = OrderedDict()
                item["idx"] = self._extract_pattern(r"intervals \[(.*)\]", 1)
                item["xmin"] = self._extract_pattern(r"xmin = (.*)", 1)
                item["xmax"] = self._extract_pattern(r"xmax = (.*)", 1)
                item["text"] = self._extract_pattern(r"text = \"(.*)\"", 1)
                item_list.append(item)
            tier["idx"] = tier_idx
            tier["class"] = tier_class
            tier["name"] = tier_name
            tier["xmin"] = tier_xmin
            tier["xmax"] = tier_xmax
            tier["size"] = tier_size
            tier["items"] = item_list
            self.tier_list.append(tier)

    def toJson(self):
        _json = OrderedDict()
        _json["file_type"] = self.file_type
        _json["xmin"] = self.xmin
        _json["xmax"] = self.xmax
        _json["size"] = self.size
        _json["tiers"] = self.tier_list
        return json.dumps(_json, ensure_ascii=False, indent=2)


def get_mel2ph(tg_fn, ph, mel, hparams):
    ph_list = ph.split(" ")
    with open(tg_fn, "r") as f:
        tg = f.readlines()
    tg = remove_empty_lines(tg)
    tg = TextGrid(tg)
    tg = json.loads(tg.toJson())
    split = np.ones(len(ph_list) + 1, np.float) * -1
    tg_idx = 0
    ph_idx = 0
    tg_align = [x for x in tg['tiers'][-1]['items']]
    tg_align_ = []
    for x in tg_align:
        x['xmin'] = float(x['xmin'])
        x['xmax'] = float(x['xmax'])
        if x['text'] in ['sil', 'sp', '', 'SIL', 'PUNC']:
            x['text'] = ''
            if len(tg_align_) > 0 and tg_align_[-1]['text'] == '':
                tg_align_[-1]['xmax'] = x['xmax']
                continue
        tg_align_.append(x)
    tg_align = tg_align_
    tg_len = len([x for x in tg_align if x['text'] != ''])
    ph_len = len([x for x in ph_list if not is_sil_phoneme(x)])
    assert tg_len == ph_len, (tg_len, ph_len, tg_align, ph_list, tg_fn)
    while tg_idx < len(tg_align) or ph_idx < len(ph_list):
        if tg_idx == len(tg_align) and is_sil_phoneme(ph_list[ph_idx]):
            split[ph_idx] = 1e8
            ph_idx += 1
            continue
        x = tg_align[tg_idx]
        if x['text'] == '' and ph_idx == len(ph_list):
            tg_idx += 1
            continue
        assert ph_idx < len(ph_list), (tg_len, ph_len, tg_align, ph_list, tg_fn)
        ph = ph_list[ph_idx]
        if x['text'] == '' and not is_sil_phoneme(ph):
            assert False, (ph_list, tg_align)
        if x['text'] != '' and is_sil_phoneme(ph):
            ph_idx += 1
        else:
            assert (x['text'] == '' and is_sil_phoneme(ph)) \
                   or x['text'].lower() == ph.lower() \
                   or x['text'].lower() == 'sil', (x['text'], ph)
            split[ph_idx] = x['xmin']
            if ph_idx > 0 and split[ph_idx - 1] == -1 and is_sil_phoneme(ph_list[ph_idx - 1]):
                split[ph_idx - 1] = split[ph_idx]
            ph_idx += 1
            tg_idx += 1
    assert tg_idx == len(tg_align), (tg_idx, [x['text'] for x in tg_align])
    assert ph_idx >= len(ph_list) - 1, (ph_idx, ph_list, len(ph_list), [x['text'] for x in tg_align], tg_fn)
    mel2ph = np.zeros([mel.shape[0]], np.int)
    split[0] = 0
    split[-1] = 1e8
    for i in range(len(split) - 1):
        assert split[i] != -1 and split[i] <= split[i + 1], (split[:-1],)
    split = [int(s * hparams['audio_sample_rate'] / hparams['hop_size'] + 0.5) for s in split]
    for ph_idx in range(len(ph_list)):
        mel2ph[split[ph_idx]:split[ph_idx + 1]] = ph_idx + 1
    mel2ph_torch = torch.from_numpy(mel2ph)
    T_t = len(ph_list)
    dur = mel2ph_torch.new_zeros([T_t + 1]).scatter_add(0, mel2ph_torch, torch.ones_like(mel2ph_torch))
    dur = dur[1:].numpy()
    return mel2ph, dur


def build_phone_encoder(data_dir):
    phone_list_file = os.path.join(data_dir, 'phone_set.json')
    phone_list = json.load(open(phone_list_file))
    return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')


def is_sil_phoneme(p):
    return not p[0].isalpha()
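get_mel2ph ends by inverting the frame-level alignment back into per-phoneme durations via scatter_add; a toy check of that inversion, mirroring the lines above (and the exact inverse of the expansion shown for ZhBinarizer):

    import torch

    mel2ph_torch = torch.tensor([1, 1, 1, 2, 2, 3, 3, 3, 3])  # frame-level phoneme ids
    T_t = 3                                                    # number of phonemes
    dur = mel2ph_torch.new_zeros([T_t + 1]).scatter_add(
        0, mel2ph_torch, torch.ones_like(mel2ph_torch))
    print(dur[1:].numpy())                                     # [3 2 4]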
    	
data_gen/tts/txt_processors/base_text_processor.py
ADDED

@@ -0,0 +1,8 @@
|  | |
| 1 | 
            +
            class BaseTxtProcessor:
         | 
| 2 | 
            +
                @staticmethod
         | 
| 3 | 
            +
                def sp_phonemes():
         | 
| 4 | 
            +
                    return ['|']
         | 
| 5 | 
            +
             | 
| 6 | 
            +
                @classmethod
         | 
| 7 | 
            +
                def process(cls, txt, pre_align_args):
         | 
| 8 | 
            +
                    raise NotImplementedError
         | 
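BaseTxtProcessor only fixes the interface; the English and Chinese processors below supply process(). A hypothetical character-level subclass, included just to make the contract explicit (process returns a phoneme list plus the normalized text, with '|' as the word separator declared by sp_phonemes; not part of the repo):

    class CharTxtProcessor(BaseTxtProcessor):
        @classmethod
        def process(cls, txt, pre_align_args):
            phs = ['|' if c == ' ' else c for c in txt]  # one token per character
            return phs, txt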
    	
data_gen/tts/txt_processors/en.py
ADDED

@@ -0,0 +1,78 @@
|  | |
| 1 | 
            +
            import re
         | 
| 2 | 
            +
            from data_gen.tts.data_gen_utils import PUNCS
         | 
| 3 | 
            +
            from g2p_en import G2p
         | 
| 4 | 
            +
            import unicodedata
         | 
| 5 | 
            +
            from g2p_en.expand import normalize_numbers
         | 
| 6 | 
            +
            from nltk import pos_tag
         | 
| 7 | 
            +
            from nltk.tokenize import TweetTokenizer
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class EnG2p(G2p):
         | 
| 13 | 
            +
                word_tokenize = TweetTokenizer().tokenize
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                def __call__(self, text):
         | 
| 16 | 
            +
                    # preprocessing
         | 
| 17 | 
            +
                    words = EnG2p.word_tokenize(text)
         | 
| 18 | 
            +
                    tokens = pos_tag(words)  # tuples of (word, tag)
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                    # steps
         | 
| 21 | 
            +
                    prons = []
         | 
| 22 | 
            +
                    for word, pos in tokens:
         | 
| 23 | 
            +
                        if re.search("[a-z]", word) is None:
         | 
| 24 | 
            +
                            pron = [word]
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                        elif word in self.homograph2features:  # Check homograph
         | 
| 27 | 
            +
                            pron1, pron2, pos1 = self.homograph2features[word]
         | 
| 28 | 
            +
                            if pos.startswith(pos1):
         | 
| 29 | 
            +
                    pron = pron1
                else:
                    pron = pron2
            elif word in self.cmu:  # lookup CMU dict
                pron = self.cmu[word][0]
            else:  # predict for oov
                pron = self.predict(word)

            prons.extend(pron)
            prons.extend([" "])

        return prons[:-1]


class TxtProcessor(BaseTxtProcessor):
    g2p = EnG2p()

    @staticmethod
    def preprocess_text(text):
        text = normalize_numbers(text)
        text = ''.join(char for char in unicodedata.normalize('NFD', text)
                       if unicodedata.category(char) != 'Mn')  # strip accents
        text = text.lower()
        text = re.sub("[\'\"()]+", "", text)
        text = re.sub("[-]+", " ", text)
        text = re.sub(f"[^ a-z{PUNCS}]", "", text)
        text = re.sub(f" ?([{PUNCS}]) ?", r"\1", text)  # "a !" -> "a!"
        text = re.sub(f"([{PUNCS}])+", r"\1", text)  # !! -> !
        text = text.replace("i.e.", "that is")
        text = text.replace("etc.", "etc")
        text = re.sub(f"([{PUNCS}])", r" \1 ", text)  # pad punctuation with spaces
        text = re.sub(rf"\s+", r" ", text)
        return text

    @classmethod
    def process(cls, txt, pre_align_args):
        txt = cls.preprocess_text(txt).strip()
        phs = cls.g2p(txt)
        phs_ = []
        n_word_sep = 0
        for p in phs:
            if p.strip() == '':  # g2p emits a bare space between words
                phs_ += ['|']
                n_word_sep += 1
            else:
                phs_ += p.split(" ")
        phs = phs_
        assert n_word_sep + 1 == len(txt.split(" ")), (phs, f"\"{txt}\"")
        return phs, txt
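
For orientation, a minimal usage sketch of the English processor above (assuming g2p_en and its CMU dictionary data are installed; the exact phonemes depend on the installed model, so the output shown is approximate):

    from data_gen.tts.txt_processors.en import TxtProcessor

    # '|' marks word boundaries; punctuation survives as its own token
    phs, txt = TxtProcessor.process("hello world!", pre_align_args={})
    # txt -> "hello world !"
    # phs -> roughly ['HH', 'AH0', 'L', 'OW1', '|', 'W', 'ER1', 'L', 'D', '|', '!']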
    	
data_gen/tts/txt_processors/zh.py
ADDED

@@ -0,0 +1,41 @@
import re
from pypinyin import pinyin, Style
from data_gen.tts.data_gen_utils import PUNCS
from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor
from utils.text_norm import NSWNormalizer


class TxtProcessor(BaseTxtProcessor):
    table = {ord(f): ord(t) for f, t in zip(
        u':,。!?【】()%#@&1234567890',
        u':,.!?[]()%#@&1234567890')}

    @staticmethod
    def preprocess_text(text):
        text = text.translate(TxtProcessor.table)
        text = NSWNormalizer(text).normalize(remove_punc=False)
        text = re.sub("[\'\"()]+", "", text)
        text = re.sub("[-]+", " ", text)
        text = re.sub(f"[^ A-Za-z\u4e00-\u9fff{PUNCS}]", "", text)
        text = re.sub(f"([{PUNCS}])+", r"\1", text)  # !! -> !
        text = re.sub(f"([{PUNCS}])", r" \1 ", text)
        text = re.sub(rf"\s+", r"", text)
        return text

    @classmethod
    def process(cls, txt, pre_align_args):
        txt = cls.preprocess_text(txt)
        shengmu = pinyin(txt, style=Style.INITIALS)  # https://blog.csdn.net/zhoulei124/article/details/89055403
        yunmu_finals = pinyin(txt, style=Style.FINALS)
        yunmu_tone3 = pinyin(txt, style=Style.FINALS_TONE3)
        yunmu = [[t[0] + '5'] if t[0] == f[0] else t for f, t in zip(yunmu_finals, yunmu_tone3)] \
            if pre_align_args['use_tone'] else yunmu_finals

        assert len(shengmu) == len(yunmu)
        phs = ["|"]
        for a, b, c in zip(shengmu, yunmu, yunmu_finals):
            if a[0] == c[0]:
                phs += [a[0], "|"]
            else:
                phs += [a[0], b[0], "|"]
        return phs, txt
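
A quick sketch of what this pypinyin-based processor produces; the split into initial (shengmu) and final (yunmu) follows pypinyin's default dictionary, so treat the exact output as illustrative:

    from data_gen.tts.txt_processors.zh import TxtProcessor

    phs, txt = TxtProcessor.process('你好', {'use_tone': True})
    # each syllable becomes initial + toned final, with '|' between characters:
    # phs -> ['|', 'n', 'i3', '|', 'h', 'ao3', '|']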
    	
data_gen/tts/txt_processors/zh_g2pM.py
ADDED

@@ -0,0 +1,71 @@
import re
import jieba
from pypinyin import pinyin, Style
from data_gen.tts.data_gen_utils import PUNCS
from data_gen.tts.txt_processors import zh
from g2pM import G2pM

ALL_SHENMU = ['b', 'c', 'ch', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'q', 'r', 's', 'sh', 't', 'x', 'z', 'zh']
ALL_YUNMU = ['a', 'ai', 'an', 'ang', 'ao', 'e', 'ei', 'en', 'eng', 'er', 'i', 'ia', 'ian', 'iang', 'iao',
             'ie', 'in', 'ing', 'iong', 'iou', 'o', 'ong', 'ou', 'u', 'ua', 'uai', 'uan', 'uang', 'uei',
             'uen', 'uo', 'v', 'van', 've', 'vn']


class TxtProcessor(zh.TxtProcessor):
    model = G2pM()

    @staticmethod
    def sp_phonemes():
        return ['|', '#']

    @classmethod
    def process(cls, txt, pre_align_args):
        txt = cls.preprocess_text(txt)
        ph_list = cls.model(txt, tone=pre_align_args['use_tone'], char_split=True)
        seg_list = '#'.join(jieba.cut(txt))
        assert len(ph_list) == len([s for s in seg_list if s != '#']), (ph_list, seg_list)

        # insert word-boundary markers '#'
        ph_list_ = []
        seg_idx = 0
        for p in ph_list:
            p = p.replace("u:", "v")
            if seg_list[seg_idx] == '#':
                ph_list_.append('#')
                seg_idx += 1
            else:
                ph_list_.append("|")
            seg_idx += 1
            if re.findall('[\u4e00-\u9fff]', p):
                if pre_align_args['use_tone']:
                    p = pinyin(p, style=Style.TONE3, strict=True)[0][0]
                    if p[-1] not in ['1', '2', '3', '4', '5']:
                        p = p + '5'  # neutral tone carries no digit; mark it with '5'
                else:
                    p = pinyin(p, style=Style.NORMAL, strict=True)[0][0]

            finished = False
            if len([c.isalpha() for c in p]) > 1:  # effectively len(p) > 1
                for shenmu in ALL_SHENMU:
                    if p.startswith(shenmu) and not p.lstrip(shenmu).isnumeric():
                        ph_list_ += [shenmu, p.lstrip(shenmu)]
                        finished = True
                        break
            if not finished:
                ph_list_.append(p)

        ph_list = ph_list_

        # remove word-boundary markers around silence symbols, e.g. [..., '#', ',', '#', ...]
        sil_phonemes = list(PUNCS) + TxtProcessor.sp_phonemes()
        ph_list_ = []
        for i in range(0, len(ph_list), 1):
            if ph_list[i] != '#' or (ph_list[i - 1] not in sil_phonemes and ph_list[i + 1] not in sil_phonemes):
                ph_list_.append(ph_list[i])
        ph_list = ph_list_
        return ph_list, txt


if __name__ == '__main__':
    phs, txt = TxtProcessor.process('他来到了,网易杭研大厦', {'use_tone': True})
    print(phs)
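
The boundary convention this g2pM-based processor emits differs from zh.py: '#' marks a jieba word boundary, while '|' separates characters inside a word (and opens the sequence). An illustrative trace for the first four characters of the test sentence above (the exact pinyin and segmentation depend on the installed g2pM and jieba versions):

    # jieba segmentation: 他 / 来到 / 了
    # output ph_list:
    # ['|', 't', 'a1', '#', 'l', 'ai2', '|', 'd', 'ao4', '#', 'l', 'e5', ...]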
    	
inference/m4singer/base_svs_infer.py
ADDED

@@ -0,0 +1,242 @@
import os

import torch
import numpy as np
from modules.hifigan.hifigan import HifiGanGenerator
from vocoders.hifigan import HifiGAN
from inference.m4singer.m4singer.map import m4singer_pinyin2ph_func

from utils import load_ckpt
from utils.hparams import set_hparams, hparams
from utils.text_encoder import TokenTextEncoder
from pypinyin import pinyin, lazy_pinyin, Style
import librosa
import glob
import re


class BaseSVSInfer:
    def __init__(self, hparams, device=None):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.hparams = hparams
        self.device = device

        phone_list = ["<AP>", "<SP>", "a", "ai", "an", "ang", "ao", "b", "c", "ch", "d", "e", "ei", "en", "eng", "er", "f", "g", "h",
                      "i", "ia", "ian", "iang", "iao", "ie", "in", "ing", "iong", "iou", "j", "k", "l", "m", "n", "o", "ong", "ou",
                      "p", "q", "r", "s", "sh", "t", "u", "ua", "uai", "uan", "uang", "uei", "uen", "uo", "v", "van", "ve", "vn",
                      "x", "z", "zh"]
        self.ph_encoder = TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
        self.pinyin2phs = m4singer_pinyin2ph_func()
        self.spk_map = {"Alto-1": 0, "Alto-2": 1, "Alto-3": 2, "Alto-4": 3, "Alto-5": 4, "Alto-6": 5, "Alto-7": 6, "Bass-1": 7,
                        "Bass-2": 8, "Bass-3": 9, "Soprano-1": 10, "Soprano-2": 11, "Soprano-3": 12, "Tenor-1": 13, "Tenor-2": 14,
                        "Tenor-3": 15, "Tenor-4": 16, "Tenor-5": 17, "Tenor-6": 18, "Tenor-7": 19}

        self.model = self.build_model()
        self.model.eval()
        self.model.to(self.device)
        self.vocoder = self.build_vocoder()
        self.vocoder.eval()
        self.vocoder.to(self.device)

    def build_model(self):
        raise NotImplementedError

    def forward_model(self, inp):
        raise NotImplementedError

    def build_vocoder(self):
        base_dir = hparams['vocoder_ckpt']
        config_path = f'{base_dir}/config.yaml'
        # pick the vocoder checkpoint with the largest step number
        ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'),
                      key=lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1]
        print('| load HifiGAN: ', ckpt)
        ckpt_dict = torch.load(ckpt, map_location="cpu")
        config = set_hparams(config_path, global_hparams=False)
        state = ckpt_dict["state_dict"]["model_gen"]
        vocoder = HifiGanGenerator(config)
        vocoder.load_state_dict(state, strict=True)
        vocoder.remove_weight_norm()
        vocoder = vocoder.eval().to(self.device)
        return vocoder

    def run_vocoder(self, c, **kwargs):
        c = c.transpose(2, 1)  # [B, 80, T]
        f0 = kwargs.get('f0')  # [B, T]
        if f0 is not None and hparams.get('use_nsf'):
            y = self.vocoder(c, f0).view(-1)
        else:
            y = self.vocoder(c).view(-1)
        return y[None]  # [1, T]

    def preprocess_word_level_input(self, inp):
        # pypinyin alone cannot resolve polyphonic characters
        text_raw = inp['text']

        # lyric
        pinyins = lazy_pinyin(text_raw, strict=False)
        ph_per_word_lst = [self.pinyin2phs[pinyin.strip()] for pinyin in pinyins if pinyin.strip() in self.pinyin2phs]

        # note
        note_per_word_lst = [x.strip() for x in inp['notes'].split('|') if x.strip() != '']
        mididur_per_word_lst = [x.strip() for x in inp['notes_duration'].split('|') if x.strip() != '']

        if len(note_per_word_lst) == len(ph_per_word_lst) == len(mididur_per_word_lst):
            print('Pass word-notes check.')
        else:
            print('The number of words doesn\'t match the number of note windows. ',
                  'You should split the note(s) for each word by the | mark.')
            print(ph_per_word_lst, note_per_word_lst, mididur_per_word_lst)
            print(len(ph_per_word_lst), len(note_per_word_lst), len(mididur_per_word_lst))
            return None

        note_lst = []
        ph_lst = []
        midi_dur_lst = []
        is_slur = []
        for idx, ph_per_word in enumerate(ph_per_word_lst):
            # for phs in one word:
            # single ph like ['ai'] or multiple phs like ['n', 'i']
            ph_in_this_word = ph_per_word.split()

            # for notes in one word:
            # single note like ['D4'] or multiple notes like ['D4', 'E4'], which denotes a slur
            note_in_this_word = note_per_word_lst[idx].split()
            midi_dur_in_this_word = mididur_per_word_lst[idx].split()
            # process for the model input
            # Step 1.
            #  Deal with a 'not slur' note or the first note of a 'slur':
            #  j        ie
            #  F#4/Gb4  F#4/Gb4
            #  0        0
            for ph in ph_in_this_word:
                ph_lst.append(ph)
                note_lst.append(note_in_this_word[0])
                midi_dur_lst.append(midi_dur_in_this_word[0])
                is_slur.append(0)
            # Step 2.
            #  Deal with the 2nd, 3rd, ... notes of a 'slur':
            #  j        ie         ie
            #  F#4/Gb4  F#4/Gb4    C#4/Db4
            #  0        0          1
            if len(note_in_this_word) > 1:  # is_slur = True: repeat the yunmu to match the 2nd, 3rd, ... notes
                for idx in range(1, len(note_in_this_word)):
                    ph_lst.append(ph_in_this_word[-1])
                    note_lst.append(note_in_this_word[idx])
                    midi_dur_lst.append(midi_dur_in_this_word[idx])
                    is_slur.append(1)
        ph_seq = ' '.join(ph_lst)

        if len(ph_lst) == len(note_lst) == len(midi_dur_lst):
            print(len(ph_lst), len(note_lst), len(midi_dur_lst))
            print('Pass word-notes check.')
        else:
            print('The number of words doesn\'t match the number of note windows. ',
                  'You should split the note(s) for each word by the | mark.')
            return None
        return ph_seq, note_lst, midi_dur_lst, is_slur

    def preprocess_phoneme_level_input(self, inp):
        ph_seq = inp['ph_seq']
        note_lst = inp['note_seq'].split()
        midi_dur_lst = inp['note_dur_seq'].split()
        is_slur = [float(x) for x in inp['is_slur_seq'].split()]
        print(len(note_lst), len(ph_seq.split()), len(midi_dur_lst))
        if len(note_lst) == len(ph_seq.split()) == len(midi_dur_lst):
            print('Pass word-notes check.')
        else:
            print('The number of phonemes doesn\'t match the number of notes. ',
                  'You should assign one note (and one duration) per phoneme.')
            return None
        return ph_seq, note_lst, midi_dur_lst, is_slur

    def preprocess_input(self, inp, input_type='word'):
        """
        :param inp: {'text': str, 'notes': str, 'notes_duration': str,
                     'input_type': str, 'item_name': (str, optional), 'spk_name': (str, optional)}
        :return: the model input item dict, or None on failure
        """
        item_name = inp.get('item_name', '<ITEM_NAME>')
        spk_name = inp.get('spk_name', 'Alto-1')

        # single spk
        spk_id = self.spk_map[spk_name]

        # get ph seq, note lst, midi dur lst, is slur lst.
        if input_type == 'word':
            ret = self.preprocess_word_level_input(inp)
        elif input_type == 'phoneme':
            ret = self.preprocess_phoneme_level_input(inp)
        else:
            print('Invalid input type.')
            return None

        if ret:
            ph_seq, note_lst, midi_dur_lst, is_slur = ret
        else:
            print('==========> Preprocess_word_level or phone_level input wrong.')
            return None

        # convert the note list to MIDI ids; convert the note duration list to float durations
        try:
            midis = [librosa.note_to_midi(x.split("/")[0]) if x != 'rest' else 0
                     for x in note_lst]
            midi_dur_lst = [float(x) for x in midi_dur_lst]
        except Exception as e:
            print(e)
            print('Invalid Input Type.')
            return None

        ph_token = self.ph_encoder.encode(ph_seq)
        item = {'item_name': item_name, 'text': inp['text'], 'ph': ph_seq, 'spk_id': spk_id,
                'ph_token': ph_token, 'pitch_midi': np.asarray(midis), 'midi_dur': np.asarray(midi_dur_lst),
                'is_slur': np.asarray(is_slur)}
        item['ph_len'] = len(item['ph_token'])
        return item

    def input_to_batch(self, item):
        item_names = [item['item_name']]
        text = [item['text']]
        ph = [item['ph']]
        txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device)
        txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
        spk_ids = torch.LongTensor([item['spk_id']])[:].to(self.device)

        pitch_midi = torch.LongTensor(item['pitch_midi'])[None, :hparams['max_frames']].to(self.device)
        midi_dur = torch.FloatTensor(item['midi_dur'])[None, :hparams['max_frames']].to(self.device)
        is_slur = torch.LongTensor(item['is_slur'])[None, :hparams['max_frames']].to(self.device)

        batch = {
            'item_name': item_names,
            'text': text,
            'ph': ph,
            'txt_tokens': txt_tokens,
            'txt_lengths': txt_lengths,
            'spk_ids': spk_ids,
            'pitch_midi': pitch_midi,
            'midi_dur': midi_dur,
            'is_slur': is_slur
        }
        return batch

    def postprocess_output(self, output):
        return output

    def infer_once(self, inp):
        inp = self.preprocess_input(inp, input_type=inp['input_type'] if inp.get('input_type') else 'word')
        output = self.forward_model(inp)
        output = self.postprocess_output(output)
        return output

    @classmethod
    def example_run(cls, inp):
        from utils.audio import save_wav
        set_hparams(print_hparams=False)
        infer_ins = cls(hparams)
        out = infer_ins.infer_once(inp)
        os.makedirs('infer_out', exist_ok=True)
        f_name = inp['spk_name'] + ' | ' + inp['text']
        save_wav(out, f'infer_out/{f_name}.wav', hparams['audio_sample_rate'])
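
To make Step 1 and Step 2 above concrete, here is how the character 信 from the example score in ds_e2e.py (below) expands, given phonemes ['x', 'in'], note window 'D#4 F4', and duration window '0.3005 0.3895':

    # Step 1: every phoneme of the word gets the first note of its window
    #   ph_lst       += ['x',      'in'    ]
    #   note_lst     += ['D#4',    'D#4'   ]
    #   midi_dur_lst += ['0.3005', '0.3005']
    #   is_slur      += [0,        0       ]
    # Step 2: the yunmu is repeated for each extra note in the window
    #   ph_lst       += ['in']
    #   note_lst     += ['F4']
    #   midi_dur_lst += ['0.3895']
    #   is_slur      += [1]

This matches the phoneme-level example in ds_e2e.py: '... x in in ...' with note_seq '... D#4 D#4 F4 ...' and is_slur_seq '... 0 0 1 ...'.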
    	
inference/m4singer/ds_e2e.py
ADDED

@@ -0,0 +1,67 @@
import torch
from inference.m4singer.base_svs_infer import BaseSVSInfer
from utils import load_ckpt
from utils.hparams import hparams
from usr.diff.shallow_diffusion_tts import GaussianDiffusion
from usr.diffsinger_task import DIFF_DECODERS
from modules.fastspeech.pe import PitchExtractor
import utils


class DiffSingerE2EInfer(BaseSVSInfer):
    def build_model(self):
        model = GaussianDiffusion(
            phone_encoder=self.ph_encoder,
            out_dims=hparams['audio_num_mel_bins'], denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
            timesteps=hparams['timesteps'],
            K_step=hparams['K_step'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
        )
        model.eval()
        load_ckpt(model, hparams['work_dir'], 'model')

        if hparams.get('pe_enable') is not None and hparams['pe_enable']:
            self.pe = PitchExtractor().to(self.device)
            utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
            self.pe.eval()
        return model

    def forward_model(self, inp):
        sample = self.input_to_batch(inp)
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        spk_id = sample.get('spk_ids')
        with torch.no_grad():
            output = self.model(txt_tokens, spk_embed=spk_id, ref_mels=None, infer=True,
                                pitch_midi=sample['pitch_midi'], midi_dur=sample['midi_dur'],
                                is_slur=sample['is_slur'])
            mel_out = output['mel_out']  # [B, T, 80]
            if hparams.get('pe_enable') is not None and hparams['pe_enable']:
                f0_pred = self.pe(mel_out)['f0_denorm_pred']  # pitch extractor predicts F0 from the predicted mel
            else:
                f0_pred = output['f0_denorm']
            wav_out = self.run_vocoder(mel_out, f0=f0_pred)
        wav_out = wav_out.cpu().numpy()
        return wav_out[0]


if __name__ == '__main__':
    # word-level input: one note window and one duration window per character
    inp = {
        'spk_name': 'Tenor-1',
        'text': 'AP你要相信AP相信我们会像童话故事里AP',
        'notes': 'rest | G#3 | A#3 C4 | D#4 | D#4 F4 | rest | E4 F4 | F4 | D#4 A#3 | A#3 | A#3 | C#4 | B3 C4 | C#4 | B3 C4 | A#3 | G#3 | rest',
        'notes_duration': '0.14 | 0.47 | 0.1905 0.1895 | 0.41 | 0.3005 0.3895 | 0.21 | 0.2391 0.1809 | 0.32 | 0.4105 0.2095 | 0.35 | 0.43 | 0.45 | 0.2309 0.2291 | 0.48 | 0.225 0.195 | 0.29 | 0.71 | 0.14',
        'input_type': 'word',
    }

    # the same score as a phoneme-level input example (defined for reference, not run below)
    c = {
        'spk_name': 'Tenor-1',
        'text': '你要相信相信我们会像童话故事里',
        'ph_seq': '<AP> n i iao iao x iang x in in <AP> x iang iang x in uo uo m en h uei x iang t ong ong h ua g u u sh i l i <AP>',
        'note_seq': 'rest G#3 G#3 A#3 C4 D#4 D#4 D#4 D#4 F4 rest E4 E4 F4 F4 F4 D#4 A#3 A#3 A#3 A#3 A#3 C#4 C#4 B3 B3 C4 C#4 C#4 B3 B3 C4 A#3 A#3 G#3 G#3 rest',
        'note_dur_seq': '0.14 0.47 0.47 0.1905 0.1895 0.41 0.41 0.3005 0.3005 0.3895 0.21 0.2391 0.2391 0.1809 0.32 0.32 0.4105 0.2095 0.35 0.35 0.43 0.43 0.45 0.45 0.2309 0.2309 0.2291 0.48 0.48 0.225 0.225 0.195 0.29 0.29 0.71 0.71 0.14',
        'is_slur_seq': '0 0 0 0 1 0 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0',
        'input_type': 'phoneme'
    }
    DiffSingerE2EInfer.example_run(inp)
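
A hedged usage note: `example_run` calls `set_hparams()` with no arguments, and in this codebase `set_hparams` reads the config path and experiment name from the command line, so running the example should look roughly like the following (paths assume the checkpoint layout shipped with this Space):

    python inference/m4singer/ds_e2e.py --config checkpoints/m4singer_diff_e2e/config.yaml --exp_name m4singer_diff_e2e

The resulting wav is written to infer_out/, named after the singer and the lyrics.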
    	
inference/m4singer/gradio/gradio_settings.yaml
ADDED

@@ -0,0 +1,48 @@
title: 'M4Singer'
description: |
  This page demonstrates the singing voice synthesis function of M4Singer. The SingerID can be switched freely to preview each singer's timbre. Click the examples below to quickly load scores and audio.
   (本页面为M4Singer歌声合成功能展示。SingerID可以自由切换用以预览各歌手的音色。点击下方Examples可以快速加载乐谱和音频。)

  Please assign pitch and duration values to each Chinese character. The pitches and durations of each character must be separated by the | separator, and the number of note windows produced by the separator must equal the number of Chinese characters. AP (aspirate) and SP (silence) each count as one Chinese character.
   (请给每个汉字分配音高和时值, 每个字对应的音高和时值需要用 | 分隔符隔开。需要保证分隔符分割出来的音符窗口与汉字个数一致。换气或静音符也算一个汉字。)

  The notes corresponding to AP and SP are fixed as rest. If a window (| .... |) contains multiple notes, the Chinese character in that window is a glissando, and each note needs its own duration.
   (AP和SP对应的音符固定为rest。若一个窗口(| .... |)内有多个音符, 代表该窗口对应的汉字为滑音, 需要为每个音符都分配时长。)

article: |
  Note: This page runs on CPU; please refer to the <a href='https://github.com/M4Singer/M4Singer' style='color:blue;' target='_blank'>Github REPO</a> for local running instructions and for our dataset.

  --------
  If our work is useful for your research, please consider citing:
  ```bibtex
  @inproceedings{zhang2022msinger,
    title={M4Singer: A Multi-Style, Multi-Singer and Musical Score Provided Mandarin Singing Corpus},
    author={Lichao Zhang and Ruiqi Li and Shoutong Wang and Liqun Deng and Jinglin Liu and Yi Ren and Jinzheng He and Rongjie Huang and Jieming Zhu and Xiao Chen and Zhou Zhao},
    booktitle={Thirty-sixth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
    year={2022},
  }
  ```

example_inputs:
  - |-
    Tenor-1<sep>AP你要相信AP相信我们会像童话故事里AP<sep>rest | G#3 | A#3 C4 | D#4 | D#4 F4 | rest | E4 F4 | F4 | D#4 A#3 | A#3 | A#3 | C#4 | B3 C4 | C#4 | B3 C4 | A#3 | G#3 | rest<sep>0.14 | 0.47 | 0.1905 0.1895 | 0.41 | 0.3005 0.3895 | 0.21 | 0.2391 0.1809 | 0.32 | 0.4105 0.2095 | 0.35 | 0.43 | 0.45 | 0.2309 0.2291 | 0.48 | 0.225 0.195 | 0.29 | 0.71 | 0.14
  - |-
    Tenor-1<sep>AP因为在一千年以后AP世界早已没有我AP<sep>rest | C#4 | D4 | E4 | F#4 | E4 | D4 G#3 | A3 | D4 E4 | rest | F#4 | E4 | D4 | C#4 | B3 F#3 | F#3 | C4 C#4 | rest<sep>0.18 | 0.32 | 0.38 | 0.81 | 0.38 | 0.39 | 0.3155 0.2045 | 0.28 | 0.4609 1.0291 | 0.27 | 0.42 | 0.15 | 0.53 | 0.22 | 0.3059 0.2841 | 0.4 | 0.2909 1.1091 | 0.3
  - |-
    Tenor-2<sep>AP可是你在敲打AP我的窗棂AP<sep>rest | G#3 | B3 | B3 C#4 | E4 | C#4 B3 | G#3 | rest | C3 | E3 | B3 G#3 | F#3 | rest<sep>0.2 | 0.38 | 0.48 | 0.41 0.72 | 0.39 | 0.5195 0.2905 | 0.5 | 0.33 | 0.4 | 0.31 | 0.565 0.265 | 1.15 | 0.24
  - |-
    Tenor-2<sep>SP一杯敬朝阳一杯敬月光AP<sep>rest | G#3 | G#3 | G#3 | G3 | G3 G#3 | G3 | C4 | C4 | A#3 | C4 | rest<sep>0.33 | 0.26 | 0.23 | 0.27 | 0.36 | 0.3159 0.4041 | 0.54 | 0.21 | 0.32 | 0.24 | 0.58 | 0.17
  - |-
    Soprano-1<sep>SP乱石穿空AP惊涛拍岸AP<sep>rest | C#5 | D#5 | F5 D#5 | C#5 | rest | C#5 | C#5 | C#5 G#4 | G#4 | rest<sep>0.325 | 0.75 | 0.54 | 0.48 0.55 | 1.38 | 0.31 | 0.55 | 0.48 | 0.4891 0.4709 | 1.15 | 0.22
  - |-
    Soprano-1<sep>AP点点滴滴染绿了村寨AP<sep>rest | C5 | A#4 | C5 | D#5 F5 D#5 | D#5 | C5 | C5 | C5 | A#4 | rest<sep>0.175 | 0.24 | 0.26 | 1.08 | 0.3541 0.4364 0.2195 | 0.47 | 0.27 | 0.12 | 0.51 | 0.72 | 0.15
  - |-
    Alto-2<sep>AP拒绝声色的张扬AP不拒绝你AP<sep>rest | C4 | C4 | C4 | B3 A3 | C4 | C4 D4 | D4 | rest | D4 | D4 | C4 | G4 E4 | rest<sep>0.49 | 0.31 | 0.18 | 0.48 | 0.3 0.4 | 0.25 | 0.3591 0.2409 | 0.46 | 0.34 | 0.4 | 0.45 | 0.45 | 2.4545 0.9855 | 0.215
  - |-
    Alto-2<sep>AP半醒着AP笑着哭着都快活AP<sep>rest | D4 | B3 | C4 D4 | rest | E4 | D4 | E4 | D4 | E4 | E4 F#4 | F4 F#4 | rest<sep>0.165 | 0.45 | 0.53 | 0.3859 0.2441 | 0.35 | 0.38 | 0.17 | 0.32 | 0.26 | 0.33 | 0.38 0.21 | 0.3309 0.9491 | 0.125


inference_cls: inference.m4singer.ds_e2e.DiffSingerE2EInfer
exp_name: m4singer_diff_e2e
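
To illustrate the score format described above, the first example line aligns as follows (one |-separated window per character, including AP/SP; the two-note window under 要 is a slur, so it carries two durations):

    text:      AP     你     要              相     信              AP    ...
    notes:     rest | G#3  | A#3 C4        | D#4  | D#4 F4        | rest ...
    durations: 0.14 | 0.47 | 0.1905 0.1895 | 0.41 | 0.3005 0.3895 | 0.21 ...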
    	
        inference/m4singer/gradio/infer.py
    ADDED
    
    | @@ -0,0 +1,143 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import importlib
         | 
| 2 | 
            +
            import re
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import gradio as gr
         | 
| 5 | 
            +
            import yaml
         | 
| 6 | 
            +
            from gradio.components import Textbox, Dropdown
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from inference.m4singer.base_svs_infer import BaseSVSInfer
         | 
| 9 | 
            +
            from utils.hparams import set_hparams
         | 
| 10 | 
            +
            from utils.hparams import hparams as hp
         | 
| 11 | 
            +
            import numpy as np
         | 
| 12 | 
            +
            from inference.m4singer.gradio.share_btn import community_icon_html, loading_icon_html, share_js
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            class GradioInfer:
         | 
| 15 | 
            +
                def __init__(self, exp_name, inference_cls, title, description, article, example_inputs):
         | 
| 16 | 
            +
                    self.exp_name = exp_name
         | 
| 17 | 
            +
                    self.title = title
         | 
| 18 | 
            +
                    self.description = description
         | 
| 19 | 
            +
                    self.article = article
         | 
| 20 | 
            +
                    self.example_inputs = example_inputs
         | 
| 21 | 
            +
                    pkg = ".".join(inference_cls.split(".")[:-1])
         | 
| 22 | 
            +
                    cls_name = inference_cls.split(".")[-1]
         | 
| 23 | 
            +
                    self.inference_cls = getattr(importlib.import_module(pkg), cls_name)
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                def greet(self, singer, text, notes, notes_duration):
         | 
| 26 | 
            +
                    PUNCS = '。?;:'
         | 
| 27 | 
            +
                    sents = re.split(rf'([{PUNCS}])', text.replace('\n', ','))
         | 
| 28 | 
            +
                    sents_notes = re.split(rf'([{PUNCS}])', notes.replace('\n', ','))
         | 
| 29 | 
            +
                    sents_notes_dur = re.split(rf'([{PUNCS}])', notes_duration.replace('\n', ','))
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                    if sents[-1] not in list(PUNCS):
         | 
| 32 | 
            +
                        sents = sents + ['']
         | 
| 33 | 
            +
                        sents_notes = sents_notes + ['']
         | 
| 34 | 
            +
                        sents_notes_dur = sents_notes_dur + ['']
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                    audio_outs = []
         | 
| 37 | 
            +
                    s, n, n_dur = "", "", ""
         | 
| 38 | 
            +
                    for i in range(0, len(sents), 2):
         | 
| 39 | 
            +
                        if len(sents[i]) > 0:
         | 
| 40 | 
            +
                            s += sents[i] + sents[i + 1]
         | 
| 41 | 
            +
                            n += sents_notes[i] + sents_notes[i+1]
         | 
| 42 | 
            +
                            n_dur += sents_notes_dur[i] + sents_notes_dur[i+1]
         | 
| 43 | 
            +
                        if len(s) >= 400 or (i >= len(sents) - 2 and len(s) > 0):
         | 
| 44 | 
            +
                            audio_out = self.infer_ins.infer_once({
         | 
| 45 | 
            +
                                'spk_name': singer,
         | 
| 46 | 
            +
                                'text': s,
         | 
| 47 | 
            +
                                'notes': n,
         | 
| 48 | 
            +
                                'notes_duration': n_dur,
         | 
| 49 | 
            +
                            })
         | 
| 50 | 
            +
                            audio_out = audio_out * 32767
         | 
| 51 | 
            +
                            audio_out = audio_out.astype(np.int16)
         | 
| 52 | 
            +
                            audio_outs.append(audio_out)
         | 
| 53 | 
            +
                            audio_outs.append(np.zeros(int(hp['audio_sample_rate'] * 0.3)).astype(np.int16))
         | 
| 54 | 
            +
                            s = ""
         | 
| 55 | 
            +
                            n = ""
         | 
| 56 | 
            +
                    audio_outs = np.concatenate(audio_outs)
         | 
| 57 | 
            +
                    return (hp['audio_sample_rate'], audio_outs), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                def run(self):
         | 
| 60 | 
            +
                    set_hparams(config=f'checkpoints/{self.exp_name}/config.yaml', exp_name=self.exp_name, print_hparams=False)
         | 
| 61 | 
            +
                    infer_cls = self.inference_cls
         | 
| 62 | 
            +
                    self.infer_ins: BaseSVSInfer = infer_cls(hp)
         | 
| 63 | 
            +
                    example_inputs = self.example_inputs
         | 
| 64 | 
            +
                    for i in range(len(example_inputs)):
         | 
| 65 | 
            +
                        singer, text, notes, notes_dur = example_inputs[i].split('<sep>')
         | 
| 66 | 
            +
                        example_inputs[i] = [singer, text, notes, notes_dur]
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    singerList = \
         | 
| 69 | 
            +
                        [
         | 
| 70 | 
            +
                        'Tenor-1', 'Tenor-2', 'Tenor-3', 'Tenor-4', 'Tenor-5', 'Tenor-6', 'Tenor-7',
         | 
| 71 | 
            +
                        'Alto-1', 'Alto-2', 'Alto-3', 'Alto-4', 'Alto-5', 'Alto-6', 'Alto-7',
         | 
| 72 | 
            +
                        'Soprano-1', 'Soprano-2', 'Soprano-3',
         | 
| 73 | 
            +
                        'Bass-1',  'Bass-2',  'Bass-3',
         | 
| 74 | 
            +
                        ]
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    css = """
         | 
| 77 | 
            +
                    #share-btn-container {
         | 
| 78 | 
            +
                        display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
         | 
| 79 | 
            +
                    }
         | 
| 80 | 
            +
                    #share-btn {
         | 
| 81 | 
            +
                        all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
         | 
| 82 | 
            +
                    }
         | 
| 83 | 
            +
                    #share-btn * {
         | 
| 84 | 
            +
                        all: unset;
         | 
| 85 | 
            +
                    }
         | 
| 86 | 
            +
                    #share-btn-container div:nth-child(-n+2){
         | 
| 87 | 
            +
                        width: auto !important;
         | 
| 88 | 
            +
                        min-height: 0px !important;
         | 
| 89 | 
            +
                    }
         | 
| 90 | 
            +
                    #share-btn-container .wrap {
         | 
| 91 | 
            +
                        display: none !important;
         | 
| 92 | 
            +
                    }
         | 
| 93 | 
            +
                    """
         | 
| 94 | 
            +
                    with gr.Blocks(css=css) as demo:
         | 
| 95 | 
            +
                        gr.HTML("""<div style="text-align: center; margin: 0 auto;">
         | 
| 96 | 
            +
                                      <div
         | 
| 97 | 
            +
                                      style="
         | 
| 98 | 
            +
                                          display: inline-flex;
         | 
| 99 | 
            +
                                          align-items: center;
         | 
| 100 | 
            +
                                          gap: 0.8rem;
         | 
| 101 | 
            +
                                          font-size: 1.75rem;
         | 
| 102 | 
            +
                                      "
         | 
| 103 | 
            +
                                      >
         | 
| 104 | 
            +
                                      <h1 style="font-weight: 900; margin-bottom: 10px; margin-top: 14px;">
         | 
| 105 | 
            +
                                          M4Singer
         | 
| 106 | 
            +
                                      </h1>
         | 
| 107 | 
            +
                                      </div>
         | 
| 108 | 
            +
                                    </div>
         | 
| 109 | 
            +
                                    """
         | 
| 110 | 
            +
                                )
         | 
| 111 | 
            +
                        gr.Markdown(self.description)
         | 
| 112 | 
            +
                        with gr.Row():
         | 
| 113 | 
            +
                            with gr.Column():
         | 
| 114 | 
            +
                                singer_l = Dropdown(choices=singerList, value=example_inputs[0][0], label="SingerID", elem_id="inp_singer")
         | 
| 115 | 
            +
                                inp_text = Textbox(lines=2, placeholder=None, value=example_inputs[0][1], label="input text", elem_id="inp_text")
         | 
| 116 | 
            +
                                inp_note = Textbox(lines=2, placeholder=None, value=example_inputs[0][2], label="input note", elem_id="inp_note")
         | 
| 117 | 
            +
                                inp_duration = Textbox(lines=2, placeholder=None, value=example_inputs[0][3], label="input duration", elem_id="inp_duration")
         | 
| 118 | 
            +
                                generate = gr.Button("Generate Singing Voice from Musical Score")
         | 
| 119 | 
            +
                            with gr.Column(elem_id="col-container"):
         | 
| 120 | 
            +
                                singing_output = gr.Audio(label="Result", type="numpy", elem_id="music-output")
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                                with gr.Group(elem_id="share-btn-container"):
         | 
| 123 | 
            +
                                    community_icon = gr.HTML(community_icon_html, visible=False)
         | 
| 124 | 
            +
                                    loading_icon = gr.HTML(loading_icon_html, visible=False)
         | 
| 125 | 
            +
                                    share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
         | 
| 126 | 
            +
                        gr.Examples(examples=self.example_inputs,
         | 
| 127 | 
            +
                                    inputs=[singer_l, inp_text, inp_note, inp_duration],
         | 
| 128 | 
            +
                                    outputs=[singing_output, share_button, community_icon, loading_icon],
         | 
| 129 | 
            +
                                    fn=self.greet,
         | 
| 130 | 
            +
                                    cache_examples=True)
         | 
| 131 | 
            +
                        gr.Markdown(self.article)
         | 
| 132 | 
            +
                        generate.click(self.greet,
         | 
| 133 | 
            +
                                           inputs=[singer_l, inp_text, inp_note, inp_duration],
         | 
| 134 | 
            +
                                           outputs=[singing_output, share_button, community_icon, loading_icon],)
         | 
| 135 | 
            +
                        share_button.click(None, [], [], _js=share_js)
         | 
| 136 | 
            +
                    demo.queue().launch(share=False)
         | 
| 137 | 
            +
             | 
| 138 | 
            +
             | 
| 139 | 
            +
            if __name__ == '__main__':
         | 
| 140 | 
            +
                gradio_config = yaml.safe_load(open('inference/m4singer/gradio/gradio_settings.yaml'))
         | 
| 141 | 
            +
                g = GradioInfer(**gradio_config)
         | 
| 142 | 
            +
                g.run()
         | 
| 143 | 
            +
             | 
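A note on the example format consumed above: each example_inputs entry packs the
four score fields into one string joined by '<sep>', and run() splits it back
into [singer, text, notes, notes_dur]. A minimal sketch with purely hypothetical
values (real entries live in gradio_settings.yaml):

    example = 'Tenor-1<sep>你好<sep>C4 D4<sep>0.5 0.5'
    singer, text, notes, notes_dur = example.split('<sep>')
    print(singer)  # 'Tenor-1'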
    	
        inference/m4singer/gradio/share_btn.py
    ADDED
    
    | @@ -0,0 +1,86 @@ | |
| 1 | 
            +
            community_icon_html = """<svg id="share-btn-share-icon" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32">
         | 
| 2 | 
            +
                <path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path>
         | 
| 3 | 
            +
                <path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path>
         | 
| 4 | 
            +
            </svg>"""
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            loading_icon_html = """<svg id="share-btn-loading-icon" class="animate-spin"
         | 
| 7 | 
            +
               style="color: #ffffff; 
         | 
| 8 | 
            +
            "
         | 
| 9 | 
            +
               xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            share_js = """async () => {
         | 
| 12 | 
            +
            	async function uploadFile(file){
         | 
| 13 | 
            +
            		const UPLOAD_URL = 'https://huggingface.co/uploads';
         | 
| 14 | 
            +
            		const response = await fetch(UPLOAD_URL, {
         | 
| 15 | 
            +
            			method: 'POST',
         | 
| 16 | 
            +
            			headers: {
         | 
| 17 | 
            +
            				'Content-Type': file.type,
         | 
| 18 | 
            +
            				'X-Requested-With': 'XMLHttpRequest',
         | 
| 19 | 
            +
            			},
         | 
| 20 | 
            +
            			body: file, /// <- File inherits from Blob
         | 
| 21 | 
            +
            		});
         | 
| 22 | 
            +
            		const url = await response.text();
         | 
| 23 | 
            +
            		return url;
         | 
| 24 | 
            +
            	}
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                async function getOutputMusicFile(audioEL){
         | 
| 27 | 
            +
                    const res = await fetch(audioEL.src);
         | 
| 28 | 
            +
                    const blob = await res.blob();
         | 
| 29 | 
            +
                    const audioId = Date.now() % 200;
         | 
| 30 | 
            +
                const fileName = `SVS-${audioId}.wav`;
         | 
| 31 | 
            +
                    const musicBlob = new File([blob], fileName, { type: 'audio/wav' });
         | 
| 32 | 
            +
                    return musicBlob;
         | 
| 33 | 
            +
            	}
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                const gradioEl = document.querySelector('body > gradio-app');
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                //const gradioEl = document.querySelector("gradio-app").shadowRoot;
         | 
| 38 | 
            +
                const inputSinger = gradioEl.querySelector('#inp_singer select').value;
         | 
| 39 | 
            +
                const inputText = gradioEl.querySelector('#inp_text textarea').value;
         | 
| 40 | 
            +
                const inputNote = gradioEl.querySelector('#inp_note textarea').value;
         | 
| 41 | 
            +
                const inputDuration = gradioEl.querySelector('#inp_duration textarea').value;
         | 
| 42 | 
            +
                const outputMusic = gradioEl.querySelector('#music-output audio');
         | 
| 43 | 
            +
                const outputMusic_src = gradioEl.querySelector('#music-output audio').src;
         | 
| 44 | 
            +
                
         | 
| 45 | 
            +
                const outputMusic_name = outputMusic_src.split('/').pop();
         | 
| 46 | 
            +
                let titleTxt = outputMusic_name;
         | 
| 47 | 
            +
                if(titleTxt.length > 30){
         | 
| 48 | 
            +
                    titleTxt = 'demo';
         | 
| 49 | 
            +
                }
         | 
| 50 | 
            +
                const shareBtnEl = gradioEl.querySelector('#share-btn');
         | 
| 51 | 
            +
                const shareIconEl = gradioEl.querySelector('#share-btn-share-icon');
         | 
| 52 | 
            +
                const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon');
         | 
| 53 | 
            +
                if(!outputMusic){
         | 
| 54 | 
            +
                    return;
         | 
| 55 | 
            +
                };
         | 
| 56 | 
            +
                shareBtnEl.style.pointerEvents = 'none';
         | 
| 57 | 
            +
                shareIconEl.style.display = 'none';
         | 
| 58 | 
            +
                loadingIconEl.style.removeProperty('display');
         | 
| 59 | 
            +
                const musicFile = await getOutputMusicFile(outputMusic);
         | 
| 60 | 
            +
                const dataOutputMusic = await uploadFile(musicFile);
         | 
| 61 | 
            +
                const descriptionMd = `#### Input Musical Score:
         | 
| 62 | 
            +
            ${inputSinger}
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            ${inputText}
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            ${inputNote}
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            ${inputDuration}
         | 
| 69 | 
            +
                
         | 
| 70 | 
            +
            #### Singing Voice:
         | 
| 71 | 
            +
                
         | 
| 72 | 
            +
            <audio controls>
         | 
| 73 | 
            +
                <source src="${dataOutputMusic}" type="audio/wav">
         | 
| 74 | 
            +
            Your browser does not support the audio element.
         | 
| 75 | 
            +
            </audio>
         | 
| 76 | 
            +
            `;
         | 
| 77 | 
            +
                const params = new URLSearchParams({
         | 
| 78 | 
            +
                    title: titleTxt,
         | 
| 79 | 
            +
                    description: descriptionMd,
         | 
| 80 | 
            +
                });
         | 
| 81 | 
            +
            	const paramsStr = params.toString();
         | 
| 82 | 
            +
            	window.open(`https://huggingface.co/spaces/zlc99/M4Singer/discussions/new?${paramsStr}`, '_blank');
         | 
| 83 | 
            +
                shareBtnEl.style.removeProperty('pointer-events');
         | 
| 84 | 
            +
                shareIconEl.style.removeProperty('display');
         | 
| 85 | 
            +
                loadingIconEl.style.display = 'none';
         | 
| 86 | 
            +
            }"""
         | 
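For reference, a rough Python equivalent of the browser-side uploadFile() helper
above (a sketch only: the demo performs this upload in JavaScript, and the
endpoint and headers below are copied from that snippet):

    import requests

    def upload_file(path, content_type='audio/wav'):
        # POST the raw bytes; the response body is the URL of the uploaded file.
        with open(path, 'rb') as f:
            resp = requests.post(
                'https://huggingface.co/uploads',
                headers={'Content-Type': content_type,
                         'X-Requested-With': 'XMLHttpRequest'},
                data=f.read(),
            )
        return resp.text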
    	
        inference/m4singer/m4singer/m4singer_pinyin2ph.txt
    ADDED
    
    | @@ -0,0 +1,413 @@ | |
| 1 | 
            +
            | a      | a        |
         | 
| 2 | 
            +
            | ai     | ai       |
         | 
| 3 | 
            +
            | an     | an       |
         | 
| 4 | 
            +
            | ang    | ang      |
         | 
| 5 | 
            +
            | ao     | ao       |
         | 
| 6 | 
            +
            | ba     | b a      |
         | 
| 7 | 
            +
            | bai    | b ai     |
         | 
| 8 | 
            +
            | ban    | b an     |
         | 
| 9 | 
            +
            | bang   | b ang    |
         | 
| 10 | 
            +
            | bao    | b ao     |
         | 
| 11 | 
            +
            | bei    | b ei     |
         | 
| 12 | 
            +
            | ben    | b en     |
         | 
| 13 | 
            +
            | beng   | b eng    |
         | 
| 14 | 
            +
            | bi     | b i      |
         | 
| 15 | 
            +
            | bian   | b ian    |
         | 
| 16 | 
            +
            | biao   | b iao    |
         | 
| 17 | 
            +
            | bie    | b ie     |
         | 
| 18 | 
            +
            | bin    | b in     |
         | 
| 19 | 
            +
            | bing   | b ing    |
         | 
| 20 | 
            +
            | bo     | b o      |
         | 
| 21 | 
            +
            | bu     | b u      |
         | 
| 22 | 
            +
            | ca     | c a      |
         | 
| 23 | 
            +
            | cai    | c ai     |
         | 
| 24 | 
            +
            | can    | c an     |
         | 
| 25 | 
            +
            | cang   | c ang    |
         | 
| 26 | 
            +
            | cao    | c ao     |
         | 
| 27 | 
            +
            | ce     | c e      |
         | 
| 28 | 
            +
            | cei    | c ei     |
         | 
| 29 | 
            +
            | cen    | c en     |
         | 
| 30 | 
            +
            | ceng   | c eng    |
         | 
| 31 | 
            +
            | cha    | ch a     |
         | 
| 32 | 
            +
            | chai   | ch ai    |
         | 
| 33 | 
            +
            | chan   | ch an    |
         | 
| 34 | 
            +
            | chang  | ch ang   |
         | 
| 35 | 
            +
            | chao   | ch ao    |
         | 
| 36 | 
            +
            | che    | ch e     |
         | 
| 37 | 
            +
            | chen   | ch en    |
         | 
| 38 | 
            +
            | cheng  | ch eng   |
         | 
| 39 | 
            +
            | chi    | ch i     |
         | 
| 40 | 
            +
            | chong  | ch ong   |
         | 
| 41 | 
            +
            | chou   | ch ou    |
         | 
| 42 | 
            +
            | chu    | ch u     |
         | 
| 43 | 
            +
            | chua   | ch ua    |
         | 
| 44 | 
            +
            | chuai  | ch uai   |
         | 
| 45 | 
            +
            | chuan  | ch uan   |
         | 
| 46 | 
            +
            | chuang | ch uang  |
         | 
| 47 | 
            +
            | chui   | ch uei   |
         | 
| 48 | 
            +
            | chun   | ch uen   |
         | 
| 49 | 
            +
            | chuo   | ch uo    |
         | 
| 50 | 
            +
            | ci     | c i      |
         | 
| 51 | 
            +
            | cong   | c ong    |
         | 
| 52 | 
            +
            | cou    | c ou     |
         | 
| 53 | 
            +
            | cu     | c u      |
         | 
| 54 | 
            +
            | cuan   | c uan    |
         | 
| 55 | 
            +
            | cui    | c uei    |
         | 
| 56 | 
            +
            | cun    | c uen    |
         | 
| 57 | 
            +
            | cuo    | c uo     |
         | 
| 58 | 
            +
            | da     | d a      |
         | 
| 59 | 
            +
            | dai    | d ai     |
         | 
| 60 | 
            +
            | dan    | d an     |
         | 
| 61 | 
            +
            | dang   | d ang    |
         | 
| 62 | 
            +
            | dao    | d ao     |
         | 
| 63 | 
            +
            | de     | d e      |
         | 
| 64 | 
            +
            | dei    | d ei     |
         | 
| 65 | 
            +
            | den    | d en     |
         | 
| 66 | 
            +
            | deng   | d eng    |
         | 
| 67 | 
            +
            | di     | d i      |
         | 
| 68 | 
            +
            | dia    | d ia     |
         | 
| 69 | 
            +
            | dian   | d ian    |
         | 
| 70 | 
            +
            | diao   | d iao    |
         | 
| 71 | 
            +
            | die    | d ie     |
         | 
| 72 | 
            +
            | ding   | d ing    |
         | 
| 73 | 
            +
            | diu    | d iou    |
         | 
| 74 | 
            +
            | dong   | d ong    |
         | 
| 75 | 
            +
            | dou    | d ou     |
         | 
| 76 | 
            +
            | du     | d u      |
         | 
| 77 | 
            +
            | duan   | d uan    |
         | 
| 78 | 
            +
            | dui    | d uei    |
         | 
| 79 | 
            +
            | dun    | d uen    |
         | 
| 80 | 
            +
            | duo    | d uo     |
         | 
| 81 | 
            +
            | e      | e        |
         | 
| 82 | 
            +
            | ei     | ei       |
         | 
| 83 | 
            +
            | en     | en       |
         | 
| 84 | 
            +
            | eng    | eng      |
         | 
| 85 | 
            +
            | er     | er       |
         | 
| 86 | 
            +
            | fa     | f a      |
         | 
| 87 | 
            +
            | fan    | f an     |
         | 
| 88 | 
            +
            | fang   | f ang    |
         | 
| 89 | 
            +
            | fei    | f ei     |
         | 
| 90 | 
            +
            | fen    | f en     |
         | 
| 91 | 
            +
            | feng   | f eng    |
         | 
| 92 | 
            +
            | fo     | f o      |
         | 
| 93 | 
            +
            | fou    | f ou     |
         | 
| 94 | 
            +
            | fu     | f u      |
         | 
| 95 | 
            +
            | ga     | g a      |
         | 
| 96 | 
            +
            | gai    | g ai     |
         | 
| 97 | 
            +
            | gan    | g an     |
         | 
| 98 | 
            +
            | gang   | g ang    |
         | 
| 99 | 
            +
            | gao    | g ao     |
         | 
| 100 | 
            +
            | ge     | g e      |
         | 
| 101 | 
            +
            | gei    | g ei     |
         | 
| 102 | 
            +
            | gen    | g en     |
         | 
| 103 | 
            +
            | geng   | g eng    |
         | 
| 104 | 
            +
            | gong   | g ong    |
         | 
| 105 | 
            +
            | gou    | g ou     |
         | 
| 106 | 
            +
            | gu     | g u      |
         | 
| 107 | 
            +
            | gua    | g ua     |
         | 
| 108 | 
            +
            | guai   | g uai    |
         | 
| 109 | 
            +
            | guan   | g uan    |
         | 
| 110 | 
            +
            | guang  | g uang   |
         | 
| 111 | 
            +
            | gui    | g uei    |
         | 
| 112 | 
            +
            | gun    | g uen    |
         | 
| 113 | 
            +
            | guo    | g uo     |
         | 
| 114 | 
            +
            | ha     | h a      |
         | 
| 115 | 
            +
            | hai    | h ai     |
         | 
| 116 | 
            +
            | han    | h an     |
         | 
| 117 | 
            +
            | hang   | h ang    |
         | 
| 118 | 
            +
            | hao    | h ao     |
         | 
| 119 | 
            +
            | he     | h e      |
         | 
| 120 | 
            +
            | hei    | h ei     |
         | 
| 121 | 
            +
            | hen    | h en     |
         | 
| 122 | 
            +
            | heng   | h eng    |
         | 
| 123 | 
            +
            | hong   | h ong    |
         | 
| 124 | 
            +
            | hou    | h ou     |
         | 
| 125 | 
            +
            | hu     | h u      |
         | 
| 126 | 
            +
            | hua    | h ua     |
         | 
| 127 | 
            +
            | huai   | h uai    |
         | 
| 128 | 
            +
            | huan   | h uan    |
         | 
| 129 | 
            +
            | huang  | h uang   |
         | 
| 130 | 
            +
            | hui    | h uei    |
         | 
| 131 | 
            +
            | hun    | h uen    |
         | 
| 132 | 
            +
            | huo    | h uo     |
         | 
| 133 | 
            +
            | ji     | j i      |
         | 
| 134 | 
            +
            | jia    | j ia     |
         | 
| 135 | 
            +
            | jian   | j ian    |
         | 
| 136 | 
            +
            | jiang  | j iang   |
         | 
| 137 | 
            +
            | jiao   | j iao    |
         | 
| 138 | 
            +
            | jie    | j ie     |
         | 
| 139 | 
            +
            | jin    | j in     |
         | 
| 140 | 
            +
            | jing   | j ing    |
         | 
| 141 | 
            +
            | jiong  | j iong   |
         | 
| 142 | 
            +
            | jiu    | j iou    |
         | 
| 143 | 
            +
            | ju     | j v      |
         | 
| 144 | 
            +
            | juan   | j van    |
         | 
| 145 | 
            +
            | jue    | j ve     |
         | 
| 146 | 
            +
            | jun    | j vn     |
         | 
| 147 | 
            +
            | ka     | k a      |
         | 
| 148 | 
            +
            | kai    | k ai     |
         | 
| 149 | 
            +
            | kan    | k an     |
         | 
| 150 | 
            +
            | kang   | k ang    |
         | 
| 151 | 
            +
            | kao    | k ao     |
         | 
| 152 | 
            +
            | ke     | k e      |
         | 
| 153 | 
            +
            | kei    | k ei     |
         | 
| 154 | 
            +
            | ken    | k en     |
         | 
| 155 | 
            +
            | keng   | k eng    |
         | 
| 156 | 
            +
            | kong   | k ong    |
         | 
| 157 | 
            +
            | kou    | k ou     |
         | 
| 158 | 
            +
            | ku     | k u      |
         | 
| 159 | 
            +
            | kua    | k ua     |
         | 
| 160 | 
            +
            | kuai   | k uai    |
         | 
| 161 | 
            +
            | kuan   | k uan    |
         | 
| 162 | 
            +
            | kuang  | k uang   |
         | 
| 163 | 
            +
            | kui    | k uei    |
         | 
| 164 | 
            +
            | kun    | k uen    |
         | 
| 165 | 
            +
            | kuo    | k uo     |
         | 
| 166 | 
            +
            | la     | l a      |
         | 
| 167 | 
            +
            | lai    | l ai     |
         | 
| 168 | 
            +
            | lan    | l an     |
         | 
| 169 | 
            +
            | lang   | l ang    |
         | 
| 170 | 
            +
            | lao    | l ao     |
         | 
| 171 | 
            +
            | le     | l e      |
         | 
| 172 | 
            +
            | lei    | l ei     |
         | 
| 173 | 
            +
            | leng   | l eng    |
         | 
| 174 | 
            +
            | li     | l i      |
         | 
| 175 | 
            +
            | lia    | l ia     |
         | 
| 176 | 
            +
            | lian   | l ian    |
         | 
| 177 | 
            +
            | liang  | l iang   |
         | 
| 178 | 
            +
            | liao   | l iao    |
         | 
| 179 | 
            +
            | lie    | l ie     |
         | 
| 180 | 
            +
            | lin    | l in     |
         | 
| 181 | 
            +
            | ling   | l ing    |
         | 
| 182 | 
            +
            | liu    | l iou    |
         | 
| 183 | 
            +
            | lo     | l o      |
         | 
| 184 | 
            +
            | long   | l ong    |
         | 
| 185 | 
            +
            | lou    | l ou     |
         | 
| 186 | 
            +
            | lu     | l u      |
         | 
| 187 | 
            +
            | luan   | l uan    |
         | 
| 188 | 
            +
            | lun    | l uen    |
         | 
| 189 | 
            +
            | luo    | l uo     |
         | 
| 190 | 
            +
            | lv     | l v      |
         | 
| 191 | 
            +
            | lve    | l ve     |
         | 
| 192 | 
            +
            | m      | m        |
         | 
| 193 | 
            +
            | ma     | m a      |
         | 
| 194 | 
            +
            | mai    | m ai     |
         | 
| 195 | 
            +
            | man    | m an     |
         | 
| 196 | 
            +
            | mang   | m ang    |
         | 
| 197 | 
            +
            | mao    | m ao     |
         | 
| 198 | 
            +
            | me     | m e      |
         | 
| 199 | 
            +
            | mei    | m ei     |
         | 
| 200 | 
            +
            | men    | m en     |
         | 
| 201 | 
            +
            | meng   | m eng    |
         | 
| 202 | 
            +
            | mi     | m i      |
         | 
| 203 | 
            +
            | mian   | m ian    |
         | 
| 204 | 
            +
            | miao   | m iao    |
         | 
| 205 | 
            +
            | mie    | m ie     |
         | 
| 206 | 
            +
            | min    | m in     |
         | 
| 207 | 
            +
            | ming   | m ing    |
         | 
| 208 | 
            +
            | miu    | m iou    |
         | 
| 209 | 
            +
            | mo     | m o      |
         | 
| 210 | 
            +
            | mou    | m ou     |
         | 
| 211 | 
            +
            | mu     | m u      |
         | 
| 212 | 
            +
            | n      | n        |
         | 
| 213 | 
            +
            | na     | n a      |
         | 
| 214 | 
            +
            | nai    | n ai     |
         | 
| 215 | 
            +
            | nan    | n an     |
         | 
| 216 | 
            +
            | nang   | n ang    |
         | 
| 217 | 
            +
            | nao    | n ao     |
         | 
| 218 | 
            +
            | ne     | n e      |
         | 
| 219 | 
            +
            | nei    | n ei     |
         | 
| 220 | 
            +
            | nen    | n en     |
         | 
| 221 | 
            +
            | neng   | n eng    |
         | 
| 222 | 
            +
            | ni     | n i      |
         | 
| 223 | 
            +
            | nian   | n ian    |
         | 
| 224 | 
            +
            | niang  | n iang   |
         | 
| 225 | 
            +
            | niao   | n iao    |
         | 
| 226 | 
            +
            | nie    | n ie     |
         | 
| 227 | 
            +
            | nin    | n in     |
         | 
| 228 | 
            +
            | ning   | n ing    |
         | 
| 229 | 
            +
            | niu    | n iou    |
         | 
| 230 | 
            +
            | nong   | n ong    |
         | 
| 231 | 
            +
            | nou    | n ou     |
         | 
| 232 | 
            +
            | nu     | n u      |
         | 
| 233 | 
            +
            | nuan   | n uan    |
         | 
| 234 | 
            +
            | nuo    | n uo     |
         | 
| 235 | 
            +
            | nv     | n v      |
         | 
| 236 | 
            +
            | nve    | n ve     |
         | 
| 237 | 
            +
            | o      | o        |
         | 
| 238 | 
            +
            | ou     | ou       |
         | 
| 239 | 
            +
            | pa     | p a      |
         | 
| 240 | 
            +
            | pai    | p ai     |
         | 
| 241 | 
            +
            | pan    | p an     |
         | 
| 242 | 
            +
            | pang   | p ang    |
         | 
| 243 | 
            +
            | pao    | p ao     |
         | 
| 244 | 
            +
            | pei    | p ei     |
         | 
| 245 | 
            +
            | pen    | p en     |
         | 
| 246 | 
            +
            | peng   | p eng    |
         | 
| 247 | 
            +
            | pi     | p i      |
         | 
| 248 | 
            +
            | pian   | p ian    |
         | 
| 249 | 
            +
            | piao   | p iao    |
         | 
| 250 | 
            +
            | pie    | p ie     |
         | 
| 251 | 
            +
            | pin    | p in     |
         | 
| 252 | 
            +
            | ping   | p ing    |
         | 
| 253 | 
            +
            | po     | p o      |
         | 
| 254 | 
            +
            | pou    | p ou     |
         | 
| 255 | 
            +
            | pu     | p u      |
         | 
| 256 | 
            +
            | qi     | q i      |
         | 
| 257 | 
            +
            | qia    | q ia     |
         | 
| 258 | 
            +
            | qian   | q ian    |
         | 
| 259 | 
            +
            | qiang  | q iang   |
         | 
| 260 | 
            +
            | qiao   | q iao    |
         | 
| 261 | 
            +
            | qie    | q ie     |
         | 
| 262 | 
            +
            | qin    | q in     |
         | 
| 263 | 
            +
            | qing   | q ing    |
         | 
| 264 | 
            +
            | qiong  | q iong   |
         | 
| 265 | 
            +
            | qiu    | q iou    |
         | 
| 266 | 
            +
            | qu     | q v      |
         | 
| 267 | 
            +
            | quan   | q van    |
         | 
| 268 | 
            +
            | que    | q ve     |
         | 
| 269 | 
            +
            | qun    | q vn     |
         | 
| 270 | 
            +
            | ran    | r an     |
         | 
| 271 | 
            +
            | rang   | r ang    |
         | 
| 272 | 
            +
            | rao    | r ao     |
         | 
| 273 | 
            +
            | re     | r e      |
         | 
| 274 | 
            +
            | ren    | r en     |
         | 
| 275 | 
            +
            | reng   | r eng    |
         | 
| 276 | 
            +
            | ri     | r i      |
         | 
| 277 | 
            +
            | rong   | r ong    |
         | 
| 278 | 
            +
            | rou    | r ou     |
         | 
| 279 | 
            +
            | ru     | r u      |
         | 
| 280 | 
            +
            | rua    | r ua     |
         | 
| 281 | 
            +
            | ruan   | r uan    |
         | 
| 282 | 
            +
            | rui    | r uei    |
         | 
| 283 | 
            +
            | run    | r uen    |
         | 
| 284 | 
            +
            | ruo    | r uo     |
         | 
| 285 | 
            +
            | sa     | s a      |
         | 
| 286 | 
            +
            | sai    | s ai     |
         | 
| 287 | 
            +
            | san    | s an     |
         | 
| 288 | 
            +
            | sang   | s ang    |
         | 
| 289 | 
            +
            | sao    | s ao     |
         | 
| 290 | 
            +
            | se     | s e      |
         | 
| 291 | 
            +
            | sen    | s en     |
         | 
| 292 | 
            +
            | seng   | s eng    |
         | 
| 293 | 
            +
            | sha    | sh a     |
         | 
| 294 | 
            +
            | shai   | sh ai    |
         | 
| 295 | 
            +
            | shan   | sh an    |
         | 
| 296 | 
            +
            | shang  | sh ang   |
         | 
| 297 | 
            +
            | shao   | sh ao    |
         | 
| 298 | 
            +
            | she    | sh e     |
         | 
| 299 | 
            +
            | shei   | sh ei    |
         | 
| 300 | 
            +
            | shen   | sh en    |
         | 
| 301 | 
            +
            | sheng  | sh eng   |
         | 
| 302 | 
            +
            | shi    | sh i     |
         | 
| 303 | 
            +
            | shou   | sh ou    |
         | 
| 304 | 
            +
            | shu    | sh u     |
         | 
| 305 | 
            +
            | shua   | sh ua    |
         | 
| 306 | 
            +
            | shuai  | sh uai   |
         | 
| 307 | 
            +
            | shuan  | sh uan   |
         | 
| 308 | 
            +
            | shuang | sh uang  |
         | 
| 309 | 
            +
            | shui   | sh uei   |
         | 
| 310 | 
            +
            | shun   | sh uen   |
         | 
| 311 | 
            +
            | shuo   | sh uo    |
         | 
| 312 | 
            +
            | si     | s i      |
         | 
| 313 | 
            +
            | song   | s ong    |
         | 
| 314 | 
            +
            | sou    | s ou     |
         | 
| 315 | 
            +
            | su     | s u      |
         | 
| 316 | 
            +
            | suan   | s uan    |
         | 
| 317 | 
            +
            | sui    | s uei    |
         | 
| 318 | 
            +
            | sun    | s uen    |
         | 
| 319 | 
            +
            | suo    | s uo     |
         | 
| 320 | 
            +
            | ta     | t a      |
         | 
| 321 | 
            +
            | tai    | t ai     |
         | 
| 322 | 
            +
            | tan    | t an     |
         | 
| 323 | 
            +
            | tang   | t ang    |
         | 
| 324 | 
            +
            | tao    | t ao     |
         | 
| 325 | 
            +
            | te     | t e      |
         | 
| 326 | 
            +
            | tei    | t ei     |
         | 
| 327 | 
            +
            | teng   | t eng    |
         | 
| 328 | 
            +
            | ti     | t i      |
         | 
| 329 | 
            +
            | tian   | t ian    |
         | 
| 330 | 
            +
            | tiao   | t iao    |
         | 
| 331 | 
            +
            | tie    | t ie     |
         | 
| 332 | 
            +
            | ting   | t ing    |
         | 
| 333 | 
            +
            | tong   | t ong    |
         | 
| 334 | 
            +
            | tou    | t ou     |
         | 
| 335 | 
            +
            | tu     | t u      |
         | 
| 336 | 
            +
            | tuan   | t uan    |
         | 
| 337 | 
            +
            | tui    | t uei    |
         | 
| 338 | 
            +
            | tun    | t uen    |
         | 
| 339 | 
            +
            | tuo    | t uo     |
         | 
| 340 | 
            +
            | wa     | ua       |
         | 
| 341 | 
            +
            | wai    | uai      |
         | 
| 342 | 
            +
            | wan    | uan      |
         | 
| 343 | 
            +
            | wang   | uang     |
         | 
| 344 | 
            +
            | wei    | uei      |
         | 
| 345 | 
            +
            | wen    | uen      |
         | 
| 346 | 
            +
            | weng   | ueng     |
         | 
| 347 | 
            +
            | wo     | uo       |
         | 
| 348 | 
            +
            | wu     | u        |
         | 
| 349 | 
            +
            | xi     | x i      |
         | 
| 350 | 
            +
            | xia    | x ia     |
         | 
| 351 | 
            +
            | xian   | x ian    |
         | 
| 352 | 
            +
            | xiang  | x iang   |
         | 
| 353 | 
            +
            | xiao   | x iao    |
         | 
| 354 | 
            +
            | xie    | x ie     |
         | 
| 355 | 
            +
            | xin    | x in     |
         | 
| 356 | 
            +
            | xing   | x ing    |
         | 
| 357 | 
            +
            | xiong  | x iong   |
         | 
| 358 | 
            +
            | xiu    | x iou    |
         | 
| 359 | 
            +
            | xu     | x v      |
         | 
| 360 | 
            +
            | xuan   | x van    |
         | 
| 361 | 
            +
            | xue    | x ve     |
         | 
| 362 | 
            +
            | xun    | x vn     |
         | 
| 363 | 
            +
            | ya     | ia       |
         | 
| 364 | 
            +
            | yan    | ian      |
         | 
| 365 | 
            +
            | yang   | iang     |
         | 
| 366 | 
            +
            | yao    | iao      |
         | 
| 367 | 
            +
            | ye     | ie       |
         | 
| 368 | 
            +
            | yi     | i        |
         | 
| 369 | 
            +
            | yin    | in       |
         | 
| 370 | 
            +
            | ying   | ing      |
         | 
| 371 | 
            +
            | yong   | iong     |
         | 
| 372 | 
            +
            | you    | iou      |
         | 
| 373 | 
            +
            | yu     | v        |
         | 
| 374 | 
            +
            | yuan   | van      |
         | 
| 375 | 
            +
            | yue    | ve       |
         | 
| 376 | 
            +
            | yun    | vn       |
         | 
| 377 | 
            +
            | za     | z a      |
         | 
| 378 | 
            +
            | zai    | z ai     |
         | 
| 379 | 
            +
            | zan    | z an     |
         | 
| 380 | 
            +
            | zang   | z ang    |
         | 
| 381 | 
            +
            | zao    | z ao     |
         | 
| 382 | 
            +
            | ze     | z e      |
         | 
| 383 | 
            +
            | zei    | z ei     |
         | 
| 384 | 
            +
            | zen    | z en     |
         | 
| 385 | 
            +
            | zeng   | z eng    |
         | 
| 386 | 
            +
            | zha    | zh a     |
         | 
| 387 | 
            +
            | zhai   | zh ai    |
         | 
| 388 | 
            +
            | zhan   | zh an    |
         | 
| 389 | 
            +
            | zhang  | zh ang   |
         | 
| 390 | 
            +
            | zhao   | zh ao    |
         | 
| 391 | 
            +
            | zhe    | zh e     |
         | 
| 392 | 
            +
            | zhei   | zh ei    |
         | 
| 393 | 
            +
            | zhen   | zh en    |
         | 
| 394 | 
            +
            | zheng  | zh eng   |
         | 
| 395 | 
            +
            | zhi    | zh i     |
         | 
| 396 | 
            +
            | zhong  | zh ong   |
         | 
| 397 | 
            +
            | zhou   | zh ou    |
         | 
| 398 | 
            +
            | zhu    | zh u     |
         | 
| 399 | 
            +
            | zhua   | zh ua    |
         | 
| 400 | 
            +
            | zhuai  | zh uai   |
         | 
| 401 | 
            +
            | zhuan  | zh uan   |
         | 
| 402 | 
            +
            | zhuang | zh uang  |
         | 
| 403 | 
            +
            | zhui   | zh uei   |
         | 
| 404 | 
            +
            | zhun   | zh uen   |
         | 
| 405 | 
            +
            | zhuo   | zh uo    |
         | 
| 406 | 
            +
            | zi     | z i      |
         | 
| 407 | 
            +
            | zong   | z ong    |
         | 
| 408 | 
            +
            | zou    | z ou     |
         | 
| 409 | 
            +
            | zu     | z u      |
         | 
| 410 | 
            +
            | zuan   | z uan    |
         | 
| 411 | 
            +
            | zui    | z uei    |
         | 
| 412 | 
            +
            | zun    | z uen    |
         | 
| 413 | 
            +
            | zuo    | z uo     |
         | 
    	
        inference/m4singer/m4singer/map.py
    ADDED
    
    | @@ -0,0 +1,7 @@ | |
| 1 | 
            +
            def m4singer_pinyin2ph_func():
         | 
| 2 | 
            +
                pinyin2phs = {'AP': '<AP>', 'SP': '<SP>'}
         | 
| 3 | 
            +
                with open('inference/m4singer/m4singer/m4singer_pinyin2ph.txt') as rf:
         | 
| 4 | 
            +
                    for line in rf.readlines():
         | 
| 5 | 
            +
                        elements = [x.strip() for x in line.split('|') if x.strip() != '']
         | 
| 6 | 
            +
                        pinyin2phs[elements[0]] = elements[1]
         | 
| 7 | 
            +
                return pinyin2phs
         | 
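The helper above loads the two-column table from the previous file into a dict
mapping each pinyin syllable to its space-separated phonemes. A usage sketch
(hypothetical; run from the repo root so the relative path resolves):

    from inference.m4singer.m4singer.map import m4singer_pinyin2ph_func

    pinyin2phs = m4singer_pinyin2ph_func()
    print(pinyin2phs['ni'])   # 'n i'
    print(pinyin2phs['hao'])  # 'h ao'
    print(pinyin2phs['AP'])   # '<AP>'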
    	
        modules/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        modules/commons/common_layers.py
    ADDED
    
    | @@ -0,0 +1,668 @@ | |
| 1 | 
            +
            import math
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            from torch import nn
         | 
| 4 | 
            +
            from torch.nn import Parameter
         | 
| 5 | 
            +
            import torch.onnx.operators
         | 
| 6 | 
            +
            import torch.nn.functional as F
         | 
| 7 | 
            +
            import utils
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class Reshape(nn.Module):
         | 
| 11 | 
            +
                def __init__(self, *args):
         | 
| 12 | 
            +
                    super(Reshape, self).__init__()
         | 
| 13 | 
            +
                    self.shape = args
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                def forward(self, x):
         | 
| 16 | 
            +
                    return x.view(self.shape)
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            class Permute(nn.Module):
         | 
| 20 | 
            +
                def __init__(self, *args):
         | 
| 21 | 
            +
                    super(Permute, self).__init__()
         | 
| 22 | 
            +
                    self.args = args
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                def forward(self, x):
         | 
| 25 | 
            +
                    return x.permute(self.args)
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            class LinearNorm(torch.nn.Module):
         | 
| 29 | 
            +
                def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
         | 
| 30 | 
            +
                    super(LinearNorm, self).__init__()
         | 
| 31 | 
            +
                    self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                    torch.nn.init.xavier_uniform_(
         | 
| 34 | 
            +
                        self.linear_layer.weight,
         | 
| 35 | 
            +
                        gain=torch.nn.init.calculate_gain(w_init_gain))
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                def forward(self, x):
         | 
| 38 | 
            +
                    return self.linear_layer(x)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert (kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal

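# A minimal sketch (editor's addition, hypothetical): both wrappers only differ
# from the stock torch layers in their Xavier init, whose gain is derived from
# the `w_init_gain` nonlinearity name.
if __name__ == '__main__':
    _proj = LinearNorm(80, 256, w_init_gain='tanh')
    _conv = ConvNorm(80, 256, kernel_size=5)  # padding inferred as dilation*(k-1)/2 = 2
    assert _conv(torch.randn(4, 80, 100)).shape == (4, 256, 100)  # [B, C, T] length preserved
    assert _proj(torch.randn(4, 80)).shape == (4, 256)
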
def Embedding(num_embeddings, embedding_dim, padding_idx=None):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m

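# A minimal sketch (editor's addition, hypothetical): the padding row is zeroed
# after the normal init, so padded tokens contribute nothing downstream.
if __name__ == '__main__':
    _emb = Embedding(10, 8, padding_idx=0)
    assert _emb.weight[0].abs().sum().item() == 0.0
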
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    if not export and torch.cuda.is_available():
        try:
            from apex.normalization import FusedLayerNorm
            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
        except ImportError:
            pass
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.)
    return m

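# A minimal sketch covering the two helpers above (editor's addition,
# hypothetical): LayerNorm falls back to torch.nn.LayerNorm whenever apex's
# fused kernel is absent (or export=True), and Linear zeroes its bias after the
# Xavier init.
if __name__ == '__main__':
    assert isinstance(LayerNorm(64, export=True), torch.nn.LayerNorm)
    assert Linear(64, 32).bias.abs().sum().item() == 0.0
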
class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input.shape[:2]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = utils.make_positions(input, self.padding_idx) if positions is None else positions
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number

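# A minimal sketch (editor's addition, hypothetical): get_embedding concatenates
# the sin and cos halves along the feature axis and zeroes the padding row, so
# the table can be rebuilt at any length on demand.
if __name__ == '__main__':
    _w = SinusoidalPositionalEmbedding.get_embedding(50, 16, padding_idx=0)
    assert _w.shape == (50, 16) and _w[0].abs().sum().item() == 0.0
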
class ConvTBC(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super(ConvTBC, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding

        # NOTE: the parameters are allocated but not initialized here; callers
        # are expected to initialize (or overwrite) them before use.
        self.weight = torch.nn.Parameter(torch.Tensor(
            self.kernel_size, in_channels, out_channels))
        self.bias = torch.nn.Parameter(torch.Tensor(out_channels))

    def forward(self, input):
        return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding)

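# A minimal sketch (editor's addition, hypothetical): ConvTBC keeps the fairseq
# (Time, Batch, Channel) layout end to end; the parameters are initialized
# explicitly here, per the note above.
if __name__ == '__main__':
    _tbc = ConvTBC(64, 64, kernel_size=3, padding=1)
    nn.init.xavier_uniform_(_tbc.weight)
    nn.init.zeros_(_tbc.bias)
    assert _tbc(torch.randn(10, 4, 64)).shape == (10, 4, 64)  # T preserved: 10 - 3 + 1 + 2*1
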
class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
                 encoder_decoder_attention=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, \
            'Self-attention requires query, key and value to be of the same size'

        if self.qkv_same_dim:
            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        else:
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))

        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        # Dispatch to the fused PyTorch kernel whenever it is available.
        self.enable_torch_version = hasattr(F, "multi_head_attention_forward")
        self.last_attn_probs = None

    def reset_parameters(self):
        if self.qkv_same_dim:
            nn.init.xavier_uniform_(self.in_proj_weight)
        else:
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)
            nn.init.xavier_uniform_(self.q_proj_weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
            self,
            query, key, value,
            key_padding_mask=None,
            incremental_state=None,
            need_weights=True,
            static_kv=False,
            attn_mask=None,
            before_softmax=False,
            need_head_weights=False,
            enc_dec_attn_constraint_mask=None,
            reset_attn_weight=None
    ):
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
            if self.qkv_same_dim:
                return F.multi_head_attention_forward(query, key, value,
                                                      self.embed_dim, self.num_heads,
                                                      self.in_proj_weight,
                                                      self.in_proj_bias, self.bias_k, self.bias_v,
                                                      self.add_zero_attn, self.dropout,
                                                      self.out_proj.weight, self.out_proj.bias,
                                                      self.training, key_padding_mask, need_weights,
                                                      attn_mask)
            else:
                return F.multi_head_attention_forward(query, key, value,
                                                      self.embed_dim, self.num_heads,
                                                      torch.empty([0]),
                                                      self.in_proj_bias, self.bias_k, self.bias_v,
                                                      self.add_zero_attn, self.dropout,
                                                      self.out_proj.weight, self.out_proj.bias,
                                                      self.training, key_padding_mask, need_weights,
                                                      attn_mask, use_separate_proj_weight=True,
                                                      q_proj_weight=self.q_proj_weight,
                                                      k_proj_weight=self.k_proj_weight,
                                                      v_proj_weight=self.v_proj_weight)

        if incremental_state is not None:
            raise NotImplementedError('incremental (step-by-step) decoding is not supported')

        if self.self_attention:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.in_proj_k(key)
                v = self.in_proj_v(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            if len(attn_mask.shape) == 2:
                attn_mask = attn_mask.unsqueeze(0)
            elif len(attn_mask.shape) == 3:
                attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
                    bsz * self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights + attn_mask

        if enc_dec_attn_constraint_mask is not None:  # bs x head x L_kv
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
                -1e9,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                -1e9,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)

        if reset_attn_weight is not None:
            if reset_attn_weight:
                self.last_attn_probs = attn_probs.detach()
            else:
                assert self.last_attn_probs is not None
                attn_probs = self.last_attn_probs
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)
        else:
            attn_weights = None

        return attn, (attn_weights, attn_logits)

    def in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_q(self, query):
        if self.qkv_same_dim:
            return self._in_proj(query, end=self.embed_dim)
        else:
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[:self.embed_dim]
            return F.linear(query, self.q_proj_weight, bias)

    def in_proj_k(self, key):
        if self.qkv_same_dim:
            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
        else:
            weight = self.k_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[self.embed_dim:2 * self.embed_dim]
            return F.linear(key, weight, bias)

    def in_proj_v(self, value):
        if self.qkv_same_dim:
            return self._in_proj(value, start=2 * self.embed_dim)
        else:
            weight = self.v_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[2 * self.embed_dim:]
            return F.linear(value, weight, bias)

    def _in_proj(self, input, start=0, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)

    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
        return attn_weights

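# A minimal sketch (editor's addition, hypothetical): with the fused path
# enabled (no incremental state, static_kv, or reset_attn_weight), the call is
# forwarded straight to F.multi_head_attention_forward and returns
# (output, averaged weights); inputs follow the (Time, Batch, Channel) layout.
if __name__ == '__main__':
    _mha = MultiheadAttention(64, num_heads=4, self_attention=True)
    _q = torch.randn(12, 2, 64)
    _out, _w = _mha(query=_q, key=_q, value=_q)
    assert _out.shape == (12, 2, 64)
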
class Swish(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]  # saved_variables is deprecated in favor of saved_tensors
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class CustomSwish(nn.Module):
    def forward(self, input_tensor):
        return Swish.apply(input_tensor)

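# A minimal sketch (editor's addition, hypothetical): the hand-written backward
# implements the closed form d/dx[x*sigmoid(x)] = sigmoid(x)*(1 + x*(1 - sigmoid(x)));
# gradcheck verifies it against numerical gradients in double precision.
if __name__ == '__main__':
    _xs = torch.randn(5, dtype=torch.double, requires_grad=True)
    torch.autograd.gradcheck(CustomSwish(), (_xs,))
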
class TransformerFFNLayer(nn.Module):
    def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
        super().__init__()
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.act = act
        if padding == 'SAME':
            self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
        elif padding == 'LEFT':
            # causal padding: pad only on the left of the time axis
            self.ffn_1 = nn.Sequential(
                nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
                nn.Conv1d(hidden_size, filter_size, kernel_size)
            )
        self.ffn_2 = Linear(filter_size, hidden_size)
        if self.act == 'swish':
            self.swish_fn = CustomSwish()

    def forward(self, x, incremental_state=None):
        # x: T x B x C
        if incremental_state is not None:
            raise NotImplementedError('Nar-generation does not allow this.')

        x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
        x = x * self.kernel_size ** -0.5

        if self.act == 'gelu':
            x = F.gelu(x)
        if self.act == 'relu':
            x = F.relu(x)
        if self.act == 'swish':
            x = self.swish_fn(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.ffn_2(x)
        return x

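# A minimal sketch (editor's addition, hypothetical): the FFN is (T, B, C) in
# and out; 'SAME' padding preserves the length for odd kernel sizes, while
# 'LEFT' pads causally for decoder use.
if __name__ == '__main__':
    _ffn = TransformerFFNLayer(64, 256, kernel_size=9, act='swish')
    assert _ffn(torch.randn(20, 2, 64)).shape == (20, 2, 64)
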
class BatchNorm1dTBC(nn.Module):
    def __init__(self, c):
        super(BatchNorm1dTBC, self).__init__()
        self.bn = nn.BatchNorm1d(c)

    def forward(self, x):
        """
        :param x: [T, B, C]
        :return: [T, B, C]
        """
        x = x.permute(1, 2, 0)  # [B, C, T]
        x = self.bn(x)  # [B, C, T]
        x = x.permute(2, 0, 1)  # [T, B, C]
        return x


class EncSALayer(nn.Module):
    def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
                 relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'):
        super().__init__()
        self.c = c
        self.dropout = dropout
        self.num_heads = num_heads
        if num_heads > 0:
            if norm == 'ln':
                self.layer_norm1 = LayerNorm(c)
            elif norm == 'bn':
                self.layer_norm1 = BatchNorm1dTBC(c)
            self.self_attn = MultiheadAttention(
                self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False,
            )
        if norm == 'ln':
            self.layer_norm2 = LayerNorm(c)
        elif norm == 'bn':
            self.layer_norm2 = BatchNorm1dTBC(c)
        self.ffn = TransformerFFNLayer(
            c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)

    def forward(self, x, encoder_padding_mask=None, **kwargs):
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training
        if self.num_heads > 0:
            residual = x
            x = self.layer_norm1(x)
            x, _ = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask
            )
            x = F.dropout(x, self.dropout, training=self.training)
            x = residual + x
            if encoder_padding_mask is not None:
                # zero out padded frames so they cannot leak into later layers
                x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]

        residual = x
        x = self.layer_norm2(x)
        x = self.ffn(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        if encoder_padding_mask is not None:
            x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
        return x

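# A minimal sketch (editor's addition, hypothetical): one pre-norm encoder
# block; padded frames are re-zeroed after each sublayer, per the masking above.
if __name__ == '__main__':
    _enc = EncSALayer(64, num_heads=2, dropout=0.1, kernel_size=9)
    _pad = torch.zeros(2, 20).bool()  # no frame is actually padded in this toy batch
    assert _enc(torch.randn(20, 2, 64), encoder_padding_mask=_pad).shape == (20, 2, 64)
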
class DecSALayer(nn.Module):
    def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9, act='gelu'):
        super().__init__()
        self.c = c
        self.dropout = dropout
        self.layer_norm1 = LayerNorm(c)
        self.self_attn = MultiheadAttention(
            c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
        )
        self.layer_norm2 = LayerNorm(c)
        self.encoder_attn = MultiheadAttention(
            c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
        )
        self.layer_norm3 = LayerNorm(c)
        self.ffn = TransformerFFNLayer(
            c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)

    def forward(
            self,
            x,
            encoder_out=None,
            encoder_padding_mask=None,
            incremental_state=None,
            self_attn_mask=None,
            self_attn_padding_mask=None,
            attn_out=None,
            reset_attn_weight=None,
            **kwargs,
    ):
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training
            self.layer_norm3.training = layer_norm_training
        residual = x
        x = self.layer_norm1(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            attn_mask=self_attn_mask
        )
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x

        residual = x
        x = self.layer_norm2(x)
        if encoder_out is not None:
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                enc_dec_attn_constraint_mask=None,  # utils.get_incremental_state(self, incremental_state, 'enc_dec_attn_constraint_mask'),
                reset_attn_weight=reset_attn_weight
            )
            attn_logits = attn[1]
        else:
            assert attn_out is not None
            x = self.encoder_attn.in_proj_v(attn_out.transpose(0, 1))
            attn_logits = None
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x

        residual = x
        x = self.layer_norm3(x)
        x = self.ffn(x, incremental_state=incremental_state)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        # if len(attn_logits.size()) > 3:
        #    indices = attn_logits.softmax(-1).max(-1).values.sum(-1).argmax(-1)
        #    attn_logits = attn_logits.gather(1,
        #        indices[:, None, None, None].repeat(1, 1, attn_logits.size(-2), attn_logits.size(-1))).squeeze(1)
        return x, attn_logits
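# A minimal sketch (editor's addition, hypothetical): one decoder block with
# self-attention, encoder-decoder attention (static_kv=True takes the manual
# path, which relies on the repo's utils.softmax helper), and a causal FFN;
# attn_logits come back for alignment-style losses.
if __name__ == '__main__':
    _dec = DecSALayer(64, num_heads=2, dropout=0.1)
    _y, _h = torch.randn(15, 2, 64), torch.randn(20, 2, 64)
    _o, _logits = _dec(_y, encoder_out=_h, encoder_padding_mask=torch.zeros(2, 20).bool())
    assert _o.shape == (15, 2, 64) and _logits.shape == (2, 2, 15, 20)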
    	
        modules/commons/espnet_positional_embedding.py
    ADDED
    
@@ -0,0 +1,113 @@
import math
import torch


class PositionalEncoding(torch.nn.Module):
    """Positional encoding.
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
        reverse (bool): Whether to reverse the input position.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
        """Construct a PositionalEncoding object."""
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.reverse = reverse
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)


class ScaledPositionalEncoding(PositionalEncoding):
    """Scaled positional encoding module.
    See Sec. 3.2 of https://arxiv.org/abs/1809.08895
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
        self.alpha = torch.nn.Parameter(torch.tensor(1.0))

    def reset_parameters(self):
        """Reset parameters."""
        self.alpha.data = torch.tensor(1.0)

    def forward(self, x):
        """Add positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(x)


class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.
    See Appendix B in https://arxiv.org/abs/1901.02860
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self, x):
        """Compute positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`), the sum of the
                scaled input and the positional embedding (each with dropout).
        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[:, : x.size(1)]
        return self.dropout(x) + self.dropout(pos_emb)
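
Usage note: the three modules above implement the standard sinusoidal scheme, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)); RelPositionalEncoding only reverses the position index and returns the scaled input plus the positional term through separate dropout. A minimal sketch of the lazy buffer growth (assumes only PyTorch and the class above; sizes are illustrative):

    # Minimal usage sketch (illustrative sizes; depends only on the class above).
    import torch

    pos_enc = PositionalEncoding(d_model=256, dropout_rate=0.0, max_len=100)
    x = torch.zeros(2, 300, 256)          # deliberately longer than max_len
    y = pos_enc(x)                        # forward() calls extend_pe(), growing self.pe
    assert y.shape == (2, 300, 256)
    assert pos_enc.pe.size(1) == 300      # buffer was extended lazily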
    	
modules/commons/ssim.py
ADDED
@@ -0,0 +1,391 @@
# '''
# https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py
# '''
#
# import torch
# import torch.jit
# import torch.nn.functional as F
#
#
# @torch.jit.script
# def create_window(window_size: int, sigma: float, channel: int):
#     '''
#     Create 1-D gauss kernel
#     :param window_size: the size of the gauss kernel
#     :param sigma: sigma of the normal distribution
#     :param channel: input channel
#     :return: 1D kernel
#     '''
#     coords = torch.arange(window_size, dtype=torch.float)
#     coords -= window_size // 2
#
#     g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
#     g /= g.sum()
#
#     g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1)
#     return g
#
#
# @torch.jit.script
# def _gaussian_filter(x, window_1d, use_padding: bool):
#     '''
#     Blur input with 1-D kernel
#     :param x: batch of tensors to be blurred
#     :param window_1d: 1-D gauss kernel
#     :param use_padding: pad the image before conv
#     :return: blurred tensors
#     '''
#     C = x.shape[1]
#     padding = 0
#     if use_padding:
#         window_size = window_1d.shape[3]
#         padding = window_size // 2
#     out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C)
#     out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C)
#     return out
#
#
# @torch.jit.script
# def ssim(X, Y, window, data_range: float, use_padding: bool = False):
#     '''
#     Calculate the ssim index for X and Y
#     :param X: images [B, C, H, N_bins]
#     :param Y: images [B, C, H, N_bins]
#     :param window: 1-D gauss kernel
#     :param data_range: value range of input images (usually 1.0 or 255)
#     :param use_padding: pad the image before conv
#     :return:
#     '''
#
#     K1 = 0.01
#     K2 = 0.03
#     compensation = 1.0
#
#     C1 = (K1 * data_range) ** 2
#     C2 = (K2 * data_range) ** 2
#
#     mu1 = _gaussian_filter(X, window, use_padding)
#     mu2 = _gaussian_filter(Y, window, use_padding)
#     sigma1_sq = _gaussian_filter(X * X, window, use_padding)
#     sigma2_sq = _gaussian_filter(Y * Y, window, use_padding)
#     sigma12 = _gaussian_filter(X * Y, window, use_padding)
#
#     mu1_sq = mu1.pow(2)
#     mu2_sq = mu2.pow(2)
#     mu1_mu2 = mu1 * mu2
#
#     sigma1_sq = compensation * (sigma1_sq - mu1_sq)
#     sigma2_sq = compensation * (sigma2_sq - mu2_sq)
#     sigma12 = compensation * (sigma12 - mu1_mu2)
#
#     cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
#     # Fixes an issue where negative values in cs_map caused ms_ssim to output NaN.
#     cs_map = cs_map.clamp_min(0.)
#     ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
#
#     ssim_val = ssim_map.mean(dim=(1, 2, 3))  # reduce along CHW
#     cs = cs_map.mean(dim=(1, 2, 3))
#
#     return ssim_val, cs
#
#
# @torch.jit.script
# def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8):
#     '''
#     interface of ms-ssim
#     :param X: a batch of images, (N,C,H,W)
#     :param Y: a batch of images, (N,C,H,W)
#     :param window: 1-D gauss kernel
#     :param data_range: value range of input images (usually 1.0 or 255)
#     :param weights: weights for the different levels
#     :param use_padding: pad the image before conv
#     :param eps: used to avoid NaN gradients
#     :return:
#     '''
#     levels = weights.shape[0]
#     cs_vals = []
#     ssim_vals = []
#     for _ in range(levels):
#         ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding)
#         # Works around an issue: when c = a ** b and a is 0, c.backward() makes a.grad inf.
#         ssim_val = ssim_val.clamp_min(eps)
#         cs = cs.clamp_min(eps)
#         cs_vals.append(cs)
#
#         ssim_vals.append(ssim_val)
#         padding = (X.shape[2] % 2, X.shape[3] % 2)
#         X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding)
#         Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding)
#
#     cs_vals = torch.stack(cs_vals, dim=0)
#     ms_ssim_val = torch.prod((cs_vals[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_vals[-1] ** weights[-1]), dim=0)
#     return ms_ssim_val
#
#
# class SSIM(torch.jit.ScriptModule):
#     __constants__ = ['data_range', 'use_padding']
#
#     def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False):
#         '''
#         :param window_size: the size of the gauss kernel
#         :param window_sigma: sigma of the normal distribution
#         :param data_range: value range of input images (usually 1.0 or 255)
#         :param channel: input channels (default: 3)
#         :param use_padding: pad the image before conv
#         '''
#         super().__init__()
#         assert window_size % 2 == 1, 'Window size must be odd.'
#         window = create_window(window_size, window_sigma, channel)
#         self.register_buffer('window', window)
#         self.data_range = data_range
#         self.use_padding = use_padding
#
#     @torch.jit.script_method
#     def forward(self, X, Y):
#         r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding)
#         return r[0]
#
#
# class MS_SSIM(torch.jit.ScriptModule):
#     __constants__ = ['data_range', 'use_padding', 'eps']
#
#     def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None,
#                  levels=None, eps=1e-8):
#         '''
#         class for ms-ssim
#         :param window_size: the size of the gauss kernel
#         :param window_sigma: sigma of the normal distribution
#         :param data_range: value range of input images (usually 1.0 or 255)
#         :param channel: input channels
#         :param use_padding: pad the image before conv
#         :param weights: weights for the different levels (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
#         :param levels: number of downsampling steps
#         :param eps: works around an issue: when c = a ** b and a is 0, c.backward() makes a.grad inf.
#         '''
#         super().__init__()
#         assert window_size % 2 == 1, 'Window size must be odd.'
#         self.data_range = data_range
#         self.use_padding = use_padding
#         self.eps = eps
#
#         window = create_window(window_size, window_sigma, channel)
#         self.register_buffer('window', window)
#
#         if weights is None:
#             weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
#         weights = torch.tensor(weights, dtype=torch.float)
#
#         if levels is not None:
#             weights = weights[:levels]
#             weights = weights / weights.sum()
#
#         self.register_buffer('weights', weights)
#
#     @torch.jit.script_method
#     def forward(self, X, Y):
#         return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights,
#                        use_padding=self.use_padding, eps=self.eps)
#
#
# if __name__ == '__main__':
#     print('Simple Test')
#     im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda')
#     img1 = im / 255
#     img2 = img1 * 0.5
#
#     losser = SSIM(data_range=1.).cuda()
#     loss = losser(img1, img2).mean()
#
#     losser2 = MS_SSIM(data_range=1.).cuda()
#     loss2 = losser2(img1, img2).mean()
#
#     print(loss.item())
#     print(loss2.item())
#
# if __name__ == '__main__':
#     print('Training Test')
#     import cv2
#     import torch.optim
#     import numpy as np
#     import imageio
#     import time
#
#     out_test_video = False
#     # Better not to write the GIF directly (it gets very large); write an .mkv first and convert it to GIF with ffmpeg.
#     video_use_gif = False
#
#     im = cv2.imread('test_img1.jpg', 1)
#     t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255.
#
#     if out_test_video:
#         if video_use_gif:
#             fps = 0.5
#             out_wh = (im.shape[1] // 2, im.shape[0] // 2)
#             suffix = '.gif'
#         else:
#             fps = 5
#             out_wh = (im.shape[1], im.shape[0])
#             suffix = '.mkv'
#         video_last_time = time.perf_counter()
#         video = imageio.get_writer('ssim_test' + suffix, fps=fps)
#
#     # Test ssim
#     print('Training SSIM')
#     rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
#     rand_im.requires_grad = True
#     optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
#     losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda()
#     ssim_score = 0
#     while ssim_score < 0.999:
#         optim.zero_grad()
#         loss = losser(rand_im, t_im)
#         (-loss).sum().backward()
#         ssim_score = loss.item()
#         optim.step()
#         r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
#         r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
#
#         if out_test_video:
#             if time.perf_counter() - video_last_time > 1. / fps:
#                 video_last_time = time.perf_counter()
#                 out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
#                 out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
#                 if isinstance(out_frame, cv2.UMat):
#                     out_frame = out_frame.get()
#                 video.append_data(out_frame)
#
#         cv2.imshow('ssim', r_im)
#         cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score)
#         cv2.waitKey(1)
#
#     if out_test_video:
#         video.close()
#
#     # Test ms_ssim
#     if out_test_video:
#         if video_use_gif:
#             fps = 0.5
#             out_wh = (im.shape[1] // 2, im.shape[0] // 2)
#             suffix = '.gif'
#         else:
#             fps = 5
#             out_wh = (im.shape[1], im.shape[0])
#             suffix = '.mkv'
#         video_last_time = time.perf_counter()
#         video = imageio.get_writer('ms_ssim_test' + suffix, fps=fps)
#
#     print('Training MS_SSIM')
#     rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
#     rand_im.requires_grad = True
#     optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
#     losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda()
#     ssim_score = 0
#     while ssim_score < 0.999:
#         optim.zero_grad()
#         loss = losser(rand_im, t_im)
#         (-loss).sum().backward()
#         ssim_score = loss.item()
#         optim.step()
#         r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
#         r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
#
#         if out_test_video:
#             if time.perf_counter() - video_last_time > 1. / fps:
#                 video_last_time = time.perf_counter()
#                 out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
#                 out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
#                 if isinstance(out_frame, cv2.UMat):
#                     out_frame = out_frame.get()
#                 video.append_data(out_frame)
#
#         cv2.imshow('ms_ssim', r_im)
#         cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score)
#         cv2.waitKey(1)
#
#     if out_test_video:
#         video.close()

"""
Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim
"""

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


window = None


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    global window
    if window is None:
        window = create_window(window_size, channel)
        if img1.is_cuda:
            window = window.cuda(img1.get_device())
        window = window.type_as(img1)
    return _ssim(img1, img2, window, window_size, channel, size_average)
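
Usage note for the active (Po-Hsun-Su) implementation above: ssim() expects 4-D batches [B, C, H, W] and builds its Gaussian window lazily from the first call's channel count; its C1/C2 constants assume inputs roughly in [0, 1]. A small sketch (the mel-spectrogram shapes and the 1 - SSIM loss convention below are illustrative assumptions, not taken from this repo's config):

    # Usage sketch for ssim() above (illustrative shapes; values assumed in [0, 1]).
    import torch

    mel_a = torch.rand(4, 1, 80, 200)                  # [B, C=1, n_mels, frames]
    mel_b = (mel_a + 0.05 * torch.randn_like(mel_a)).clamp(0, 1)

    score = ssim(mel_a, mel_b)                         # scalar similarity, near 1.0 here
    loss = 1.0 - score                                 # a common way to use SSIM as a loss
    print(float(score), float(loss))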
    	
modules/diffsinger_midi/fs2.py
ADDED
@@ -0,0 +1,118 @@
from modules.commons.common_layers import *
from modules.commons.common_layers import Embedding
from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
    EnergyPredictor, FastspeechEncoder
from utils.cwt import cwt2f0
from utils.hparams import hparams
from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
from modules.fastspeech.fs2 import FastSpeech2


class FastspeechMIDIEncoder(FastspeechEncoder):
    def forward_embedding(self, txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding):
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(txt_tokens)
        x = x + midi_embedding + midi_dur_embedding + slur_embedding
        if hparams['use_pos_embed']:
            if hparams.get('rel_pos') is not None and hparams['rel_pos']:
                x = self.embed_positions(x)
            else:
                positions = self.embed_positions(txt_tokens)
                x = x + positions
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x

    def forward(self, txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding):
        """
        :param txt_tokens: [B, T]
        :return: {
            'encoder_out': [T x B x C]
        }
        """
        encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
        x = self.forward_embedding(txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding)  # [B, T, H]
        x = super(FastspeechEncoder, self).forward(x, encoder_padding_mask)
        return x


FS_ENCODERS = {
    'fft': lambda hp, embed_tokens, d: FastspeechMIDIEncoder(
        embed_tokens, hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'],
        num_heads=hp['num_heads']),
}


class FastSpeech2MIDI(FastSpeech2):
    def __init__(self, dictionary, out_dims=None):
        super().__init__(dictionary, out_dims)
        del self.encoder
        self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary)
        self.midi_embed = Embedding(300, self.hidden_size, self.padding_idx)
        self.midi_dur_layer = Linear(1, self.hidden_size)
        self.is_slur_embed = Embedding(2, self.hidden_size)

    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
                ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
                spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
        ret = {}

        midi_embedding = self.midi_embed(kwargs['pitch_midi'])
        midi_dur_embedding, slur_embedding = 0, 0
        if kwargs.get('midi_dur') is not None:
            midi_dur_embedding = self.midi_dur_layer(kwargs['midi_dur'][:, :, None])  # [B, T, 1] -> [B, T, H]
        if kwargs.get('is_slur') is not None:
            slur_embedding = self.is_slur_embed(kwargs['is_slur'])
        encoder_out = self.encoder(txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding)  # [B, T, C]
        src_nonpadding = (txt_tokens > 0).float()[:, :, None]

        # add ref style embed
        # Not implemented
        # variance encoder
        var_embed = 0

        # encoder_out_dur denotes the encoder outputs fed to the duration predictor;
        # in speaker adaptation, the duration predictor uses the old speaker embedding
        if hparams['use_spk_embed']:
            spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
        elif hparams['use_spk_id']:
            spk_embed_id = spk_embed
            if spk_embed_dur_id is None:
                spk_embed_dur_id = spk_embed_id
            if spk_embed_f0_id is None:
                spk_embed_f0_id = spk_embed_id
            spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
            spk_embed_dur = spk_embed_f0 = spk_embed
            if hparams['use_split_spk_id']:
                spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
                spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
        else:
            spk_embed_dur = spk_embed_f0 = spk_embed = 0

        # add dur
        dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding

        mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)

        decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])

        mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
        decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_)  # [B, T, H]

        tgt_nonpadding = (mel2ph > 0).float()[:, :, None]

        # add pitch and energy embed
        pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
        if hparams['use_pitch_embed']:
            pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
            decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
        if hparams['use_energy_embed']:
            decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)

        ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding

        if skip_decoder:
            return ret
        ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)

        return ret
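
The MIDI variant above conditions each phoneme token by simply summing three extra terms into the token embedding before the FFT blocks: a note-pitch embedding (Embedding(300, H)), a projected note duration (Linear(1, H)), and a slur-flag embedding (Embedding(2, H)). A shape-level sketch of that sum (all names and sizes below are illustrative, not the repo's config):

    # Shape-level sketch of the conditioning sum in forward_embedding (illustrative).
    import torch
    import torch.nn as nn

    B, T, H = 2, 16, 256
    tok_emb      = torch.randn(B, T, H)                                  # scaled token embeddings
    midi_emb     = nn.Embedding(300, H)(torch.randint(0, 128, (B, T)))   # pitch_midi
    midi_dur_emb = nn.Linear(1, H)(torch.rand(B, T, 1))                  # note duration
    slur_emb     = nn.Embedding(2, H)(torch.randint(0, 2, (B, T)))       # is_slur flag

    x = tok_emb + midi_emb + midi_dur_emb + slur_emb                     # still [B, T, H]
    assert x.shape == (B, T, H)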
    	
modules/fastspeech/fs2.py
ADDED
@@ -0,0 +1,255 @@
from modules.commons.common_layers import *
from modules.commons.common_layers import Embedding
from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
    EnergyPredictor, FastspeechEncoder
from utils.cwt import cwt2f0
from utils.hparams import hparams
from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0

FS_ENCODERS = {
    'fft': lambda hp, embed_tokens, d: FastspeechEncoder(
        embed_tokens, hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'],
        num_heads=hp['num_heads']),
}

FS_DECODERS = {
    'fft': lambda hp: FastspeechDecoder(
        hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'], hp['num_heads']),
}


class FastSpeech2(nn.Module):
    def __init__(self, dictionary, out_dims=None):
        super().__init__()
        self.dictionary = dictionary
        self.padding_idx = dictionary.pad()
        self.enc_layers = hparams['enc_layers']
        self.dec_layers = hparams['dec_layers']
        self.hidden_size = hparams['hidden_size']
        self.encoder_embed_tokens = self.build_embedding(self.dictionary, self.hidden_size)
        self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary)
        self.decoder = FS_DECODERS[hparams['decoder_type']](hparams)
        self.out_dims = out_dims
        if out_dims is None:
            self.out_dims = hparams['audio_num_mel_bins']
        self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True)

        if hparams['use_spk_id']:
            self.spk_embed_proj = Embedding(hparams['num_spk'] + 1, self.hidden_size)
            if hparams['use_split_spk_id']:
                self.spk_embed_f0 = Embedding(hparams['num_spk'] + 1, self.hidden_size)
                self.spk_embed_dur = Embedding(hparams['num_spk'] + 1, self.hidden_size)
        elif hparams['use_spk_embed']:
            self.spk_embed_proj = Linear(256, self.hidden_size, bias=True)
        predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
        self.dur_predictor = DurationPredictor(
            self.hidden_size,
            n_chans=predictor_hidden,
            n_layers=hparams['dur_predictor_layers'],
            dropout_rate=hparams['predictor_dropout'], padding=hparams['ffn_padding'],
            kernel_size=hparams['dur_predictor_kernel'])
        self.length_regulator = LengthRegulator()
        if hparams['use_pitch_embed']:
            self.pitch_embed = Embedding(300, self.hidden_size, self.padding_idx)
            if hparams['pitch_type'] == 'cwt':
                h = hparams['cwt_hidden_size']
                cwt_out_dims = 10
                if hparams['use_uv']:
                    cwt_out_dims = cwt_out_dims + 1
                self.cwt_predictor = nn.Sequential(
                    nn.Linear(self.hidden_size, h),
                    PitchPredictor(
                        h,
                        n_chans=predictor_hidden,
                        n_layers=hparams['predictor_layers'],
                        dropout_rate=hparams['predictor_dropout'], odim=cwt_out_dims,
                        padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']))
                self.cwt_stats_layers = nn.Sequential(
                    nn.Linear(self.hidden_size, h), nn.ReLU(),
                    nn.Linear(h, h), nn.ReLU(), nn.Linear(h, 2)
                )
            else:
                self.pitch_predictor = PitchPredictor(
                    self.hidden_size,
                    n_chans=predictor_hidden,
                    n_layers=hparams['predictor_layers'],
                    dropout_rate=hparams['predictor_dropout'],
                    odim=2 if hparams['pitch_type'] == 'frame' else 1,
                    padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
        if hparams['use_energy_embed']:
            self.energy_embed = Embedding(256, self.hidden_size, self.padding_idx)
            self.energy_predictor = EnergyPredictor(
                self.hidden_size,
                n_chans=predictor_hidden,
                n_layers=hparams['predictor_layers'],
                dropout_rate=hparams['predictor_dropout'], odim=1,
                padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])

    def build_embedding(self, dictionary, embed_dim):
        num_embeddings = len(dictionary)
        emb = Embedding(num_embeddings, embed_dim, self.padding_idx)
        return emb

    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
                ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
                spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
        ret = {}
        encoder_out = self.encoder(txt_tokens)  # [B, T, C]
        src_nonpadding = (txt_tokens > 0).float()[:, :, None]

        # add ref style embed
        # Not implemented
        # variance encoder
        var_embed = 0

        # encoder_out_dur denotes encoder outputs for the duration predictor;
        # in speaker adaptation, the duration predictor uses the old speaker embedding
        if hparams['use_spk_embed']:
            spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
        elif hparams['use_spk_id']:
            spk_embed_id = spk_embed
            if spk_embed_dur_id is None:
                spk_embed_dur_id = spk_embed_id
            if spk_embed_f0_id is None:
                spk_embed_f0_id = spk_embed_id
            spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
            spk_embed_dur = spk_embed_f0 = spk_embed
            if hparams['use_split_spk_id']:
                spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
                spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
        else:
            spk_embed_dur = spk_embed_f0 = spk_embed = 0

        # add dur
        dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding

        mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)

        decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])

        mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
        decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_)  # [B, T, H]

        tgt_nonpadding = (mel2ph > 0).float()[:, :, None]

        # add pitch and energy embed
        pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
        if hparams['use_pitch_embed']:
            pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
            decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
        if hparams['use_energy_embed']:
            decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)

        ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding

        if skip_decoder:
            return ret
        ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)

        return ret

    def add_dur(self, dur_input, mel2ph, txt_tokens, ret):
        """

        :param dur_input: [B, T_txt, H]
        :param mel2ph: [B, T_mel]
        :param txt_tokens: [B, T_txt]
        :param ret:
        :return:
        """
        src_padding = txt_tokens == 0
        dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach())
        if mel2ph is None:
            dur, xs = self.dur_predictor.inference(dur_input, src_padding)
            ret['dur'] = xs
            ret['dur_choice'] = dur
            mel2ph = self.length_regulator(dur, src_padding).detach()
        else:
            ret['dur'] = self.dur_predictor(dur_input, src_padding)
        ret['mel2ph'] = mel2ph
        return mel2ph
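
The `dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach())` line above (the same idiom reappears in add_energy and add_pitch below) leaves the forward value untouched while scaling the gradient the predictor sends back into the encoder. A self-contained sketch of the trick, with 0.1 standing in for predictor_grad:

import torch

def grad_scale(x, g):
    # forward: equals x exactly; backward: only the g * (x - x.detach()) term
    # participates in autograd, so upstream gradients are multiplied by g
    return x.detach() + g * (x - x.detach())

x = torch.ones(3, requires_grad=True)
grad_scale(x, 0.1).sum().backward()
print(x.grad)  # tensor([0.1000, 0.1000, 0.1000])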

    def add_energy(self, decoder_inp, energy, ret):
        decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
        ret['energy_pred'] = energy_pred = self.energy_predictor(decoder_inp)[:, :, 0]
        if energy is None:
            energy = energy_pred
        energy = torch.clamp(energy * 256 // 4, max=255).long()
        energy_embed = self.energy_embed(energy)
        return energy_embed
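
The clamp line above buckets continuous energy (assumed to lie roughly in [0, 4), as an RMS-style amplitude) into the 256 bins of self.energy_embed; values of 4 and above saturate at bin 255. A small sketch of the mapping:

import torch

# illustrative values only; `energy` is a per-frame amplitude tensor in the model
energy = torch.tensor([0.0, 1.0, 3.98, 5.0])
print(torch.clamp(energy * 256 // 4, max=255).long())  # tensor([  0,  64, 254, 255])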

    def add_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
        if hparams['pitch_type'] == 'ph':
            pitch_pred_inp = encoder_out.detach() + hparams['predictor_grad'] * (encoder_out - encoder_out.detach())
            pitch_padding = encoder_out.sum().abs() == 0
            ret['pitch_pred'] = pitch_pred = self.pitch_predictor(pitch_pred_inp)
            if f0 is None:
                f0 = pitch_pred[:, :, 0]
            ret['f0_denorm'] = f0_denorm = denorm_f0(f0, None, hparams, pitch_padding=pitch_padding)
            pitch = f0_to_coarse(f0_denorm)  # start from 0 [B, T_txt]
            pitch = F.pad(pitch, [1, 0])
            pitch = torch.gather(pitch, 1, mel2ph)  # [B, T_mel]
            pitch_embed = self.pitch_embed(pitch)
            return pitch_embed
        decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())

        pitch_padding = mel2ph == 0

        if hparams['pitch_type'] == 'cwt':
            pitch_padding = None
            ret['cwt'] = cwt_out = self.cwt_predictor(decoder_inp)
            stats_out = self.cwt_stats_layers(encoder_out[:, 0, :])  # [B, 2]
            mean = ret['f0_mean'] = stats_out[:, 0]
            std = ret['f0_std'] = stats_out[:, 1]
            cwt_spec = cwt_out[:, :, :10]
            if f0 is None:
                std = std * hparams['cwt_std_scale']
                f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
                if hparams['use_uv']:
                    assert cwt_out.shape[-1] == 11
                    uv = cwt_out[:, :, -1] > 0
        elif hparams['pitch_ar']:
            ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if self.training else None)
            if f0 is None:
                f0 = pitch_pred[:, :, 0]
        else:
            ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
            if f0 is None:
                f0 = pitch_pred[:, :, 0]
            if hparams['use_uv'] and uv is None:
                uv = pitch_pred[:, :, 1] > 0
        ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
        if pitch_padding is not None:
            f0[pitch_padding] = 0

        pitch = f0_to_coarse(f0_denorm)  # start from 0
        pitch_embed = self.pitch_embed(pitch)
        return pitch_embed

    def run_decoder(self, decoder_inp, tgt_nonpadding, ret, infer, **kwargs):
        x = decoder_inp  # [B, T, H]
        x = self.decoder(x)
        x = self.mel_out(x)
        return x * tgt_nonpadding

    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
        f0 = cwt2f0(cwt_spec, mean, std, hparams['cwt_scales'])
        f0 = torch.cat(
            [f0] + [f0[:, -1:]] * (mel2ph.shape[1] - f0.shape[1]), 1)
        f0_norm = norm_f0(f0, None, hparams)
        return f0_norm

    def out2mel(self, out):
        return out

    @staticmethod
    def mel_norm(x):
        return (x + 5.5) / (6.3 / 2) - 1

    @staticmethod
    def mel_denorm(x):
        return (x + 1) * (6.3 / 2) - 5.5
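
For orientation, here is a hedged sketch of driving this class end to end. Everything below is illustrative: `phone_encoder` stands in for a real phoneme dictionary exposing a .pad() method, and the module-level `hparams` must already have been populated (e.g. via utils.hparams.set_hparams with one of the configs under checkpoints/) before the constructor runs:

import torch

# Illustrative sketch only; `phone_encoder` is a placeholder dictionary object,
# and hparams is assumed loaded with encoder/decoder sizes and variance flags.
model = FastSpeech2(phone_encoder)
txt_tokens = torch.randint(1, 50, (1, 20))  # [B, T_txt]; index 0 is padding
with torch.no_grad():
    ret = model(txt_tokens, infer=True)     # no mel2ph/f0/energy: predictors fill them in
print(ret['mel_out'].shape)                 # [B, T_mel, audio_num_mel_bins]
print(ret['mel2ph'].shape)                  # [B, T_mel] alignment from the duration predictor

Note also that mel_norm and mel_denorm are exact inverses, mapping mel values in [-5.5, 0.8] to [-1, 1] and back.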
    	
modules/fastspeech/pe.py
ADDED
@@ -0,0 +1,149 @@
from modules.commons.common_layers import *
from utils.hparams import hparams
from modules.fastspeech.tts_modules import PitchPredictor
from utils.pitch_utils import denorm_f0


class Prenet(nn.Module):
    def __init__(self, in_dim=80, out_dim=256, kernel=5, n_layers=3, strides=None):
        super(Prenet, self).__init__()
        padding = kernel // 2
        self.layers = []
        self.strides = strides if strides is not None else [1] * n_layers
        for l in range(n_layers):
            self.layers.append(nn.Sequential(
                nn.Conv1d(in_dim, out_dim, kernel_size=kernel, padding=padding, stride=self.strides[l]),
                nn.ReLU(),
                nn.BatchNorm1d(out_dim)
            ))
            in_dim = out_dim
        self.layers = nn.ModuleList(self.layers)
        self.out_proj = nn.Linear(out_dim, out_dim)

    def forward(self, x):
        """

        :param x: [B, T, 80]
        :return: [L, B, T, H], [B, T, H]
        """
        padding_mask = x.abs().sum(-1).eq(0).data  # [B, T]
        nonpadding_mask_TB = 1 - padding_mask.float()[:, None, :]  # [B, 1, T]
        x = x.transpose(1, 2)
        hiddens = []
        for i, l in enumerate(self.layers):
            nonpadding_mask_TB = nonpadding_mask_TB[:, :, ::self.strides[i]]
            x = l(x) * nonpadding_mask_TB
        hiddens.append(x)
        hiddens = torch.stack(hiddens, 0)  # [L, B, H, T]
        hiddens = hiddens.transpose(2, 3)  # [L, B, T, H]
        x = self.out_proj(x.transpose(1, 2))  # [B, T, H]
        x = x * nonpadding_mask_TB.transpose(1, 2)
        return hiddens, x


class ConvBlock(nn.Module):
    def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0):
        super().__init__()
        self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride)
        self.norm = norm
        if self.norm == 'bn':
            self.norm = nn.BatchNorm1d(n_chans)
        elif self.norm == 'in':
            self.norm = nn.InstanceNorm1d(n_chans, affine=True)
        elif self.norm == 'gn':
            self.norm = nn.GroupNorm(n_chans // 16, n_chans)
        elif self.norm == 'ln':
            self.norm = LayerNorm(n_chans // 16, n_chans)
        elif self.norm == 'wn':
            self.conv = torch.nn.utils.weight_norm(self.conv.conv)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        """

        :param x: [B, C, T]
        :return: [B, C, T]
        """
        x = self.conv(x)
        if not isinstance(self.norm, str):
            if self.norm == 'none':
                pass
            elif self.norm == 'ln':
                x = self.norm(x.transpose(1, 2)).transpose(1, 2)
            else:
                x = self.norm(x)
        x = self.relu(x)
        x = self.dropout(x)
        return x


class ConvStacks(nn.Module):
    def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn',
                 dropout=0, strides=None, res=True):
        super().__init__()
        self.conv = torch.nn.ModuleList()
        self.kernel_size = kernel_size
        self.res = res
        self.in_proj = Linear(idim, n_chans)
        if strides is None:
            strides = [1] * n_layers
        else:
            assert len(strides) == n_layers
        for idx in range(n_layers):
            self.conv.append(ConvBlock(
                n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout))
        self.out_proj = Linear(n_chans, odim)

    def forward(self, x, return_hiddens=False):
        """

        :param x: [B, T, H]
        :return: [B, T, H]
        """
        x = self.in_proj(x)
        x = x.transpose(1, -1)  # (B, idim, Tmax)
        hiddens = []
        for f in self.conv:
            x_ = f(x)
            x = x + x_ if self.res else x_  # (B, C, Tmax)
            hiddens.append(x)
        x = x.transpose(1, -1)
        x = self.out_proj(x)  # (B, Tmax, H)
        if return_hiddens:
            hiddens = torch.stack(hiddens, 1)  # [B, L, C, T]
            return x, hiddens
        return x


class PitchExtractor(nn.Module):
    def __init__(self, n_mel_bins=80, conv_layers=2):
        super().__init__()
        self.hidden_size = 256
        self.predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
        self.conv_layers = conv_layers

        self.mel_prenet = Prenet(n_mel_bins, self.hidden_size, strides=[1, 1, 1])
        if self.conv_layers > 0:
            self.mel_encoder = ConvStacks(
                idim=self.hidden_size, n_chans=self.hidden_size, odim=self.hidden_size, n_layers=self.conv_layers)
        self.pitch_predictor = PitchPredictor(
            self.hidden_size, n_chans=self.predictor_hidden,
            n_layers=5, dropout_rate=0.5, odim=2,
            padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])

    def forward(self, mel_input=None):
        ret = {}
        mel_hidden = self.mel_prenet(mel_input)[1]
        if self.conv_layers > 0:
            mel_hidden = self.mel_encoder(mel_hidden)

        ret['pitch_pred'] = pitch_pred = self.pitch_predictor(mel_hidden)

        pitch_padding = mel_input.abs().sum(-1) == 0
        use_uv = hparams['pitch_type'] == 'frame' and hparams['use_uv']

        ret['f0_denorm_pred'] = denorm_f0(
            pitch_pred[:, :, 0], (pitch_pred[:, :, 1] > 0) if use_uv else None,
            hparams, pitch_padding=pitch_padding)
        return ret
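
A minimal sketch of calling the extractor on a batch of mels (shapes only; it assumes `hparams` has already been loaded, e.g. from checkpoints/m4singer_pe/config.yaml, so that predictor_hidden, ffn_padding, predictor_kernel, pitch_type and use_uv are all defined):

import torch

# illustrative sketch; real usage loads a trained checkpoint into this module
pe = PitchExtractor(n_mel_bins=80, conv_layers=2)
mels = torch.randn(2, 100, 80)       # [B, T_mel, n_mel_bins]; all-zero frames count as padding
ret = pe(mels)
print(ret['pitch_pred'].shape)       # [2, 100, 2]: normalized f0 plus a voiced/unvoiced logit
print(ret['f0_denorm_pred'].shape)   # [2, 100]: f0 denormalized via denorm_f0, zeroed on padding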
    	
modules/fastspeech/tts_modules.py
ADDED
@@ -0,0 +1,357 @@
import logging
import math

import torch
import torch.nn as nn
from torch.nn import functional as F

from modules.commons.espnet_positional_embedding import RelPositionalEncoding
from modules.commons.common_layers import SinusoidalPositionalEmbedding, Linear, EncSALayer, DecSALayer, BatchNorm1dTBC
from utils.hparams import hparams

DEFAULT_MAX_SOURCE_POSITIONS = 2000
DEFAULT_MAX_TARGET_POSITIONS = 2000


class TransformerEncoderLayer(nn.Module):
    def __init__(self, hidden_size, dropout, kernel_size=None, num_heads=2, norm='ln'):
        super().__init__()
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_heads = num_heads
        self.op = EncSALayer(
            hidden_size, num_heads, dropout=dropout,
            attention_dropout=0.0, relu_dropout=dropout,
            kernel_size=kernel_size
            if kernel_size is not None else hparams['enc_ffn_kernel_size'],
            padding=hparams['ffn_padding'],
            norm=norm, act=hparams['ffn_act'])

    def forward(self, x, **kwargs):
        return self.op(x, **kwargs)


######################
# fastspeech modules
######################
class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.
    :param int nout: output dim size
    :param int dim: dimension to be normalized
    """

    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.
        :param torch.Tensor x: input tensor
        :return: layer normalized tensor
        :rtype torch.Tensor
        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)


class DurationPredictor(torch.nn.Module):
    """Duration predictor module.
    This is the duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
    It predicts the duration of each frame in log domain from the hidden embeddings of the encoder.
    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf
    Note:
        The calculation domain of the outputs differs between `forward` and `inference`: in `forward`
        the outputs are calculated in log domain, while in `inference` they are calculated in linear domain.
    """

    def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0, padding='SAME'):
        """Initialize duration predictor module.
        Args:
            idim (int): Input dimension.
            n_layers (int, optional): Number of convolutional layers.
            n_chans (int, optional): Number of channels of convolutional layers.
            kernel_size (int, optional): Kernel size of convolutional layers.
            dropout_rate (float, optional): Dropout rate.
            offset (float, optional): Offset value to avoid nan in log domain.
        """
        super(DurationPredictor, self).__init__()
        self.offset = offset
        self.conv = torch.nn.ModuleList()
        self.kernel_size = kernel_size
        self.padding = padding
        for idx in range(n_layers):
            in_chans = idim if idx == 0 else n_chans
            self.conv += [torch.nn.Sequential(
                torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
                                       if padding == 'SAME'
                                       else (kernel_size - 1, 0), 0),
                torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
                torch.nn.ReLU(),
                LayerNorm(n_chans, dim=1),
                torch.nn.Dropout(dropout_rate)
            )]
        if hparams['dur_loss'] in ['mse', 'huber']:
            odims = 1
        elif hparams['dur_loss'] == 'mog':
            odims = 15
        elif hparams['dur_loss'] == 'crf':
            odims = 32
            from torchcrf import CRF
            self.crf = CRF(odims, batch_first=True)
        self.linear = torch.nn.Linear(n_chans, odims)

    def _forward(self, xs, x_masks=None, is_inference=False):
        xs = xs.transpose(1, -1)  # (B, idim, Tmax)
        for f in self.conv:
            xs = f(xs)  # (B, C, Tmax)
            if x_masks is not None:
                xs = xs * (1 - x_masks.float())[:, None, :]

        xs = self.linear(xs.transpose(1, -1))  # [B, T, C]
        xs = xs * (1 - x_masks.float())[:, :, None]  # (B, T, C)
        if is_inference:
            return self.out2dur(xs), xs
        else:
            if hparams['dur_loss'] in ['mse']:
                xs = xs.squeeze(-1)  # (B, Tmax)
        return xs

    def out2dur(self, xs):
        if hparams['dur_loss'] in ['mse']:
            # NOTE: calculate in log domain
            xs = xs.squeeze(-1)  # (B, Tmax)
            dur = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long()  # avoid negative values
        elif hparams['dur_loss'] == 'mog':
            raise NotImplementedError
        elif hparams['dur_loss'] == 'crf':
            dur = torch.LongTensor(self.crf.decode(xs)).cuda()
        return dur

    def forward(self, xs, x_masks=None):
        """Calculate forward propagation.
        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
        Returns:
            Tensor: Batch of predicted durations in log domain (B, Tmax).
        """
        return self._forward(xs, x_masks, False)

    def inference(self, xs, x_masks=None):
        """Inference duration.
        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
        Returns:
            LongTensor: Batch of predicted durations in linear domain (B, Tmax).
        """
        return self._forward(xs, x_masks, True)
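
With the default `dur_loss` of 'mse', the predictor regresses durations in log domain and `out2dur` inverts that with round(exp(x) - offset). A worked sketch of the decode step, assuming the training target is log(dur + offset) as in FastSpeech:

import torch

offset = 1.0
log_dur = torch.log(torch.tensor([2.0, 5.0, 0.0]) + offset)  # assumed training target
dur = torch.clamp(torch.round(log_dur.exp() - offset), min=0).long()
print(dur)  # tensor([2, 5, 0]): the original durations round-trip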
class LengthRegulator(torch.nn.Module):
    def __init__(self, pad_value=0.0):
        super(LengthRegulator, self).__init__()
        self.pad_value = pad_value

    def forward(self, dur, dur_padding=None, alpha=1.0):
        """
        Example (no batch dim version):
            1. dur = [2,2,3]
            2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4]
            3. token_mask = [[1,1,0,0,0,0,0],
                             [0,0,1,1,0,0,0],
                             [0,0,0,0,1,1,1]]
            4. token_idx * token_mask = [[1,1,0,0,0,0,0],
                                         [0,0,2,2,0,0,0],
                                         [0,0,0,0,3,3,3]]
            5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3]

        :param dur: Batch of durations of each frame (B, T_txt)
        :param dur_padding: Batch of padding of each frame (B, T_txt)
        :param alpha: duration rescale coefficient
        :return:
            mel2ph (B, T_speech)
        """
        assert alpha > 0
        dur = torch.round(dur.float() * alpha).long()
        if dur_padding is not None:
            dur = dur * (1 - dur_padding.long())
        token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device)
        dur_cumsum = torch.cumsum(dur, 1)
        dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0)

        pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device)
        token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
        mel2ph = (token_idx * token_mask.long()).sum(1)
        return mel2ph
             | 
| 191 | 
            +
             | 
| 192 | 
            +
            class PitchPredictor(torch.nn.Module):
         | 
| 193 | 
            +
                def __init__(self, idim, n_layers=5, n_chans=384, odim=2, kernel_size=5,
         | 
| 194 | 
            +
                             dropout_rate=0.1, padding='SAME'):
         | 
| 195 | 
            +
                    """Initilize pitch predictor module.
         | 
| 196 | 
            +
                    Args:
         | 
| 197 | 
            +
                        idim (int): Input dimension.
         | 
| 198 | 
            +
                        n_layers (int, optional): Number of convolutional layers.
         | 
| 199 | 
            +
                        n_chans (int, optional): Number of channels of convolutional layers.
         | 
| 200 | 
            +
                        kernel_size (int, optional): Kernel size of convolutional layers.
         | 
| 201 | 
            +
                        dropout_rate (float, optional): Dropout rate.
         | 
| 202 | 
            +
                    """
         | 
| 203 | 
            +
                    super(PitchPredictor, self).__init__()
         | 
| 204 | 
            +
                    self.conv = torch.nn.ModuleList()
         | 
| 205 | 
            +
                    self.kernel_size = kernel_size
         | 
| 206 | 
            +
                    self.padding = padding
         | 
| 207 | 
            +
                    for idx in range(n_layers):
         | 
| 208 | 
            +
                        in_chans = idim if idx == 0 else n_chans
         | 
| 209 | 
            +
                        self.conv += [torch.nn.Sequential(
         | 
| 210 | 
            +
                            torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
         | 
| 211 | 
            +
                                                   if padding == 'SAME'
         | 
| 212 | 
            +
                                                   else (kernel_size - 1, 0), 0),
         | 
| 213 | 
            +
                            torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
         | 
| 214 | 
            +
                            torch.nn.ReLU(),
         | 
| 215 | 
            +
                            LayerNorm(n_chans, dim=1),
         | 
| 216 | 
            +
                            torch.nn.Dropout(dropout_rate)
         | 
| 217 | 
            +
                        )]
         | 
| 218 | 
            +
                    self.linear = torch.nn.Linear(n_chans, odim)
         | 
| 219 | 
            +
                    self.embed_positions = SinusoidalPositionalEmbedding(idim, 0, init_size=4096)
         | 
| 220 | 
            +
                    self.pos_embed_alpha = nn.Parameter(torch.Tensor([1]))
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                def forward(self, xs):
         | 
| 223 | 
            +
                    """
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                    :param xs: [B, T, H]
         | 
| 226 | 
            +
                    :return: [B, T, H]
         | 
| 227 | 
            +
                    """
         | 
| 228 | 
            +
                    positions = self.pos_embed_alpha * self.embed_positions(xs[..., 0])
         | 
| 229 | 
            +
                    xs = xs + positions
         | 
| 230 | 
            +
                    xs = xs.transpose(1, -1)  # (B, idim, Tmax)
         | 
| 231 | 
            +
                    for f in self.conv:
         | 
| 232 | 
            +
                        xs = f(xs)  # (B, C, Tmax)
         | 
| 233 | 
            +
                    # NOTE: calculate in log domain
         | 
| 234 | 
            +
                    xs = self.linear(xs.transpose(1, -1))  # (B, Tmax, H)
         | 
| 235 | 
            +
                    return xs
         | 
| 236 | 
            +
             | 
| 237 | 
            +
             | 
| 238 | 
            +
            class EnergyPredictor(PitchPredictor):
         | 
| 239 | 
            +
                pass
         | 
| 240 | 
            +
             | 
| 241 | 
            +
             | 
| 242 | 
            +
            def mel2ph_to_dur(mel2ph, T_txt, max_dur=None):
         | 
| 243 | 
            +
                B, _ = mel2ph.shape
         | 
| 244 | 
            +
                dur = mel2ph.new_zeros(B, T_txt + 1).scatter_add(1, mel2ph, torch.ones_like(mel2ph))
         | 
| 245 | 
            +
                dur = dur[:, 1:]
         | 
| 246 | 
            +
                if max_dur is not None:
         | 
| 247 | 
            +
                    dur = dur.clamp(max=max_dur)
         | 
| 248 | 
            +
                return dur
         | 
| 249 | 
            +
             | 
| 250 | 
            +
             | 
| 251 | 
            +
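# FFTBlocks is the FastSpeech-style Transformer stack shared by the encoder
# and decoder below: it works internally in [T, B, C] layout and re-applies
# the non-padding mask after every layer so padded frames stay exactly zero.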
class FFTBlocks(nn.Module):
    def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=None, num_heads=2,
                 use_pos_embed=True, use_last_norm=True, norm='ln', use_pos_embed_alpha=True):
        super().__init__()
        self.num_layers = num_layers
        embed_dim = self.hidden_size = hidden_size
        self.dropout = dropout if dropout is not None else hparams['dropout']
        self.use_pos_embed = use_pos_embed
        self.use_last_norm = use_last_norm
        if use_pos_embed:
            self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
            self.padding_idx = 0
            self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
            self.embed_positions = SinusoidalPositionalEmbedding(
                embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
            )

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(self.hidden_size, self.dropout,
                                    kernel_size=ffn_kernel_size, num_heads=num_heads)
            for _ in range(self.num_layers)
        ])
        if self.use_last_norm:
            if norm == 'ln':
                self.layer_norm = nn.LayerNorm(embed_dim)
            elif norm == 'bn':
                self.layer_norm = BatchNorm1dTBC(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
        """
        :param x: [B, T, C]
        :param padding_mask: [B, T]
        :return: [B, T, C] or [L, B, T, C]
        """
        padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
        nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None]  # [T, B, 1]
        if self.use_pos_embed:
            positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
            x = x + positions
            x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1) * nonpadding_mask_TB
        hiddens = []
        for layer in self.layers:
            x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
            hiddens.append(x)
        if self.use_last_norm:
            x = self.layer_norm(x) * nonpadding_mask_TB
        if return_hiddens:
            x = torch.stack(hiddens, 0)  # [L, T, B, C]
            x = x.transpose(1, 2)  # [L, B, T, C]
        else:
            x = x.transpose(0, 1)  # [B, T, C]
        return x

class FastspeechEncoder(FFTBlocks):
    def __init__(self, embed_tokens, hidden_size=None, num_layers=None, kernel_size=None, num_heads=2):
        hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
        kernel_size = hparams['enc_ffn_kernel_size'] if kernel_size is None else kernel_size
        num_layers = hparams['dec_layers'] if num_layers is None else num_layers
        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
                         use_pos_embed=False)  # use_pos_embed_alpha for compatibility
        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(hidden_size)
        self.padding_idx = 0
        if hparams.get('rel_pos') is not None and hparams['rel_pos']:
            self.embed_positions = RelPositionalEncoding(hidden_size, dropout_rate=0.0)
        else:
            self.embed_positions = SinusoidalPositionalEmbedding(
                hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
            )

    def forward(self, txt_tokens):
        """
        :param txt_tokens: [B, T]
        :return: {
            'encoder_out': [T x B x C]
        }
        """
        encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
        x = self.forward_embedding(txt_tokens)  # [B, T, H]
        x = super(FastspeechEncoder, self).forward(x, encoder_padding_mask)
        return x

    def forward_embedding(self, txt_tokens):
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(txt_tokens)
        if hparams['use_pos_embed']:
            positions = self.embed_positions(txt_tokens)
            x = x + positions
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x

class FastspeechDecoder(FFTBlocks):
    def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None):
        num_heads = hparams['num_heads'] if num_heads is None else num_heads
        hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
        kernel_size = hparams['dec_ffn_kernel_size'] if kernel_size is None else kernel_size
        num_layers = hparams['dec_layers'] if num_layers is None else num_layers
        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
    	
        modules/hifigan/hifigan.py
    ADDED
    
@@ -0,0 +1,370 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm

from modules.parallel_wavegan.layers import UpsampleNetwork, ConvInUpsampleNetwork
from modules.parallel_wavegan.models.source import SourceModuleHnNSF
import numpy as np

LRELU_SLOPE = 0.1


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)

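# ResBlock1/ResBlock2 are HiFi-GAN's multi-receptive-field residual blocks:
# stacks of dilated 1-D convolutions under weight norm whose outputs are
# averaged across several kernel sizes by the generator below.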
class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1])))
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)

class Conv1d1x1(Conv1d):
    """1x1 Conv1d with customized initialization."""

    def __init__(self, in_channels, out_channels, bias):
        """Initialize 1x1 Conv1d module."""
        super(Conv1d1x1, self).__init__(in_channels, out_channels,
                                        kernel_size=1, padding=0,
                                        dilation=1, bias=bias)

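# HifiGanGenerator upsamples an 80-bin mel spectrogram to a waveform through a
# chain of transposed convolutions. When use_pitch_embed is set, an NSF-style
# harmonic source derived from upsampled F0 (SourceModuleHnNSF) is injected
# after every upsampling stage via the matching noise_convs entry.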
class HifiGanGenerator(torch.nn.Module):
    def __init__(self, h, c_out=1):
        super(HifiGanGenerator, self).__init__()
        self.h = h
        self.num_kernels = len(h['resblock_kernel_sizes'])
        self.num_upsamples = len(h['upsample_rates'])

        if h['use_pitch_embed']:
            self.harmonic_num = 8
            self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h['upsample_rates']))
            self.m_source = SourceModuleHnNSF(
                sampling_rate=h['audio_sample_rate'],
                harmonic_num=self.harmonic_num)
            self.noise_convs = nn.ModuleList()
        self.conv_pre = weight_norm(Conv1d(80, h['upsample_initial_channel'], 7, 1, padding=3))
        resblock = ResBlock1 if h['resblock'] == '1' else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h['upsample_rates'], h['upsample_kernel_sizes'])):
            c_cur = h['upsample_initial_channel'] // (2 ** (i + 1))
            self.ups.append(weight_norm(
                ConvTranspose1d(c_cur * 2, c_cur, k, u, padding=(k - u) // 2)))
            if h['use_pitch_embed']:
                if i + 1 < len(h['upsample_rates']):
                    stride_f0 = np.prod(h['upsample_rates'][i + 1:])
                    self.noise_convs.append(Conv1d(
                        1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
                else:
                    self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h['upsample_initial_channel'] // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(h['resblock_kernel_sizes'], h['resblock_dilation_sizes'])):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, c_out, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, f0=None):
        if f0 is not None:
            # harmonic-source signal, noise-source signal, uv flag
            f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)
            har_source, noi_source, uv = self.m_source(f0)
            har_source = har_source.transpose(1, 2)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            if f0 is not None:
                x_source = self.noise_convs[i](har_source)
                x_source = torch.nn.functional.relu(x_source)
                tmp_shape = x_source.shape[1]
                x_source = torch.nn.functional.layer_norm(x_source.transpose(1, -1), (tmp_shape,)).transpose(1, -1)
                x = x + x_source
            xs = None
            for j in range(self.num_kernels):
                xs_ = self.resblocks[i * self.num_kernels + j](x)
                if xs is None:
                    xs = xs_
                else:
                    xs += xs_
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)

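# DiscriminatorP is HiFi-GAN's period discriminator: it reflect-pads the
# waveform to a multiple of `period`, folds it into a 2-D [T/period, period]
# map, and convolves with (k, 1) kernels so each column sees samples spaced
# `period` apart. With use_cond, an upsampled mel is concatenated as a second
# input channel.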
class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, use_cond=False, c_in=1):
        super(DiscriminatorP, self).__init__()
        self.use_cond = use_cond
        if use_cond:
            from utils.hparams import hparams
            t = hparams['hop_size']
            self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
            c_in = 2

        self.period = period
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(c_in, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x, mel):
        fmap = []
        if self.use_cond:
            x_mel = self.cond_net(mel)
            x = torch.cat([x_mel, x], 1)
        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap

class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_cond=False, c_in=1):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorP(2, use_cond=use_cond, c_in=c_in),
            DiscriminatorP(3, use_cond=use_cond, c_in=c_in),
            DiscriminatorP(5, use_cond=use_cond, c_in=c_in),
            DiscriminatorP(7, use_cond=use_cond, c_in=c_in),
            DiscriminatorP(11, use_cond=use_cond, c_in=c_in),
        ])

    def forward(self, y, y_hat, mel=None):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y, mel)
            y_d_g, fmap_g = d(y_hat, mel)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

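# DiscriminatorS is the scale discriminator: grouped 1-D convolutions over the
# raw waveform. MultiScaleDiscriminator below runs three copies, the second
# and third on 2x and 4x average-pooled audio.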
class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False, use_cond=False, upsample_rates=None, c_in=1):
        super(DiscriminatorS, self).__init__()
        self.use_cond = use_cond
        if use_cond:
            t = np.prod(upsample_rates)
            self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
            c_in = 2
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv1d(c_in, 128, 15, 1, padding=7)),
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x, mel):
        if self.use_cond:
            x_mel = self.cond_net(mel)
            x = torch.cat([x_mel, x], 1)
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap

class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self, use_cond=False, c_in=1):
        super(MultiScaleDiscriminator, self).__init__()
        from utils.hparams import hparams
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True, use_cond=use_cond,
                           upsample_rates=[4, 4, hparams['hop_size'] // 16],
                           c_in=c_in),
            DiscriminatorS(use_cond=use_cond,
                           upsample_rates=[4, 4, hparams['hop_size'] // 32],
                           c_in=c_in),
            DiscriminatorS(use_cond=use_cond,
                           upsample_rates=[4, 4, hparams['hop_size'] // 64],
                           c_in=c_in),
        ])
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=1),
            AvgPool1d(4, 2, padding=1)
        ])

    def forward(self, y, y_hat, mel=None):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(y, mel)
            y_d_g, fmap_g = d(y_hat, mel)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

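# LSGAN-style objectives: discriminators push real outputs toward 1 and fake
# outputs toward 0, while the generator pushes its outputs toward 1.
# feature_loss is the L1 feature-matching term over all discriminator
# activations, scaled by 2 as in the original HiFi-GAN recipe.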
def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    r_losses = 0
    g_losses = 0
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg ** 2)
        r_losses += r_loss
        g_losses += g_loss
    r_losses = r_losses / len(disc_real_outputs)
    g_losses = g_losses / len(disc_real_outputs)
    return r_losses, g_losses


def cond_discriminator_loss(outputs):
    loss = 0
    for dg in outputs:
        g_loss = torch.mean(dg ** 2)
        loss += g_loss
    loss = loss / len(outputs)
    return loss


def generator_loss(disc_outputs):
    loss = 0
    for dg in disc_outputs:
        l = torch.mean((1 - dg) ** 2)
        loss += l
    loss = loss / len(disc_outputs)
    return loss
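A minimal inference sketch for the generator. It is hedged: the config values below are the common HiFi-GAN "V1" settings, not necessarily those in checkpoints/m4singer_hifigan/config.yaml, and use_pitch_embed is disabled here so the NSF source module is not exercised.

import torch
from modules.hifigan.hifigan import HifiGanGenerator

h = {
    'resblock': '1',
    'resblock_kernel_sizes': [3, 7, 11],
    'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    'upsample_rates': [8, 8, 2, 2],    # product = 256 = assumed hop_size
    'upsample_kernel_sizes': [16, 16, 4, 4],
    'upsample_initial_channel': 512,
    'use_pitch_embed': False,          # skip the f0/harmonic-source path here
}
gen = HifiGanGenerator(h)
mel = torch.randn(1, 80, 100)          # [B, n_mels, T_frames]
with torch.no_grad():
    wav = gen(mel)                     # [B, 1, T_frames * 256]
print(wav.shape)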
    	
        modules/hifigan/mel_utils.py
    ADDED
    
@@ -0,0 +1,81 @@
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}

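# mel_spectrogram: reflect-pad the waveform by (n_fft - hop_size) / 2 on each
# side, take the STFT magnitude, project it through a cached librosa mel
# filter bank, and log-compress (dynamic_range_compression_torch). With
# complex=True the raw STFT is returned instead.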
def mel_spectrogram(y, hparams, center=False, complex=False):
    # hop_size: 512  # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate)
    # win_size: 2048  # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate)
    # fmin: 55  # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
    # fmax: 10000  # To be increased/reduced depending on data.
    # fft_size: 2048  # Extra window size is filled with 0 paddings to match this parameter
    # n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax,
    n_fft = hparams['fft_size']
    num_mels = hparams['audio_num_mel_bins']
    sampling_rate = hparams['audio_sample_rate']
    hop_size = hparams['hop_size']
    win_size = hparams['win_size']
    fmin = hparams['fmin']
    fmax = hparams['fmax']
    y = y.clamp(min=-1., max=1.)
    global mel_basis, hann_window
    mel_key = str(fmax) + '_' + str(y.device)
    if mel_key not in mel_basis:  # cache per (fmax, device); the original check tested the bare fmax against string keys and never hit
        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
                                mode='reflect')
    y = y.squeeze(1)

    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True)

    if not complex:
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
        spec = torch.matmul(mel_basis[mel_key], spec)
        spec = spectral_normalize_torch(spec)
    else:
        B, C, T, _ = spec.shape
        spec = spec.transpose(1, 2)  # [B, T, n_fft, 2]
    return spec
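A hedged usage sketch: the hyperparameter values below mirror common 24 kHz vocoder settings rather than this project's config files, and the call assumes the repo's pinned dependency versions (positional librosa_mel_fn arguments and real-valued torch.stft output were removed in newer librosa/torch releases).

import torch
from modules.hifigan.mel_utils import mel_spectrogram

hp = {
    'fft_size': 512,
    'audio_num_mel_bins': 80,
    'audio_sample_rate': 24000,
    'hop_size': 128,
    'win_size': 512,
    'fmin': 30,
    'fmax': 12000,
}
wav = torch.randn(1, 24000)          # [B, T] waveform in [-1, 1]
mel = mel_spectrogram(wav, hp)       # [B, n_mels, T_frames]
print(mel.shape)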
    	
        modules/parallel_wavegan/__init__.py
    ADDED
    
File without changes
