# Segmentation utilities: color palettes, segmentation-map encode/decode,
# 2-D cross-entropy losses, IoU metric, and polynomial LR schedule helpers.
import os
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
def recursive_glob(rootdir='.', suffix=''):
    """Recursively collect every file under *rootdir* whose name ends with *suffix*.

    :param rootdir: root directory to walk
    :param suffix: filename suffix to match (empty string matches everything)
    :return: list of full paths (rootdir joined with each matching filename)
    """
    matches = []
    for dirpath, _dirs, filenames in os.walk(rootdir):
        for fname in filenames:
            if fname.endswith(suffix):
                matches.append(os.path.join(dirpath, fname))
    return matches
def get_cityscapes_labels():
    """Return the 19-entry Cityscapes color palette as a (19, 3) int array.

    Row i is the RGB color used when decoding class index i. No row is
    emitted for the void/background class (the original kept it commented
    out), so index 0 is the first real class.
    """
    # Class-name comments follow the standard Cityscapes train-id order.
    palette = (
        (128, 64, 128),   # road
        (244, 35, 232),   # sidewalk
        (70, 70, 70),     # building
        (102, 102, 156),  # wall
        (190, 153, 153),  # fence
        (153, 153, 153),  # pole
        (250, 170, 30),   # traffic light
        (220, 220, 0),    # traffic sign
        (107, 142, 35),   # vegetation
        (152, 251, 152),  # terrain
        (0, 130, 180),    # sky
        (220, 20, 60),    # person
        (255, 0, 0),      # rider
        (0, 0, 142),      # car
        (0, 0, 70),       # truck
        (0, 60, 100),     # bus
        (0, 80, 100),     # train
        (0, 0, 230),      # motorcycle
        (119, 11, 32),    # bicycle
    )
    return np.array(palette)
def get_pascal_labels():
    """Load the mapping that associates pascal classes with label colors.

    The palette is the standard PASCAL VOC colormap, generated with the
    devkit's bit-interleaving scheme rather than a hard-coded table: bit k
    of the class id contributes bit (7 - k//3) to channel (k % 3).

    Returns:
        np.ndarray with dimensions (21, 3)
    """
    colours = np.zeros((21, 3), dtype=int)
    for class_id in range(21):
        r = g = b = 0
        cid = class_id
        for shift in range(8):
            r |= ((cid >> 0) & 1) << (7 - shift)
            g |= ((cid >> 1) & 1) << (7 - shift)
            b |= ((cid >> 2) & 1) << (7 - shift)
            cid >>= 3
        colours[class_id] = (r, g, b)
    return colours
