Spaces:
Running
Running
import sys | |
import os | |
import logging | |
import re | |
import functools | |
import fnmatch | |
import numpy as np | |
def setup_logger(distributed_rank=0, filename="log.txt"):
    """Return the shared "Logger" logger, printing DEBUG+ records to stdout on rank 0.

    Args:
        distributed_rank: rank of the calling process. For any rank > 0 the
            logger is returned early without attaching a handler, so only the
            rank-0 process adds stdout output.
        filename: accepted for backward compatibility; not used by this
            function.

    Returns:
        The ``logging.getLogger("Logger")`` instance (level DEBUG).

    Calling this more than once in a process is safe: the stdout handler is
    attached only once, so records are not duplicated.
    """
    logger = logging.getLogger("Logger")
    logger.setLevel(logging.DEBUG)
    # don't log results for the non-master process
    if distributed_rank > 0:
        return logger
    # getLogger() returns the same object on every call; without this guard
    # each call would add another handler and every record would print N times.
    if any(getattr(h, "_setup_logger_stdout", False) for h in logger.handlers):
        return logger
    ch = logging.StreamHandler(stream=sys.stdout)
    ch._setup_logger_stdout = True  # marker so later calls can recognise it
    ch.setLevel(logging.DEBUG)
    fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
    ch.setFormatter(logging.Formatter(fmt))
    logger.addHandler(ch)
    return logger
def find_recursive(root_dir, ext='.jpg'):
    """Collect paths of every file under ``root_dir`` whose name matches ``'*' + ext``.

    Matching uses ``fnmatch.filter``, so ``ext`` is treated as part of a
    shell-style pattern. Paths are returned in ``os.walk`` traversal order.
    """
    pattern = '*' + ext
    return [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk(root_dir)
        for name in fnmatch.filter(names, pattern)
    ]
class AverageMeter(object):
    """Running weighted average that also remembers the most recent value.

    The first ``update`` seeds the meter (``avg`` = that value, ``sum`` =
    value * weight, ``count`` = weight); later updates accumulate into
    ``sum``/``count`` and recompute ``avg = sum / count``.
    """

    def __init__(self):
        self.initialized = False
        self.val = self.avg = self.sum = self.count = None

    def initialize(self, val, weight):
        """Seed the meter with its first observation."""
        self.val, self.avg = val, val
        self.sum, self.count = val * weight, weight
        self.initialized = True

    def update(self, val, weight=1):
        """Record ``val`` with ``weight``, seeding the meter on first use."""
        if self.initialized:
            self.add(val, weight)
        else:
            self.initialize(val, weight)

    def add(self, val, weight):
        """Accumulate ``val`` into the running totals and refresh the average."""
        self.val = val
        self.sum += val * weight
        self.count += weight
        self.avg = self.sum / self.count

    def value(self):
        """Most recently recorded value."""
        return self.val

    def average(self):
        """Current weighted average."""
        return self.avg
def unique(ar, return_index=False, return_inverse=False, return_counts=False):
    """Return the sorted unique elements of ``ar`` (flattened).

    Args:
        ar: array-like input; it is flattened before processing.
        return_index: also return, for each unique value, the index of its
            first occurrence in the flattened input.
        return_inverse: also return indices that rebuild the flattened input
            from the unique values.
        return_counts: also return how many times each unique value occurs.

    Returns:
        The unique-values array alone when no optional output is requested;
        otherwise a tuple ``(unique, [index], [inverse], [counts])`` holding
        only the requested extras, in that order.
    """
    # flatten() always returns a copy, so the in-place sort below never
    # mutates the caller's array.
    ar = np.asanyarray(ar).flatten()

    optional_indices = return_index or return_inverse
    optional_returns = optional_indices or return_counts

    if ar.size == 0:
        if not optional_returns:
            ret = ar
        else:
            ret = (ar,)
            # Index/inverse outputs are integer index arrays (np.intp), matching
            # the non-empty path. The previous np.bool alias was removed in
            # NumPy 1.24 and, where it exists, gave the wrong dtype.
            if return_index:
                ret += (np.empty(0, np.intp),)
            if return_inverse:
                ret += (np.empty(0, np.intp),)
            if return_counts:
                ret += (np.empty(0, np.intp),)
        return ret

    if optional_indices:
        # mergesort is stable, so perm[flag] points at the first occurrence of
        # each value when return_index is requested.
        perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')
        aux = ar[perm]
    else:
        ar.sort()
        aux = ar
    # flag[i] is True where aux[i] starts a new run of equal values.
    flag = np.concatenate(([True], aux[1:] != aux[:-1]))

    if not optional_returns:
        ret = aux[flag]
    else:
        ret = (aux[flag],)
        if return_index:
            ret += (perm[flag],)
        if return_inverse:
            iflag = np.cumsum(flag) - 1
            inv_idx = np.empty(ar.shape, dtype=np.intp)
            inv_idx[perm] = iflag
            ret += (inv_idx,)
        if return_counts:
            # Run lengths are the gaps between consecutive run starts.
            idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
            ret += (np.diff(idx),)
    return ret
