mlinmg commited on
Commit
8cd94f2
·
verified ·
1 Parent(s): 525f7bc

Delete xtts2_config.py

Browse files
Files changed (1) hide show
  1. xtts2_config.py +0 -229
xtts2_config.py DELETED
@@ -1,229 +0,0 @@
1
- from dataclasses import asdict, dataclass
2
- from typing import Dict, Optional, List
3
- from transformers.configuration_utils import PretrainedConfig
4
- from transformers.utils import logging
5
-
6
- logger = logging.get_logger(__name__)
7
-
8
-
9
- @dataclass
10
- class GPTAudioConfig:
11
- """Configuration for GPT audio processing parameters"""
12
- mel_channels: int = 80
13
- sample_rate: int = 22050
14
- output_sample_rate: int = 24000
15
-
16
- @dataclass
17
- class XTTSAudioConfig:
18
- """Configuration for audio processing parameters"""
19
- sample_rate: int = 22050
20
- output_sample_rate: int = 24000
21
- mel_channels: int = 80
22
- hop_length: int = 256
23
- win_length: int = 1024
24
- n_fft: int = 1024
25
- fmin: int = 0
26
- fmax: int = 8000
27
- power: float = 1.0
28
- mel_norms_file: Optional[str] = None
29
-
30
-
31
- class XTTSGPTConfig(PretrainedConfig):
32
- """Configuration class for the GPT component of XTTS."""
33
- model_type = "xtts_gpt"
34
-
35
- def __init__(
36
- self,
37
- # Model architecture
38
- hidden_size: int = 1024, # gpt_n_model_channels in original
39
- n_inner: int = 4096,
40
- num_hidden_layers: int = 30, # gpt_layers in original
41
- num_attention_heads: int = 16, # gpt_n_heads in original
42
-
43
- # Tokenizer settings
44
- vocab_size: int = 6681, # gpt_number_text_tokens in original
45
- number_text_tokens: int = 6681, # Explicit text token vocabulary size
46
- start_text_token: Optional[int] = None,
47
- stop_text_token: Optional[int] = None,
48
-
49
- # Audio token settings
50
- num_audio_tokens: int = 1026, # gpt_num_audio_tokens in original
51
- start_audio_token: int = 1024, # gpt_start_audio_token in original
52
- stop_audio_token: int = 1025, # gpt_stop_audio_token in original
53
-
54
- # Sequence length settings
55
- max_audio_tokens: int = 605, # gpt_max_audio_tokens in original
56
- max_text_tokens: int = 402, # gpt_max_text_tokens in original
57
- max_prompt_tokens: int = 70, # gpt_max_prompt_tokens in original
58
- gpt_max_audio_tokens: int = 605, # Used for generation
59
-
60
- # Model behavior settings
61
- use_masking_gt_prompt_approach: bool = True, # gpt_use_masking_gt_prompt_approach in original
62
- use_perceiver_resampler: bool = True, # gpt_use_perceiver_resampler in original
63
- kv_cache: bool = True,
64
- enable_redaction: bool = False,
65
-
66
- # GPT batch settings
67
- gpt_batch_size: int = 1,
68
-
69
- # Audio processing
70
- audio_config: Optional[Dict] = None,
71
-
72
- # Architecture specifics
73
- layer_norm_epsilon: float = 1e-5,
74
- initializer_range: float = 0.02,
75
- add_cross_attention: bool = False,
76
- scale_attn_by_inverse_layer_idx: bool = False,
77
- reorder_and_upcast_attn: bool = False,
78
-
79
- # Size settings for the decoder
80
- decoder_input_dim: int = 1024,
81
- architectures=["XttsGPT"],
82
- auto_map = {
83
- "AutoConfig": "AstraMindAI/xtts2-gpt--gpt_config.XTTSGPTConfig",
84
- "AutoModelForCausalLM": "AstraMindAI/xtts2-gpt--xtts2_gpt_modeling.XttsGPT",
85
- },
86
- activation_function: str = "gelu",
87
- attn_pdrop: float = 0.1,
88
- **kwargs
89
- ):
90
- super().__init__(**kwargs)
91
- self.architectures = architectures
92
- self.auto_map = auto_map
93
- self.audio_config = GPTAudioConfig(
94
- **audio_config if audio_config is not None else {}
95
- )
96
- self.activation_function = activation_function
97
- self.attn_pdrop = attn_pdrop
98
- self.hidden_size = hidden_size
99
- self.n_inner = n_inner
100
- self.num_hidden_layers = num_hidden_layers
101
- self.num_attention_heads = num_attention_heads
102
-
103
- self.vocab_size = vocab_size
104
- self.number_text_tokens = number_text_tokens
105
- self.start_text_token = start_text_token
106
- self.stop_text_token = stop_text_token
107
-
108
- self.num_audio_tokens = num_audio_tokens
109
- self.start_audio_token = start_audio_token
110
- self.stop_audio_token = stop_audio_token
111
-
112
- self.max_audio_tokens = max_audio_tokens
113
- self.max_text_tokens = max_text_tokens
114
- self.max_prompt_tokens = max_prompt_tokens
115
- self.gpt_max_audio_tokens = gpt_max_audio_tokens
116
-
117
- self.use_masking_gt_prompt_approach = use_masking_gt_prompt_approach
118
- self.use_perceiver_resampler = use_perceiver_resampler
119
- self.kv_cache = kv_cache
120
- self.enable_redaction = enable_redaction
121
-
122
- self.gpt_batch_size = gpt_batch_size
123
-
124
- self.layer_norm_epsilon = layer_norm_epsilon
125
- self.initializer_range = initializer_range
126
- self.add_cross_attention = add_cross_attention
127
- self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
128
- self.reorder_and_upcast_attn = reorder_and_upcast_attn
129
-
130
- self.decoder_input_dim = decoder_input_dim
131
-
132
- def to_dict(self) -> Dict:
133
- """Convert the config to a dictionary."""
134
- output = super().to_dict()
135
- output["audio_config"] = asdict(self.audio_config)
136
- return output
137
-
138
- @classmethod
139
- def from_dict(cls, config_dict: Dict, *args, **kwargs) -> "XTTSGPTConfig":
140
- """Create a config from a dictionary."""
141
- return cls(**config_dict)
142
-
143
-
144
- class XTTSConfig(PretrainedConfig):
145
- """Configuration class for XTTS model components except GPT."""
146
- model_type = "xtts"
147
-
148
- def __init__(
149
- self,
150
- # Audio settings
151
- audio_config: Optional[Dict] = None,
152
- input_sample_rate: int = 22050,
153
- output_sample_rate: int = 24000,
154
- output_hop_length: int = 256,
155
-
156
- # Model architecture
157
- decoder_input_dim: int = 1024,
158
- d_vector_dim: int = 512,
159
- cond_d_vector_in_each_upsampling_layer: bool = True,
160
-
161
- # Training settings
162
- gpt_code_stride_len: int = 1024,
163
- duration_const: int = 102400,
164
-
165
- # Tokenizer settings
166
- tokenizer_file: str = "",
167
- num_chars: int = 255,
168
-
169
- # Language support
170
- languages: Optional[List[str]] = None,
171
-
172
- # GPT configuration
173
- gpt_config: Optional[Dict] = None,
174
- architectures=["Xtts"],
175
- auto_map = {
176
- "AutoConfig": "AstraMindAI/xtts2--xtts2_config.XTTSConfig",
177
- "AutoModelForCausalLM": "AstraMindAI/xtts2--xtts2_modeling.Xtts",
178
- "AutoTokenizer": "AstraMindAI/xtts2--tokenizer.XTTSTokenizerFast"
179
- },
180
- **kwargs
181
- ):
182
- super().__init__(**kwargs)
183
- self.architectures = architectures
184
- self.auto_map = auto_map
185
- # Initialize audio config
186
- self.audio_config = XTTSAudioConfig(
187
- **audio_config if audio_config is not None else {}
188
- )
189
-
190
- self.input_sample_rate = input_sample_rate
191
- self.output_sample_rate = output_sample_rate
192
- self.output_hop_length = output_hop_length
193
-
194
- self.decoder_input_dim = decoder_input_dim
195
- self.d_vector_dim = d_vector_dim
196
- self.cond_d_vector_in_each_upsampling_layer = cond_d_vector_in_each_upsampling_layer
197
-
198
- self.gpt_code_stride_len = gpt_code_stride_len
199
- self.duration_const = duration_const
200
-
201
- self.tokenizer_file = tokenizer_file
202
- self.num_chars = num_chars
203
-
204
- # Initialize GPT config
205
- self.gpt = XTTSGPTConfig(**gpt_config if gpt_config is not None else {})
206
-
207
- if languages is None:
208
- self.languages = [
209
- "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru",
210
- "nl", "cs", "ar", "zh-cn", "hu", "ko", "ja", "hi"
211
- ]
212
- else:
213
- self.languages = languages
214
-
215
- def to_dict(self) -> Dict:
216
- """Convert the config to a dictionary."""
217
- output = super().to_dict()
218
- output["audio_config"] = asdict(self.audio_config)
219
- output["gpt_config"] = self.gpt.to_dict()
220
- return output
221
-
222
- @classmethod
223
- def from_dict(cls, config_dict: Dict, *args, **kwargs) -> "XTTSConfig":
224
- """Create a config from a dictionary."""
225
- if "gpt_config" in config_dict:
226
- gpt_config = config_dict["gpt_config"]
227
- config_dict = {k: v for k, v in config_dict.items() if k != "gpt_config"}
228
- return cls(gpt_config=gpt_config, **config_dict)
229
- return cls(**config_dict)