|
import os |
|
import sys |
|
import tqdm |
|
import torch |
|
|
|
import numpy as np |
|
|
|
from PIL import Image |
|
from torch.utils.data.dataloader import DataLoader |
|
|
|
# Make the repository root importable so the project-local packages imported
# below (lib, utils.misc, data.dataloader) resolve when this script is run
# directly. `filepath` is this file's directory; `repopath` is its parent.
filepath = os.path.dirname(os.path.abspath(__file__))
repopath = os.path.dirname(filepath)
sys.path.append(repopath)
|
|
|
from lib import * |
|
from utils.misc import * |
|
from data.dataloader import * |
|
|
|
# Disable TF32 for both cuBLAS matmuls and cuDNN convolutions so inference
# runs in full FP32 precision (chained assignment: matmul flag first, then cuDNN).
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = False
|
|
|
# Shared tqdm layout for the outer (per test set) and inner (per sample) bars.
_BAR_FORMAT = '{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}'


def _progress(iterable, verbose, **tqdm_kwargs):
    """Wrap *iterable* in a tqdm progress bar when *verbose* is truthy, else return it unchanged."""
    if not verbose:
        return iterable
    return tqdm.tqdm(iterable, total=len(iterable), bar_format=_BAR_FORMAT, **tqdm_kwargs)


def _load_model(opt):
    """Build the configured model, load ``latest.pth`` into it, and put it on CUDA in eval mode."""
    # NOTE(review): eval() resolves the model class from a config string; only
    # run with trusted configs.
    model = eval(opt.Model.name)(**opt.Model)
    ckpt_path = os.path.join(opt.Test.Checkpoint.checkpoint_dir, 'latest.pth')
    # Load onto the CPU first so a checkpoint saved on any GPU index can be read;
    # the model is moved to CUDA right after, so the end state is unchanged.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict, strict=True)
    model.cuda()
    model.eval()
    return model


def _save_prediction(pred, path):
    """Scale *pred* by 255, clip to the uint8 range, and save it as an image at *path*."""
    # Clip before casting: values slightly outside [0, 1] would otherwise wrap
    # around in the uint8 conversion (e.g. 256 -> 0).
    pixels = np.clip(pred * 255, 0, 255).astype(np.uint8)
    Image.fromarray(pixels).save(path)


def test(opt, args):
    """Run the checkpointed model over every configured test set and save PNG predictions.

    Args:
        opt: config object. Reads ``Model.name`` and ``Model`` (unpacked as the
            model's kwargs); ``Test.Checkpoint.checkpoint_dir``;
            ``Test.Dataset.{sets, type, root, transforms}``; and
            ``Test.Dataloader.{num_workers, pin_memory}``.
        args: parsed CLI arguments; only ``args.verbose`` is read, which
            toggles the tqdm progress bars.

    For each set name, predictions are written to
    ``<checkpoint_dir>/<set name>/<sample name>.png``. Requires CUDA, because
    the model and each batch are moved to the GPU.
    """
    model = _load_model(opt)
    verbose = bool(args.verbose)

    sets = _progress(opt.Test.Dataset.sets, verbose, desc='Total TestSet', position=0)
    for set_name in sets:
        save_path = os.path.join(opt.Test.Checkpoint.checkpoint_dir, set_name)
        os.makedirs(save_path, exist_ok=True)

        # NOTE(review): eval() resolves the dataset class from a config string.
        test_dataset = eval(opt.Test.Dataset.type)(
            opt.Test.Dataset.root, [set_name], opt.Test.Dataset.transforms)
        # batch_size=1 is relied on below: sample['name'][0] is the only item.
        test_loader = DataLoader(dataset=test_dataset, batch_size=1,
                                 num_workers=opt.Test.Dataloader.num_workers,
                                 pin_memory=opt.Test.Dataloader.pin_memory)

        samples = _progress(test_loader, verbose, desc=set_name + ' - Test',
                            position=1, leave=False)
        for sample in samples:
            sample = to_cuda(sample)
            with torch.no_grad():
                out = model(sample)

            pred = to_numpy(out['pred'], sample['shape'])
            _save_prediction(pred, os.path.join(save_path, sample['name'][0] + '.png'))
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the CLI, load the config file it names, then run inference.
    cli_args = parse_args()
    test(load_config(cli_args.config), cli_args)
|
|