mlinmg committed on
Commit 5650fbc · verified · 1 Parent(s): 0d089b3

Delete tokenizer.py

Files changed (1)
  1. tokenizer.py +0 -233
tokenizer.py DELETED
@@ -1,233 +0,0 @@
from typing import List, Optional, Union, Dict, Tuple, Any
import os
from functools import cached_property

from transformers import PreTrainedTokenizerFast
from transformers.tokenization_utils_base import TruncationStrategy, PaddingStrategy
from tokenizers import Tokenizer, processors
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import TemplateProcessing
import torch
from hangul_romanize import Transliter
from hangul_romanize.rule import academic
import cutlet

from TTS.tts.layers.xtts.tokenizer import (multilingual_cleaners, basic_cleaners,
                                           chinese_transliterate, korean_transliterate,
                                           japanese_cleaners)


class XTTSTokenizerFast(PreTrainedTokenizerFast):
    """
    Fast Tokenizer implementation for the XTTS model using HuggingFace's PreTrainedTokenizerFast
    """
    def __init__(
        self,
        vocab_file: str = None,
        tokenizer_object: Optional[Tokenizer] = None,
        unk_token: str = "[UNK]",
        pad_token: str = "[PAD]",
        bos_token: str = "[START]",
        eos_token: str = "[STOP]",
        clean_up_tokenization_spaces: bool = True,
        **kwargs
    ):
        if tokenizer_object is None and vocab_file is not None:
            tokenizer_object = Tokenizer.from_file(vocab_file)

        if tokenizer_object is not None:
            # Configure the tokenizer
            tokenizer_object.pre_tokenizer = WhitespaceSplit()
            tokenizer_object.enable_padding(
                direction='right',
                pad_id=tokenizer_object.token_to_id(pad_token) or 0,
                pad_token=pad_token
            )
            tokenizer_object.post_processor = TemplateProcessing(
                single=f"{bos_token} $A {eos_token}",
                special_tokens=[
                    (bos_token, tokenizer_object.token_to_id(bos_token)),
                    (eos_token, tokenizer_object.token_to_id(eos_token)),
                ],
            )

        super().__init__(
            tokenizer_object=tokenizer_object,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs
        )

        # Character limits per language
        self.char_limits = {
            "en": 250, "de": 253, "fr": 273, "es": 239,
            "it": 213, "pt": 203, "pl": 224, "zh": 82,
            "ar": 166, "cs": 186, "ru": 182, "nl": 251,
            "tr": 226, "ja": 71, "hu": 224, "ko": 95,
        }

        # Initialize language tools
        self._katsu = None
        self._korean_transliter = Transliter(academic)

    @cached_property
    def katsu(self):
        if self._katsu is None:
            self._katsu = cutlet.Cutlet()
        return self._katsu

    def check_input_length(self, text: str, lang: str):
        """Check if input text length is within limits for language"""
        lang = lang.split("-")[0]  # remove region
        limit = self.char_limits.get(lang, 250)
        if len(text) > limit:
            print(f"Warning: Text length exceeds {limit} char limit for '{lang}', may cause truncation.")

    def preprocess_text(self, text: str, lang: str) -> str:
        """Apply text preprocessing for language"""
        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it",
                    "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
            text = multilingual_cleaners(text, lang)
            if lang == "zh":
                text = chinese_transliterate(text)
            if lang == "ko":
                text = korean_transliterate(text)
        elif lang == "ja":
            text = japanese_cleaners(text, self.katsu)
        else:
            text = basic_cleaners(text)
        return text

    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs,
        add_special_tokens: bool = True,
        padding_strategy=PaddingStrategy.DO_NOT_PAD,
        truncation_strategy=TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = 402,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Override batch encoding to handle language-specific preprocessing
        """
        lang = kwargs.pop("lang", ["en"] * len(batch_text_or_text_pairs))
        if isinstance(lang, str):
            lang = [lang] * len(batch_text_or_text_pairs)

        # Preprocess each text in the batch with its corresponding language
        processed_texts = []
        for text, text_lang in zip(batch_text_or_text_pairs, lang):
            if isinstance(text, str):
                # Check length and preprocess
                self.check_input_length(text, text_lang)
                processed_text = self.preprocess_text(text, text_lang)

                # Format text with language tag and spaces
                lang_code = "zh-cn" if text_lang == "zh" else text_lang
                processed_text = f"[{lang_code}]{processed_text}"
                processed_text = processed_text.replace(" ", "[SPACE]")

                processed_texts.append(processed_text)
            else:
                processed_texts.append(text)

        # Call the parent class's encoding method with processed texts
        return super()._batch_encode_plus(
            processed_texts,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs
        )

    def __call__(
        self,
        text: Union[str, List[str]],
        lang: Union[str, List[str]] = "en",
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = True,  # Changed default to True
        truncation: Union[bool, str, TruncationStrategy] = True,  # Changed default to True
        max_length: Optional[int] = 402,
        stride: int = 0,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = True,  # Changed default to True
        **kwargs
    ):
        """
        Main tokenization method

        Args:
            text: Text or list of texts to tokenize
            lang: Language code or list of language codes corresponding to each text
            add_special_tokens: Whether to add special tokens
            padding: Padding strategy (default True)
            truncation: Truncation strategy (default True)
            max_length: Maximum length
            stride: Stride for truncation
            return_tensors: Format of output tensors ("pt" for PyTorch)
            return_token_type_ids: Whether to return token type IDs
            return_attention_mask: Whether to return attention mask (default True)
        """
        # Convert single string to list for batch processing
        if isinstance(text, str):
            text = [text]
        if isinstance(lang, str):
            lang = [lang]

        # Ensure text and lang lists have same length
        if len(text) != len(lang):
            raise ValueError(f"Number of texts ({len(text)}) must match number of language codes ({len(lang)})")

        # Convert padding strategy
        if isinstance(padding, bool):
            padding_strategy = PaddingStrategy.MAX_LENGTH if padding else PaddingStrategy.DO_NOT_PAD
        else:
            padding_strategy = PaddingStrategy(padding)

        # Convert truncation strategy
        if isinstance(truncation, bool):
            truncation_strategy = TruncationStrategy.LONGEST_FIRST if truncation else TruncationStrategy.DO_NOT_TRUNCATE
        else:
            truncation_strategy = TruncationStrategy(truncation)

        # Use the batch encoding method
        encoded = self._batch_encode_plus(
            text,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            lang=lang,
            **kwargs
        )

        return encoded
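For reference, a minimal usage sketch of the class removed by this commit. The module name `tokenizer` and the `vocab.json` path are illustrative assumptions (they match the deleted file's name but are not confirmed by the commit); the call signature follows the `__call__` defined above.

# Minimal usage sketch (hypothetical import path and vocab file).
from tokenizer import XTTSTokenizerFast

tok = XTTSTokenizerFast(vocab_file="vocab.json")  # assumed local XTTS BPE vocab

# Batch of texts with per-text language codes; texts are cleaned, prefixed with a
# language tag, space-escaped to [SPACE], and padded to max_length (402) as PyTorch tensors.
batch = tok(
    ["Hello world", "Bonjour le monde"],
    lang=["en", "fr"],
    return_tensors="pt",
)
print(batch["input_ids"].shape, batch["attention_mask"].shape)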