Pouriarouzrokh commited on
Commit
40deb12
·
1 Parent(s): a803714

Removed all GPU-requiring operations from the app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -6
app.py CHANGED
@@ -13,7 +13,6 @@ import torch
13
 
14
  model = DiffusionModule("./diffusion_configs.yaml")
15
  model.load_ckpt("./data/model.ckpt")
16
- model.cuda().half()
17
  model.eval();
18
 
19
  # Loading a baseline noise for making predictions
@@ -21,10 +20,8 @@ model.eval();
21
  seed = 3407
22
  np.random.seed(seed)
23
  torch.random.manual_seed(seed)
24
- torch.cuda.manual_seed(seed)
25
- torch.cuda.manual_seed_all(seed)
26
  torch.backends.cudnn.deterministic = True
27
- BASELINE_NOISE = torch.randn(1, 1, 256, 256).cuda().half()
28
 
29
  # Model helper functions
30
 
@@ -88,13 +85,13 @@ def make_predictions(img_path, angles=None, cls_batch=None, rotate_to_standard=F
88
  cls_value = torch.tensor([2, *angles, *fp])
89
  else:
90
  cls_value = torch.tensor([1, *angles, *fp])
91
- cls_batch = cls_value.unsqueeze(0).repeat(input_batch["img"].shape[0], 1).cuda().half()
92
 
93
  # Generate noise
94
  noise = BASELINE_NOISE.repeat(input_batch["img"].shape[0], 1, 1, 1)
95
  model_kwargs = {
96
  "cls": cls_batch,
97
- "concat": input_batch["img"].cuda().half(),
98
  }
99
 
100
  # Make predictions
 
13
 
14
  model = DiffusionModule("./diffusion_configs.yaml")
15
  model.load_ckpt("./data/model.ckpt")
 
16
  model.eval();
17
 
18
  # Loading a baseline noise for making predictions
 
20
  seed = 3407
21
  np.random.seed(seed)
22
  torch.random.manual_seed(seed)
 
 
23
  torch.backends.cudnn.deterministic = True
24
+ BASELINE_NOISE = torch.randn(1, 1, 256, 256).half()
25
 
26
  # Model helper functions
27
 
 
85
  cls_value = torch.tensor([2, *angles, *fp])
86
  else:
87
  cls_value = torch.tensor([1, *angles, *fp])
88
+ cls_batch = cls_value.unsqueeze(0).repeat(input_batch["img"].shape[0], 1)
89
 
90
  # Generate noise
91
  noise = BASELINE_NOISE.repeat(input_batch["img"].shape[0], 1, 1, 1)
92
  model_kwargs = {
93
  "cls": cls_batch,
94
+ "concat": input_batch["img"]
95
  }
96
 
97
  # Make predictions