import timeit
import numpy as np
from PIL import Image
import sys

sys.path.append('./')

# PyTorch includes
import torch
from torch.autograd import Variable
from torchvision import transforms
import cv2

# Custom includes
from networks import deeplab_xception_transfer, graph
from dataloaders import custom_transforms as tr

import argparse
import torch.nn.functional as F

import warnings
warnings.filterwarnings("ignore")

label_colours = [(0, 0, 0),
                 (128, 0, 0), (255, 0, 0), (0, 85, 0), (170, 0, 51), (255, 85, 0),
                 (0, 0, 85), (0, 119, 221), (85, 85, 0), (0, 85, 85), (85, 51, 0),
                 (52, 86, 128), (0, 128, 0), (0, 0, 255), (51, 170, 221), (0, 255, 255),
                 (85, 255, 170), (170, 255, 85), (255, 255, 0), (255, 170, 0)]
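# For reference, these colours line up with the 20 CIHP part labels in this
# order (an assumption based on the CIHP dataset definition, not stated in
# this file): 0 background, 1 hat, 2 hair, 3 glove, 4 sunglasses,
# 5 upper-clothes, 6 dress, 7 coat, 8 socks, 9 pants, 10 torso-skin,
# 11 scarf, 12 skirt, 13 face, 14 left-arm, 15 right-arm, 16 left-leg,
# 17 right-leg, 18 left-shoe, 19 right-shoe.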


def flip(x, dim):
    """Reverse tensor x along dimension dim."""
    indices = [slice(None)] * x.dim()
    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
                                dtype=torch.long, device=x.device)
    return x[tuple(indices)]
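
# Note: flip() is equivalent to torch.flip(x, dims=[dim]); e.g.
#   flip(torch.tensor([[1, 2, 3]]), dim=-1)  ->  tensor([[3, 2, 1]])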


def flip_cihp(tail_list):
    '''
    Swap the channels of the left/right body-part classes so that a score map
    computed on a horizontally flipped image can be fused with the original.

    :param tail_list: tensor of size n_class x h x w (n_class = 20)
    :return: tensor of the same size with left/right class channels swapped
    '''
    tail_list_rev = [None] * 20
    for xx in range(14):
        tail_list_rev[xx] = tail_list[xx].unsqueeze(0)
    # Swap the paired classes: 14/15 (arms), 16/17 (legs), 18/19 (shoes).
    tail_list_rev[14] = tail_list[15].unsqueeze(0)
    tail_list_rev[15] = tail_list[14].unsqueeze(0)
    tail_list_rev[16] = tail_list[17].unsqueeze(0)
    tail_list_rev[17] = tail_list[16].unsqueeze(0)
    tail_list_rev[18] = tail_list[19].unsqueeze(0)
    tail_list_rev[19] = tail_list[18].unsqueeze(0)
    return torch.cat(tail_list_rev, dim=0)
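
# Equivalently, the swap can be written as a single index permutation
# (a sketch, assuming the CIHP label order noted above):
#   perm = list(range(14)) + [15, 14, 17, 16, 19, 18]
#   flipped = tail_list[perm]   # same result as flip_cihp(tail_list)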


def decode_labels(mask, num_images=1, num_classes=20):
    """Decode batch of segmentation masks.

    Args:
      mask: result of inference after taking argmax.
      num_images: number of images to decode from the batch.
      num_classes: number of classes to predict (including background).

    Returns:
      A batch with num_images RGB images of the same size as the input.
    """
    n, h, w = mask.shape
    assert n >= num_images, 'Batch size %d should be greater than or equal to the number of images to save (%d).' % (n, num_images)
    outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
    for i in range(num_images):
        img = Image.new('RGB', (w, h))
        pixels = img.load()
        for j_, j in enumerate(mask[i, :, :]):
            for k_, k in enumerate(j):
                if k < num_classes:
                    pixels[k_, j_] = label_colours[k]
        outputs[i] = np.array(img)
    return outputs
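
# Example (a sketch): turning a 1 x H x W argmax map into a colour image.
#   pred = np.zeros((1, 4, 4), dtype=np.int64)  # all background
#   rgb = decode_labels(pred)                   # shape (1, 4, 4, 3), all black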


def read_img(img_path):
    _img = Image.open(img_path).convert('RGB')  # returned image is RGB
    return _img


def img_transform(img, transform=None):
    # Wrap the image in the {'image', 'label'} dict the custom transforms expect.
    sample = {'image': img, 'label': 0}
    sample = transform(sample)
    return sample


def inference(net, img_path='', output_path='./', output_name='f', use_gpu=True):
    '''
    Run multi-scale, flip-augmented inference on a single image and save the
    colour-coded and grayscale parsing results.

    :param net: the parsing network (already loaded and moved to the GPU)
    :param img_path: path of the input image
    :param output_path: directory for the output images
    :param output_name: file name (without extension) for the outputs
    :param use_gpu: whether to run the forward pass on the GPU
    '''
    # Adjacency matrices for the graph-transfer heads.
    adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
    adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).cuda().transpose(2, 3)

    adj1_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float())
    adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7).cuda()

    cihp_adj = graph.preprocess_adj(graph.cihp_graph)
    adj3_ = Variable(torch.from_numpy(cihp_adj).float())
    adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20).cuda()

    # Multi-scale testing: each scale is evaluated on the image and on its
    # horizontal flip, and the score maps are averaged.
    scale_list = [1, 0.5, 0.75, 1.25, 1.5, 1.75]
    img = read_img(img_path)
    testloader_list = []
    testloader_flip_list = []
    for pv in scale_list:
        composed_transforms_ts = transforms.Compose([
            tr.Scale_only_img(pv),
            tr.Normalize_xception_tf_only_img(),
            tr.ToTensor_only_img()])

        composed_transforms_ts_flip = transforms.Compose([
            tr.Scale_only_img(pv),
            tr.HorizontalFlip_only_img(),
            tr.Normalize_xception_tf_only_img(),
            tr.ToTensor_only_img()])

        testloader_list.append(img_transform(img, composed_transforms_ts))
        testloader_flip_list.append(img_transform(img, composed_transforms_ts_flip))

    start_time = timeit.default_timer()
    # One testing epoch
    net.eval()

    for iii, sample_batched in enumerate(zip(testloader_list, testloader_flip_list)):
        inputs = sample_batched[0]['image']
        inputs_f = sample_batched[1]['image']
        inputs = inputs.unsqueeze(0)
        inputs_f = inputs_f.unsqueeze(0)
        # Batch the image and its flipped copy together.
        inputs = torch.cat((inputs, inputs_f), dim=0)
        if iii == 0:
            # Remember the scale-1 resolution; all other scales are resized to it.
            _, _, h, w = inputs.size()

        # Forward pass of the mini-batch
        inputs = Variable(inputs, requires_grad=False)
        with torch.no_grad():
            if use_gpu:
                inputs = inputs.cuda()
            outputs = net(inputs, adj1_test.cuda(), adj3_test.cuda(), adj2_test.cuda())
            # Average the direct prediction with the un-flipped prediction of
            # the flipped input (swapping left/right class channels first).
            outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2
            outputs = outputs.unsqueeze(0)

            if iii > 0:
                outputs = F.interpolate(outputs, size=(h, w), mode='bilinear', align_corners=True)
                outputs_final = outputs_final + outputs
            else:
                outputs_final = outputs.clone()

    ################ plot pic
    predictions = torch.max(outputs_final, 1)[1]
    results = predictions.cpu().numpy()
    vis_res = decode_labels(results)

    parsing_im = Image.fromarray(vis_res[0])
    parsing_im.save(output_path + '/{}.png'.format(output_name))
    # Cast to uint8 so cv2.imwrite accepts the label map.
    cv2.imwrite(output_path + '/{}_gray.png'.format(output_name), results[0, :, :].astype(np.uint8))

    end_time = timeit.default_timer()
    print('Time used for the multi-scale image inference is: ' + str(end_time - start_time))
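
# Programmatic usage is a one-liner once a network is loaded (paths here are
# placeholders):
#   inference(net, img_path='input.jpg', output_path='.', output_name='result')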


if __name__ == '__main__':
    # argparse begin
    parser = argparse.ArgumentParser()
    parser.add_argument('--loadmodel', default='', type=str)
    parser.add_argument('--img_path', default='', type=str)
    parser.add_argument('--output_path', default='', type=str)
    parser.add_argument('--output_name', default='', type=str)
    parser.add_argument('--use_gpu', default=1, type=int)
    opts = parser.parse_args()

    net = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(n_classes=20,
                                                                                 hidden_layers=128,
                                                                                 source_classes=7)
    if opts.loadmodel != '':
        x = torch.load(opts.loadmodel)
        net.load_source_model(x)
        print('load model:', opts.loadmodel)
    else:
        print('no model loaded!')
        raise RuntimeError('No model!')

    if opts.use_gpu > 0:
        net.cuda()
        use_gpu = True
    else:
        use_gpu = False
        raise RuntimeError('A GPU is required: this script calls .cuda() unconditionally.')

    inference(net=net, img_path=opts.img_path, output_path=opts.output_path, output_name=opts.output_name, use_gpu=use_gpu)
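
# Example invocation (the checkpoint and image paths are placeholders; use a
# trained parsing checkpoint of your own):
#   python inference.py --loadmodel ./model.pth \
#       --img_path ./input.jpg --output_path ./output --output_name result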