# Scraped page header (not code): uploaded by wizzseen, "Upload 948 files",
# commit 8a6df40 (verified).
import sys
sys.path.append('./')
# PyTorch includes
import torch
import numpy as np
from utils import test_human
from PIL import Image
#
import argparse
def get_parser(argv=None):
    """Build the command-line parser and return the parsed options.

    Despite its name, this returns the parsed ``argparse.Namespace``,
    not the parser itself.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before; passing an
            explicit list lets tests or other scripts call this function.

    Returns:
        argparse.Namespace holding every option defined below. This file's
        ``__main__`` block only reads ``pred_path``, ``gt_path``, ``classes``
        and ``txt_file``; the remaining options (epochs, batch, lr, ...) are
        not used here. NOTE(review): they look like leftovers from a training
        script; confirm before removing.
    """

    class LookupChoices(argparse.Action):
        """Store ``choices[value]`` instead of the raw command-line string.

        With ``choices=dict(true=True, false=False)`` this turns
        ``--freezeBN false`` into the boolean ``False``, rather than the
        non-empty (and therefore truthy) string ``'false'``.
        """

        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, self.choices[values])

    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--batch', default=16, type=int)
    parser.add_argument('--lr', default=1e-7, type=float)
    parser.add_argument('--numworker', default=12, type=int)
    parser.add_argument('--freezeBN', choices=dict(true=True, false=False),
                        default=True, action=LookupChoices)
    parser.add_argument('--step', default=30, type=int)
    # Input list of image ids (one per line) and the two image locations.
    # Both paths are concatenated directly with the file name by eval_,
    # so they should end with a path separator.
    parser.add_argument('--txt_file', default=None, type=str)
    parser.add_argument('--pred_path', default=None, type=str)
    parser.add_argument('--gt_path', default=None, type=str)
    parser.add_argument('--classes', default=7, type=int)
    parser.add_argument('--testepoch', default=10, type=int)
    return parser.parse_args(argv)
def _load_label_map(path):
    """Read the image at ``path`` and return its pixels as a torch tensor.

    The file is opened in a ``with`` block so its handle is released
    deterministically; ``np.array`` forces the pixel data to be read
    before the file is closed.
    """
    with Image.open(path) as img:
        return torch.from_numpy(np.array(img))


def eval_(pred_path, gt_path, classes, txt_file):
    """Compute and print the mean IoU of predicted label maps vs. ground truth.

    Args:
        pred_path: Location of the predicted ``<id>.png`` images. It is
            concatenated directly with the file name, so it should end with
            a path separator.
        gt_path: Location of the ground-truth ``<id>.png`` images; same
            concatenation rule as ``pred_path``.
        classes: Number of classes, forwarded as ``n_cls`` to
            ``test_human.get_iou_from_list``.
        txt_file: Text file listing one image id (without extension) per
            line. Blank lines are ignored.

    Returns:
        The mIoU value from ``test_human.get_iou_from_list``. It is also
        printed. The original version returned ``None``; returning the value
        lets other code reuse it.

    Raises:
        RuntimeError: If a prediction or ground-truth image cannot be opened
            or decoded. The underlying ``OSError`` is chained as the cause.
    """
    with open(txt_file, encoding='utf-8') as f:
        # Strip whitespace/newlines and drop blank lines (e.g. a trailing
        # empty line), which would otherwise be looked up as '.png'.
        names = [stripped for stripped in (line.strip() for line in f) if stripped]

    output_list = []
    label_list = []
    for i, name in enumerate(names):
        print(i)  # progress indicator
        file_name = name + '.png'
        try:
            prediction = _load_label_map(pred_path + file_name)
            ground_truth = _load_label_map(gt_path + file_name)
        except OSError as err:
            # Covers missing files (FileNotFoundError) and undecodable images
            # (Pillow raises OSError subclasses for those).
            print(file_name, flush=True)
            raise RuntimeError('no predict/gt image.') from err
        output_list.append(prediction)
        label_list.append(ground_truth)

    miou = test_human.get_iou_from_list(output_list, label_list, n_cls=classes)
    print('Validation:')
    print('MIoU: %f\n' % miou)
    return miou
if __name__ == '__main__':
    # Script entry point: read the CLI options, then evaluate the predictions
    # listed in --txt_file against their ground-truth counterparts.
    cli_args = get_parser()
    eval_(
        pred_path=cli_args.pred_path,
        gt_path=cli_args.gt_path,
        classes=cli_args.classes,
        txt_file=cli_args.txt_file,
    )