Spaces:
Running
Running
update
Browse files
examples/vm_sound_classification/step_3_train_model.py
CHANGED
@@ -186,7 +186,8 @@ def main():
|
|
186 |
|
187 |
if args.pretrained_model is not None and os.path.exists(args.pretrained_model):
|
188 |
logger.info(f"load pretrained model state dict from: {args.pretrained_model}")
|
189 |
-
|
|
|
190 |
out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
|
191 |
# print(out_root.as_posix())
|
192 |
if out_root.exists():
|
@@ -194,7 +195,7 @@ def main():
|
|
194 |
out_root.mkdir(parents=True, exist_ok=True)
|
195 |
f_zip.extractall(path=out_root)
|
196 |
|
197 |
-
tgt_path = out_root /
|
198 |
model_pt_file = tgt_path / "model.pt"
|
199 |
with open(model_pt_file, "r") as f:
|
200 |
state_dict = torch.load(f, map_location="cpu")
|
|
|
186 |
|
187 |
if args.pretrained_model is not None and os.path.exists(args.pretrained_model):
|
188 |
logger.info(f"load pretrained model state dict from: {args.pretrained_model}")
|
189 |
+
pretrained_model = Path(args.pretrained_model)
|
190 |
+
with zipfile.ZipFile(pretrained_model.as_posix(), "r") as f_zip:
|
191 |
out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
|
192 |
# print(out_root.as_posix())
|
193 |
if out_root.exists():
|
|
|
195 |
out_root.mkdir(parents=True, exist_ok=True)
|
196 |
f_zip.extractall(path=out_root)
|
197 |
|
198 |
+
tgt_path = out_root / pretrained_model.stem
|
199 |
model_pt_file = tgt_path / "model.pt"
|
200 |
with open(model_pt_file, "r") as f:
|
201 |
state_dict = torch.load(f, map_location="cpu")
|