Text Generation · Transformers · Safetensors · llama · text-generation-inference
mfromm committed (verified)
Commit 2bdb949 · 1 Parent(s): 11ff4bc

Update gptx_tokenizer.py

Files changed (1)
  1. gptx_tokenizer.py +51 -46
gptx_tokenizer.py CHANGED
@@ -38,7 +38,6 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
             text (str): The text to encode.
             return_tokens (bool): If True, returns token strings instead of token IDs.
             is_continuation (bool): If True, uses a continuation tokenizer (if available).
-
         Returns:
             List[int] or List[str]: Encoded text as a list of token IDs or token strings.
         """
@@ -56,7 +55,6 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
         """
         Create a list of special tokens, including the BOS, EOS, PAD, EOD tokens,
         and 256 additional placeholder tokens.
-
         Returns:
             List[str]: List of special tokens.
         """
@@ -64,7 +62,7 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
             f"<placeholder_tok_{i}>" for i in range(256)
         ]
 
-    def find_tokenizer_config(self, config_path: Path, repo_id: str = None) -> Optional[Path]:
+    def find_tokenizer_config(self, config_path: Path, repo_id: str = None) -> Optional[Path]:
         if not os.path.isfile(config_path):
             config_path = try_to_load_from_cache(repo_id=repo_id, filename=Path(config_path).name)
             if not config_path:
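Note: the two hunks above describe the special-token list the tokenizer builds (BOS, EOS, PAD, EOD plus 256 placeholder tokens, matching the comprehension shown in the diff). A minimal standalone sketch of that construction follows; the function name and the literal BOS/EOS/PAD/EOD strings are assumptions for illustration, not taken from this file.

    from typing import List

    def build_special_tokens() -> List[str]:
        # Assumed literal token strings; the real file may spell them differently.
        base_tokens = ["<s>", "</s>", "<pad>", "<eod>"]  # BOS, EOS, PAD, EOD
        # 256 placeholder tokens, as in the comprehension shown in the diff.
        placeholders = [f"<placeholder_tok_{i}>" for i in range(256)]
        return base_tokens + placeholders

    print(len(build_special_tokens()))  # 260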
@@ -89,33 +87,60 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
             OSError: If the model file cannot be loaded or downloaded.
         """
         if not os.path.isfile(model_file_or_name):
-            if repo_id is None:
-                raise ValueError("repo_id must be provided if model_file_or_name is not a local file")
-
-            try:
-                # List all files in the repo
-                repo_files = list_repo_files(repo_id)
-
-                # Find the tokenizer model file
-                tokenizer_files = [f for f in repo_files if f.endswith('.model')]
-                if not tokenizer_files:
-                    raise FileNotFoundError(f"No .model file found in repository {repo_id}")
-
-                # Use the first .model file found
-                model_file = tokenizer_files[0]
-                print(f"Found tokenizer model file: {model_file}")
-
-                # Download the file
-                model_file_or_name = hf_hub_download(repo_id=repo_id, filename=model_file)
-                print(f"Downloaded tokenizer model to: {model_file_or_name}")
-            except Exception as e:
-                raise OSError(f"Failed to download tokenizer model: {str(e)}")
+            model_file_or_name = try_to_load_from_cache(repo_id=repo_id, filename=Path(model_file_or_name).name)
+            if not model_file_or_name:
+                model_file_or_name = self._download_model_from_hub(repo_id=repo_id)
 
         try:
             return spm.SentencePieceProcessor(model_file=model_file_or_name)
         except Exception as e:
             raise OSError(f"Failed to load tokenizer model: {str(e)}")
-
+
+    def _download_model_from_hub(self, repo_id: str) -> Optional[str]:
+        try:
+            # List all files in the repo
+            repo_files = list_repo_files(repo_id)
+
+            # Find the tokenizer model file
+            tokenizer_files = [f for f in repo_files if f.endswith('.model')]
+            if not tokenizer_files:
+                raise FileNotFoundError(f"No .model file found in repository {repo_id}")
+
+            # Use the first .model file found
+            model_file = tokenizer_files[0]
+            print(f"Found tokenizer model file: {model_file}")
+
+            # Download the file
+            model_file_or_name = hf_hub_download(repo_id=repo_id, filename=model_file)
+            print(f"Downloaded tokenizer model to: {model_file_or_name}")
+        except Exception as e:
+            raise OSError(f"Failed to download tokenizer model: {str(e)}")
+
+        return model_file_or_name
+
+    def _download_config_from_hub(self, repo_id: str):
+        if repo_id is None:
+            raise ValueError("repo_id must be provided if config_path is not a local file")
+
+        try:
+            # List all files in the repo
+            repo_files = list_repo_files(repo_id)
+
+            # Find the tokenizer config file
+            tokenizer_files = [f for f in repo_files if f.endswith('tokenizer_config.json')]
+            if not tokenizer_files:
+                raise FileNotFoundError(f"No tokenizer_config.json file found in repository {repo_id}")
+
+            # Use the first tokenizer_config.json file found
+            tokenizer_config_file = tokenizer_files[0]
+            print(f"Found tokenizer config file: {tokenizer_config_file}")
+
+            # Download the file
+            tokenizer_config_file_or_name = hf_hub_download(repo_id=repo_id, filename=tokenizer_config_file)
+            print(f"Downloaded tokenizer config file to: {tokenizer_config_file_or_name}")
+            return tokenizer_config_file_or_name
+        except Exception as e:
+            raise OSError(f"Failed to download tokenizer model: {str(e)}")
     def __init__(
         self,
         model_path: Optional[str] = None,
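Note: the net effect of this hunk is a three-step lookup for the SentencePiece model: use a local file if it exists, then check the local Hugging Face cache via try_to_load_from_cache, and only then download the first .model file from the Hub repository. A self-contained sketch of that flow follows; the function name and the example repo id are placeholders, not part of this commit.

    import os
    from pathlib import Path
    from typing import Optional

    import sentencepiece as spm
    from huggingface_hub import hf_hub_download, list_repo_files, try_to_load_from_cache

    def load_sp_model(model_file_or_name: str, repo_id: Optional[str] = None) -> spm.SentencePieceProcessor:
        # First choice: an existing local file.
        if not os.path.isfile(model_file_or_name):
            # Second choice: the local Hugging Face cache.
            cached = try_to_load_from_cache(repo_id=repo_id, filename=Path(model_file_or_name).name)
            if cached and isinstance(cached, str):
                model_file_or_name = cached
            else:
                # Last resort: download the first *.model file found in the Hub repo.
                model_files = [f for f in list_repo_files(repo_id) if f.endswith(".model")]
                if not model_files:
                    raise FileNotFoundError(f"No .model file found in repository {repo_id}")
                model_file_or_name = hf_hub_download(repo_id=repo_id, filename=model_files[0])
        return spm.SentencePieceProcessor(model_file=model_file_or_name)

    # Example call (placeholder repo id):
    # sp = load_sp_model("tokenizer.model", repo_id="some-org/some-model")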
@@ -124,12 +149,10 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
     ) -> None:
         """
         Initialize the tokenizer.
-
         Args:
             model_path (Optional[str]): Path to the tokenizer model file.
             config_path (Optional[str]): Path to the tokenizer configuration file.
             **kwargs: Additional keyword arguments passed to the superclass.
-
         This method also ensures backward compatibility by setting
         `clean_up_tokenization_spaces` to False by default.
         """
@@ -176,7 +199,6 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
     def vocab_size(self) -> int:
         """
         Get the size of the tokenizer vocabulary.
-
         Returns:
             int: The size of the vocabulary.
         """
@@ -185,7 +207,6 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
     def get_vocab(self) -> Dict[str, int]:
         """
         Get the vocabulary as a dictionary mapping token strings to their IDs.
-
         Returns:
             Dict[str, int]: Vocabulary mapping.
         """
@@ -196,11 +217,9 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
     def _tokenize(self, text: str, **kwargs) -> List[int]:
         """
         Tokenize the input text.
-
         Args:
             text (str): Text to tokenize.
             **kwargs: Additional keyword arguments.
-
         Returns:
             List[int]: List of token IDs.
         """
@@ -210,13 +229,10 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
     def _convert_token_to_id(self, token: str) -> int:
         """
         Convert a token string to its corresponding ID.
-
         Args:
             token (str): The token to convert.
-
         Returns:
             int: The token's ID.
-
         Raises:
             ValueError: If the token is unknown and cannot be encoded to a single ID.
         """
@@ -230,11 +246,9 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
     ) -> str:
         """
         Decode a list of token IDs into a string.
-
         Args:
             token_ids (Union[List[int], List[List[int]]]): List of token IDs or lists of token IDs.
             num_threads (Optional[int]): Number of threads to use for decoding.
-
         Returns:
             str: Decoded string.
         """
@@ -243,10 +257,8 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
     def _convert_id_to_token(self, index: int) -> str:
         """
         Convert a token ID to its corresponding token string.
-
         Args:
             index (int): Token ID.
-
         Returns:
             str: Corresponding token string.
         """
@@ -255,10 +267,8 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
         """
         Convert a list of tokens into a single string.
-
         Args:
             tokens (List[str]): List of token strings.
-
         Returns:
             str: Concatenated string of tokens.
         """
@@ -267,14 +277,11 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
     def _tok_decode(self, token_ids: List[int], **kwargs: Any) -> str:
         """
         Internal method to decode token IDs with additional arguments.
-
         Args:
             token_ids (List[int]): List of token IDs.
             **kwargs: Additional arguments to pass to the decode method.
-
         Returns:
             str: Decoded string.
-
         This method also issues a warning if unsupported arguments are provided.
         """
         passed_kwargs = {key: value for (key, value) in kwargs.items() if key in self.decode_kwargs}
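Note: the last context line above filters **kwargs down to the keys listed in self.decode_kwargs, and the docstring says unsupported arguments trigger a warning. A small illustration of that filter-and-warn pattern; SUPPORTED_DECODE_KWARGS is an assumption standing in for self.decode_kwargs:

    import warnings
    from typing import Any, Dict

    SUPPORTED_DECODE_KWARGS = {"num_threads"}  # assumed; the class keeps this in self.decode_kwargs

    def filter_decode_kwargs(**kwargs: Any) -> Dict[str, Any]:
        passed = {key: value for key, value in kwargs.items() if key in SUPPORTED_DECODE_KWARGS}
        dropped = sorted(set(kwargs) - set(passed))
        if dropped:
            warnings.warn(f"Unsupported decode arguments ignored: {dropped}")
        return passed

    print(filter_decode_kwargs(num_threads=4, skip_special_tokens=True))  # {'num_threads': 4}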
@@ -440,6 +447,4 @@ class SPTokenizer(HFGPTXTokenizer):
         self.chat_template = {
             lang: f"System: {sys_msg}" + "{{- '\\n'}}\n" + chat_template
             for lang, sys_msg in self.system_messages_by_lang.items()
-        }
-
-
+        }
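Note: this final hunk only drops trailing blank lines after the chat_template dict, which maps each language code to a Jinja chat template prefixed with that language's system message. A standalone sketch of the same construction; the system messages and the template body below are made-up illustration values, not the ones shipped in gptx_tokenizer.py.

    # Assumed example values; the real system messages and template live elsewhere in the file.
    system_messages_by_lang = {
        "en": "You are a helpful assistant.",
        "de": "Du bist ein hilfreicher Assistent.",
    }
    chat_template = (
        "{%- for message in messages %}"
        "{{ message['role'] }}: {{ message['content'] }}{{- '\\n' }}"
        "{%- endfor %}"
    )
    chat_templates_by_lang = {
        lang: f"System: {sys_msg}" + "{{- '\\n'}}\n" + chat_template
        for lang, sys_msg in system_messages_by_lang.items()
    }
    # Each value starts with "System: <message>" followed by the shared Jinja template.
    print(chat_templates_by_lang["en"])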
 