Spaces:
Runtime error
Runtime error
| import importlib | |
| import re | |
| from coqpit import Coqpit | |
def to_camel(text):
    """Convert a ``snake_case`` name to ``CamelCase``.

    The whole string is first passed through ``str.capitalize`` (first
    character upper-cased, the rest lower-cased). Every underscore that is not
    at the very start of the string and is directly followed by an ASCII
    letter is then dropped, and that letter is upper-cased.
    """
    capitalized = text.capitalize()

    def _upper_without_underscore(match):
        # Group 1 is the letter that followed the underscore.
        return match.group(1).upper()

    return re.sub(r"(?!^)_([a-zA-Z])", _upper_without_underscore, capitalized)
def setup_model(config: Coqpit):
    """Load a vocoder model directly from configuration.

    If ``config`` contains both ``discriminator_model`` and
    ``generator_model``, the ``GAN`` class from ``TTS.vocoder.models.gan`` is
    used. Otherwise the module ``TTS.vocoder.models.<config.model lower-cased>``
    is imported and the model class is looked up in it: ``Wavernn``, ``GAN``
    and ``Wavegrad`` for the matching model names, ``to_camel(config.model)``
    for anything else.

    Args:
        config: Model configuration. ``config.model`` is read unless both GAN
            keys are present.

    Returns:
        Whatever ``init_from_config(config)`` of the selected class returns.

    Raises:
        ValueError: If the module for ``config.model`` does not exist, or the
            module exists but does not define the expected class.
    """
    if "discriminator_model" in config and "generator_model" in config:
        module = importlib.import_module("TTS.vocoder.models.gan")
        model_class = getattr(module, "GAN")
    else:
        model_name = config.model.lower()
        module_path = "TTS.vocoder.models." + model_name
        try:
            module = importlib.import_module(module_path)
        except ModuleNotFoundError as e:
            # Only a missing *model module* means "unknown model". A missing
            # dependency imported by an existing module must surface unchanged.
            if e.name != module_path:
                raise
            raise ValueError(f"Model {config.model} not exist!") from e

        # Class names that cannot be derived from the model name with to_camel
        # (e.g. "gan" -> "GAN"), kept explicit as in the original lookup.
        explicit_class_names = {"wavernn": "Wavernn", "gan": "GAN", "wavegrad": "Wavegrad"}
        class_name = explicit_class_names.get(model_name)
        if class_name is None:
            class_name = to_camel(config.model)
        try:
            # getattr raises AttributeError (not ModuleNotFoundError) when the
            # module lacks the class.
            model_class = getattr(module, class_name)
        except AttributeError as e:
            raise ValueError(f"Model {config.model} not exist!") from e
    print(" > Vocoder Model: {}".format(config.model))
    return model_class.init_from_config(config)
