AlexK-PL commited on
Commit
378de71
·
1 Parent(s): 30b323b

Delete melgan

Browse files
melgan/.gitignore DELETED
@@ -1,119 +0,0 @@
1
- # IDE configuration
2
- .idea/
3
-
4
- # configuration
5
- config/*
6
- !config/default.yaml
7
- temp-restore.yaml
8
-
9
- # logs, checkpoints
10
- chkpt/
11
- logs/
12
-
13
- # just a temporary folder
14
- temp/
15
-
16
- # Byte-compiled / optimized / DLL files
17
- __pycache__/
18
- *.py[cod]
19
- *$py.class
20
-
21
- # C extensions
22
- *.so
23
-
24
- # Distribution / packaging
25
- .Python
26
- build/
27
- develop-eggs/
28
- dist/
29
- downloads/
30
- eggs/
31
- .eggs/
32
- lib/
33
- lib64/
34
- parts/
35
- sdist/
36
- var/
37
- wheels/
38
- *.egg-info/
39
- .installed.cfg
40
- *.egg
41
- MANIFEST
42
-
43
- # PyInstaller
44
- # Usually these files are written by a python script from a template
45
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
46
- *.manifest
47
- *.spec
48
-
49
- # Installer logs
50
- pip-log.txt
51
- pip-delete-this-directory.txt
52
-
53
- # Unit test / coverage reports
54
- htmlcov/
55
- .tox/
56
- .coverage
57
- .coverage.*
58
- .cache
59
- nosetests.xml
60
- coverage.xml
61
- *.cover
62
- .hypothesis/
63
- .pytest_cache/
64
-
65
- # Translations
66
- *.mo
67
- *.pot
68
-
69
- # Django stuff:
70
- *.log
71
- local_settings.py
72
- db.sqlite3
73
-
74
- # Flask stuff:
75
- instance/
76
- .webassets-cache
77
-
78
- # Scrapy stuff:
79
- .scrapy
80
-
81
- # Sphinx documentation
82
- docs/_build/
83
-
84
- # PyBuilder
85
- target/
86
-
87
- # Jupyter Notebook
88
- .ipynb_checkpoints
89
-
90
- # pyenv
91
- .python-version
92
-
93
- # celery beat schedule file
94
- celerybeat-schedule
95
-
96
- # SageMath parsed files
97
- *.sage.py
98
-
99
- # Environments
100
- .env
101
- .venv
102
- env/
103
- venv/
104
- ENV/
105
- env.bak/
106
- venv.bak/
107
-
108
- # Spyder project settings
109
- .spyderproject
110
- .spyproject
111
-
112
- # Rope project settings
113
- .ropeproject
114
-
115
- # mkdocs documentation
116
- /site
117
-
118
- # mypy
119
- .mypy_cache/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/LICENSE DELETED
@@ -1,29 +0,0 @@
1
- BSD 3-Clause License
2
-
3
- Copyright (c) 2019, Seungwon Park 박승원
4
- All rights reserved.
5
-
6
- Redistribution and use in source and binary forms, with or without
7
- modification, are permitted provided that the following conditions are met:
8
-
9
- 1. Redistributions of source code must retain the above copyright notice, this
10
- list of conditions and the following disclaimer.
11
-
12
- 2. Redistributions in binary form must reproduce the above copyright notice,
13
- this list of conditions and the following disclaimer in the documentation
14
- and/or other materials provided with the distribution.
15
-
16
- 3. Neither the name of the copyright holder nor the names of its
17
- contributors may be used to endorse or promote products derived from
18
- this software without specific prior written permission.
19
-
20
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/README.md DELETED
@@ -1,83 +0,0 @@
1
- # MelGAN
2
- Unofficial PyTorch implementation of [MelGAN vocoder](https://arxiv.org/abs/1910.06711)
3
-
4
- ## Key Features
5
-
6
- - MelGAN is lighter, faster, and better at generalizing to unseen speakers than [WaveGlow](https://github.com/NVIDIA/waveglow).
7
- - This repository use identical mel-spectrogram function from [NVIDIA/tacotron2](https://github.com/NVIDIA/tacotron2), so this can be directly used to convert output from NVIDIA's tacotron2 into raw-audio.
8
- - Pretrained model on LJSpeech-1.1 via [PyTorch Hub](https://pytorch.org/hub).
9
-
10
- ![](./assets/gd.png)
11
-
12
- ## Prerequisites
13
-
14
- Tested on Python 3.6
15
- ```bash
16
- pip install -r requirements.txt
17
- ```
18
-
19
- ## Prepare Dataset
20
-
21
- - Download dataset for training. This can be any wav files with sample rate 22050Hz. (e.g. LJSpeech was used in paper)
22
- - preprocess: `python preprocess.py -c config/default.yaml -d [data's root path]`
23
- - Edit configuration `yaml` file
24
-
25
- ## Train & Tensorboard
26
-
27
- - `python trainer.py -c [config yaml file] -n [name of the run]`
28
- - `cp config/default.yaml config/config.yaml` and then edit `config.yaml`
29
- - Write down the root path of train/validation files to 2nd/3rd line.
30
- - Each path should contain pairs of `*.wav` with corresponding (preprocessed) `*.mel` file.
31
- - The data loader parses list of files within the path recursively.
32
- - `tensorboard --logdir logs/`
33
-
34
- ## Pretrained model
35
-
36
- Try with Google Colab: TODO
37
-
38
- ```python
39
- import torch
40
- vocoder = torch.hub.load('seungwonpark/melgan', 'melgan')
41
- vocoder.eval()
42
- mel = torch.randn(1, 80, 234) # use your own mel-spectrogram here
43
-
44
- if torch.cuda.is_available():
45
- vocoder = vocoder.cuda()
46
- mel = mel.cuda()
47
-
48
- with torch.no_grad():
49
- audio = vocoder.inference(mel)
50
- ```
51
-
52
- ## Inference
53
-
54
- - `python inference.py -p [checkpoint path] -i [input mel path]`
55
-
56
- ## Results
57
-
58
- See audio samples at: http://swpark.me/melgan/.
59
- Model was trained at V100 GPU for 14 days using LJSpeech-1.1.
60
-
61
- ![](./assets/lj-tensorboard-v0.3-alpha.png)
62
-
63
-
64
- ## Implementation Authors
65
-
66
- - [Seungwon Park](http://swpark.me) @ MINDsLab Inc. ([email protected], [email protected])
67
- - Myunchul Joe @ MINDsLab Inc.
68
- - [Rishikesh](https://github.com/rishikksh20) @ DeepSync Technologies Pvt Ltd.
69
-
70
- ## License
71
-
72
- BSD 3-Clause License.
73
-
74
- - [utils/stft.py](./utils/stft.py) by Prem Seetharaman (BSD 3-Clause License)
75
- - [datasets/mel2samp.py](./datasets/mel2samp.py) from https://github.com/NVIDIA/waveglow (BSD 3-Clause License)
76
- - [utils/hparams.py](./utils/hparams.py) from https://github.com/HarryVolek/PyTorch_Speaker_Verification (No License specified)
77
-
78
- ## Useful resources
79
-
80
- - [How to Train a GAN? Tips and tricks to make GANs work](https://github.com/soumith/ganhacks) by Soumith Chintala
81
- - [Official MelGAN implementation by original authors](https://github.com/descriptinc/melgan-neurips)
82
- - [Reproduction of MelGAN - NeurIPS 2019 Reproducibility Challenge (Ablation Track)](https://openreview.net/pdf?id=9jTbNbBNw0) by Yifei Zhao, Yichao Yang, and Yang Gao
83
- - "replacing the average pooling layer with max pooling layer and replacing reflection padding with replication padding improves the performance significantly, while combining them produces worse results"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/assets/gd.png DELETED
Binary file (114 kB)
 
melgan/assets/lj-tensorboard-v0.3-alpha.png DELETED
Binary file (61.1 kB)
 
melgan/assets/lj-tensorboard.png DELETED
Binary file (45 kB)
 
melgan/config/default.yaml DELETED
@@ -1,34 +0,0 @@
1
- data: # root path of train/validation data (either relative/absoulte path is ok)
2
- train: ''
3
- validation: ''
4
- ---
5
- train:
6
- rep_discriminator: 1
7
- num_workers: 32
8
- batch_size: 16
9
- optimizer: 'adam'
10
- adam:
11
- lr: 0.0001
12
- beta1: 0.5
13
- beta2: 0.9
14
- ---
15
- audio:
16
- n_mel_channels: 80
17
- segment_length: 16000
18
- pad_short: 2000
19
- filter_length: 1024
20
- hop_length: 256 # WARNING: this can't be changed.
21
- win_length: 1024
22
- sampling_rate: 22050
23
- mel_fmin: 0.0
24
- mel_fmax: 8000.0
25
- ---
26
- model:
27
- feat_match: 10.0
28
- ---
29
- log:
30
- summary_interval: 1
31
- validation_interval: 5
32
- save_interval: 25
33
- chkpt_dir: 'chkpt'
34
- log_dir: 'logs'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/datasets/dataloader.py DELETED
@@ -1,67 +0,0 @@
1
- import os
2
- import glob
3
- import torch
4
- import random
5
- import numpy as np
6
- from torch.utils.data import Dataset, DataLoader
7
-
8
- from utils.utils import read_wav_np
9
-
10
-
11
- def create_dataloader(hp, args, train):
12
- dataset = MelFromDisk(hp, args, train)
13
-
14
- if train:
15
- return DataLoader(dataset=dataset, batch_size=hp.train.batch_size, shuffle=True,
16
- num_workers=hp.train.num_workers, pin_memory=True, drop_last=True)
17
- else:
18
- return DataLoader(dataset=dataset, batch_size=1, shuffle=False,
19
- num_workers=hp.train.num_workers, pin_memory=True, drop_last=False)
20
-
21
-
22
- class MelFromDisk(Dataset):
23
- def __init__(self, hp, args, train):
24
- self.hp = hp
25
- self.args = args
26
- self.train = train
27
- self.path = hp.data.train if train else hp.data.validation
28
- self.wav_list = glob.glob(os.path.join(self.path, '**', '*.wav'), recursive=True)
29
- self.mel_segment_length = hp.audio.segment_length // hp.audio.hop_length + 2
30
- self.mapping = [i for i in range(len(self.wav_list))]
31
-
32
- def __len__(self):
33
- return len(self.wav_list)
34
-
35
- def __getitem__(self, idx):
36
- if self.train:
37
- idx1 = idx
38
- idx2 = self.mapping[idx1]
39
- return self.my_getitem(idx1), self.my_getitem(idx2)
40
- else:
41
- return self.my_getitem(idx)
42
-
43
- def shuffle_mapping(self):
44
- random.shuffle(self.mapping)
45
-
46
- def my_getitem(self, idx):
47
- wavpath = self.wav_list[idx]
48
- melpath = wavpath.replace('.wav', '.mel')
49
- sr, audio = read_wav_np(wavpath)
50
- if len(audio) < self.hp.audio.segment_length + self.hp.audio.pad_short:
51
- audio = np.pad(audio, (0, self.hp.audio.segment_length + self.hp.audio.pad_short - len(audio)), \
52
- mode='constant', constant_values=0.0)
53
-
54
- audio = torch.from_numpy(audio).unsqueeze(0)
55
- mel = torch.load(melpath).squeeze(0)
56
-
57
- if self.train:
58
- max_mel_start = mel.size(1) - self.mel_segment_length
59
- mel_start = random.randint(0, max_mel_start)
60
- mel_end = mel_start + self.mel_segment_length
61
- mel = mel[:, mel_start:mel_end]
62
-
63
- audio_start = mel_start * self.hp.audio.hop_length
64
- audio = audio[:, audio_start:audio_start+self.hp.audio.segment_length]
65
-
66
- audio = audio + (1/32768) * torch.randn_like(audio)
67
- return mel, audio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/hubconf.py DELETED
@@ -1,41 +0,0 @@
1
- dependencies = ['torch']
2
- import torch
3
- from model.generator import Generator
4
-
5
- model_params = {
6
- 'nvidia_tacotron2_LJ11_epoch6400': {
7
- 'mel_channel': 80,
8
- 'model_url': 'https://github.com/seungwonpark/melgan/releases/download/v0.3-alpha/nvidia_tacotron2_LJ11_epoch6400.pt',
9
- },
10
- }
11
-
12
-
13
- def melgan(model_name='nvidia_tacotron2_LJ11_epoch6400', pretrained=True, progress=True):
14
- params = model_params[model_name]
15
- model = Generator(params['mel_channel'])
16
-
17
- if pretrained:
18
- state_dict = torch.hub.load_state_dict_from_url(params['model_url'],
19
- progress=progress)
20
- model.load_state_dict(state_dict['model_g'])
21
-
22
- model.eval(inference=True)
23
-
24
- return model
25
-
26
-
27
- if __name__ == '__main__':
28
- vocoder = torch.hub.load('seungwonpark/melgan', 'melgan')
29
- mel = torch.randn(1, 80, 234) # use your own mel-spectrogram here
30
-
31
- print('Input mel-spectrogram shape: {}'.format(mel.shape))
32
-
33
- if torch.cuda.is_available():
34
- print('Moving data & model to GPU')
35
- vocoder = vocoder.cuda()
36
- mel = mel.cuda()
37
-
38
- with torch.no_grad():
39
- audio = vocoder.inference(mel)
40
-
41
- print('Output audio shape: {}'.format(audio.shape))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/inference.py DELETED
@@ -1,49 +0,0 @@
1
- import os
2
- import glob
3
- import tqdm
4
- import torch
5
- import argparse
6
- from scipy.io.wavfile import write
7
-
8
- from model.generator import Generator
9
- from utils.hparams import HParam, load_hparam_str
10
-
11
- MAX_WAV_VALUE = 32768.0
12
-
13
-
14
- def main(args):
15
- checkpoint = torch.load(args.checkpoint_path)
16
- if args.config is not None:
17
- hp = HParam(args.config)
18
- else:
19
- hp = load_hparam_str(checkpoint['hp_str'])
20
-
21
- model = Generator(hp.audio.n_mel_channels).cuda()
22
- model.load_state_dict(checkpoint['model_g'])
23
- model.eval(inference=False)
24
-
25
- with torch.no_grad():
26
- for melpath in tqdm.tqdm(glob.glob(os.path.join(args.input_folder, '*.mel'))):
27
- mel = torch.load(melpath)
28
- if len(mel.shape) == 2:
29
- mel = mel.unsqueeze(0)
30
- mel = mel.cuda()
31
-
32
- audio = model.inference(mel)
33
- audio = audio.cpu().detach().numpy()
34
-
35
- out_path = melpath.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch'])
36
- write(out_path, hp.audio.sampling_rate, audio)
37
-
38
-
39
- if __name__ == '__main__':
40
- parser = argparse.ArgumentParser()
41
- parser.add_argument('-c', '--config', type=str, default=None,
42
- help="yaml file for config. will use hp_str from checkpoint if not given.")
43
- parser.add_argument('-p', '--checkpoint_path', type=str, required=True,
44
- help="path of checkpoint pt file for evaluation")
45
- parser.add_argument('-i', '--input_folder', type=str, required=True,
46
- help="directory of mel-spectrograms to invert into raw audio. ")
47
- args = parser.parse_args()
48
-
49
- main(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/model/discriminator.py DELETED
@@ -1,64 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
-
6
- class Discriminator(nn.Module):
7
- def __init__(self):
8
- super(Discriminator, self).__init__()
9
-
10
- self.discriminator = nn.ModuleList([
11
- nn.Sequential(
12
- nn.ReflectionPad1d(7),
13
- nn.utils.weight_norm(nn.Conv1d(1, 16, kernel_size=15, stride=1)),
14
- nn.LeakyReLU(0.2, inplace=True),
15
- ),
16
- nn.Sequential(
17
- nn.utils.weight_norm(nn.Conv1d(16, 64, kernel_size=41, stride=4, padding=20, groups=4)),
18
- nn.LeakyReLU(0.2, inplace=True),
19
- ),
20
- nn.Sequential(
21
- nn.utils.weight_norm(nn.Conv1d(64, 256, kernel_size=41, stride=4, padding=20, groups=16)),
22
- nn.LeakyReLU(0.2, inplace=True),
23
- ),
24
- nn.Sequential(
25
- nn.utils.weight_norm(nn.Conv1d(256, 1024, kernel_size=41, stride=4, padding=20, groups=64)),
26
- nn.LeakyReLU(0.2, inplace=True),
27
- ),
28
- nn.Sequential(
29
- nn.utils.weight_norm(nn.Conv1d(1024, 1024, kernel_size=41, stride=4, padding=20, groups=256)),
30
- nn.LeakyReLU(0.2, inplace=True),
31
- ),
32
- nn.Sequential(
33
- nn.utils.weight_norm(nn.Conv1d(1024, 1024, kernel_size=5, stride=1, padding=2)),
34
- nn.LeakyReLU(0.2, inplace=True),
35
- ),
36
- nn.utils.weight_norm(nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1)),
37
- ])
38
-
39
- def forward(self, x):
40
- '''
41
- returns: (list of 6 features, discriminator score)
42
- we directly predict score without last sigmoid function
43
- since we're using Least Squares GAN (https://arxiv.org/abs/1611.04076)
44
- '''
45
- features = list()
46
- for module in self.discriminator:
47
- x = module(x)
48
- features.append(x)
49
- return features[:-1], features[-1]
50
-
51
-
52
- if __name__ == '__main__':
53
- model = Discriminator()
54
-
55
- x = torch.randn(3, 1, 22050)
56
- print(x.shape)
57
-
58
- features, score = model(x)
59
- for feat in features:
60
- print(feat.shape)
61
- print(score.shape)
62
-
63
- pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
64
- print(pytorch_total_params)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/model/generator.py DELETED
@@ -1,99 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from .res_stack import ResStack
6
- # from res_stack import ResStack
7
-
8
- MAX_WAV_VALUE = 32768.0
9
-
10
-
11
- class Generator(nn.Module):
12
- def __init__(self, mel_channel):
13
- super(Generator, self).__init__()
14
- self.mel_channel = mel_channel
15
-
16
- self.generator = nn.Sequential(
17
- nn.ReflectionPad1d(3),
18
- nn.utils.weight_norm(nn.Conv1d(mel_channel, 512, kernel_size=7, stride=1)),
19
-
20
- nn.LeakyReLU(0.2),
21
- nn.utils.weight_norm(nn.ConvTranspose1d(512, 256, kernel_size=16, stride=8, padding=4)),
22
-
23
- ResStack(256),
24
-
25
- nn.LeakyReLU(0.2),
26
- nn.utils.weight_norm(nn.ConvTranspose1d(256, 128, kernel_size=16, stride=8, padding=4)),
27
-
28
- ResStack(128),
29
-
30
- nn.LeakyReLU(0.2),
31
- nn.utils.weight_norm(nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1)),
32
-
33
- ResStack(64),
34
-
35
- nn.LeakyReLU(0.2),
36
- nn.utils.weight_norm(nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1)),
37
-
38
- ResStack(32),
39
-
40
- nn.LeakyReLU(0.2),
41
- nn.ReflectionPad1d(3),
42
- nn.utils.weight_norm(nn.Conv1d(32, 1, kernel_size=7, stride=1)),
43
- nn.Tanh(),
44
- )
45
-
46
- def forward(self, mel):
47
- mel = (mel + 5.0) / 5.0 # roughly normalize spectrogram
48
- return self.generator(mel)
49
-
50
- def eval(self, inference=False):
51
- super(Generator, self).eval()
52
-
53
- # don't remove weight norm while validation in training loop
54
- if inference:
55
- self.remove_weight_norm()
56
-
57
- def remove_weight_norm(self):
58
- for idx, layer in enumerate(self.generator):
59
- if len(layer.state_dict()) != 0:
60
- try:
61
- nn.utils.remove_weight_norm(layer)
62
- except:
63
- layer.remove_weight_norm()
64
-
65
- def inference(self, mel):
66
- hop_length = 256
67
- # pad input mel with zeros to cut artifact
68
- # see https://github.com/seungwonpark/melgan/issues/8
69
- zero = torch.full((1, self.mel_channel, 10), -11.5129).to(mel.device)
70
- mel = torch.cat((mel, zero), dim=2)
71
-
72
- audio = self.forward(mel)
73
- audio = audio.squeeze() # collapse all dimension except time axis
74
- audio = audio[:-(hop_length*10)]
75
- audio = MAX_WAV_VALUE * audio
76
- audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1)
77
- audio = audio.short()
78
-
79
- return audio
80
-
81
-
82
- '''
83
- to run this, fix
84
- from . import ResStack
85
- into
86
- from res_stack import ResStack
87
- '''
88
- if __name__ == '__main__':
89
- model = Generator(80)
90
-
91
- x = torch.randn(3, 80, 10)
92
- print(x.shape)
93
-
94
- y = model(x)
95
- print(y.shape)
96
- assert y.shape == torch.Size([3, 1, 2560])
97
-
98
- pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
99
- print(pytorch_total_params)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/model/identity.py DELETED
@@ -1,12 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
-
6
- class Identity(nn.Module):
7
- def __init__(self):
8
- super(Identity, self).__init__()
9
-
10
- def forward(self, x):
11
- return x
12
-
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/model/multiscale.py DELETED
@@ -1,29 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from .discriminator import Discriminator
6
- from .identity import Identity
7
-
8
-
9
- class MultiScaleDiscriminator(nn.Module):
10
- def __init__(self):
11
- super(MultiScaleDiscriminator, self).__init__()
12
-
13
- self.discriminators = nn.ModuleList(
14
- [Discriminator() for _ in range(3)]
15
- )
16
-
17
- self.pooling = nn.ModuleList(
18
- [Identity()] +
19
- [nn.AvgPool1d(kernel_size=4, stride=2, padding=1, count_include_pad=False) for _ in range(1, 3)]
20
- )
21
-
22
- def forward(self, x):
23
- ret = list()
24
-
25
- for pool, disc in zip(self.pooling, self.discriminators):
26
- x = pool(x)
27
- ret.append(disc(x))
28
-
29
- return ret # [(feat, score), (feat, score), (feat, score)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/model/res_stack.py DELETED
@@ -1,36 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import numpy as np
5
-
6
-
7
- class ResStack(nn.Module):
8
- def __init__(self, channel):
9
- super(ResStack, self).__init__()
10
-
11
- self.blocks = nn.ModuleList([
12
- nn.Sequential(
13
- nn.LeakyReLU(0.2),
14
- nn.ReflectionPad1d(3**i),
15
- nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=3, dilation=3**i)),
16
- nn.LeakyReLU(0.2),
17
- nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
18
- )
19
- for i in range(3)
20
- ])
21
-
22
- self.shortcuts = nn.ModuleList([
23
- nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
24
- for i in range(3)
25
- ])
26
-
27
- def forward(self, x):
28
- for block, shortcut in zip(self.blocks, self.shortcuts):
29
- x = shortcut(x) + block(x)
30
- return x
31
-
32
- def remove_weight_norm(self):
33
- for block, shortcut in zip(self.blocks, self.shortcuts):
34
- nn.utils.remove_weight_norm(block[2])
35
- nn.utils.remove_weight_norm(block[4])
36
- nn.utils.remove_weight_norm(shortcut)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/preprocess.py DELETED
@@ -1,50 +0,0 @@
1
- import os
2
- import glob
3
- import tqdm
4
- import torch
5
- import argparse
6
- import numpy as np
7
-
8
- from utils.stft import TacotronSTFT
9
- from utils.hparams import HParam
10
- from utils.utils import read_wav_np
11
-
12
-
13
- def main(hp, args):
14
- stft = TacotronSTFT(filter_length=hp.audio.filter_length,
15
- hop_length=hp.audio.hop_length,
16
- win_length=hp.audio.win_length,
17
- n_mel_channels=hp.audio.n_mel_channels,
18
- sampling_rate=hp.audio.sampling_rate,
19
- mel_fmin=hp.audio.mel_fmin,
20
- mel_fmax=hp.audio.mel_fmax)
21
-
22
- wav_files = glob.glob(os.path.join(args.data_path, '**', '*.wav'), recursive=True)
23
-
24
- for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
25
- sr, wav = read_wav_np(wavpath)
26
- assert sr == hp.audio.sampling_rate, \
27
- "sample rate mismatch. expected %d, got %d at %s" % \
28
- (hp.audio.sampling_rate, sr, wavpath)
29
-
30
- if len(wav) < hp.audio.segment_length + hp.audio.pad_short:
31
- wav = np.pad(wav, (0, hp.audio.segment_length + hp.audio.pad_short - len(wav)), \
32
- mode='constant', constant_values=0.0)
33
-
34
- wav = torch.from_numpy(wav).unsqueeze(0)
35
- mel = stft.mel_spectrogram(wav)
36
-
37
- melpath = wavpath.replace('.wav', '.mel')
38
- torch.save(mel, melpath)
39
-
40
-
41
- if __name__ == '__main__':
42
- parser = argparse.ArgumentParser()
43
- parser.add_argument('-c', '--config', type=str, required=True,
44
- help="yaml file for config.")
45
- parser.add_argument('-d', '--data_path', type=str, required=True,
46
- help="root directory of wav files")
47
- args = parser.parse_args()
48
- hp = HParam(args.config)
49
-
50
- main(hp, args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/requirements.txt DELETED
@@ -1,9 +0,0 @@
1
- librosa
2
- matplotlib
3
- numpy
4
- scipy
5
- tensorboardX
6
- torch
7
- tqdm
8
- pillow
9
- pyyaml
 
 
 
 
 
 
 
 
 
 
melgan/trainer.py DELETED
@@ -1,52 +0,0 @@
1
- import os
2
- import time
3
- import logging
4
- import argparse
5
-
6
- from utils.train import train
7
- from utils.hparams import HParam
8
- from utils.writer import MyWriter
9
- from datasets.dataloader import create_dataloader
10
-
11
-
12
- if __name__ == '__main__':
13
- parser = argparse.ArgumentParser()
14
- parser.add_argument('-c', '--config', type=str, required=True,
15
- help="yaml file for configuration")
16
- parser.add_argument('-p', '--checkpoint_path', type=str, default=None,
17
- help="path of checkpoint pt file to resume training")
18
- parser.add_argument('-n', '--name', type=str, required=True,
19
- help="name of the model for logging, saving checkpoint")
20
- args = parser.parse_args()
21
-
22
- hp = HParam(args.config)
23
- with open(args.config, 'r') as f:
24
- hp_str = ''.join(f.readlines())
25
-
26
- pt_dir = os.path.join(hp.log.chkpt_dir, args.name)
27
- log_dir = os.path.join(hp.log.log_dir, args.name)
28
- os.makedirs(pt_dir, exist_ok=True)
29
- os.makedirs(log_dir, exist_ok=True)
30
-
31
- logging.basicConfig(
32
- level=logging.INFO,
33
- format='%(asctime)s - %(levelname)s - %(message)s',
34
- handlers=[
35
- logging.FileHandler(os.path.join(log_dir,
36
- '%s-%d.log' % (args.name, time.time()))),
37
- logging.StreamHandler()
38
- ]
39
- )
40
- logger = logging.getLogger()
41
-
42
- writer = MyWriter(hp, log_dir)
43
-
44
- assert hp.audio.hop_length == 256, \
45
- 'hp.audio.hop_length must be equal to 256, got %d' % hp.audio.hop_length
46
- assert hp.data.train != '' and hp.data.validation != '', \
47
- 'hp.data.train and hp.data.validation can\'t be empty: please fix %s' % args.config
48
-
49
- trainloader = create_dataloader(hp, args, True)
50
- valloader = create_dataloader(hp, args, False)
51
-
52
- train(args, pt_dir, args.checkpoint_path, trainloader, valloader, writer, logger, hp, hp_str)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/utils/audio_processing.py DELETED
@@ -1,93 +0,0 @@
1
- import torch
2
- import numpy as np
3
- from scipy.signal import get_window
4
- import librosa.util as librosa_util
5
-
6
-
7
- def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
8
- n_fft=800, dtype=np.float32, norm=None):
9
- """
10
- # from librosa 0.6
11
- Compute the sum-square envelope of a window function at a given hop length.
12
-
13
- This is used to estimate modulation effects induced by windowing
14
- observations in short-time fourier transforms.
15
-
16
- Parameters
17
- ----------
18
- window : string, tuple, number, callable, or list-like
19
- Window specification, as in `get_window`
20
-
21
- n_frames : int > 0
22
- The number of analysis frames
23
-
24
- hop_length : int > 0
25
- The number of samples to advance between frames
26
-
27
- win_length : [optional]
28
- The length of the window function. By default, this matches `n_fft`.
29
-
30
- n_fft : int > 0
31
- The length of each analysis frame.
32
-
33
- dtype : np.dtype
34
- The data type of the output
35
-
36
- Returns
37
- -------
38
- wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
39
- The sum-squared envelope of the window function
40
- """
41
- if win_length is None:
42
- win_length = n_fft
43
-
44
- n = n_fft + hop_length * (n_frames - 1)
45
- x = np.zeros(n, dtype=dtype)
46
-
47
- # Compute the squared window at the desired length
48
- win_sq = get_window(window, win_length, fftbins=True)
49
- win_sq = librosa_util.normalize(win_sq, norm=norm)**2
50
- win_sq = librosa_util.pad_center(win_sq, n_fft)
51
-
52
- # Fill the envelope
53
- for i in range(n_frames):
54
- sample = i * hop_length
55
- x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
56
- return x
57
-
58
-
59
- def griffin_lim(magnitudes, stft_fn, n_iters=30):
60
- """
61
- PARAMS
62
- ------
63
- magnitudes: spectrogram magnitudes
64
- stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
65
- """
66
-
67
- angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
68
- angles = angles.astype(np.float32)
69
- angles = torch.autograd.Variable(torch.from_numpy(angles))
70
- signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
71
-
72
- for i in range(n_iters):
73
- _, angles = stft_fn.transform(signal)
74
- signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
75
- return signal
76
-
77
-
78
- def dynamic_range_compression(x, C=1, clip_val=1e-5):
79
- """
80
- PARAMS
81
- ------
82
- C: compression factor
83
- """
84
- return torch.log(torch.clamp(x, min=clip_val) * C)
85
-
86
-
87
- def dynamic_range_decompression(x, C=1):
88
- """
89
- PARAMS
90
- ------
91
- C: compression factor used to compress
92
- """
93
- return torch.exp(x) / C
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/utils/hparams.py DELETED
@@ -1,67 +0,0 @@
1
- # modified from https://github.com/HarryVolek/PyTorch_Speaker_Verification
2
-
3
- import os
4
- import yaml
5
-
6
-
7
- def load_hparam_str(hp_str):
8
- path = 'temp-restore.yaml'
9
- with open(path, 'w') as f:
10
- f.write(hp_str)
11
- ret = HParam(path)
12
- os.remove(path)
13
- return ret
14
-
15
-
16
- def load_hparam(filename):
17
- stream = open(filename, 'r')
18
- docs = yaml.load_all(stream, Loader=yaml.Loader)
19
- hparam_dict = dict()
20
- for doc in docs:
21
- for k, v in doc.items():
22
- hparam_dict[k] = v
23
- return hparam_dict
24
-
25
-
26
- def merge_dict(user, default):
27
- if isinstance(user, dict) and isinstance(default, dict):
28
- for k, v in default.items():
29
- if k not in user:
30
- user[k] = v
31
- else:
32
- user[k] = merge_dict(user[k], v)
33
- return user
34
-
35
-
36
- class Dotdict(dict):
37
- """
38
- a dictionary that supports dot notation
39
- as well as dictionary access notation
40
- usage: d = DotDict() or d = DotDict({'val1':'first'})
41
- set attributes: d.val2 = 'second' or d['val2'] = 'second'
42
- get attributes: d.val2 or d['val2']
43
- """
44
- __getattr__ = dict.__getitem__
45
- __setattr__ = dict.__setitem__
46
- __delattr__ = dict.__delitem__
47
-
48
- def __init__(self, dct=None):
49
- dct = dict() if not dct else dct
50
- for key, value in dct.items():
51
- if hasattr(value, 'keys'):
52
- value = Dotdict(value)
53
- self[key] = value
54
-
55
-
56
- class HParam(Dotdict):
57
-
58
- def __init__(self, file):
59
- super(Dotdict, self).__init__()
60
- hp_dict = load_hparam(file)
61
- hp_dotdict = Dotdict(hp_dict)
62
- for k, v in hp_dotdict.items():
63
- setattr(self, k, v)
64
-
65
- __getattr__ = Dotdict.__getitem__
66
- __setattr__ = Dotdict.__setitem__
67
- __delattr__ = Dotdict.__delitem__
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/utils/plotting.py DELETED
@@ -1,29 +0,0 @@
1
- import matplotlib
2
- matplotlib.use("Agg")
3
- import matplotlib.pylab as plt
4
- import numpy as np
5
-
6
-
7
- def save_figure_to_numpy(fig):
8
- # save it to a numpy array.
9
- data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
10
- data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
11
- data = np.transpose(data, (2, 0, 1))
12
- return data
13
-
14
-
15
- def plot_waveform_to_numpy(waveform):
16
- fig, ax = plt.subplots(figsize=(12, 3))
17
- ax.plot()
18
- ax.plot(range(len(waveform)), waveform,
19
- linewidth=0.1, alpha=0.7, color='blue')
20
-
21
- plt.xlabel("Samples")
22
- plt.ylabel("Amplitude")
23
- plt.ylim(-1, 1)
24
- plt.tight_layout()
25
-
26
- fig.canvas.draw()
27
- data = save_figure_to_numpy(fig)
28
- plt.close()
29
- return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/utils/stft.py DELETED
@@ -1,184 +0,0 @@
1
- """
2
- BSD 3-Clause License
3
-
4
- Copyright (c) 2017, Prem Seetharaman
5
- All rights reserved.
6
-
7
- * Redistribution and use in source and binary forms, with or without
8
- modification, are permitted provided that the following conditions are met:
9
-
10
- * Redistributions of source code must retain the above copyright notice,
11
- this list of conditions and the following disclaimer.
12
-
13
- * Redistributions in binary form must reproduce the above copyright notice, this
14
- list of conditions and the following disclaimer in the
15
- documentation and/or other materials provided with the distribution.
16
-
17
- * Neither the name of the copyright holder nor the names of its
18
- contributors may be used to endorse or promote products derived from this
19
- software without specific prior written permission.
20
-
21
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22
- ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23
- WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25
- ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26
- (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27
- LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28
- ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29
- (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30
- SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
- """
32
-
33
- import torch
34
- import numpy as np
35
- import torch.nn.functional as F
36
- from torch.autograd import Variable
37
- from scipy.signal import get_window
38
- from librosa.util import pad_center, tiny
39
- from .audio_processing import window_sumsquare, dynamic_range_compression, dynamic_range_decompression
40
- from librosa.filters import mel as librosa_mel_fn
41
-
42
-
43
- class STFT(torch.nn.Module):
44
- """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
45
- def __init__(self, filter_length=800, hop_length=200, win_length=800,
46
- window='hann'):
47
- super(STFT, self).__init__()
48
- self.filter_length = filter_length
49
- self.hop_length = hop_length
50
- self.win_length = win_length
51
- self.window = window
52
- self.forward_transform = None
53
- scale = self.filter_length / self.hop_length
54
- fourier_basis = np.fft.fft(np.eye(self.filter_length))
55
-
56
- cutoff = int((self.filter_length / 2 + 1))
57
- fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
58
- np.imag(fourier_basis[:cutoff, :])])
59
-
60
- forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
61
- inverse_basis = torch.FloatTensor(
62
- np.linalg.pinv(scale * fourier_basis).T[:, None, :])
63
-
64
- if window is not None:
65
- assert(filter_length >= win_length)
66
- # get window and zero center pad it to filter_length
67
- fft_window = get_window(window, win_length, fftbins=True)
68
- fft_window = pad_center(fft_window, filter_length)
69
- fft_window = torch.from_numpy(fft_window).float()
70
-
71
- # window the bases
72
- forward_basis *= fft_window
73
- inverse_basis *= fft_window
74
-
75
- self.register_buffer('forward_basis', forward_basis.float())
76
- self.register_buffer('inverse_basis', inverse_basis.float())
77
-
78
- def transform(self, input_data):
79
- num_batches = input_data.size(0)
80
- num_samples = input_data.size(1)
81
-
82
- self.num_samples = num_samples
83
-
84
- # similar to librosa, reflect-pad the input
85
- input_data = input_data.view(num_batches, 1, num_samples)
86
- input_data = F.pad(
87
- input_data.unsqueeze(1),
88
- (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
89
- mode='reflect')
90
- input_data = input_data.squeeze(1)
91
-
92
- # https://github.com/NVIDIA/tacotron2/issues/125
93
- forward_transform = F.conv1d(
94
- input_data, # cuda()
95
- Variable(self.forward_basis, requires_grad=False), # cuda()
96
- stride=self.hop_length,
97
- padding=0).cpu()
98
-
99
- cutoff = int((self.filter_length / 2) + 1)
100
- real_part = forward_transform[:, :cutoff, :]
101
- imag_part = forward_transform[:, cutoff:, :]
102
-
103
- magnitude = torch.sqrt(real_part**2 + imag_part**2)
104
- phase = torch.autograd.Variable(
105
- torch.atan2(imag_part.data, real_part.data))
106
-
107
- return magnitude, phase
108
-
109
- def inverse(self, magnitude, phase):
110
- recombine_magnitude_phase = torch.cat(
111
- [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
112
-
113
- inverse_transform = F.conv_transpose1d(
114
- recombine_magnitude_phase,
115
- Variable(self.inverse_basis, requires_grad=False),
116
- stride=self.hop_length,
117
- padding=0)
118
-
119
- if self.window is not None:
120
- window_sum = window_sumsquare(
121
- self.window, magnitude.size(-1), hop_length=self.hop_length,
122
- win_length=self.win_length, n_fft=self.filter_length,
123
- dtype=np.float32)
124
- # remove modulation effects
125
- approx_nonzero_indices = torch.from_numpy(
126
- np.where(window_sum > tiny(window_sum))[0])
127
- window_sum = torch.autograd.Variable(
128
- torch.from_numpy(window_sum), requires_grad=False)
129
- # window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
130
- inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
131
-
132
- # scale by hop ratio
133
- inverse_transform *= float(self.filter_length) / self.hop_length
134
-
135
- inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
136
- inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
137
-
138
- return inverse_transform
139
-
140
- def forward(self, input_data):
141
- self.magnitude, self.phase = self.transform(input_data)
142
- reconstruction = self.inverse(self.magnitude, self.phase)
143
- return reconstruction
144
-
145
-
146
- class TacotronSTFT(torch.nn.Module):
147
- def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
148
- n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
149
- mel_fmax=None):
150
- super(TacotronSTFT, self).__init__()
151
- self.n_mel_channels = n_mel_channels
152
- self.sampling_rate = sampling_rate
153
- self.stft_fn = STFT(filter_length, hop_length, win_length)
154
- mel_basis = librosa_mel_fn(
155
- sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
156
- mel_basis = torch.from_numpy(mel_basis).float()
157
- self.register_buffer('mel_basis', mel_basis)
158
-
159
- def spectral_normalize(self, magnitudes):
160
- output = dynamic_range_compression(magnitudes)
161
- return output
162
-
163
- def spectral_de_normalize(self, magnitudes):
164
- output = dynamic_range_decompression(magnitudes)
165
- return output
166
-
167
- def mel_spectrogram(self, y):
168
- """Computes mel-spectrograms from a batch of waves
169
- PARAMS
170
- ------
171
- y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
172
-
173
- RETURNS
174
- -------
175
- mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
176
- """
177
- assert(torch.min(y.data) >= -1)
178
- assert(torch.max(y.data) <= 1)
179
-
180
- magnitudes, phases = self.stft_fn.transform(y)
181
- magnitudes = magnitudes.data
182
- mel_output = torch.matmul(self.mel_basis, magnitudes)
183
- mel_output = self.spectral_normalize(mel_output)
184
- return mel_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/utils/train.py DELETED
@@ -1,131 +0,0 @@
1
- import os
2
- import math
3
- import tqdm
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- import itertools
8
- import traceback
9
-
10
- from model.generator import Generator
11
- from model.multiscale import MultiScaleDiscriminator
12
- from .utils import get_commit_hash
13
- from .validation import validate
14
-
15
-
16
- def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp, hp_str):
17
- model_g = Generator(hp.audio.n_mel_channels) # cuda()
18
- model_d = MultiScaleDiscriminator() # cuda()
19
-
20
- optim_g = torch.optim.Adam(model_g.parameters(),
21
- lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
22
- optim_d = torch.optim.Adam(model_d.parameters(),
23
- lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
24
-
25
- githash = get_commit_hash()
26
-
27
- init_epoch = -1
28
- step = 0
29
-
30
- if chkpt_path is not None:
31
- logger.info("Resuming from checkpoint: %s" % chkpt_path)
32
- checkpoint = torch.load(chkpt_path)
33
- model_g.load_state_dict(checkpoint['model_g'])
34
- model_d.load_state_dict(checkpoint['model_d'])
35
- optim_g.load_state_dict(checkpoint['optim_g'])
36
- optim_d.load_state_dict(checkpoint['optim_d'])
37
- step = checkpoint['step']
38
- init_epoch = checkpoint['epoch']
39
-
40
- if hp_str != checkpoint['hp_str']:
41
- logger.warning("New hparams is different from checkpoint. Will use new.")
42
-
43
- if githash != checkpoint['githash']:
44
- logger.warning("Code might be different: git hash is different.")
45
- logger.warning("%s -> %s" % (checkpoint['githash'], githash))
46
-
47
- else:
48
- logger.info("Starting new training run.")
49
-
50
- # this accelerates training when the size of minibatch is always consistent.
51
- # if not consistent, it'll horribly slow down.
52
- torch.backends.cudnn.benchmark = True
53
-
54
- try:
55
- model_g.train()
56
- model_d.train()
57
- for epoch in itertools.count(init_epoch+1):
58
- if epoch % hp.log.validation_interval == 0:
59
- with torch.no_grad():
60
- validate(hp, args, model_g, model_d, valloader, writer, step)
61
-
62
- trainloader.dataset.shuffle_mapping()
63
- loader = tqdm.tqdm(trainloader, desc='Loading train data')
64
- for (melG, audioG), (melD, audioD) in loader:
65
- # melG = melG.cuda()
66
- # audioG = audioG.cuda()
67
- # melD = melD.cuda()
68
- # audioD = audioD.cuda()
69
-
70
- # generator
71
- optim_g.zero_grad()
72
- fake_audio = model_g(melG)[:, :, :hp.audio.segment_length]
73
- disc_fake = model_d(fake_audio)
74
- disc_real = model_d(audioG)
75
- loss_g = 0.0
76
- for (feats_fake, score_fake), (feats_real, _) in zip(disc_fake, disc_real):
77
- loss_g += torch.mean(torch.sum(torch.pow(score_fake - 1.0, 2), dim=[1, 2]))
78
- for feat_f, feat_r in zip(feats_fake, feats_real):
79
- loss_g += hp.model.feat_match * torch.mean(torch.abs(feat_f - feat_r))
80
-
81
- loss_g.backward()
82
- optim_g.step()
83
-
84
- # discriminator
85
- fake_audio = model_g(melD)[:, :, :hp.audio.segment_length]
86
- fake_audio = fake_audio.detach()
87
- loss_d_sum = 0.0
88
- for _ in range(hp.train.rep_discriminator):
89
- optim_d.zero_grad()
90
- disc_fake = model_d(fake_audio)
91
- disc_real = model_d(audioD)
92
- loss_d = 0.0
93
- for (_, score_fake), (_, score_real) in zip(disc_fake, disc_real):
94
- loss_d += torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
95
- loss_d += torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
96
-
97
- loss_d.backward()
98
- optim_d.step()
99
- loss_d_sum += loss_d
100
-
101
- step += 1
102
- # logging
103
- loss_g = loss_g.item()
104
- loss_d_avg = loss_d_sum / hp.train.rep_discriminator
105
- loss_d_avg = loss_d_avg.item()
106
- if any([loss_g > 1e8, math.isnan(loss_g), loss_d_avg > 1e8, math.isnan(loss_d_avg)]):
107
- logger.error("loss_g %.01f loss_d_avg %.01f at step %d!" % (loss_g, loss_d_avg, step))
108
- raise Exception("Loss exploded")
109
-
110
- if step % hp.log.summary_interval == 0:
111
- writer.log_training(loss_g, loss_d_avg, step)
112
- loader.set_description("g %.04f d %.04f | step %d" % (loss_g, loss_d_avg, step))
113
-
114
- if epoch % hp.log.save_interval == 0:
115
- save_path = os.path.join(pt_dir, '%s_%s_%04d.pt'
116
- % (args.name, githash, epoch))
117
- torch.save({
118
- 'model_g': model_g.state_dict(),
119
- 'model_d': model_d.state_dict(),
120
- 'optim_g': optim_g.state_dict(),
121
- 'optim_d': optim_d.state_dict(),
122
- 'step': step,
123
- 'epoch': epoch,
124
- 'hp_str': hp_str,
125
- 'githash': githash,
126
- }, save_path)
127
- logger.info("Saved checkpoint to: %s" % save_path)
128
-
129
- except Exception as e:
130
- logger.info("Exiting due to exception: %s" % e)
131
- traceback.print_exc()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/utils/utils.py DELETED
@@ -1,26 +0,0 @@
1
- import random
2
- import subprocess
3
- import numpy as np
4
- from scipy.io.wavfile import read
5
-
6
-
7
- def get_commit_hash():
8
- message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
9
- return message.strip().decode('utf-8')
10
-
11
- def read_wav_np(path):
12
- sr, wav = read(path)
13
-
14
- if len(wav.shape) == 2:
15
- wav = wav[:, 0]
16
-
17
- if wav.dtype == np.int16:
18
- wav = wav / 32768.0
19
- elif wav.dtype == np.int32:
20
- wav = wav / 2147483648.0
21
- elif wav.dtype == np.uint8:
22
- wav = (wav - 128) / 128.0
23
-
24
- wav = wav.astype(np.float32)
25
-
26
- return sr, wav
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/utils/validation.py DELETED
@@ -1,41 +0,0 @@
1
- import tqdm
2
- import torch
3
-
4
-
5
- def validate(hp, args, generator, discriminator, valloader, writer, step):
6
- generator.eval()
7
- discriminator.eval()
8
- torch.backends.cudnn.benchmark = False
9
-
10
- loader = tqdm.tqdm(valloader, desc='Validation loop')
11
- loss_g_sum = 0.0
12
- loss_d_sum = 0.0
13
- for mel, audio in loader:
14
- # mel = mel.cuda()
15
- # audio = audio.cuda()
16
-
17
- # generator
18
- fake_audio = generator(mel)
19
- disc_fake = discriminator(fake_audio[:, :, :audio.size(2)])
20
- disc_real = discriminator(audio)
21
- loss_g = 0.0
22
- loss_d = 0.0
23
- for (feats_fake, score_fake), (feats_real, score_real) in zip(disc_fake, disc_real):
24
- loss_g += torch.mean(torch.sum(torch.pow(score_fake - 1.0, 2), dim=[1, 2]))
25
- for feat_f, feat_r in zip(feats_fake, feats_real):
26
- loss_g += hp.model.feat_match * torch.mean(torch.abs(feat_f - feat_r))
27
- loss_d += torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
28
- loss_d += torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
29
-
30
- loss_g_sum += loss_g.item()
31
- loss_d_sum += loss_d.item()
32
-
33
- loss_g_avg = loss_g_sum / len(valloader.dataset)
34
- loss_d_avg = loss_d_sum / len(valloader.dataset)
35
-
36
- audio = audio[0][0].cpu().detach().numpy()
37
- fake_audio = fake_audio[0][0].cpu().detach().numpy()
38
-
39
- writer.log_validation(loss_g_avg, loss_d_avg, generator, discriminator, audio, fake_audio, step)
40
-
41
- torch.backends.cudnn.benchmark = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
melgan/utils/writer.py DELETED
@@ -1,33 +0,0 @@
1
- from tensorboardX import SummaryWriter
2
-
3
- from .plotting import plot_waveform_to_numpy
4
-
5
-
6
- class MyWriter(SummaryWriter):
7
- def __init__(self, hp, logdir):
8
- super(MyWriter, self).__init__(logdir)
9
- self.sample_rate = hp.audio.sampling_rate
10
- self.is_first = True
11
-
12
- def log_training(self, g_loss, d_loss, step):
13
- self.add_scalar('train.g_loss', g_loss, step)
14
- self.add_scalar('train.d_loss', d_loss, step)
15
-
16
- def log_validation(self, g_loss, d_loss, generator, discriminator, target, prediction, step):
17
- self.add_scalar('validation.g_loss', g_loss, step)
18
- self.add_scalar('validation.d_loss', d_loss, step)
19
-
20
- self.add_audio('raw_audio_predicted', prediction, step, self.sample_rate)
21
- self.add_image('waveform_predicted', plot_waveform_to_numpy(prediction), step)
22
-
23
- self.log_histogram(generator, step)
24
- self.log_histogram(discriminator, step)
25
-
26
- if self.is_first:
27
- self.add_audio('raw_audio_target', target, step, self.sample_rate)
28
- self.add_image('waveform_target', plot_waveform_to_numpy(target), step)
29
- self.is_first = False
30
-
31
- def log_histogram(self, model, step):
32
- for tag, value in model.named_parameters():
33
- self.add_histogram(tag.replace('.', '/'), value.cpu().detach().numpy(), step)