Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	| from coqpit import Coqpit | |
| from TTS.model import BaseTrainerModel | |
| # pylint: skip-file | |
class BaseVocoder(BaseTrainerModel):
    """Base `vocoder` class. Every new `vocoder` model must inherit this.

    It defines `vocoder` specific functions on top of `Model`.

    Notes on input/output tensor shapes:
        Any input or output tensor of the model must be shaped as

        - 3D tensors `batch x time x channels`
        - 2D tensors `batch x channels`
        - 1D tensors `batch x 1`
    """

    MODEL_TYPE = "vocoder"

    def __init__(self, config):
        """Initialise the parent model and derive the model arguments from ``config``.

        Args:
            config: Either a training config (class name contains "Config") or a
                model-args object (class name contains "Args").

        Raises:
            ValueError: If the class name of ``config`` contains neither "Config" nor "Args".
        """
        super().__init__()
        self._set_model_args(config)

    def _set_model_args(self, config: Coqpit):
        """Setup model args based on the config type.

        If the config is for training with a name like "*Config", then the model args are embedded in
        ``config.model_args`` (or ``config.model_params`` for older configs). The config is stored on
        ``self.config`` and the embedded args, when present, on ``self.args``.

        If the config is for the model with a name like "*Args", then we assign it directly to
        ``self.args``. Note that ``self.config`` is not set on this path.

        Args:
            config: The training config or model-args object to read from.

        Raises:
            ValueError: If ``config`` is neither a "*Config" nor a "*Args" object.
        """
        # Dispatch on the class name rather than `isinstance` so the concrete config
        # classes do not have to be imported here (avoids recursive imports).
        config_type = config.__class__.__name__
        if "Config" in config_type:
            if "characters" in config:
                # NOTE(review): `get_characters` is not defined in this class; it is assumed
                # to be provided by a subclass or a base class - confirm.
                _, self.config, num_chars = self.get_characters(config)
                self.config.num_chars = num_chars
                if hasattr(self.config, "model_args"):
                    # Write to the same object that was just checked, so the check and
                    # the assignment cannot disagree.
                    self.config.model_args.num_chars = num_chars
            else:
                self.config = config
            self._assign_args_from_config()
        elif "Args" in config_type:
            # Model args passed directly, as promised by the docstring and the error message.
            self.args = config
        else:
            raise ValueError("config must be either a *Config or *Args")

    def _assign_args_from_config(self):
        """Point ``self.args`` at the model arguments stored on ``self.config``, if any.

        ``model_params`` is checked last, so it takes precedence when both keys are present.
        ``self.args`` is left untouched when neither key exists.
        """
        if "model_args" in self.config:
            self.args = self.config.model_args
        # This is for backward compatibility
        if "model_params" in self.config:
            self.args = self.config.model_params