def get_mhp_labels():
    """Load the mapping that associates MHP classes with label colors.

    Returns:
        np.ndarray with dimensions (59, 3) — one RGB row per MHP class.
        (The original docstring claimed (21, 3); the table actually holds
        59 entries, matching ``n_classes = 59`` used for 'mhp' decoding.)

    NOTE(review): indices 24 and 44 both map to [0, 0, 96], so those two
    classes are indistinguishable in decoded images. Values kept as-is to
    preserve existing outputs — confirm with the palette's author.
    """
    return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                       [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                       [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                       [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                       [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                       [0, 64, 128],  # 21 entries so far
                       [96, 0, 0], [0, 96, 0], [96, 96, 0],
                       [0, 0, 96], [96, 0, 96], [0, 96, 96], [96, 96, 96],
                       [32, 0, 0], [160, 0, 0], [32, 96, 0], [160, 96, 0],
                       [32, 0, 96], [160, 0, 96], [32, 96, 96], [160, 96, 96],
                       [0, 32, 0], [96, 32, 0], [0, 160, 0], [96, 160, 0],
                       [0, 32, 96],  # 41 entries so far
                       [48, 0, 0], [0, 48, 0], [48, 48, 0],
                       [0, 0, 96],  # duplicate of index 24 — see NOTE above
                       [48, 0, 48], [0, 48, 48], [48, 48, 48],
                       [16, 0, 0], [80, 0, 0], [16, 48, 0], [80, 48, 0],
                       [16, 0, 48], [80, 0, 48], [16, 48, 48], [80, 48, 48],
                       [0, 16, 0], [48, 16, 0], [0, 80, 0],  # 59 entries total
                       ])
def encode_segmap(mask):
    """Encode segmentation label images as pascal classes.

    Args:
        mask (np.ndarray): raw segmentation label image of dimension
            (M, N, 3), in which the Pascal classes are encoded as colours.
    Returns:
        (np.ndarray): class map with dimensions (M, N), where the value at
        a given location is the integer denoting the class index.
        Pixels matching no palette colour stay 0 (background).
    """
    mask = mask.astype(int)
    label_mask = np.zeros(mask.shape[:2], dtype=np.int16)
    for class_idx, colour in enumerate(get_pascal_labels()):
        # Boolean mask of pixels whose RGB triple equals this class colour.
        hits = np.all(mask == colour, axis=-1)
        label_mask[hits] = class_idx
    return label_mask.astype(int)
def decode_seg_map_sequence(label_masks, dataset='pascal'):
    """Decode a batch of integer label masks into a tensor of RGB images.

    Args:
        label_masks: iterable of (M, N) integer class maps.
        dataset (str): palette selector passed through to decode_segmap.
    Returns:
        torch.Tensor of shape (B, 3, M, N) with float channels in [0, 1].
    """
    decoded = [decode_segmap(mask, dataset) for mask in label_masks]
    # Stack to (B, M, N, 3), then move channels first for torch conventions.
    stacked = np.array(decoded).transpose([0, 3, 1, 2])
    return torch.from_numpy(stacked)
def decode_segmap(label_mask, dataset, plot=False):
    """Decode segmentation class labels into a color image.

    Args:
        label_mask (np.ndarray): an (M, N) array of integer values denoting
            the class label at each spatial location.
        dataset (str): one of 'pascal', 'cityscapes', or 'mhp'; selects the
            palette and class count.
        plot (bool, optional): whether to show the resulting color image
            in a figure (in which case nothing is returned).
    Returns:
        (np.ndarray, optional): the resulting (M, N, 3) decoded color image
        with channels scaled into [0, 1]. Labels outside [0, n_classes)
        pass through as raw values divided by 255.
    Raises:
        NotImplementedError: if *dataset* is not a recognised name.
    """
    palettes = {
        'pascal': (21, get_pascal_labels),
        'cityscapes': (19, get_cityscapes_labels),
        'mhp': (59, get_mhp_labels),
    }
    if dataset not in palettes:
        raise NotImplementedError
    n_classes, palette_fn = palettes[dataset]
    label_colours = palette_fn()

    # One working copy per channel so unrecognised labels keep their value,
    # matching the original r/g/b copy-then-overwrite behaviour.
    channels = [label_mask.copy() for _ in range(3)]
    for cls in range(n_classes):
        hit = label_mask == cls
        for ch in range(3):
            channels[ch][hit] = label_colours[cls, ch]

    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
    for ch in range(3):
        rgb[:, :, ch] = channels[ch] / 255.0

    if plot:
        plt.imshow(rgb)
        plt.show()
    else:
        return rgb
def generate_param_report(logfile, param):
    """Write every ``key:value`` pair of *param* to *logfile*, one per line.

    Args:
        logfile (str): path of the report file to (over)write.
        param (dict): hyper-parameter mapping; values are stringified.
    """
    # ``with`` guarantees the handle is closed even if a write fails
    # (the original leaked the handle on exception).
    with open(logfile, 'w') as log_file:
        for key, val in param.items():
            log_file.write(key + ':' + str(val) + '\n')
def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True):
    """Pixel-wise cross-entropy loss for (N, C, H, W) logits.

    Args:
        logit (torch.Tensor): raw scores of shape (N, C, H, W).
        target (torch.Tensor): labels of shape (N, 1, H, W) or (N, H, W);
            a singleton channel dim is squeezed away.
        ignore_index (int): label value excluded from the loss.
        weight (sequence, optional): per-class weights; moved to CUDA as the
            original did (callers appear to run on GPU — TODO confirm).
        size_average (bool): True -> mean reduction, False -> sum.
        batch_average (bool): unused; kept for interface compatibility.
    Returns:
        torch.Tensor: scalar loss.
    """
    target = target.squeeze(1)
    # Map the legacy ``size_average`` flag onto the modern ``reduction``
    # argument (``size_average=`` is deprecated in current PyTorch).
    reduction = 'mean' if size_average else 'sum'
    if weight is None:
        criterion = nn.CrossEntropyLoss(ignore_index=ignore_index,
                                        reduction=reduction)
    else:
        criterion = nn.CrossEntropyLoss(
            weight=torch.from_numpy(np.array(weight)).float().cuda(),
            ignore_index=ignore_index, reduction=reduction)
    return criterion(logit, target.long())
def cross_entropy2d_dataparallel(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True):
    """Like cross_entropy2d, but wraps the loss module in nn.DataParallel.

    Args:
        logit (torch.Tensor): raw scores of shape (N, C, H, W).
        target (torch.Tensor): labels of shape (N, 1, H, W) or (N, H, W).
        ignore_index (int): label value excluded from the loss.
        weight (sequence, optional): per-class weights, moved to CUDA as the
            original did — TODO confirm callers always have a GPU here.
        size_average (bool): True -> mean reduction, False -> sum.
        batch_average (bool): unused; kept for interface compatibility.
    Returns:
        torch.Tensor: scalar loss (per-replica losses summed).
    """
    target = target.squeeze(1)
    # Legacy ``size_average`` flag -> modern ``reduction`` argument.
    reduction = 'mean' if size_average else 'sum'
    if weight is None:
        loss_module = nn.CrossEntropyLoss(ignore_index=ignore_index,
                                          reduction=reduction)
    else:
        loss_module = nn.CrossEntropyLoss(
            weight=torch.from_numpy(np.array(weight)).float().cuda(),
            ignore_index=ignore_index, reduction=reduction)
    criterion = nn.DataParallel(loss_module)
    loss = criterion(logit, target.long())
    # DataParallel may return one loss per replica; fold them into a scalar.
    return loss.sum()
def lr_poly(base_lr, iter_, max_iter=100, power=0.9):
    """Polynomial learning-rate decay: base_lr * (1 - iter_/max_iter) ** power."""
    remaining = 1.0 - float(iter_) / max_iter
    return base_lr * remaining ** power
def get_iou(pred, gt, n_classes=21):
    """Sum of per-image mean IoU over a batch of predictions.

    Args:
        pred (torch.Tensor): (B, ...) predicted class maps.
        gt (torch.Tensor): (B, ...) ground-truth class maps, same shape.
        n_classes (int): number of classes to evaluate.
    Returns:
        float: sum over the batch of each image's mean IoU, where the mean
        is taken only over classes present in pred or gt for that image.
    """
    total_iou = 0.0
    for pred_i, gt_i in zip(pred, gt):
        ious = []
        for cls in range(n_classes):
            pred_mask = pred_i == cls
            gt_mask = gt_i == cls
            # Use explicit boolean ops: on current torch, `+` on bool tensors
            # is logical OR, so the original `match == 2` intersection count
            # was always zero.
            inter = torch.sum(pred_mask & gt_mask).item()
            union = torch.sum(pred_mask | gt_mask).item()
            if union == 0:
                # Class absent from both maps: skip, as the original did.
                continue
            ious.append(inter / union)
        total_iou += sum(ious) / len(ious)
    return total_iou
def scale_tensor(input, size=512, mode='bilinear'):
    """Resize a (N, C, H, W) tensor to spatial size (size, size).

    Args:
        input (torch.Tensor): 4-D tensor to resize.
        size (int): target height and width.
        mode (str): 'nearest' or any bilinear-family mode accepted by
            F.interpolate.
    Returns:
        torch.Tensor: resized tensor; the input object itself when it is
        already (size, size).

    Fixes vs. original: removed a leftover debug print; the
    already-at-target check compared against a hard-coded 512 instead of
    ``size``; F.upsample/F.upsample_nearest are deprecated in favour of
    F.interpolate (identical results).
    """
    _, _, h, w = input.size()
    if h == size and w == size:
        # Already at the target resolution — nothing to do.
        return input
    if mode == 'nearest':
        # 'nearest' does not accept align_corners.
        return F.interpolate(input, size=(size, size), mode='nearest')
    return F.interpolate(input, size=(size, size), mode=mode, align_corners=True)
def scale_tensor_list(input,):
    """Upsample every tensor in all but the last level of *input* to match
    the last level's spatial sizes.

    Args:
        input (list[list[torch.Tensor]]): per-level lists of (N, C, H, W)
            tensors; ``input[-1]`` defines the target (H, W) for each index j.
    Returns:
        list: new list where levels 0..len-2 hold bilinear-upsampled copies
        and the final element is ``input[-1]`` itself (not copied).

    Fix vs. original: F.upsample is deprecated — F.interpolate is the
    drop-in replacement with identical behaviour.
    """
    output = []
    for i in range(len(input) - 1):
        level = []
        for j in range(len(input[i])):
            _, _, h, w = input[-1][j].size()
            level.append(F.interpolate(input[i][j], size=(h, w),
                                       mode='bilinear', align_corners=True))
        output.append(level)
    output.append(input[-1])
    return output
def scale_tensor_list_0(input, base_input):
    """Upsample each input[j] to base_input[j]'s spatial size and add it in.

    Args:
        input (list[torch.Tensor]): tensors to upsample (bilinear).
        base_input (list[torch.Tensor]): target tensors; MUTATED in place —
            each element is replaced by ``base_input[j] + upsampled``.
    Returns:
        list[torch.Tensor]: the (mutated) base_input list.

    Fix vs. original: F.upsample is deprecated — F.interpolate is the
    drop-in replacement with identical behaviour.
    """
    assert len(input) == len(base_input)
    for j in range(len(input)):
        _, _, h, w = base_input[j].size()
        upsampled = F.interpolate(input[j], size=(h, w),
                                  mode='bilinear', align_corners=True)
        base_input[j] = base_input[j] + upsampled
    return base_input
if __name__ == '__main__':
    # Smoke-check: poly LR near the end of a 150-iteration schedule.
    final_lr = lr_poly(0.007, iter_=99, max_iter=150)
    print(final_lr)