Spaces:
Runtime error
Runtime error
Commit
·
40deb12
1
Parent(s):
a803714
Removed all GPU-requiring operations from the app.py
Browse files
app.py
CHANGED
|
@@ -13,7 +13,6 @@ import torch
|
|
| 13 |
|
| 14 |
model = DiffusionModule("./diffusion_configs.yaml")
|
| 15 |
model.load_ckpt("./data/model.ckpt")
|
| 16 |
-
model.cuda().half()
|
| 17 |
model.eval();
|
| 18 |
|
| 19 |
# Loading a baseline noise for making predictions
|
|
@@ -21,10 +20,8 @@ model.eval();
|
|
| 21 |
seed = 3407
|
| 22 |
np.random.seed(seed)
|
| 23 |
torch.random.manual_seed(seed)
|
| 24 |
-
torch.cuda.manual_seed(seed)
|
| 25 |
-
torch.cuda.manual_seed_all(seed)
|
| 26 |
torch.backends.cudnn.deterministic = True
|
| 27 |
-
BASELINE_NOISE = torch.randn(1, 1, 256, 256).
|
| 28 |
|
| 29 |
# Model helper functions
|
| 30 |
|
|
@@ -88,13 +85,13 @@ def make_predictions(img_path, angles=None, cls_batch=None, rotate_to_standard=F
|
|
| 88 |
cls_value = torch.tensor([2, *angles, *fp])
|
| 89 |
else:
|
| 90 |
cls_value = torch.tensor([1, *angles, *fp])
|
| 91 |
-
cls_batch = cls_value.unsqueeze(0).repeat(input_batch["img"].shape[0], 1)
|
| 92 |
|
| 93 |
# Generate noise
|
| 94 |
noise = BASELINE_NOISE.repeat(input_batch["img"].shape[0], 1, 1, 1)
|
| 95 |
model_kwargs = {
|
| 96 |
"cls": cls_batch,
|
| 97 |
-
"concat": input_batch["img"]
|
| 98 |
}
|
| 99 |
|
| 100 |
# Make predictions
|
|
|
|
| 13 |
|
| 14 |
model = DiffusionModule("./diffusion_configs.yaml")
|
| 15 |
model.load_ckpt("./data/model.ckpt")
|
|
|
|
| 16 |
model.eval();
|
| 17 |
|
| 18 |
# Loading a baseline noise for making predictions
|
|
|
|
| 20 |
seed = 3407
|
| 21 |
np.random.seed(seed)
|
| 22 |
torch.random.manual_seed(seed)
|
|
|
|
|
|
|
| 23 |
torch.backends.cudnn.deterministic = True
|
| 24 |
+
BASELINE_NOISE = torch.randn(1, 1, 256, 256).half()
|
| 25 |
|
| 26 |
# Model helper functions
|
| 27 |
|
|
|
|
| 85 |
cls_value = torch.tensor([2, *angles, *fp])
|
| 86 |
else:
|
| 87 |
cls_value = torch.tensor([1, *angles, *fp])
|
| 88 |
+
cls_batch = cls_value.unsqueeze(0).repeat(input_batch["img"].shape[0], 1)
|
| 89 |
|
| 90 |
# Generate noise
|
| 91 |
noise = BASELINE_NOISE.repeat(input_batch["img"].shape[0], 1, 1, 1)
|
| 92 |
model_kwargs = {
|
| 93 |
"cls": cls_batch,
|
| 94 |
+
"concat": input_batch["img"]
|
| 95 |
}
|
| 96 |
|
| 97 |
# Make predictions
|