Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Update convert.py
Browse files- convert.py +9 -2
convert.py
CHANGED
|
@@ -113,13 +113,20 @@ def convert_file(
|
|
| 113 |
loaded = torch.load(pt_filename, map_location="cpu")
|
| 114 |
if "state_dict" in loaded:
|
| 115 |
loaded = loaded["state_dict"]
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
# For tensors to be contiguous
|
| 118 |
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
| 119 |
|
| 120 |
dirname = os.path.dirname(sf_filename)
|
| 121 |
os.makedirs(dirname, exist_ok=True)
|
| 122 |
-
save_file(loaded, sf_filename, metadata=
|
| 123 |
reloaded = load_file(sf_filename)
|
| 124 |
for k in loaded:
|
| 125 |
pt_tensor = loaded[k]
|
|
|
|
| 113 |
loaded = torch.load(pt_filename, map_location="cpu")
|
| 114 |
if "state_dict" in loaded:
|
| 115 |
loaded = loaded["state_dict"]
|
| 116 |
+
to_removes = _remove_duplicate_names(loaded)
|
| 117 |
+
|
| 118 |
+
metadata = {"format": "pt"}
|
| 119 |
+
for kept_name, to_remove_group in to_removes.items():
|
| 120 |
+
for to_remove in to_remove_group:
|
| 121 |
+
if to_remove not in metadata:
|
| 122 |
+
metadata[to_remove] = kept_name
|
| 123 |
+
del loaded[to_remove]
|
| 124 |
# For tensors to be contiguous
|
| 125 |
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
| 126 |
|
| 127 |
dirname = os.path.dirname(sf_filename)
|
| 128 |
os.makedirs(dirname, exist_ok=True)
|
| 129 |
+
save_file(loaded, sf_filename, metadata=metadata)
|
| 130 |
reloaded = load_file(sf_filename)
|
| 131 |
for k in loaded:
|
| 132 |
pt_tensor = loaded[k]
|