Spaces:
Runtime error
Runtime error
Commit
·
4a7c70f
1
Parent(s):
a33fac3
fix cuda error
Browse files
app.py
CHANGED
@@ -12,7 +12,7 @@ from scripts.test_functions import process_image
|
|
12 |
model_path = "hf/best_unet_model.pth"
|
13 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
14 |
unet_model = UNet().to(device)
|
15 |
-
unet_model.load_state_dict(torch.load(model_path))
|
16 |
unet_model.eval()
|
17 |
|
18 |
def block(image, source_age, target_age):
|
|
|
12 |
model_path = "hf/best_unet_model.pth"
|
13 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
14 |
unet_model = UNet().to(device)
|
15 |
+
unet_model.load_state_dict(torch.load(model_path, map_location=device))
|
16 |
unet_model.eval()
|
17 |
|
18 |
def block(image, source_age, target_age):
|