danifei commited on
Commit
da6c643
·
verified ·
1 Parent(s): 86fee05

assert the model runs on gpu if available, if not, on cpu

Browse files
Files changed (1) hide show
  1. app.py +2 -4
app.py CHANGED
@@ -15,7 +15,7 @@ from options.options import parse
15
  path_opt = './options/predict/LOLBlur.yml'
16
 
17
  opt = parse(path_opt)
18
-
19
  #define some auxiliary functions
20
  pil_to_tensor = transforms.ToTensor()
21
 
@@ -29,12 +29,10 @@ model = Network(img_channel=opt['network']['img_channels'],
29
  dilations=opt['network']['dilations'],
30
  extra_depth_wise = opt['network']['extra_depth_wise'])
31
 
32
- checkpoints = torch.load(opt['save']['best'])
33
  # print(checkpoints)
34
  model.load_state_dict(checkpoints['model_state_dict'])
35
 
36
-
37
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
38
  model = model.to(device)
39
 
40
  def load_img (filename):
 
15
  path_opt = './options/predict/LOLBlur.yml'
16
 
17
  opt = parse(path_opt)
18
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
19
  #define some auxiliary functions
20
  pil_to_tensor = transforms.ToTensor()
21
 
 
29
  dilations=opt['network']['dilations'],
30
  extra_depth_wise = opt['network']['extra_depth_wise'])
31
 
32
+ checkpoints = torch.load(opt['save']['best'], map_location=device)
33
  # print(checkpoints)
34
  model.load_state_dict(checkpoints['model_state_dict'])
35
 
 
 
36
  model = model.to(device)
37
 
38
  def load_img (filename):