allow inference of hyperpcm on cpu
Browse files
app.py
CHANGED
@@ -143,7 +143,7 @@ def retrieval():
|
|
143 |
memory = dataset
|
144 |
model = HyperPCM(memory=memory).to(device)
|
145 |
model = torch.nn.DataParallel(model)
|
146 |
-
model.load_state_dict(torch.load(checkpoint_path))
|
147 |
model.eval()
|
148 |
|
149 |
with torch.set_grad_enabled(False):
|
|
|
143 |
memory = dataset
|
144 |
model = HyperPCM(memory=memory).to(device)
|
145 |
model = torch.nn.DataParallel(model)
|
146 |
+
model.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
|
147 |
model.eval()
|
148 |
|
149 |
with torch.set_grad_enabled(False):
|