import os
import os.path as osp
from pathlib import Path

import numpy as np
import cv2
import torch
import torchvision.transforms as transforms
from PIL import Image
import configargparse
import tqdm

from model import BiSeNet


def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg',
                     img_size=(512, 512)):
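    """Colour-code a BiSeNet parsing map.

    Writes a full segmentation mask to `save_path` (when `save_im` is True) and a
    blurred face-only mask next to it as `*_face.png`. The label grouping below
    assumes the 19-class CelebAMask-HQ label set used by face-parsing.PyTorch
    (1-13 face parts, 14-15 neck, 16 cloth, 17 hair, 18 hat).
    """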
    im = np.array(im)
    vis_im = im.copy().astype(np.uint8)
    vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
    vis_parsing_anno = cv2.resize(
        vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
    vis_parsing_anno_color = np.zeros(
        (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255])
    vis_parsing_anno_color_face = np.zeros(
        (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255])

    num_of_class = np.max(vis_parsing_anno)

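    # full mask: face parts -> red, neck -> green, cloth -> blue,
    # hair/hat and any remaining labels -> red; background stays white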
    for pi in range(1, 14):
        index = np.where(vis_parsing_anno == pi)
        vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])
    for pi in range(14, 16):
        index = np.where(vis_parsing_anno == pi)
        vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 255, 0])
    for pi in range(16, 17):
        index = np.where(vis_parsing_anno == pi)
        vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 0, 255])
    for pi in range(17, num_of_class + 1):
        index = np.where(vis_parsing_anno == pi)
        vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])

    vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
    vis_im = cv2.resize(vis_parsing_anno_color, img_size,
                        interpolation=cv2.INTER_NEAREST)
    if save_im:
        cv2.imwrite(save_path, vis_im)

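    # face-only mask: skin/brows/eyes/glasses (1-6) and nose/mouth/lips (10-13);
    # ears and earrings (7-9) are left out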
    for pi in range(1, 7):
        index = np.where(vis_parsing_anno == pi)
        vis_parsing_anno_color_face[index[0], index[1], :] = np.array([255, 0, 0])
    for pi in range(10, 14):
        index = np.where(vis_parsing_anno == pi)
        vis_parsing_anno_color_face[index[0], index[1], :] = np.array([255, 0, 0])

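    # locate the bottom-most face pixel in every column, then shift those points
    # down by `pad` rows so the mask extends slightly below the chin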
    pad = 5
    vis_parsing_anno_color_face = vis_parsing_anno_color_face.astype(np.uint8)
    face_part = (vis_parsing_anno_color_face[..., 0] == 255) & (vis_parsing_anno_color_face[..., 1] == 0) & (vis_parsing_anno_color_face[..., 2] == 0)
    face_coords = np.stack(np.nonzero(face_part), axis=-1)
    sorted_inds = np.lexsort((-face_coords[:, 0], face_coords[:, 1]))
    sorted_face_coords = face_coords[sorted_inds]
    u, uid, ucnt = np.unique(sorted_face_coords[:, 1], return_index=True, return_counts=True)
    bottom_face_coords = sorted_face_coords[uid] + np.array([pad, 0])
    rows, cols, _ = vis_parsing_anno_color_face.shape

    bottom_face_coords[:, 0] = np.clip(bottom_face_coords[:, 0], 0, rows - 1)

    y_min = np.min(bottom_face_coords[:, 1])
    y_max = np.max(bottom_face_coords[:, 1])

    y_range = y_max - y_min
    height_per_part = y_range // 4

    start_y_part1 = y_min + height_per_part
    end_y_part1 = start_y_part1 + height_per_part

    start_y_part2 = end_y_part1
    end_y_part2 = start_y_part2 + height_per_part

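    # despite the names, the `y` values here index image columns: the face width is
    # split into quarters and the mask is only extended in the two middle quarters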
    for coord in bottom_face_coords:
        x, y = coord
        start_x = max(x - pad, 0)
        end_x = min(x + pad, rows)
        if start_y_part1 <= y <= end_y_part1 or start_y_part2 <= y <= end_y_part2:
            vis_parsing_anno_color_face[start_x:end_x, y] = [255, 0, 0]

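    # soften the hard mask boundary, then resize back to the requested output size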
    vis_im = cv2.GaussianBlur(vis_parsing_anno_color_face, (9, 9), cv2.BORDER_DEFAULT)
    vis_im = cv2.resize(vis_im, img_size,
                        interpolation=cv2.INTER_NEAREST)

    cv2.imwrite(save_path.replace('.png', '_face.png'), vis_im)


def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
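    """Run BiSeNet face parsing on every image in `dspth` and write the masks to `respth`."""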
    Path(respth).mkdir(parents=True, exist_ok=True)

    print('[INFO] loading model...')
    n_classes = 19
    net = BiSeNet(n_classes=n_classes)
    net.cuda()
    net.load_state_dict(torch.load(cp))
    net.eval()

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    image_paths = os.listdir(dspth)

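    # inference only, so gradients are disabled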
    with torch.no_grad():
        for image_path in tqdm.tqdm(image_paths):
            if image_path.endswith('.jpg') or image_path.endswith('.png'):
                img = Image.open(osp.join(dspth, image_path))
                ori_size = img.size  # keep the original resolution for the saved masks
                image = img.resize((512, 512), Image.BILINEAR)
                image = image.convert("RGB")
                img = to_tensor(image)

                inputs = torch.unsqueeze(img, 0)
                outputs = net(inputs.cuda())
                parsing = outputs.mean(0).cpu().numpy().argmax(0)

                # filenames are expected to be numeric frame indices, e.g. '123.jpg' -> '123.png'
                image_path = int(image_path[:-4])
                image_path = str(image_path) + '.png'

                vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path), img_size=ori_size)


if __name__ == "__main__":
    parser = configargparse.ArgumentParser()
    parser.add_argument('--respath', type=str, default='./result/', help='result path for label')
    parser.add_argument('--imgpath', type=str, default='./imgs/', help='path for input images')
    parser.add_argument('--modelpath', type=str, default='data_utils/face_parsing/79999_iter.pth', help='path to the pretrained BiSeNet checkpoint')
    args = parser.parse_args()

    evaluate(respth=args.respath, dspth=args.imgpath, cp=args.modelpath)