timroelofs123 commited on
Commit
4a7c70f
·
1 Parent(s): a33fac3

fix cuda error

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -12,7 +12,7 @@ from scripts.test_functions import process_image
12
  model_path = "hf/best_unet_model.pth"
13
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
  unet_model = UNet().to(device)
15
- unet_model.load_state_dict(torch.load(model_path))
16
  unet_model.eval()
17
 
18
  def block(image, source_age, target_age):
 
12
  model_path = "hf/best_unet_model.pth"
13
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
  unet_model = UNet().to(device)
15
+ unet_model.load_state_dict(torch.load(model_path, map_location=device))
16
  unet_model.eval()
17
 
18
  def block(image, source_age, target_age):