atiwari751 commited on
Commit
0880bdc
·
1 Parent(s): 28ba8c0

updated app to support CPU

Browse files
Files changed (1) hide show
  1. utils.py +5 -1
utils.py CHANGED
@@ -12,12 +12,16 @@ def save_checkpoint(model, optimizer, epoch, loss, path):
12
  print(f"Checkpoint saved at epoch {epoch}")
13
 
14
  def load_checkpoint(model, optimizer, checkpoint_path):
15
- checkpoint = torch.load(checkpoint_path, weights_only=True)
 
 
 
16
  model.load_state_dict(checkpoint['model_state_dict'])
17
  if optimizer is not None:
18
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
19
  start_epoch = checkpoint['epoch']
20
  loss = checkpoint['loss']
 
21
  return model, optimizer, start_epoch, loss
22
 
23
  def plot_training_curves(epochs, train_acc1, test_acc1, train_acc5, test_acc5, train_losses, test_losses, learning_rates):
 
12
  print(f"Checkpoint saved at epoch {epoch}")
13
 
14
  def load_checkpoint(model, optimizer, checkpoint_path):
15
+ # Use map_location to load the checkpoint on CPU if CUDA is not available
16
+ map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
18
+
19
  model.load_state_dict(checkpoint['model_state_dict'])
20
  if optimizer is not None:
21
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
22
  start_epoch = checkpoint['epoch']
23
  loss = checkpoint['loss']
24
+
25
  return model, optimizer, start_epoch, loss
26
 
27
  def plot_training_curves(epochs, train_acc1, test_acc1, train_acc5, test_acc5, train_losses, test_losses, learning_rates):