Spaces:
Runtime error
Runtime error
Commit
·
40deb12
1
Parent(s):
a803714
Removed all GPU-requiring operations from the app.py
Browse files
app.py
CHANGED
@@ -13,7 +13,6 @@ import torch
|
|
13 |
|
14 |
model = DiffusionModule("./diffusion_configs.yaml")
|
15 |
model.load_ckpt("./data/model.ckpt")
|
16 |
-
model.cuda().half()
|
17 |
model.eval();
|
18 |
|
19 |
# Loading a baseline noise for making predictions
|
@@ -21,10 +20,8 @@ model.eval();
|
|
21 |
seed = 3407
|
22 |
np.random.seed(seed)
|
23 |
torch.random.manual_seed(seed)
|
24 |
-
torch.cuda.manual_seed(seed)
|
25 |
-
torch.cuda.manual_seed_all(seed)
|
26 |
torch.backends.cudnn.deterministic = True
|
27 |
-
BASELINE_NOISE = torch.randn(1, 1, 256, 256).
|
28 |
|
29 |
# Model helper functions
|
30 |
|
@@ -88,13 +85,13 @@ def make_predictions(img_path, angles=None, cls_batch=None, rotate_to_standard=F
|
|
88 |
cls_value = torch.tensor([2, *angles, *fp])
|
89 |
else:
|
90 |
cls_value = torch.tensor([1, *angles, *fp])
|
91 |
-
cls_batch = cls_value.unsqueeze(0).repeat(input_batch["img"].shape[0], 1)
|
92 |
|
93 |
# Generate noise
|
94 |
noise = BASELINE_NOISE.repeat(input_batch["img"].shape[0], 1, 1, 1)
|
95 |
model_kwargs = {
|
96 |
"cls": cls_batch,
|
97 |
-
"concat": input_batch["img"]
|
98 |
}
|
99 |
|
100 |
# Make predictions
|
|
|
13 |
|
14 |
model = DiffusionModule("./diffusion_configs.yaml")
|
15 |
model.load_ckpt("./data/model.ckpt")
|
|
|
16 |
model.eval();
|
17 |
|
18 |
# Loading a baseline noise for making predictions
|
|
|
20 |
seed = 3407
|
21 |
np.random.seed(seed)
|
22 |
torch.random.manual_seed(seed)
|
|
|
|
|
23 |
torch.backends.cudnn.deterministic = True
|
24 |
+
BASELINE_NOISE = torch.randn(1, 1, 256, 256).half()
|
25 |
|
26 |
# Model helper functions
|
27 |
|
|
|
85 |
cls_value = torch.tensor([2, *angles, *fp])
|
86 |
else:
|
87 |
cls_value = torch.tensor([1, *angles, *fp])
|
88 |
+
cls_batch = cls_value.unsqueeze(0).repeat(input_batch["img"].shape[0], 1)
|
89 |
|
90 |
# Generate noise
|
91 |
noise = BASELINE_NOISE.repeat(input_batch["img"].shape[0], 1, 1, 1)
|
92 |
model_kwargs = {
|
93 |
"cls": cls_batch,
|
94 |
+
"concat": input_batch["img"]
|
95 |
}
|
96 |
|
97 |
# Make predictions
|