Spaces:
Build error
Build error
File size: 1,520 Bytes
af8dd52 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import os
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import time
import torchvision
import cv2
import torchvision.utils as tvu
import torch.functional as F
import argparse
from net.Ushape_Trans import *
def inference_img(img_path, Net, device):
    """Enhance a single image with ``Net`` and time the forward pass.

    The image is loaded as RGB, resized to 256x256, converted to a tensor
    with ``transforms.ToTensor`` and given a leading batch dimension, then
    moved to ``device`` and passed through ``Net`` under ``torch.no_grad()``.

    Args:
        img_path: Path to the input image file.
        Net: Model to run; called as ``Net(batch)``. Assumed to already be
            on ``device`` and in eval mode.
        device: ``torch.device`` (or device string such as ``'cuda'``) to
            run on.

    Returns:
        Tuple ``(restored, seconds)``: whatever ``Net`` returns for the
        batch, and the wall-clock duration of the forward pass in seconds.
    """
    # Context manager closes the underlying file handle deterministically
    # instead of leaving it open until garbage collection.
    with Image.open(img_path) as img:
        low_image = img.convert('RGB')
    enhance_transforms = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    # CUDA work is queued asynchronously; without synchronizing, the timer
    # would only capture kernel launch, not the actual computation.
    on_cuda = str(device).startswith('cuda')
    with torch.no_grad():
        # Transfer happens before the timer so only the forward pass is timed.
        batch = enhance_transforms(low_image).unsqueeze(0).to(device)
        if on_cuda:
            torch.cuda.synchronize(device)
        start = time.perf_counter()
        restored2 = Net(batch)
        if on_cuda:
            torch.cuda.synchronize(device)
        end = time.perf_counter()
    return restored2, end - start
if __name__ == '__main__':
    # Command-line entry point: enhance the image at --test_path with the
    # checkpoint at --pk_path and write the result into --save_path.
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', type=str, required=True, help='Path to test')
    parser.add_argument('--save_path', type=str, required=True, help='Path to save')
    parser.add_argument('--pk_path', type=str, default='model_zoo/underwater.pth', help='Path of the checkpoint')
    opt = parser.parse_args()
    # makedirs also creates missing parent directories and is a no-op when
    # the directory already exists.
    os.makedirs(opt.save_path, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    Net = Generator().eval()
    # map_location lets a checkpoint saved from a GPU load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    Net.load_state_dict(torch.load(opt.pk_path, map_location=device))
    Net = Net.to(device)
    image = opt.test_path
    print(image)
    restored2, time_num = inference_img(image, Net, device)
    print(f'Inference time: {time_num:.4f}s')
    # os.path.join inserts the path separator, so --save_path works with or
    # without a trailing slash.
    out_path = os.path.join(opt.save_path, os.path.split(image)[-1])
    torchvision.utils.save_image(restored2, out_path)
|