Update convert.py
Browse files- convert.py +4 -2
convert.py
CHANGED
|
@@ -95,7 +95,9 @@ def convert_file(
|
|
| 95 |
pt_filename: str,
|
| 96 |
sf_filename: str,
|
| 97 |
):
|
| 98 |
-
loaded = torch.load(pt_filename
|
|
|
|
|
|
|
| 99 |
shared = shared_pointers(loaded)
|
| 100 |
for shared_weights in shared:
|
| 101 |
for name in shared_weights[1:]:
|
|
@@ -238,7 +240,7 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["Commi
|
|
| 238 |
operations = convert_multi(model_id, folder)
|
| 239 |
else:
|
| 240 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
| 241 |
-
|
| 242 |
else:
|
| 243 |
operations = convert_generic(model_id, folder, filenames)
|
| 244 |
|
|
|
|
| 95 |
pt_filename: str,
|
| 96 |
sf_filename: str,
|
| 97 |
):
|
| 98 |
+
loaded = torch.load(pt_filename)
|
| 99 |
+
if "state_dict" in loaded:
|
| 100 |
+
loaded = loaded["state_dict"]
|
| 101 |
shared = shared_pointers(loaded)
|
| 102 |
for shared_weights in shared:
|
| 103 |
for name in shared_weights[1:]:
|
|
|
|
| 240 |
operations = convert_multi(model_id, folder)
|
| 241 |
else:
|
| 242 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
| 243 |
+
check_final_model(model_id, folder)
|
| 244 |
else:
|
| 245 |
operations = convert_generic(model_id, folder, filenames)
|
| 246 |
|