52Hz commited on
Commit
6b1147f
·
1 Parent(s): d1a6db0

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +5 -5
predict.py CHANGED
@@ -8,7 +8,7 @@ import glob
8
  import torch
9
  from skimage import img_as_ubyte
10
  from PIL import Image
11
- from model.SRMNet import SRMNet
12
  from main_test_SRMNet import save_img, setup
13
  import torchvision.transforms.functional as TF
14
  import torch.nn.functional as F
@@ -16,7 +16,7 @@ import torch.nn.functional as F
16
 
17
  class Predictor(cog.Predictor):
18
  def setup(self):
19
- model_dir = 'experiments/pretrained_models/AWGN_denoising_SRMNet.pth'
20
 
21
  parser = argparse.ArgumentParser(description='Demo Image Denoising')
22
  parser.add_argument('--input_dir', default='./test/', type=str, help='Input images')
@@ -38,7 +38,7 @@ class Predictor(cog.Predictor):
38
  shutil.copy(str(image), input_path)
39
 
40
  # Load corresponding models architecture and weights
41
- model = SRMNet()
42
  model.eval()
43
  model = model.to(self.device)
44
 
@@ -46,7 +46,7 @@ class Predictor(cog.Predictor):
46
  os.makedirs(save_dir, exist_ok=True)
47
 
48
  out_path = Path(tempfile.mkdtemp()) / "out.png"
49
- mul = 16
50
  for file_ in sorted(glob.glob(os.path.join(folder, '*.PNG'))):
51
  img = Image.open(file_).convert('RGB')
52
  input_ = TF.to_tensor(img).unsqueeze(0).cuda()
@@ -60,7 +60,7 @@ class Predictor(cog.Predictor):
60
  with torch.no_grad():
61
  restored = model(input_)
62
 
63
- restored = torch.clamp(restored, 0, 1)
64
  restored = restored[:, :, :h, :w]
65
  restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
66
  restored = img_as_ubyte(restored[0])
 
8
  import torch
9
  from skimage import img_as_ubyte
10
  from PIL import Image
11
+ from model.CMFNet import CMFNet
12
  from main_test_SRMNet import save_img, setup
13
  import torchvision.transforms.functional as TF
14
  import torch.nn.functional as F
 
16
 
17
  class Predictor(cog.Predictor):
18
  def setup(self):
19
+ model_dir = 'experiments/pretrained_models/deraindrop_model.pth'
20
 
21
  parser = argparse.ArgumentParser(description='Demo Image Denoising')
22
  parser.add_argument('--input_dir', default='./test/', type=str, help='Input images')
 
38
  shutil.copy(str(image), input_path)
39
 
40
  # Load corresponding models architecture and weights
41
+ model = CMFNet()
42
  model.eval()
43
  model = model.to(self.device)
44
 
 
46
  os.makedirs(save_dir, exist_ok=True)
47
 
48
  out_path = Path(tempfile.mkdtemp()) / "out.png"
49
+ mul = 8
50
  for file_ in sorted(glob.glob(os.path.join(folder, '*.PNG'))):
51
  img = Image.open(file_).convert('RGB')
52
  input_ = TF.to_tensor(img).unsqueeze(0).cuda()
 
60
  with torch.no_grad():
61
  restored = model(input_)
62
 
63
+ restored = torch.clamp(restored[0], 0, 1)
64
  restored = restored[:, :, :h, :w]
65
  restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
66
  restored = img_as_ubyte(restored[0])