lysandre HF staff commited on
Commit
c236734
·
1 Parent(s): 7174749

Delete convert.py

Browse files
Files changed (1) hide show
  1. convert.py +0 -382
convert.py DELETED
@@ -1,382 +0,0 @@
1
- import argparse
2
- import json
3
- import os
4
- import shutil
5
- from collections import defaultdict
6
- from tempfile import TemporaryDirectory
7
- from typing import Dict, List, Optional, Set, Tuple
8
-
9
- import torch
10
-
11
- from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
12
- from huggingface_hub.file_download import repo_folder_name
13
- from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file
14
-
15
-
16
- COMMIT_DESCRIPTION = """
17
- This is an automated PR created with https://huggingface.co/spaces/safetensors/convert
18
-
19
- This new file is equivalent to `pytorch_model.bin` but safe in the sense that
20
- no arbitrary code can be put into it.
21
-
22
- These files also happen to load much faster than their pytorch counterpart:
23
- https://colab.research.google.com/github/huggingface/notebooks/blob/main/safetensors_doc/en/speed.ipynb
24
-
25
- The widgets on your model page will run using this model even if this is not merged
26
- making sure the file actually works.
27
-
28
- If you find any issues: please report here: https://huggingface.co/spaces/safetensors/convert/discussions
29
-
30
- Feel free to ignore this PR.
31
- """
32
-
33
- ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]]
34
-
35
-
36
- def _remove_duplicate_names(
37
- state_dict: Dict[str, torch.Tensor],
38
- *,
39
- preferred_names: List[str] = None,
40
- discard_names: List[str] = None,
41
- ) -> Dict[str, List[str]]:
42
- if preferred_names is None:
43
- preferred_names = []
44
- preferred_names = set(preferred_names)
45
- if discard_names is None:
46
- discard_names = []
47
- discard_names = set(discard_names)
48
-
49
- shareds = _find_shared_tensors(state_dict)
50
- to_remove = defaultdict(list)
51
- for shared in shareds:
52
- complete_names = set([name for name in shared if _is_complete(state_dict[name])])
53
- if not complete_names:
54
- if len(shared) == 1:
55
- # Force contiguous
56
- name = list(shared)[0]
57
- state_dict[name] = state_dict[name].clone()
58
- complete_names = {name}
59
- else:
60
- raise RuntimeError(
61
- f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
62
- )
63
-
64
- keep_name = sorted(list(complete_names))[0]
65
-
66
- # Mecanism to preferentially select keys to keep
67
- # coming from the on-disk file to allow
68
- # loading models saved with a different choice
69
- # of keep_name
70
- preferred = complete_names.difference(discard_names)
71
- if preferred:
72
- keep_name = sorted(list(preferred))[0]
73
-
74
- if preferred_names:
75
- preferred = preferred_names.intersection(complete_names)
76
- if preferred:
77
- keep_name = sorted(list(preferred))[0]
78
- for name in sorted(shared):
79
- if name != keep_name:
80
- to_remove[keep_name].append(name)
81
- return to_remove
82
-
83
-
84
- def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]:
85
- try:
86
- import json
87
-
88
- import transformers
89
-
90
- config_filename = hf_hub_download(
91
- model_id, revision=revision, filename="config.json", token=token, cache_dir=folder
92
- )
93
- with open(config_filename, "r") as f:
94
- config = json.load(f)
95
- architecture = config["architectures"][0]
96
-
97
- class_ = getattr(transformers, architecture)
98
-
99
- # Name for this varible depends on transformers version.
100
- discard_names = getattr(class_, "_tied_weights_keys", [])
101
-
102
- except Exception:
103
- discard_names = []
104
- return discard_names
105
-
106
-
107
- class AlreadyExists(Exception):
108
- pass
109
-
110
-
111
- def check_file_size(sf_filename: str, pt_filename: str):
112
- sf_size = os.stat(sf_filename).st_size
113
- pt_size = os.stat(pt_filename).st_size
114
-
115
- if (sf_size - pt_size) / pt_size > 0.01:
116
- raise RuntimeError(
117
- f"""The file size different is more than 1%:
118
- - {sf_filename}: {sf_size}
119
- - {pt_filename}: {pt_size}
120
- """
121
- )
122
-
123
-
124
- def rename(pt_filename: str) -> str:
125
- filename, ext = os.path.splitext(pt_filename)
126
- local = f"{filename}.safetensors"
127
- local = local.replace("pytorch_model", "model")
128
- return local
129
-
130
-
131
- def convert_multi(
132
- model_id: str, *, revision=Optional[str], folder: str, token: Optional[str], discard_names: List[str]
133
- ) -> ConversionResult:
134
- filename = hf_hub_download(
135
- repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder
136
- )
137
- with open(filename, "r") as f:
138
- data = json.load(f)
139
-
140
- filenames = set(data["weight_map"].values())
141
- local_filenames = []
142
- for filename in filenames:
143
- pt_filename = hf_hub_download(repo_id=model_id, filename=filename, token=token, cache_dir=folder)
144
-
145
- sf_filename = rename(pt_filename)
146
- sf_filename = os.path.join(folder, sf_filename)
147
- convert_file(pt_filename, sf_filename, discard_names=discard_names)
148
- local_filenames.append(sf_filename)
149
-
150
- index = os.path.join(folder, "model.safetensors.index.json")
151
- with open(index, "w") as f:
152
- newdata = {k: v for k, v in data.items()}
153
- newmap = {k: rename(v) for k, v in data["weight_map"].items()}
154
- newdata["weight_map"] = newmap
155
- json.dump(newdata, f, indent=4)
156
- local_filenames.append(index)
157
-
158
- operations = [
159
- CommitOperationAdd(path_in_repo=local.split("/")[-1], path_or_fileobj=local) for local in local_filenames
160
- ]
161
- errors: List[Tuple[str, "Exception"]] = []
162
-
163
- return operations, errors
164
-
165
-
166
- def convert_single(
167
- model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]
168
- ) -> ConversionResult:
169
- pt_filename = hf_hub_download(
170
- repo_id=model_id, revision=revision, filename="pytorch_model.bin", token=token, cache_dir=folder
171
- )
172
-
173
- sf_name = "model.safetensors"
174
- sf_filename = os.path.join(folder, sf_name)
175
- convert_file(pt_filename, sf_filename, discard_names)
176
- operations = [CommitOperationAdd(path_in_repo=sf_name, path_or_fileobj=sf_filename)]
177
- errors: List[Tuple[str, "Exception"]] = []
178
- return operations, errors
179
-
180
-
181
- def convert_file(
182
- pt_filename: str,
183
- sf_filename: str,
184
- discard_names: List[str],
185
- ):
186
- loaded = torch.load(pt_filename, map_location="cpu")
187
- if "state_dict" in loaded:
188
- loaded = loaded["state_dict"]
189
- to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)
190
-
191
- metadata = {"format": "pt"}
192
- for kept_name, to_remove_group in to_removes.items():
193
- for to_remove in to_remove_group:
194
- if to_remove not in metadata:
195
- metadata[to_remove] = kept_name
196
- del loaded[to_remove]
197
- # Force tensors to be contiguous
198
- loaded = {k: v.contiguous() for k, v in loaded.items()}
199
-
200
- dirname = os.path.dirname(sf_filename)
201
- os.makedirs(dirname, exist_ok=True)
202
- save_file(loaded, sf_filename, metadata=metadata)
203
- check_file_size(sf_filename, pt_filename)
204
- reloaded = load_file(sf_filename)
205
- for k in loaded:
206
- pt_tensor = loaded[k]
207
- sf_tensor = reloaded[k]
208
- if not torch.equal(pt_tensor, sf_tensor):
209
- raise RuntimeError(f"The output tensors do not match for key {k}")
210
-
211
-
212
- def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]]) -> str:
213
- errors = []
214
- for key in ["missing_keys", "mismatched_keys", "unexpected_keys"]:
215
- pt_set = set(pt_infos[key])
216
- sf_set = set(sf_infos[key])
217
-
218
- pt_only = pt_set - sf_set
219
- sf_only = sf_set - pt_set
220
-
221
- if pt_only:
222
- errors.append(f"{key} : PT warnings contain {pt_only} which are not present in SF warnings")
223
- if sf_only:
224
- errors.append(f"{key} : SF warnings contain {sf_only} which are not present in PT warnings")
225
- return "\n".join(errors)
226
-
227
-
228
- def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[str]) -> Optional["Discussion"]:
229
- try:
230
- revision_commit = api.model_info(model_id, revision=revision).sha
231
- discussions = api.get_repo_discussions(repo_id=model_id)
232
- except Exception:
233
- return None
234
- for discussion in discussions:
235
- if discussion.status in {"open", "closed"} and discussion.is_pull_request and discussion.title == pr_title:
236
- commits = api.list_repo_commits(model_id, revision=discussion.git_reference)
237
-
238
- if revision_commit == commits[1].commit_id:
239
- return discussion
240
- return None
241
-
242
-
243
- def convert_generic(
244
- model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str]
245
- ) -> ConversionResult:
246
- operations = []
247
- errors = []
248
-
249
- extensions = set([".bin", ".ckpt"])
250
- for filename in filenames:
251
- prefix, ext = os.path.splitext(filename)
252
- if ext in extensions:
253
- pt_filename = hf_hub_download(
254
- model_id, revision=revision, filename=filename, token=token, cache_dir=folder
255
- )
256
- dirname, raw_filename = os.path.split(filename)
257
- if raw_filename == "pytorch_model.bin":
258
- # XXX: This is a special case to handle `transformers` and the
259
- # `transformers` part of the model which is actually loaded by `transformers`.
260
- sf_in_repo = os.path.join(dirname, "model.safetensors")
261
- else:
262
- sf_in_repo = f"{prefix}.safetensors"
263
- sf_filename = os.path.join(folder, sf_in_repo)
264
- try:
265
- convert_file(pt_filename, sf_filename, discard_names=[])
266
- operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
267
- except Exception as e:
268
- errors.append((pt_filename, e))
269
- return operations, errors
270
-
271
-
272
- def convert(
273
- api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False
274
- ) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
275
- pr_title = "Adding `safetensors` variant of this model"
276
- info = api.model_info(model_id, revision=revision)
277
- filenames = set(s.rfilename for s in info.siblings)
278
-
279
- with TemporaryDirectory() as d:
280
- folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
281
- os.makedirs(folder)
282
- new_pr = None
283
- try:
284
- operations = None
285
- pr = previous_pr(api, model_id, pr_title, revision=revision)
286
-
287
- library_name = getattr(info, "library_name", None)
288
- if any(filename.endswith(".safetensors") for filename in filenames) and not force:
289
- raise AlreadyExists(f"Model {model_id} is already converted, skipping..")
290
- elif pr is not None and not force:
291
- url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
292
- new_pr = pr
293
- raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
294
- elif library_name == "transformers":
295
-
296
- discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token)
297
- if "pytorch_model.bin" in filenames:
298
- operations, errors = convert_single(
299
- model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
300
- )
301
- elif "pytorch_model.bin.index.json" in filenames:
302
- operations, errors = convert_multi(
303
- model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
304
- )
305
- else:
306
- raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
307
- else:
308
- operations, errors = convert_generic(
309
- model_id, revision=revision, folder=folder, filenames=filenames, token=api.token
310
- )
311
-
312
- if operations:
313
- new_pr = api.create_commit(
314
- repo_id=model_id,
315
- revision=revision,
316
- operations=operations,
317
- commit_message=pr_title,
318
- commit_description=COMMIT_DESCRIPTION,
319
- create_pr=True,
320
- )
321
- print(f"Pr created at {new_pr.pr_url}")
322
- else:
323
- print("No files to convert")
324
- finally:
325
- shutil.rmtree(folder)
326
- return new_pr, errors
327
-
328
-
329
- if __name__ == "__main__":
330
- DESCRIPTION = """
331
- Simple utility tool to convert automatically some weights on the hub to `safetensors` format.
332
- It is PyTorch exclusive for now.
333
- It works by downloading the weights (PT), converting them locally, and uploading them back
334
- as a PR on the hub.
335
- """
336
- parser = argparse.ArgumentParser(description=DESCRIPTION)
337
- parser.add_argument(
338
- "model_id",
339
- type=str,
340
- help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
341
- )
342
- parser.add_argument(
343
- "--revision",
344
- type=str,
345
- help="The revision to convert",
346
- )
347
- parser.add_argument(
348
- "--force",
349
- action="store_true",
350
- help="Create the PR even if it already exists of if the model was already converted.",
351
- )
352
- parser.add_argument(
353
- "-y",
354
- action="store_true",
355
- help="Ignore safety prompt",
356
- )
357
- args = parser.parse_args()
358
- model_id = args.model_id
359
- api = HfApi()
360
- if args.y:
361
- txt = "y"
362
- else:
363
- txt = input(
364
- "This conversion script will unpickle a pickled file, which is inherently unsafe. If you do not trust this file, we invite you to use"
365
- " https://huggingface.co/spaces/safetensors/convert or google colab or other hosted solution to avoid potential issues with this file."
366
- " Continue [Y/n] ?"
367
- )
368
- if txt.lower() in {"", "y"}:
369
- commit_info, errors = convert(api, model_id, revision=args.revision, force=args.force)
370
- string = f"""
371
- ### Success 🔥
372
- Yay! This model was successfully converted and a PR was open using your token, here:
373
- [{commit_info.pr_url}]({commit_info.pr_url})
374
- """
375
- if errors:
376
- string += "\nErrors during conversion:\n"
377
- string += "\n".join(
378
- f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors
379
- )
380
- print(string)
381
- else:
382
- print(f"Answer was `{txt}` aborting.")