Update gptx_tokenizer.py
Browse files- gptx_tokenizer.py +51 -46
gptx_tokenizer.py
CHANGED
@@ -38,7 +38,6 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
38 |
text (str): The text to encode.
|
39 |
return_tokens (bool): If True, returns token strings instead of token IDs.
|
40 |
is_continuation (bool): If True, uses a continuation tokenizer (if available).
|
41 |
-
|
42 |
Returns:
|
43 |
List[int] or List[str]: Encoded text as a list of token IDs or token strings.
|
44 |
"""
|
@@ -56,7 +55,6 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
56 |
"""
|
57 |
Create a list of special tokens, including the BOS, EOS, PAD, EOD tokens,
|
58 |
and 256 additional placeholder tokens.
|
59 |
-
|
60 |
Returns:
|
61 |
List[str]: List of special tokens.
|
62 |
"""
|
@@ -64,7 +62,7 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
64 |
f"<placeholder_tok_{i}>" for i in range(256)
|
65 |
]
|
66 |
|
67 |
-
|
68 |
if not os.path.isfile(config_path):
|
69 |
config_path = try_to_load_from_cache(repo_id=repo_id, filename=Path(config_path).name)
|
70 |
if not config_path:
|
@@ -89,33 +87,60 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
89 |
OSError: If the model file cannot be loaded or downloaded.
|
90 |
"""
|
91 |
if not os.path.isfile(model_file_or_name):
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
try:
|
96 |
-
# List all files in the repo
|
97 |
-
repo_files = list_repo_files(repo_id)
|
98 |
-
|
99 |
-
# Find the tokenizer model file
|
100 |
-
tokenizer_files = [f for f in repo_files if f.endswith('.model')]
|
101 |
-
if not tokenizer_files:
|
102 |
-
raise FileNotFoundError(f"No .model file found in repository {repo_id}")
|
103 |
-
|
104 |
-
# Use the first .model file found
|
105 |
-
model_file = tokenizer_files[0]
|
106 |
-
print(f"Found tokenizer model file: {model_file}")
|
107 |
-
|
108 |
-
# Download the file
|
109 |
-
model_file_or_name = hf_hub_download(repo_id=repo_id, filename=model_file)
|
110 |
-
print(f"Downloaded tokenizer model to: {model_file_or_name}")
|
111 |
-
except Exception as e:
|
112 |
-
raise OSError(f"Failed to download tokenizer model: {str(e)}")
|
113 |
|
114 |
try:
|
115 |
return spm.SentencePieceProcessor(model_file=model_file_or_name)
|
116 |
except Exception as e:
|
117 |
raise OSError(f"Failed to load tokenizer model: {str(e)}")
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
def __init__(
|
120 |
self,
|
121 |
model_path: Optional[str] = None,
|
@@ -124,12 +149,10 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
124 |
) -> None:
|
125 |
"""
|
126 |
Initialize the tokenizer.
|
127 |
-
|
128 |
Args:
|
129 |
model_path (Optional[str]): Path to the tokenizer model file.
|
130 |
config_path (Optional[str]): Path to the tokenizer configuration file.
|
131 |
**kwargs: Additional keyword arguments passed to the superclass.
|
132 |
-
|
133 |
This method also ensures backward compatibility by setting
|
134 |
`clean_up_tokenization_spaces` to False by default.
|
135 |
"""
|
@@ -176,7 +199,6 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
176 |
def vocab_size(self) -> int:
|
177 |
"""
|
178 |
Get the size of the tokenizer vocabulary.
|
179 |
-
|
180 |
Returns:
|
181 |
int: The size of the vocabulary.
|
182 |
"""
|
@@ -185,7 +207,6 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
185 |
def get_vocab(self) -> Dict[str, int]:
|
186 |
"""
|
187 |
Get the vocabulary as a dictionary mapping token strings to their IDs.
|
188 |
-
|
189 |
Returns:
|
190 |
Dict[str, int]: Vocabulary mapping.
|
191 |
"""
|
@@ -196,11 +217,9 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
196 |
def _tokenize(self, text: str, **kwargs) -> List[int]:
|
197 |
"""
|
198 |
Tokenize the input text.
|
199 |
-
|
200 |
Args:
|
201 |
text (str): Text to tokenize.
|
202 |
**kwargs: Additional keyword arguments.
|
203 |
-
|
204 |
Returns:
|
205 |
List[int]: List of token IDs.
|
206 |
"""
|
@@ -210,13 +229,10 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
210 |
def _convert_token_to_id(self, token: str) -> int:
|
211 |
"""
|
212 |
Convert a token string to its corresponding ID.
|
213 |
-
|
214 |
Args:
|
215 |
token (str): The token to convert.
|
216 |
-
|
217 |
Returns:
|
218 |
int: The token's ID.
|
219 |
-
|
220 |
Raises:
|
221 |
ValueError: If the token is unknown and cannot be encoded to a single ID.
|
222 |
"""
|
@@ -230,11 +246,9 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
230 |
) -> str:
|
231 |
"""
|
232 |
Decode a list of token IDs into a string.
|
233 |
-
|
234 |
Args:
|
235 |
token_ids (Union[List[int], List[List[int]]]): List of token IDs or lists of token IDs.
|
236 |
num_threads (Optional[int]): Number of threads to use for decoding.
|
237 |
-
|
238 |
Returns:
|
239 |
str: Decoded string.
|
240 |
"""
|
@@ -243,10 +257,8 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
243 |
def _convert_id_to_token(self, index: int) -> str:
|
244 |
"""
|
245 |
Convert a token ID to its corresponding token string.
|
246 |
-
|
247 |
Args:
|
248 |
index (int): Token ID.
|
249 |
-
|
250 |
Returns:
|
251 |
str: Corresponding token string.
|
252 |
"""
|
@@ -255,10 +267,8 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
255 |
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
256 |
"""
|
257 |
Convert a list of tokens into a single string.
|
258 |
-
|
259 |
Args:
|
260 |
tokens (List[str]): List of token strings.
|
261 |
-
|
262 |
Returns:
|
263 |
str: Concatenated string of tokens.
|
264 |
"""
|
@@ -267,14 +277,11 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
267 |
def _tok_decode(self, token_ids: List[int], **kwargs: Any) -> str:
|
268 |
"""
|
269 |
Internal method to decode token IDs with additional arguments.
|
270 |
-
|
271 |
Args:
|
272 |
token_ids (List[int]): List of token IDs.
|
273 |
**kwargs: Additional arguments to pass to the decode method.
|
274 |
-
|
275 |
Returns:
|
276 |
str: Decoded string.
|
277 |
-
|
278 |
This method also issues a warning if unsupported arguments are provided.
|
279 |
"""
|
280 |
passed_kwargs = {key: value for (key, value) in kwargs.items() if key in self.decode_kwargs}
|
@@ -440,6 +447,4 @@ class SPTokenizer(HFGPTXTokenizer):
|
|
440 |
self.chat_template = {
|
441 |
lang: f"System: {sys_msg}" + "{{- '\\n'}}\n" + chat_template
|
442 |
for lang, sys_msg in self.system_messages_by_lang.items()
|
443 |
-
}
|
444 |
-
|
445 |
-
|
|
|
38 |
text (str): The text to encode.
|
39 |
return_tokens (bool): If True, returns token strings instead of token IDs.
|
40 |
is_continuation (bool): If True, uses a continuation tokenizer (if available).
|
|
|
41 |
Returns:
|
42 |
List[int] or List[str]: Encoded text as a list of token IDs or token strings.
|
43 |
"""
|
|
|
55 |
"""
|
56 |
Create a list of special tokens, including the BOS, EOS, PAD, EOD tokens,
|
57 |
and 256 additional placeholder tokens.
|
|
|
58 |
Returns:
|
59 |
List[str]: List of special tokens.
|
60 |
"""
|
|
|
62 |
f"<placeholder_tok_{i}>" for i in range(256)
|
63 |
]
|
64 |
|
65 |
+
def find_tokenizer_config(self, config_path: Path, repo_id: str = None) -> Optional[Path]:
|
66 |
if not os.path.isfile(config_path):
|
67 |
config_path = try_to_load_from_cache(repo_id=repo_id, filename=Path(config_path).name)
|
68 |
if not config_path:
|
|
|
87 |
OSError: If the model file cannot be loaded or downloaded.
|
88 |
"""
|
89 |
if not os.path.isfile(model_file_or_name):
|
90 |
+
model_file_or_name = try_to_load_from_cache(repo_id=repo_id, filename=Path(model_file_or_name).name)
|
91 |
+
if not model_file_or_name:
|
92 |
+
model_file_or_name = self._download_model_from_hub(repo_id=repo_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
try:
|
95 |
return spm.SentencePieceProcessor(model_file=model_file_or_name)
|
96 |
except Exception as e:
|
97 |
raise OSError(f"Failed to load tokenizer model: {str(e)}")
|
98 |
+
|
99 |
+
def _download_model_from_hub(self, repo_id: str) -> Optional[str]:
|
100 |
+
try:
|
101 |
+
# List all files in the repo
|
102 |
+
repo_files = list_repo_files(repo_id)
|
103 |
+
|
104 |
+
# Find the tokenizer model file
|
105 |
+
tokenizer_files = [f for f in repo_files if f.endswith('.model')]
|
106 |
+
if not tokenizer_files:
|
107 |
+
raise FileNotFoundError(f"No .model file found in repository {repo_id}")
|
108 |
+
|
109 |
+
# Use the first .model file found
|
110 |
+
model_file = tokenizer_files[0]
|
111 |
+
print(f"Found tokenizer model file: {model_file}")
|
112 |
+
|
113 |
+
# Download the file
|
114 |
+
model_file_or_name = hf_hub_download(repo_id=repo_id, filename=model_file)
|
115 |
+
print(f"Downloaded tokenizer model to: {model_file_or_name}")
|
116 |
+
except Exception as e:
|
117 |
+
raise OSError(f"Failed to download tokenizer model: {str(e)}")
|
118 |
+
|
119 |
+
return model_file_or_name
|
120 |
+
|
121 |
+
def _download_config_from_hub(self, repo_id: str):
|
122 |
+
if repo_id is None:
|
123 |
+
raise ValueError("repo_id must be provided if config_path is not a local file")
|
124 |
+
|
125 |
+
try:
|
126 |
+
# List all files in the repo
|
127 |
+
repo_files = list_repo_files(repo_id)
|
128 |
+
|
129 |
+
# Find the tokenizer config file
|
130 |
+
tokenizer_files = [f for f in repo_files if f.endswith('tokenizer_config.json')]
|
131 |
+
if not tokenizer_files:
|
132 |
+
raise FileNotFoundError(f"No tokenizer_config.json file found in repository {repo_id}")
|
133 |
+
|
134 |
+
# Use the first tokenizer_config.json file found
|
135 |
+
tokenizer_config_file = tokenizer_files[0]
|
136 |
+
print(f"Found tokenizer config file: {tokenizer_config_file}")
|
137 |
+
|
138 |
+
# Download the file
|
139 |
+
tokenizer_config_file_or_name = hf_hub_download(repo_id=repo_id, filename=tokenizer_config_file)
|
140 |
+
print(f"Downloaded tokenizer config file to: {tokenizer_config_file_or_name}")
|
141 |
+
return tokenizer_config_file_or_name
|
142 |
+
except Exception as e:
|
143 |
+
raise OSError(f"Failed to download tokenizer model: {str(e)}")
|
144 |
def __init__(
|
145 |
self,
|
146 |
model_path: Optional[str] = None,
|
|
|
149 |
) -> None:
|
150 |
"""
|
151 |
Initialize the tokenizer.
|
|
|
152 |
Args:
|
153 |
model_path (Optional[str]): Path to the tokenizer model file.
|
154 |
config_path (Optional[str]): Path to the tokenizer configuration file.
|
155 |
**kwargs: Additional keyword arguments passed to the superclass.
|
|
|
156 |
This method also ensures backward compatibility by setting
|
157 |
`clean_up_tokenization_spaces` to False by default.
|
158 |
"""
|
|
|
199 |
def vocab_size(self) -> int:
|
200 |
"""
|
201 |
Get the size of the tokenizer vocabulary.
|
|
|
202 |
Returns:
|
203 |
int: The size of the vocabulary.
|
204 |
"""
|
|
|
207 |
def get_vocab(self) -> Dict[str, int]:
|
208 |
"""
|
209 |
Get the vocabulary as a dictionary mapping token strings to their IDs.
|
|
|
210 |
Returns:
|
211 |
Dict[str, int]: Vocabulary mapping.
|
212 |
"""
|
|
|
217 |
def _tokenize(self, text: str, **kwargs) -> List[int]:
|
218 |
"""
|
219 |
Tokenize the input text.
|
|
|
220 |
Args:
|
221 |
text (str): Text to tokenize.
|
222 |
**kwargs: Additional keyword arguments.
|
|
|
223 |
Returns:
|
224 |
List[int]: List of token IDs.
|
225 |
"""
|
|
|
229 |
def _convert_token_to_id(self, token: str) -> int:
|
230 |
"""
|
231 |
Convert a token string to its corresponding ID.
|
|
|
232 |
Args:
|
233 |
token (str): The token to convert.
|
|
|
234 |
Returns:
|
235 |
int: The token's ID.
|
|
|
236 |
Raises:
|
237 |
ValueError: If the token is unknown and cannot be encoded to a single ID.
|
238 |
"""
|
|
|
246 |
) -> str:
|
247 |
"""
|
248 |
Decode a list of token IDs into a string.
|
|
|
249 |
Args:
|
250 |
token_ids (Union[List[int], List[List[int]]]): List of token IDs or lists of token IDs.
|
251 |
num_threads (Optional[int]): Number of threads to use for decoding.
|
|
|
252 |
Returns:
|
253 |
str: Decoded string.
|
254 |
"""
|
|
|
257 |
def _convert_id_to_token(self, index: int) -> str:
|
258 |
"""
|
259 |
Convert a token ID to its corresponding token string.
|
|
|
260 |
Args:
|
261 |
index (int): Token ID.
|
|
|
262 |
Returns:
|
263 |
str: Corresponding token string.
|
264 |
"""
|
|
|
267 |
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
268 |
"""
|
269 |
Convert a list of tokens into a single string.
|
|
|
270 |
Args:
|
271 |
tokens (List[str]): List of token strings.
|
|
|
272 |
Returns:
|
273 |
str: Concatenated string of tokens.
|
274 |
"""
|
|
|
277 |
def _tok_decode(self, token_ids: List[int], **kwargs: Any) -> str:
|
278 |
"""
|
279 |
Internal method to decode token IDs with additional arguments.
|
|
|
280 |
Args:
|
281 |
token_ids (List[int]): List of token IDs.
|
282 |
**kwargs: Additional arguments to pass to the decode method.
|
|
|
283 |
Returns:
|
284 |
str: Decoded string.
|
|
|
285 |
This method also issues a warning if unsupported arguments are provided.
|
286 |
"""
|
287 |
passed_kwargs = {key: value for (key, value) in kwargs.items() if key in self.decode_kwargs}
|
|
|
447 |
self.chat_template = {
|
448 |
lang: f"System: {sys_msg}" + "{{- '\\n'}}\n" + chat_template
|
449 |
for lang, sys_msg in self.system_messages_by_lang.items()
|
450 |
+
}
|
|
|
|