def colorEncode(labelmap, colors, mode='RGB'):
    """Render an integer label map as a uint8 color image.

    Args:
        labelmap: array of class labels (cast with ``astype('int')``).
            Pixels with a negative label stay black (0, 0, 0).
        colors: palette array-like of shape (num_classes, 3); row ``k`` is the
            color for label ``k``. Every non-negative label must index a row.
        mode: ``'BGR'`` reverses the channel order of the result; any other
            value returns channels as stored in ``colors``.

    Returns:
        ``np.uint8`` array of shape ``labelmap.shape + (3,)``.
    """
    labelmap = np.asarray(labelmap).astype('int')
    palette = np.asarray(colors)
    labelmap_rgb = np.zeros(labelmap.shape + (3,), dtype=np.uint8)

    valid = labelmap >= 0
    # One gather over all labeled pixels. Previously this was one
    # full-image np.tile + masked add per distinct label, costing
    # O(num_labels * H * W) time and memory.
    labelmap_rgb[valid] = palette[labelmap[valid]]

    if mode == 'BGR':
        return labelmap_rgb[..., ::-1]
    return labelmap_rgb
def accuracy(preds, label):
    """Fraction of non-negatively-labeled entries where ``preds`` equals ``label``.

    Returns:
        ``(acc, valid_sum)``, where:

        - ``acc`` is correct / valid as a Python float. A 1e-10 term in the
          denominator avoids division by zero when nothing is labeled.
        - ``valid_sum`` is the count of entries with ``label >= 0``.
    """
    labeled = label >= 0
    num_labeled = labeled.sum()
    num_correct = (labeled * (preds == label)).sum()
    return float(num_correct) / (num_labeled + 1e-10), num_labeled
def intersectionAndUnion(imPred, imLab, numClass):
    """Per-class intersection and union pixel areas of a prediction vs. ground truth.

    Both inputs are shifted by +1, so integer class ``c`` becomes ``c + 1``.
    Ground-truth entries that end up <= 0 (i.e. negative integer labels) are
    treated as unlabeled, and the prediction is ignored there. Classes
    ``0 .. numClass-1`` are counted; other values fall outside the histogram
    range and are dropped.

    Returns:
        ``(area_intersection, area_union)``, each an array of length
        ``numClass``.
    """
    pred = np.asarray(imPred).copy()
    gt = np.asarray(imLab).copy()
    pred += 1
    gt += 1

    # Remove classes from unlabeled pixels in gt image.
    # We should not penalize detections in unlabeled portions of the image.
    pred = pred * (gt > 0)

    def class_areas(values):
        # bins=numClass over [1, numClass] puts each integer 1..numClass in
        # its own bin; 0 (masked / unlabeled) lies outside the range.
        counts, _edges = np.histogram(values, bins=numClass, range=(1, numClass))
        return counts

    area_intersection = class_areas(pred * (pred == gt))
    area_pred = class_areas(pred)
    area_lab = class_areas(gt)
    area_union = area_pred + area_lab - area_intersection
    return (area_intersection, area_union)
class NotSupportedCliException(Exception):
    """Raised by ``parse_devices`` when a device token matches no known pattern."""
def process_range(xpu, inp):
    """Expand a two-endpoint index range into prefixed device names.

    Args:
        xpu: name prefix, e.g. ``'gpu'``.
        inp: pair of integer-like endpoints (regex groups such as
            ``('1', '3')``). The order does not matter; both ends are inclusive.

    Returns:
        List of names such as ``['gpu1', 'gpu2', 'gpu3']``. It is a list rather
        than a lazy ``map``, so it is re-iterable and matches the single-device
        handlers, which also return lists.
    """
    start, end = map(int, inp)
    if start > end:
        start, end = end, start
    return ['{}{}'.format(xpu, i) for i in range(start, end + 1)]
# Device-token patterns tried in order by parse_devices (tokens arrive there
# lower-cased and stripped). Each entry pairs a compiled regex with a function
# that maps the match groups to a list of canonical 'gpuN' names. Indices go
# through int() so that 'gpu01' and 'gpu1' canonicalize identically, just as
# process_range already does for ranges. This lets parse_devices dedupe them.
REGEX = [
    (re.compile(r'^gpu(\d+)$'), lambda x: ['gpu{}'.format(int(x[0]))]),
    (re.compile(r'^(\d+)$'), lambda x: ['gpu{}'.format(int(x[0]))]),
    (re.compile(r'^gpu(\d+)-(?:gpu)?(\d+)$'),
     functools.partial(process_range, 'gpu')),
    (re.compile(r'^(\d+)-(\d+)$'),
     functools.partial(process_range, 'gpu')),
]
def parse_devices(input_devices):
    """Parse user's devices input str to standard format.

    e.g. [gpu0, gpu1, ...]

    Each comma-separated token is lower-cased and stripped, then matched
    against the ``REGEX`` table; the first matching pattern's handler expands
    it. Names are collected in first-seen order without duplicates.

    Raises:
        NotSupportedCliException: if a token matches none of the patterns.
    """
    devices = []
    for token in input_devices.split(','):
        normalized = token.lower().strip()
        expanded = None
        for pattern, expand in REGEX:
            match = pattern.match(normalized)
            if match:
                expanded = expand(match.groups())
                break
        if expanded is None:
            raise NotSupportedCliException(
                'Can not recognize device: "{}"'.format(token))
        # prevent duplicate
        for dev in expanded:
            if dev not in devices:
                devices.append(dev)
    return devices