|
import os, numpy, torch, csv, re, zipfile
|
from collections import OrderedDict |
|
from torchvision.datasets.folder import default_loader |
|
from torchvision import transforms |
|
from scipy import ndimage |
|
from urllib.request import urlopen |
|
|
|
class BrodenDataset(torch.utils.data.Dataset): |
|
''' |
|
A multicategory segmentation data set. |
|
|
|
Returns three streams: |
|
(1) The image (3, h, w). |
|
(2) The multicategory segmentation (labelcount, h, w). |
|
(3) A bincount of pixels in the segmentation (labelcount). |
|
|
|
Net dissect also assumes that the dataset object has three properties |
|
with human-readable labels: |
|
|
|
ds.labels = ['red', 'black', 'car', 'tree', 'grid', ...] |
|
ds.categories = ['color', 'part', 'object', 'texture'] |
|
ds.label_category = [0, 0, 2, 2, 3, ...] # The category for each label |
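    A minimal usage sketch (the dataset path is an assumption; download=True
    fetches and unzips the broden data on first use):

        ds = BrodenDataset('dataset/broden', resolution=384, download=True)
        image, segment, bincount = ds[0]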
|
''' |
|
def __init__(self, directory='dataset/broden', resolution=384, |
|
split='train', categories=None, |
|
transform=None, transform_segment=None, |
|
download=False, size=None, include_bincount=True, |
|
broden_version=1, max_segment_depth=6): |
|
assert resolution in [224, 227, 384] |
|
if download: |
|
ensure_broden_downloaded(directory, resolution, broden_version) |
|
self.directory = directory |
|
self.resolution = resolution |
|
self.resdir = os.path.join(directory, 'broden%d_%d' % |
|
(broden_version, resolution)) |
|
self.loader = default_loader |
|
self.transform = transform |
|
self.transform_segment = transform_segment |
|
self.include_bincount = include_bincount |
|
|
|
self.max_segment_depth = max_segment_depth |
|
with open(os.path.join(self.resdir, 'category.csv'), |
|
encoding='utf-8') as f: |
|
self.category_info = OrderedDict() |
|
for row in csv.DictReader(f): |
|
self.category_info[row['name']] = row |
|
if categories is not None: |
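            # Restrict to the requested categories that exist on disk.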
|
|
|
categories = set([c for c in categories if c in self.category_info]) |
|
for cat in list(self.category_info.keys()): |
|
if cat not in categories: |
|
del self.category_info[cat] |
|
categories = list(self.category_info.keys()) |
|
self.categories = categories |
|
|
|
|
|
with open(os.path.join(self.resdir, 'index.csv'), |
|
encoding='utf-8') as f: |
|
all_images = [decode_index_dict(r) for r in csv.DictReader(f)] |
|
self.image = [row for row in all_images |
|
if index_has_any_data(row, categories) and row['split'] == split] |
|
if size is not None: |
|
self.image = self.image[:size] |
|
with open(os.path.join(self.resdir, 'label.csv'), |
|
encoding='utf-8') as f: |
|
self.label_info = build_dense_label_array([ |
|
decode_label_dict(r) for r in csv.DictReader(f)]) |
|
self.labels = [l['name'] for l in self.label_info] |
|
|
|
|
|
self.category_map = {} |
|
self.category_unmap = {} |
|
self.category_label = {} |
|
for cat in self.categories: |
|
with open(os.path.join(self.resdir, 'c_%s.csv' % cat), |
|
encoding='utf-8') as f: |
|
c_data = [decode_label_dict(r) for r in csv.DictReader(f)] |
|
self.category_unmap[cat], self.category_map[cat] = ( |
|
build_numpy_category_map(c_data)) |
|
self.category_label[cat] = build_dense_label_array( |
|
c_data, key='code') |
|
        self.num_labels = len(self.labels)

        # For each label, choose the category in which the label has the
        # largest 'coverage' value; ties go to the later category in order.
self.label_category = numpy.zeros(self.num_labels, dtype=int) |
|
for i in range(self.num_labels): |
|
maxcoverage, self.label_category[i] = max( |
|
(self.category_label[cat][self.category_map[cat][i]]['coverage'] |
|
if i < len(self.category_map[cat]) |
|
and self.category_map[cat][i] else 0, ic) |
|
for ic, cat in enumerate(categories)) |
|
|
|
def __len__(self): |
|
return len(self.image) |
|
|
|
def __getitem__(self, idx): |
|
        record = self.image[idx]
        # Load the RGB image, then decode every segmentation layer of this
        # record into a (max_segment_depth, sh, sw) integer label array.
image = self.loader(os.path.join(self.resdir, 'images', |
|
record['image'])) |
|
segment = numpy.zeros(shape=(self.max_segment_depth, |
|
record['sh'], record['sw']), dtype=int) |
|
if self.include_bincount: |
|
bincount = numpy.zeros(shape=(self.num_labels,), dtype=int) |
|
depth = 0 |
|
for cat in self.categories: |
|
for layer in record[cat]: |
|
if isinstance(layer, int): |
|
segment[depth,:,:] = layer |
|
if self.include_bincount: |
|
bincount[layer] += segment.shape[1] * segment.shape[2] |
|
else: |
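                    # A non-integer layer is a filename: the PNG's red and
                    # green channels encode a 16-bit label, R + 256 * G.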
|
png = numpy.asarray(self.loader(os.path.join( |
|
self.resdir, 'images', layer))) |
|
segment[depth,:,:] = png[:,:,0] + png[:,:,1] * 256 |
|
if self.include_bincount: |
|
bincount += numpy.bincount(segment[depth,:,:].flatten(), |
|
minlength=self.num_labels) |
|
depth += 1 |
|
if self.transform: |
|
image = self.transform(image) |
|
if self.transform_segment: |
|
segment = self.transform_segment(segment) |
|
if self.include_bincount: |
|
bincount[0] = 0 |
|
return (image, segment, bincount) |
|
else: |
|
return (image, segment) |
|
|
|
def build_dense_label_array(label_data, key='number', allow_none=False): |
|
''' |
|
Input: set of rows with 'number' fields (or another field name key). |
|
Output: array such that a[number] = the row with the given number. |
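
    Missing numbers are filled with synthesized empty rows unless
    allow_none is set. A minimal sketch with hypothetical rows:

        rows = [{'number': 0, 'name': 'x'}, {'number': 2, 'name': 'y'}]
        arr = build_dense_label_array(rows)
        # arr[2]['name'] == 'y'; arr[1] == {'number': 1, 'name': ''}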
|
''' |
|
result = [None] * (max([d[key] for d in label_data]) + 1) |
|
for d in label_data: |
|
result[d[key]] = d |
|
|
|
if not allow_none: |
|
example = label_data[0] |
|
def make_empty(k): |
|
            return dict((c, k if c == key else type(v)())
|
for c, v in example.items()) |
|
for i, d in enumerate(result): |
|
if d is None: |
|
result[i] = dict(make_empty(i)) |
|
return result |
|
|
|
def build_numpy_category_map(map_data, key1='code', key2='number'): |
|
''' |
|
    Input: set of rows with 'code' and 'number' fields (or other key names).

    Output: a pair of arrays mapping in both directions, so that
    result[0][code] = number and result[1][number] = code.
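
    For example (with hypothetical rows):

        rows = [{'code': 1, 'number': 10}, {'code': 2, 'number': 20}]
        code_to_num, num_to_code = build_numpy_category_map(rows)
        # code_to_num[2] == 20 and num_to_code[20] == 2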
|
''' |
|
results = list(numpy.zeros((max([d[key] for d in map_data]) + 1), |
|
dtype=numpy.int16) for key in (key1, key2)) |
|
for d in map_data: |
|
results[0][d[key1]] = d[key2] |
|
results[1][d[key2]] = d[key1] |
|
return results |
|
|
|
def index_has_any_data(row, categories): |
|
for c in categories: |
|
for data in row[c]: |
|
if data: return True |
|
return False |
|
|
|
def decode_label_dict(row): |
|
result = {} |
|
for key, val in row.items(): |
|
if key == 'category': |
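            # Values look like 'name(123);name(456)'; parse each field
            # into a (category name -> integer) mapping.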
|
result[key] = dict((c, int(n)) |
|
                for c, n in [re.match(r'^([^(]*)\(([^)]*)\)$', f).groups()
|
for f in val.split(';')]) |
|
elif key == 'name': |
|
result[key] = val |
|
elif key == 'syns': |
|
result[key] = val.split(';') |
|
        elif re.match(r'^\d+$', val):

            result[key] = int(val)

        elif re.match(r'^\d+\.\d*$', val):
|
result[key] = float(val) |
|
else: |
|
result[key] = val |
|
return result |
|
|
|
def decode_index_dict(row): |
|
result = {} |
|
for key, val in row.items(): |
|
if key in ['image', 'split']: |
|
result[key] = val |
|
elif key in ['sw', 'sh', 'iw', 'ih']: |
|
result[key] = int(val) |
|
else: |
|
item = [s for s in val.split(';') if s] |
|
for i, v in enumerate(item): |
|
                if re.match(r'^\d+$', v):
|
item[i] = int(v) |
|
result[key] = item |
|
return result |
|
|
|
class ScaleSegmentation: |
|
''' |
|
Utility for scaling segmentations, using nearest-neighbor zooming. |
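
    A minimal sketch (the segmentation shape is an assumption):

        scale = ScaleSegmentation(224, 224)
        seg224 = scale(seg)   # (depth, h, w) -> (depth, 224, 224)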
|
''' |
|
def __init__(self, target_height, target_width): |
|
self.target_height = target_height |
|
self.target_width = target_width |
|
def __call__(self, seg): |
|
ratio = (1, self.target_height / float(seg.shape[1]), |
|
self.target_width / float(seg.shape[2])) |
|
return ndimage.zoom(seg, ratio, order=0) |
|
|
|
def scatter_batch(seg, num_labels, omit_zero=True, dtype=torch.uint8): |
|
''' |
|
    Utility for scattering segmentations into a one-hot representation.
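
    A minimal sketch (the shapes here are assumptions):

        seg = torch.zeros(2, 1, 4, 4, dtype=torch.int64)  # (batch, depth, h, w)
        onehot = scatter_batch(seg, num_labels=10)        # -> (2, 10, 4, 4)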
|
''' |
|
result = torch.zeros(*((seg.shape[0], num_labels,) + seg.shape[2:]), |
|
dtype=dtype, device=seg.device) |
|
result.scatter_(1, seg, 1) |
|
if omit_zero: |
|
result[:,0] = 0 |
|
return result |
|
|
|
def ensure_broden_downloaded(directory, resolution, broden_version=1): |
|
assert resolution in [224, 227, 384] |
|
    baseurl = 'http://netdissect.csail.mit.edu/data'
|
dirname = 'broden%d_%d' % (broden_version, resolution) |
|
if os.path.isfile(os.path.join(directory, dirname, 'index.csv')): |
|
return |
|
    zipfilename = 'broden%d_%d.zip' % (broden_version, resolution)
|
download_dir = os.path.join(directory, 'download') |
|
os.makedirs(download_dir, exist_ok=True) |
|
full_zipfilename = os.path.join(download_dir, zipfilename) |
|
if not os.path.exists(full_zipfilename): |
|
url = '%s/%s' % (baseurl, zipfilename) |
|
print('Downloading %s' % url) |
|
data = urlopen(url) |
|
with open(full_zipfilename, 'wb') as f: |
|
f.write(data.read()) |
|
print('Unzipping %s' % zipfilename) |
|
with zipfile.ZipFile(full_zipfilename, 'r') as zip_ref: |
|
zip_ref.extractall(directory) |
|
assert os.path.isfile(os.path.join(directory, dirname, 'index.csv')) |
|
|
|
def test_broden_dataset(): |
|
''' |
|
Testing code. |
|
''' |
|
bds = BrodenDataset('dataset/broden', resolution=384, |
|
transform=transforms.Compose([ |
|
transforms.Resize(224), |
|
transforms.ToTensor()]), |
|
transform_segment=transforms.Compose([ |
|
ScaleSegmentation(224, 224) |
|
]), |
|
include_bincount=True) |
|
loader = torch.utils.data.DataLoader(bds, batch_size=100, num_workers=24) |
|
    for i in range(1, 20):

        print(bds.label_info[i]['name'],

                bds.categories[bds.label_category[i]])
|
for i, (im, seg, bc) in enumerate(loader): |
|
print(i, im.shape, seg.shape, seg.max(), bc.shape) |
|
|
|
if __name__ == '__main__': |
|
test_broden_dataset() |
|
|