HoneyTian commited on
Commit
30c0bb4
·
1 Parent(s): df6a635
examples/vm_sound_classification/step_3_train_model.py CHANGED
@@ -186,7 +186,8 @@ def main():
186
 
187
  if args.pretrained_model is not None and os.path.exists(args.pretrained_model):
188
  logger.info(f"load pretrained model state dict from: {args.pretrained_model}")
189
- with zipfile.ZipFile(args.pretrained_model, "r") as f_zip:
 
190
  out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
191
  # print(out_root.as_posix())
192
  if out_root.exists():
@@ -194,7 +195,7 @@ def main():
194
  out_root.mkdir(parents=True, exist_ok=True)
195
  f_zip.extractall(path=out_root)
196
 
197
- tgt_path = out_root / os.path.basename(args.pretrained_model)
198
  model_pt_file = tgt_path / "model.pt"
199
  with open(model_pt_file, "r") as f:
200
  state_dict = torch.load(f, map_location="cpu")
 
186
 
187
  if args.pretrained_model is not None and os.path.exists(args.pretrained_model):
188
  logger.info(f"load pretrained model state dict from: {args.pretrained_model}")
189
+ pretrained_model = Path(args.pretrained_model)
190
+ with zipfile.ZipFile(pretrained_model.as_posix(), "r") as f_zip:
191
  out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
192
  # print(out_root.as_posix())
193
  if out_root.exists():
 
195
  out_root.mkdir(parents=True, exist_ok=True)
196
  f_zip.extractall(path=out_root)
197
 
198
+ tgt_path = out_root / pretrained_model.stem
199
  model_pt_file = tgt_path / "model.pt"
200
  with open(model_pt_file, "r") as f:
201
  state_dict = torch.load(f, map_location="cpu")