Spaces:
Build error
Build error
File size: 1,422 Bytes
1951449 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import os
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import time
import torchvision
import cv2
import torchvision.utils as tvu
import torch.functional as F
import argparse
def inference_img(haze_path, Net, size=(400, 400)):
    """Run a dehazing network on a single image file.

    The image is loaded as RGB, resized with ``transforms.Resize(size)`` and
    converted with ``transforms.ToTensor()``, then given a leading batch
    dimension before being passed to ``Net``.

    Args:
        haze_path: Path to the hazy input image.
        Net: Callable model (in this script, a TorchScript module loaded with
            ``torch.jit.load``) applied to the batched image tensor.
        size: Target size passed to ``transforms.Resize``. Defaults to the
            previously hard-coded ``(400, 400)``.

    Returns:
        A tuple ``(restored, seconds)``: the model output and the elapsed
        time of the forward pass in seconds.
    """
    # Close the file handle deterministically; convert() returns a new,
    # independent image, so it stays usable after the file is closed.
    with Image.open(haze_path) as img:
        haze_image = img.convert('RGB')
    enhance_transforms = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
    # PIL size of the image before resizing, reported as (width, height).
    print(haze_image.size)
    with torch.no_grad():
        # Add a leading batch dimension of 1 for the model.
        batch = enhance_transforms(haze_image).unsqueeze(0)
        # perf_counter is monotonic and high-resolution, which suits
        # duration measurement better than time.time.
        start = time.perf_counter()
        restored2 = Net(batch)
        elapsed = time.perf_counter() - start
    return restored2, elapsed
if __name__ == '__main__':
    # Command-line entry point: dehaze one image with a TorchScript
    # checkpoint and save the result into the --save_path directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', type=str, required=True, help='Path to test')
    parser.add_argument('--save_path', type=str, required=True, help='Path to save')
    parser.add_argument('--pk_path', type=str, default='model_zoo/Haze4k.tjm', help='Path of the checkpoint')
    opt = parser.parse_args()
    # makedirs also creates missing parent directories and is a no-op when
    # the directory already exists.
    os.makedirs(opt.save_path, exist_ok=True)
    # Load the TorchScript model onto the CPU and switch it to eval mode.
    Net = torch.jit.load(opt.pk_path, map_location=torch.device('cpu')).eval()
    image = opt.test_path
    print(image)
    restored2, time_num = inference_img(image, Net)
    print('Inference time: {:.4f}s'.format(time_num))
    # Join with a path separator. Plain string concatenation produced
    # "<save_path><name>" (e.g. "outimg.png") when save_path had no
    # trailing slash, writing the file beside the directory, not inside it.
    out_path = os.path.join(opt.save_path, os.path.basename(image))
    torchvision.utils.save_image(restored2, out_path)
|