def setup_generator(c):
    """Instantiate the generator network named by ``c.generator_model``.

    The module ``TTS.vocoder.models.<generator_model lower-cased>`` is imported
    and the class ``to_camel(c.generator_model)`` is taken from it. The class is
    then constructed with the keyword arguments for that generator type.

    Args:
        c: Config object. Reads ``c.generator_model``,
            ``c.generator_model_params`` and, for most generators,
            ``c.audio["num_mels"]``.

    Returns:
        The constructed generator instance.

    Raises:
        ValueError: If the retired name ``melgan_fb_generator`` is requested.
        NotImplementedError: If the module and class exist but no construction
            recipe is defined here for that generator.
    """
    # TODO: use config object as arguments
    print(" > Generator Model: {}".format(c.generator_model))
    name = c.generator_model.lower()

    # Checked before the import: no module exists under the old name, so the
    # import would fail first and this hint would never be shown.
    if name == "melgan_fb_generator":
        raise ValueError("melgan_fb_generator is now fullband_melgan_generator")

    module = importlib.import_module("TTS.vocoder.models." + name)
    model_class = getattr(module, to_camel(c.generator_model))

    # The MelGAN-family generators share one constructor call and differ only
    # in these two settings.
    melgan_variants = {
        "melgan_generator": {"out_channels": 1, "base_channels": 512},
        "multiband_melgan_generator": {"out_channels": 4, "base_channels": 384},
        "fullband_melgan_generator": {"out_channels": 1, "base_channels": 512},
    }

    # Exact comparisons: `name in "literal"` would be a substring test.
    if name == "hifigan_generator":
        model = model_class(in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params)
    elif name in melgan_variants:
        variant = melgan_variants[name]
        model = model_class(
            in_channels=c.audio["num_mels"],
            out_channels=variant["out_channels"],
            proj_kernel=7,
            base_channels=variant["base_channels"],
            upsample_factors=c.generator_model_params["upsample_factors"],
            res_kernel=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
        )
    elif name == "parallel_wavegan_generator":
        model = model_class(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
            stacks=c.generator_model_params["stacks"],
            res_channels=64,
            gate_channels=128,
            skip_channels=64,
            aux_channels=c.audio["num_mels"],
            dropout=0.0,
            bias=True,
            use_weight_norm=True,
            upsample_factors=c.generator_model_params["upsample_factors"],
        )
    elif name == "univnet_generator":
        model = model_class(**c.generator_model_params)
    else:
        raise NotImplementedError(f"Model {c.generator_model} not implemented!")
    return model
def setup_discriminator(c):
    """Instantiate the discriminator network named by ``c.discriminator_model``.

    Names containing ``parallel_wavegan`` are loaded from
    ``TTS.vocoder.models.parallel_wavegan_discriminator``; any other name is
    loaded from ``TTS.vocoder.models.<discriminator_model lower-cased>``. The
    class ``to_camel(<lower-cased name>)`` is taken from that module and
    constructed with the keyword arguments for that discriminator type.

    Args:
        c: Config object. Reads ``c.discriminator_model`` and, depending on the
            discriminator, ``c.discriminator_model_params`` and ``c.audio``.

    Returns:
        The constructed discriminator instance.

    Raises:
        NotImplementedError: If the module and class exist but no construction
            recipe is defined here for that discriminator.
    """
    # TODO: use config object as arguments
    print(" > Discriminator Model: {}".format(c.discriminator_model))
    name = c.discriminator_model.lower()

    # Intentional substring test: the plain and the residual parallel-wavegan
    # discriminators are both looked up in the same module.
    if "parallel_wavegan" in name:
        module = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
    else:
        module = importlib.import_module("TTS.vocoder.models." + name)
    model_class = getattr(module, to_camel(name))

    # Exact comparisons in one elif chain, so exactly one recipe runs and an
    # unsupported name fails explicitly instead of leaving `model` unbound.
    if name == "hifigan_discriminator":
        model = model_class()
    elif name == "random_window_discriminator":
        model = model_class(
            cond_channels=c.audio["num_mels"],
            hop_length=c.audio["hop_length"],
            uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"],
            cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"],
            cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"],
            window_sizes=c.discriminator_model_params["window_sizes"],
        )
    elif name == "melgan_multiscale_discriminator":
        model = model_class(
            in_channels=1,
            out_channels=1,
            kernel_sizes=(5, 3),
            base_channels=c.discriminator_model_params["base_channels"],
            max_channels=c.discriminator_model_params["max_channels"],
            downsample_factors=c.discriminator_model_params["downsample_factors"],
        )
    elif name == "residual_parallel_wavegan_discriminator":
        model = model_class(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_layers=c.discriminator_model_params["num_layers"],
            stacks=c.discriminator_model_params["stacks"],
            res_channels=64,
            gate_channels=128,
            skip_channels=64,
            dropout=0.0,
            bias=True,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
        )
    elif name == "parallel_wavegan_discriminator":
        model = model_class(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_layers=c.discriminator_model_params["num_layers"],
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
            bias=True,
        )
    elif name == "univnet_discriminator":
        model = model_class()
    else:
        raise NotImplementedError(f"Model {c.discriminator_model} not implemented!")
    return model