import os
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
def recursive_glob(rootdir='.', suffix=''):
    """Performs recursive glob with given suffix and rootdir
    :param rootdir: the root directory to search
    :param suffix: the file suffix to match
    """
    return [os.path.join(looproot, filename)
            for looproot, _, filenames in os.walk(rootdir)
            for filename in filenames if filename.endswith(suffix)]
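# Illustrative usage sketch (not part of the original module): collect every
# '.png' file under a hypothetical './data' directory. The directory name and
# suffix here are assumptions chosen only for demonstration.
def _demo_recursive_glob():
    return recursive_glob(rootdir='./data', suffix='.png')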
def get_cityscapes_labels():
    """Load the mapping that associates Cityscapes classes with label colors
    Returns:
        np.ndarray with dimensions (19, 3)
    """
    return np.array([
        # [ 0, 0, 0],
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [0, 130, 180],
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32]])
def get_pascal_labels():
    """Load the mapping that associates pascal classes with label colors
    Returns:
        np.ndarray with dimensions (21, 3)
    """
    return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                       [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                       [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                       [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                       [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                       [0, 64, 128]])
def get_mhp_labels():
    """Load the mapping that associates MHP (Multi-Human Parsing) classes with label colors
    Returns:
        np.ndarray with dimensions (59, 3)
    """
    return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                       [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                       [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                       [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                       [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                       [0, 64, 128],  # 21
                       [96, 0, 0], [0, 96, 0], [96, 96, 0],
                       [0, 0, 96], [96, 0, 96], [0, 96, 96], [96, 96, 96],
                       [32, 0, 0], [160, 0, 0], [32, 96, 0], [160, 96, 0],
                       [32, 0, 96], [160, 0, 96], [32, 96, 96], [160, 96, 96],
                       [0, 32, 0], [96, 32, 0], [0, 160, 0], [96, 160, 0],
                       [0, 32, 96],  # 41
                       [48, 0, 0], [0, 48, 0], [48, 48, 0],
                       [0, 0, 96],  # NOTE: duplicates the colour at index 24; kept as in the original palette
                       [48, 0, 48], [0, 48, 48], [48, 48, 48],
                       [16, 0, 0], [80, 0, 0], [16, 48, 0], [80, 48, 0],
                       [16, 0, 48], [80, 0, 48], [16, 48, 48], [80, 48, 48],
                       [0, 16, 0], [48, 16, 0], [0, 80, 0],  # 59
                       ])
def encode_segmap(mask):
    """Encode segmentation label images as pascal classes
    Args:
        mask (np.ndarray): raw segmentation label image of dimension
            (M, N, 3), in which the Pascal classes are encoded as colours.
    Returns:
        (np.ndarray): class map with dimensions (M,N), where the value at
        a given location is the integer denoting the class index.
    """
    mask = mask.astype(int)
    label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
    for ii, label in enumerate(get_pascal_labels()):
        label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
    label_mask = label_mask.astype(int)
    return label_mask
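# Illustrative usage sketch (not part of the original module): encode a tiny
# synthetic RGB mask into a class-index map. The two colours used below come
# from get_pascal_labels() (indices 1 and 2); the 2x2 size is an assumption
# chosen only for demonstration.
def _demo_encode_segmap():
    rgb = np.zeros((2, 2, 3), dtype=np.uint8)
    rgb[0, 0] = [128, 0, 0]  # pascal class index 1
    rgb[1, 1] = [0, 128, 0]  # pascal class index 2
    label = encode_segmap(rgb)
    # label[0, 0] == 1, label[1, 1] == 2, remaining pixels == 0 (background)
    return label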
def decode_seg_map_sequence(label_masks, dataset='pascal'):
    """Decode a batch of label masks into a (N, 3, H, W) tensor of RGB images"""
    rgb_masks = []
    for label_mask in label_masks:
        rgb_mask = decode_segmap(label_mask, dataset)
        rgb_masks.append(rgb_mask)
    rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2]))
    return rgb_masks
def decode_segmap(label_mask, dataset, plot=False):
    """Decode segmentation class labels into a color image
    Args:
        label_mask (np.ndarray): an (M,N) array of integer values denoting
            the class label at each spatial location.
        dataset (str): which palette to use; one of 'pascal',
            'cityscapes' or 'mhp'.
        plot (bool, optional): whether to show the resulting color image
            in a figure.
    Returns:
        (np.ndarray, optional): the resulting decoded color image.
    """
    if dataset == 'pascal':
        n_classes = 21
        label_colours = get_pascal_labels()
    elif dataset == 'cityscapes':
        n_classes = 19
        label_colours = get_cityscapes_labels()
    elif dataset == 'mhp':
        n_classes = 59
        label_colours = get_mhp_labels()
    else:
        raise NotImplementedError
    r = label_mask.copy()
    g = label_mask.copy()
    b = label_mask.copy()
    for ll in range(0, n_classes):
        r[label_mask == ll] = label_colours[ll, 0]
        g[label_mask == ll] = label_colours[ll, 1]
        b[label_mask == ll] = label_colours[ll, 2]
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
    rgb[:, :, 0] = r / 255.0
    rgb[:, :, 1] = g / 255.0
    rgb[:, :, 2] = b / 255.0
    if plot:
        plt.imshow(rgb)
        plt.show()
    else:
        return rgb
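# Illustrative usage sketch (not part of the original module): turn a small
# class-index map back into a float RGB image using the 'pascal' palette.
# The 2x2 labels below are an assumption chosen only for demonstration.
def _demo_decode_segmap():
    label_mask = np.array([[0, 1], [2, 15]])
    rgb = decode_segmap(label_mask, dataset='pascal')
    # rgb has shape (2, 2, 3) with channel values scaled to [0, 1]
    return rgb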
def generate_param_report(logfile, param):
    """Write one 'key:value' line per entry of the param dict to the given log file"""
    with open(logfile, 'w') as log_file:
        for key, val in param.items():
            log_file.write(key + ':' + str(val) + '\n')
def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True):
    # batch_average is kept for API compatibility but is not used here.
    n, c, h, w = logit.size()
    # logit = logit.permute(0, 2, 3, 1)
    target = target.squeeze(1)
    # size_average is deprecated in recent PyTorch; map it to the equivalent reduction.
    reduction = 'mean' if size_average else 'sum'
    if weight is None:
        criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, reduction=reduction)
    else:
        criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(weight)).float().cuda(),
                                        ignore_index=ignore_index, reduction=reduction)
    loss = criterion(logit, target.long())
    return loss
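# Illustrative usage sketch (not part of the original module): compute the 2D
# cross-entropy on random logits and labels. The shapes (batch of 2, 21
# classes, 8x8 maps) are assumptions; weight is left as None so no GPU is
# required.
def _demo_cross_entropy2d():
    logit = torch.randn(2, 21, 8, 8)                     # (N, C, H, W) class scores
    target = torch.randint(0, 21, (2, 1, 8, 8)).float()  # (N, 1, H, W) labels
    return cross_entropy2d(logit, target)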
def cross_entropy2d_dataparallel(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True):
    # Same as cross_entropy2d, but the loss module is wrapped in nn.DataParallel
    # and the per-replica losses are summed. batch_average is unused.
    n, c, h, w = logit.size()
    # logit = logit.permute(0, 2, 3, 1)
    target = target.squeeze(1)
    reduction = 'mean' if size_average else 'sum'
    if weight is None:
        criterion = nn.DataParallel(nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, reduction=reduction))
    else:
        criterion = nn.DataParallel(nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(weight)).float().cuda(),
                                                        ignore_index=ignore_index, reduction=reduction))
    loss = criterion(logit, target.long())
    return loss.sum()
def lr_poly(base_lr, iter_, max_iter=100, power=0.9):
    """Polynomial ('poly') learning-rate decay: base_lr * (1 - iter/max_iter) ** power"""
    return base_lr * ((1 - float(iter_) / max_iter) ** power)
def get_iou(pred, gt, n_classes=21):
    """Sum of per-image mean IoU over a batch of predictions and ground truths"""
    total_iou = 0.0
    for i in range(len(pred)):
        pred_tmp = pred[i]
        gt_tmp = gt[i]
        intersect = [0] * n_classes
        union = [0] * n_classes
        for j in range(n_classes):
            # Cast to long before adding: on recent PyTorch, adding two bool
            # masks is a logical OR, which would break the `match == 2` test.
            match = (pred_tmp == j).long() + (gt_tmp == j).long()
            it = torch.sum(match == 2).item()
            un = torch.sum(match > 0).item()
            intersect[j] += it
            union[j] += un
        iou = []
        for k in range(n_classes):
            if union[k] == 0:
                continue
            iou.append(intersect[k] / union[k])
        img_iou = (sum(iou) / len(iou))
        total_iou += img_iou
    return total_iou
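# Illustrative usage sketch (not part of the original module): get_iou returns
# the *sum* of per-image mean IoUs, so dividing by the batch size gives the
# batch mean. The random shapes and class count below are assumptions.
def _demo_get_iou():
    pred = torch.randint(0, 21, (2, 8, 8))
    gt = torch.randint(0, 21, (2, 8, 8))
    total = get_iou(pred, gt, n_classes=21)
    return total / len(pred)  # mean IoU over the batch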
def scale_tensor(input, size=512, mode='bilinear'):
    """Resize a (N, C, H, W) tensor to (size, size) with the given interpolation mode"""
    # print(input.size())  # debug output left in the original code
    # b,h,w = input.size()
    _, _, h, w = input.size()
    if mode == 'nearest':
        if h == size and w == size:
            return input
        return F.interpolate(input, size=(size, size), mode='nearest')
    return F.interpolate(input, size=(size, size), mode=mode, align_corners=True)
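# Illustrative usage sketch (not part of the original module): upscale a random
# feature map to the default 512x512 with bilinear interpolation. The input
# shape is an assumption chosen only for demonstration.
def _demo_scale_tensor():
    feat = torch.randn(1, 3, 256, 256)
    return scale_tensor(feat, size=512, mode='bilinear')  # -> (1, 3, 512, 512)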
def scale_tensor_list(input):
    """Upsample every tensor in input[:-1] to the spatial size of the matching tensor in input[-1]"""
    output = []
    for i in range(len(input) - 1):
        output_item = []
        for j in range(len(input[i])):
            _, _, h, w = input[-1][j].size()
            output_item.append(F.interpolate(input[i][j], size=(h, w), mode='bilinear', align_corners=True))
        output.append(output_item)
    output.append(input[-1])
    return output
def scale_tensor_list_0(input, base_input):
    """Upsample each tensor in input to the size of its matching base_input tensor and add it in place"""
    assert len(input) == len(base_input)
    for j in range(len(input)):
        _, _, h, w = base_input[j].size()
        after_size = F.interpolate(input[j], size=(h, w), mode='bilinear', align_corners=True)
        base_input[j] = base_input[j] + after_size
    return base_input
if __name__ == '__main__':
    print(lr_poly(0.007, iter_=99, max_iter=150))