emmas96 commited on
Commit
956fc7c
·
1 Parent(s): b5de43a

allow inference of hyperpcm on cpu

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -143,7 +143,7 @@ def retrieval():
143
  memory = dataset
144
  model = HyperPCM(memory=memory).to(device)
145
  model = torch.nn.DataParallel(model)
146
- model.load_state_dict(torch.load(checkpoint_path))
147
  model.eval()
148
 
149
  with torch.set_grad_enabled(False):
 
143
  memory = dataset
144
  model = HyperPCM(memory=memory).to(device)
145
  model = torch.nn.DataParallel(model)
146
+ model.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
147
  model.eval()
148
 
149
  with torch.set_grad_enabled(False):