tezuesh commited on
Commit
f2cde80
·
verified ·
1 Parent(s): 06e2d44

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +1 -1
inference.py CHANGED
@@ -19,7 +19,7 @@ class Inferencer:
19
  transforms.Normalize((0.1307,), (0.3081,))
20
  ])
21
 
22
- def _load_model(self, model_path='saved_models/best_model.pth'):
23
  """Load the trained model."""
24
  model = MNISTModel().to(self.device)
25
  model.load_state_dict(
 
19
  transforms.Normalize((0.1307,), (0.3081,))
20
  ])
21
 
22
+ def _load_model(self, model_path='best_model.pth'):
23
  """Load the trained model."""
24
  model = MNISTModel().to(self.device)
25
  model.load_state_dict(