|
''' |
|
To run dissection: |
|
|
|
1. Load up the convolutional model you wish to dissect, and wrap it in |
|
an InstrumentedModel; then call imodel.retain_layers([layernames,..]) |
|
to instrument the layers of interest. |
|
2. Load the segmentation dataset using the BrodenDataset class; |
|
use the transform_image argument to normalize images to be |
|
suitable for the model, or the size argument to truncate the dataset. |
|
3. Choose a directory in which to write the output, and call |
|
dissect(outdir, model, dataset). |
|
|
|
Example: |
|
|
|
from dissect import InstrumentedModel, dissect |
|
from broden import BrodenDataset |
|
|
|
model = InstrumentedModel(load_my_model()) |
|
model.eval() |
|
model.cuda() |
|
model.retain_layers(['conv1', 'conv2', 'conv3', 'conv4', 'conv5']) |
|
bds = BrodenDataset('dataset/broden1_227', |
|
transform_image=transforms.Compose([ |
|
transforms.ToTensor(), |
|
transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]), |
|
size=1000) |
|
dissect('result/dissect', model, bds, |
|
examples_per_unit=10) |
|
''' |
|
|
|
import torch, numpy, os, re, json, shutil, types, tempfile, torchvision |
|
|
|
|
|
from PIL import Image |
|
from xml.etree import ElementTree as et |
|
from collections import OrderedDict, defaultdict |
|
from .progress import verbose_progress, default_progress, print_progress |
|
from .progress import desc_progress |
|
from .runningstats import RunningQuantile, RunningTopK |
|
from .runningstats import RunningCrossCovariance, RunningConditionalQuantile |
|
from .sampler import FixedSubsetSampler |
|
from .actviz import activation_visualization |
|
from .segviz import segment_visualization, high_contrast |
|
from .workerpool import WorkerBase, WorkerPool |
|
from .segmenter import UnifiedParsingSegmenter |
|
|
|
def dissect(outdir, model, dataset,
        segrunner=None,
        train_dataset=None,
        model_segmenter=None,
        quantile_threshold=0.005,
        iou_threshold=0.05,
        iqr_threshold=0.01,
        examples_per_unit=100,
        batch_size=100,
        num_workers=24,
        seg_batch_size=5,
        make_images=True,
        make_labels=True,
        make_maxiou=False,
        make_covariance=False,
        make_report=True,
        make_row_images=True,
        make_single_images=False,
        rank_all_labels=False,
        netname=None,
        meta=None,
        merge=None,
        settings=None,
        ):
    '''
    Runs net dissection in-memory, using pytorch, and saves visualizations
    and metadata into outdir.

    Arguments:
        outdir: directory where all per-layer results, images, and the
            dissect.json/dissect.html report are written.
        model: an instrumented model (must expose retained_features() and
            parameters(), and be in eval mode).
        dataset: the segmentation-labeled dataset to dissect over.
        segrunner: object providing get_label_and_category_names() and
            run_and_segment_batch(); defaults to ClassifierSegRunner(dataset).
        train_dataset: dataset for the maxiou/iqr passes; defaults to dataset.
        model_segmenter: accepted for API compatibility.
            NOTE(review): not referenced anywhere in this function — verify
            whether callers still rely on it.
        quantile_threshold: activation quantile used to pick per-unit
            thresholds; the special value 'iqr' switches to IQR-based levels.
        iou_threshold, iqr_threshold: cutoffs for calling a unit interpretable.
        examples_per_unit: number of top-activating images kept per unit.
        make_*: flags enabling the individual collection/report passes.
        merge: optional previously-computed per-layer data to merge into
            the report.

    Returns (quantiledata, labeldata), the tuples also passed to
    generate_report.
    '''
    assert not model.training, 'Run model.eval() before dissection'
    if netname is None:
        # Default the reported network name to the model's class name.
        netname = type(model).__name__
    if segrunner is None:
        # ClassifierSegRunner is defined elsewhere in this module.
        segrunner = ClassifierSegRunner(dataset)
    if train_dataset is None:
        train_dataset = dataset
    # The sentinel value 'iqr' selects IQR-derived activation levels
    # instead of a fixed quantile threshold.
    make_iqr = (quantile_threshold == 'iqr')
    with torch.no_grad():
        device = next(model.parameters()).device
        levels = None
        labelnames, catnames = None, None
        maxioudata, iqrdata = None, None
        labeldata = None
        iqrdata, cov = None, None

        labelnames, catnames = segrunner.get_label_and_category_names()
        # Map each label to the index of its category (0 if unrecognized).
        label_category = [catnames.index(c) if c in catnames else 0
                for l, c in labelnames]

        # Pass 1: estimate per-channel quantiles and exact top-k examples.
        segloader = torch.utils.data.DataLoader(dataset,
                batch_size=batch_size, num_workers=num_workers,
                pin_memory=(device.type == 'cuda'))
        quantiles, topk = collect_quantiles_and_topk(outdir, model,
                segloader, segrunner, k=examples_per_unit)

        # Choose per-unit activation thresholds ("levels").
        if make_iqr:
            # IQR-based levels require a batch_size-1 pass over the
            # training dataset.
            segloader = torch.utils.data.DataLoader(train_dataset,
                    batch_size=1, num_workers=num_workers,
                    pin_memory=(device.type == 'cuda'))
            iqrdata = collect_iqr(outdir, model, segloader, segrunner)
            max_iqr, full_iqr_levels = iqrdata[:2]
            max_iqr_agreement = iqrdata[4]
            # For each unit, take the level associated with its
            # best-IQR label.
            levels = {layer: full_iqr_levels[layer][
                max_iqr[layer].max(0)[1],
                torch.arange(max_iqr[layer].shape[1])].to(device)
                for layer in full_iqr_levels}
        else:
            # Fixed-quantile levels: the (1 - threshold) quantile of
            # each channel's activation distribution.
            levels = {k: qc.quantiles([1.0 - quantile_threshold])[:,0]
                    for k, qc in quantiles.items()}

        quantiledata = (topk, quantiles, levels, quantile_threshold)

        # Pass 2 (optional): render top-activation image strips per unit.
        if make_images:
            segloader = torch.utils.data.DataLoader(dataset,
                    batch_size=batch_size, num_workers=num_workers,
                    pin_memory=(device.type == 'cuda'))
            generate_images(outdir, model, dataset, topk, levels, segrunner,
                    row_length=examples_per_unit, batch_size=seg_batch_size,
                    row_images=make_row_images,
                    single_images=make_single_images,
                    num_workers=num_workers)

        # Pass 3 (optional): per-label max-IoU statistics over the
        # training dataset (requires batch_size 1).
        if make_maxiou:
            assert train_dataset, "Need training dataset for maxiou."
            segloader = torch.utils.data.DataLoader(train_dataset,
                    batch_size=1, num_workers=num_workers,
                    pin_memory=(device.type == 'cuda'))
            maxioudata = collect_maxiou(outdir, model, segloader,
                    segrunner)

        # Pass 4 (optional): IoU/IQR label scores from thresholded
        # activation bincounts (requires batch_size 1).
        if make_labels:
            segloader = torch.utils.data.DataLoader(dataset,
                    batch_size=1, num_workers=num_workers,
                    pin_memory=(device.type == 'cuda'))
            iou_scores, iqr_scores, tcs, lcs, ccs, ics = (
                    collect_bincounts(outdir, model, segloader,
                        levels, segrunner))
            # Note: total counts (tcs) are intentionally not included.
            labeldata = (iou_scores, iqr_scores, lcs, ccs, ics, iou_threshold,
                    iqr_threshold)

        # Pass 5 (optional): activation/label cross-covariance.
        if make_covariance:
            segloader = torch.utils.data.DataLoader(dataset,
                    batch_size=seg_batch_size,
                    num_workers=num_workers,
                    pin_memory=(device.type == 'cuda'))
            cov = collect_covariance(outdir, model, segloader, segrunner)

        # Final pass: write dissect.json / bargraph.svg / html files.
        if make_report:
            generate_report(outdir,
                    quantiledata=quantiledata,
                    labelnames=labelnames,
                    catnames=catnames,
                    labeldata=labeldata,
                    maxioudata=maxioudata,
                    iqrdata=iqrdata,
                    covariancedata=cov,
                    rank_all_labels=rank_all_labels,
                    netname=netname,
                    meta=meta,
                    mergedata=merge,
                    settings=settings)

        return quantiledata, labeldata
|
|
|
def generate_report(outdir, quantiledata, labelnames=None, catnames=None,
        labeldata=None, maxioudata=None, iqrdata=None, covariancedata=None,
        rank_all_labels=False, netname='Model', meta=None, settings=None,
        mergedata=None):
    '''
    Creates dissection.json reports and summary bargraph.svg files in the
    specified output directory, and copies a dissection.html interface
    to go along with it.

    quantiledata is (topk, quantiles, levels, quantile_threshold).  The
    optional labeldata/maxioudata/iqrdata/covariancedata arguments are the
    outputs of the corresponding collect_* passes; each one present adds
    extra per-unit fields and unit rankings to the report.  mergedata, if
    given, is a previously generated report dict whose per-unit records
    are merged into the new one.  All output is written under outdir.
    '''
    all_layers = []
    # Directory holding the html templates that ship with this module.
    srcdir = os.path.realpath(
        os.path.join(os.getcwd(), os.path.dirname(__file__)))
    topk, quantiles, levels, quantile_threshold = quantiledata
    top_record = dict(
        netname=netname,
        meta=meta,
        default_ranking='unit',
        quantile_threshold=quantile_threshold)
    if settings is not None:
        top_record['settings'] = settings
    if labeldata is not None:
        iou_scores, iqr_scores, lcs, ccs, ics, iou_threshold, iqr_threshold = (
            labeldata)
        # Canonical display order for the standard categories; any other
        # categories found in catnames are appended after these.
        catorder = {'object': -7, 'scene': -6, 'part': -5,
                    'piece': -4,
                    'material': -3, 'texture': -2, 'color': -1}
        for i, cat in enumerate(c for c in catnames if c not in catorder):
            catorder[cat] = i
        catnumber = {n: i for i, n in enumerate(catnames)}
        catnumber['-'] = 0
        top_record['default_ranking'] = 'label'
        top_record['iou_threshold'] = iou_threshold
        top_record['iqr_threshold'] = iqr_threshold
        labelnumber = dict((name[0], num)
            for num, name in enumerate(labelnames))
    # Record the segmentation palette so the html viewer can decode colors.
    segcolors = {}
    for i, name in enumerate(labelnames):
        key = ','.join(str(s) for s in high_contrast[i % len(high_contrast)])
        if key in segcolors:
            segcolors[key] += '/' + name[0]
        else:
            segcolors[key] = name[0]
    top_record['segcolors'] = segcolors
    for layer in topk.keys():
        units, rankings = [], []
        record = dict(layer=layer, units=units, rankings=rankings)
        # Basic per-unit stats: activation level and top-activating images.
        topa, topi = topk[layer].result()
        lev = levels[layer]
        for u in range(len(topa)):
            units.append(dict(
                unit=u,
                interp=True,
                level=lev[u].item(),
                top=[dict(imgnum=i.item(), maxact=a.item())
                    for i, a in zip(topi[u], topa[u])],
                ))
        rankings.append(dict(name="unit", score=list([
            u for u in range(len(topa))])))
        if labeldata is not None:
            lscore, qscore, cc, ic = [dat[layer]
                    for dat in [iou_scores, iqr_scores, ccs, ics]]
            if iqrdata is not None:
                # With iqr data available, pick each unit's best label by
                # max IQR, then read off its iou/iqr scores.
                max_iqr, max_iqr_level = iqrdata[:2]
                best_label = max_iqr[layer].max(0)[1]
                best_score = lscore[best_label, torch.arange(lscore.shape[1])]
                best_qscore = qscore[best_label, torch.arange(lscore.shape[1])]
            else:
                # Otherwise pick each unit's best label by max IoU.
                best_score, best_label = lscore.max(0)
                best_qscore = qscore[best_label, torch.arange(qscore.shape[1])]
            # BUGFIX: a stray trailing comma previously made this a
            # 1-tuple, which json.dump serialized as a one-element list.
            record['iou_threshold'] = iou_threshold
            for u, urec in enumerate(units):
                score, qscore, label = (
                        best_score[u], best_qscore[u], best_label[u])
                urec.update(dict(
                    iou=score.item(),
                    iou_iqr=qscore.item(),
                    lc=lcs[label].item(),
                    cc=cc[catnumber[labelnames[label][1]], u].item(),
                    ic=ic[label, u].item(),
                    # A unit is "interpretable" when both scores clear
                    # their thresholds.
                    interp=(qscore.item() > iqr_threshold and
                        score.item() > iou_threshold),
                    iou_labelnum=label.item(),
                    iou_label=labelnames[label.item()][0],
                    iou_cat=labelnames[label.item()][1],
                    ))
        if maxioudata is not None:
            max_iou, max_iou_level, max_iou_quantile = maxioudata
            qualified_iou = max_iou[layer].clone()
            best_score, best_label = qualified_iou.max(0)
            for u, urec in enumerate(units):
                urec.update(dict(
                    maxiou=best_score[u].item(),
                    maxiou_label=labelnames[best_label[u].item()][0],
                    maxiou_cat=labelnames[best_label[u].item()][1],
                    maxiou_level=max_iou_level[layer][best_label[u], u].item(),
                    maxiou_quantile=max_iou_quantile[layer][
                        best_label[u], u].item()))
        if iqrdata is not None:
            [max_iqr, max_iqr_level, max_iqr_quantile,
                    max_iqr_iou, max_iqr_agreement] = iqrdata
            qualified_iqr = max_iqr[layer].clone()
            # Disqualify labels whose activation quantile is too high.
            qualified_iqr[max_iqr_quantile[layer] > 0.5] = 0
            best_score, best_label = qualified_iqr.max(0)
            for u, urec in enumerate(units):
                urec.update(dict(
                    iqr=best_score[u].item(),
                    iqr_label=labelnames[best_label[u].item()][0],
                    iqr_cat=labelnames[best_label[u].item()][1],
                    iqr_level=max_iqr_level[layer][best_label[u], u].item(),
                    iqr_quantile=max_iqr_quantile[layer][
                        best_label[u], u].item(),
                    iqr_iou=max_iqr_iou[layer][best_label[u], u].item()
                    ))
        if covariancedata is not None:
            score = covariancedata[layer].correlation()
            best_score, best_label = score.max(1)
            for u, urec in enumerate(units):
                urec.update(dict(
                    cor=best_score[u].item(),
                    cor_label=labelnames[best_label[u].item()][0],
                    cor_cat=labelnames[best_label[u].item()][1]
                    ))
        if mergedata is not None:
            # Merge previously-computed unit records for this layer, if any.
            for lrec in mergedata.get('layers', []):
                if lrec['layer'] == layer:
                    break
            else:
                lrec = None
            for u, urec in enumerate(lrec.get('units', []) if lrec else []):
                units[u].update(urec)
        if labeldata is not None:
            # Group units by their best-iou label (interpretable only,
            # plus a second map including uninterpretable units).
            labelunits = defaultdict(list)
            all_labelunits = defaultdict(list)
            for u, urec in enumerate(units):
                if urec['interp']:
                    labelunits[urec['iou_labelnum']].append(u)
                all_labelunits[urec['iou_labelnum']].append(u)
            # Order: interpretable first, then larger label groups,
            # then higher best-iou within the group, then by label and iou.
            label_ordering = sorted(units,
                key=lambda r: (-1 if r['interp'] else 0,
                    -len(labelunits[r['iou_labelnum']]),
                    -max([units[u]['iou']
                        for u in labelunits[r['iou_labelnum']]], default=0),
                    r['iou_labelnum'],
                    -r['iou']))
            rankings.append(dict(name="label", score=(numpy.argsort(list(
                ur['unit'] for ur in label_ordering))).tolist()))
            rankings.append(dict(name="max iou", metric="iou", score=list(
                -ur['iou'] for ur in units)))
            record['labels'] = [dict(
                        label=labelnames[label][0],
                        labelnum=label,
                        units=labelunits[label],
                        cat=labelnames[label][1])
                    for label in (sorted(labelunits.keys(),
                        key=lambda l: (catorder.get(
                            labelnames[l][1], 0),
                            -len(labelunits[l]),
                            -max([units[u]['iou'] for u in labelunits[l]],
                                default=0)
                            ))) if len(labelunits[label])]
            record['interpretable'] = sum(len(group['units'])
                for group in record['labels'])
            # Draw the per-layer summary bargraph of labels by category.
            os.makedirs(os.path.join(outdir, safe_dir_name(layer)),
                    exist_ok=True)
            catgroups = OrderedDict()
            for _, cat in sorted([(v, k) for k, v in catorder.items()]):
                catgroups[cat] = []
            for rec in record['labels']:
                if rec['cat'] not in catgroups:
                    catgroups[rec['cat']] = []
                catgroups[rec['cat']].append(rec['label'])
            make_svg_bargraph(
                    [rec['label'] for rec in record['labels']],
                    [len(rec['units']) for rec in record['labels']],
                    [(cat, len(group)) for cat, group in catgroups.items()],
                    filename=os.path.join(outdir, safe_dir_name(layer),
                        'bargraph.svg'))
            # Only reference the bargraph if it is non-empty.
            if len(record['labels']):
                record['bargraph'] = 'bargraph.svg'
        if maxioudata is not None:
            rankings.append(dict(name="max maxiou", metric="maxiou", score=list(
                -ur['maxiou'] for ur in units)))
        if iqrdata is not None:
            rankings.append(dict(name="max iqr", metric="iqr", score=list(
                -ur['iqr'] for ur in units)))
        if covariancedata is not None:
            rankings.append(dict(name="max cor", metric="cor", score=list(
                -ur['cor'] for ur in units)))
        all_layers.append(record)
    # Collect the list of labels to generate per-label rankings for.
    all_labels = None
    if rank_all_labels:
        all_labels = [name for name, cat in labelnames]
    if labeldata is not None:
        # Count how often each (direction-suffix-stripped) label was a
        # unit's best-iou label, and rank labels by that count.
        counted_labels = defaultdict(int)
        for label in [
                re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '', unitrec['iou_label'])
                for record in all_layers for unitrec in record['units']]:
            counted_labels[label] += 1
        if all_labels is None:
            all_labels = [label for count, label in sorted((-v, k)
                for k, v in counted_labels.items())]
        for record in all_layers:
            layer = record['layer']
            for label in all_labels:
                labelnum = labelnumber[label]
                record['rankings'].append(dict(name="%s-iou" % label,
                    concept=label, metric='iou',
                    score=(-iou_scores[layer][labelnum, :]).tolist()))

    if maxioudata is not None:
        if all_labels is None:
            counted_labels = defaultdict(int)
            for label in [
                    re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '',
                        unitrec['maxiou_label'])
                    for record in all_layers for unitrec in record['units']]:
                counted_labels[label] += 1
            all_labels = [label for count, label in sorted((-v, k)
                for k, v in counted_labels.items())]
        for record in all_layers:
            layer = record['layer']
            # BUGFIX: qualified_iou was previously computed once, outside
            # this loop, using the stale `layer` variable left over from
            # the preceding loop, so every layer's maxiou rankings were
            # scored with one layer's data.  Recompute per layer here.
            qualified_iou = max_iou[layer].clone()
            qualified_iou[max_iou_quantile[layer] > 0.5] = 0
            for label in all_labels:
                labelnum = labelnumber[label]
                record['rankings'].append(dict(name="%s-maxiou" % label,
                    concept=label, metric='maxiou',
                    score=(-qualified_iou[labelnum, :]).tolist()))

    if iqrdata is not None:
        if all_labels is None:
            counted_labels = defaultdict(int)
            for label in [
                    re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '',
                        unitrec['iqr_label'])
                    for record in all_layers for unitrec in record['units']]:
                counted_labels[label] += 1
            all_labels = [label for count, label in sorted((-v, k)
                for k, v in counted_labels.items())]
        for record in all_layers:
            layer = record['layer']
            # NOTE(review): unlike the per-unit section above, no
            # quantile disqualification is applied here — confirm intended.
            qualified_iqr = max_iqr[layer].clone()
            for label in all_labels:
                labelnum = labelnumber[label]
                record['rankings'].append(dict(name="%s-iqr" % label,
                    concept=label, metric='iqr',
                    score=(-qualified_iqr[labelnum, :]).tolist()))

    if covariancedata is not None:
        if all_labels is None:
            counted_labels = defaultdict(int)
            for label in [
                    re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '',
                        unitrec['cor_label'])
                    for record in all_layers for unitrec in record['units']]:
                counted_labels[label] += 1
            all_labels = [label for count, label in sorted((-v, k)
                for k, v in counted_labels.items())]
        for record in all_layers:
            layer = record['layer']
            score = covariancedata[layer].correlation()
            for label in all_labels:
                labelnum = labelnumber[label]
                record['rankings'].append(dict(name="%s-cor" % label,
                    concept=label, metric='cor',
                    score=(-score[:, labelnum]).tolist()))

    # Write a standalone per-layer report (dirname '.') into each layer's
    # own directory, alongside a copy of the html viewer.
    for record in all_layers:
        layer = record['layer']
        record['dirname'] = '.'
        with open(os.path.join(outdir, safe_dir_name(layer), 'dissect.json'),
                'w') as jsonfile:
            top_record['layers'] = [record]
            json.dump(top_record, jsonfile, indent=1)
        shutil.copy(os.path.join(srcdir, 'dissect.html'),
                os.path.join(outdir, safe_dir_name(layer), 'dissect.html'))
        # Restore the real dirname for the combined report below.
        record['dirname'] = safe_dir_name(layer)

    # Write the combined all-layer report and the html viewers.
    with open(os.path.join(outdir, 'dissect.json'), 'w') as jsonfile:
        top_record['layers'] = all_layers
        json.dump(top_record, jsonfile, indent=1)
    shutil.copy(os.path.join(srcdir, 'dissect.html'),
            os.path.join(outdir, 'dissect.html'))
    shutil.copy(os.path.join(srcdir, 'edit.html'),
            os.path.join(outdir, 'edit.html'))
|
|
|
|
|
def generate_images(outdir, model, dataset, topk, levels,
        segrunner, row_length=None, gap_pixels=5,
        row_images=True, single_images=False, prefix='',
        batch_size=100, num_workers=24):
    '''
    Creates an image strip file for every unit of every retained layer
    of the model, in the format [outdir]/[layername]/[unitnum]-top.jpg.
    Assumes that the indexes of topk refer to the indexes of dataset.
    Limits each strip to the top row_length images.

    Also writes matching -orig.jpg (unmodified image), -seg.png
    (segmentation rendering), and -mask.png (activation mask) strips.
    The work is staged through temp-file-backed shared memmaps so that
    rendering and saving can be fanned out to worker processes.
    '''
    progress = default_progress()
    needed_images = {}
    if row_images is False:
        # When not making row strips, a single top image per unit suffices.
        row_length = 1
    # Pass 1: scan the topk results to find which dataset images are
    # needed, and which (layer, unit, rank) cells each image feeds.
    for layer in topk:
        topresult = topk[layer].result()[1].cpu()
        for unit, row in enumerate(topresult):
            for rank, imgnum in enumerate(row[:row_length]):
                imgnum = imgnum.item()
                if imgnum not in needed_images:
                    needed_images[imgnum] = []
                needed_images[imgnum].append((layer, unit, rank))
    levels = {k: v.cpu().numpy() for k, v in levels.items()}
    # NOTE(review): relies on the loop variable `row` leaking out of the
    # loop above; clamps row_length to the actual number of top entries.
    row_length = len(row[:row_length])
    needed_sample = FixedSubsetSampler(sorted(needed_images.keys()))
    device = next(model.parameters()).device
    segloader = torch.utils.data.DataLoader(dataset,
            batch_size=batch_size, num_workers=num_workers,
            pin_memory=(device.type == 'cuda'),
            sampler=needed_sample)
    # Four grids per layer: visualization, mask (RGBA), original, segmentation.
    vizgrid, maskgrid, origrid, seggrid = [{} for _ in range(4)]
    # Pass 2: run the needed images through the model and hand each
    # (layer, unit, rank) cell to a VisualizeImageWorker.
    pool = None
    for i, batch in enumerate(
            progress(segloader, desc='Making images')):
        # Get segmentation and the de-normalized byte image for the batch.
        seg, _, byte_im, _ = segrunner.run_and_segment_batch(batch, model,
                want_rgb=True)
        torch_features = model.retained_features()
        scale_offset = getattr(model, 'scale_offset', None)
        if pool is None:
            # First batch: now that image dimensions are known, allocate
            # the shared memmap grids and start the worker pool.
            for layer, tf in torch_features.items():
                [vizgrid[layer], maskgrid[layer], origrid[layer],
                        seggrid[layer]] = [
                    create_temp_mmap_grid((tf.shape[1],
                        byte_im.shape[1], row_length,
                        byte_im.shape[2] + gap_pixels, depth),
                        dtype='uint8',
                        fill=255)
                    for depth in [3, 4, 3, 3]]
            # Workers reopen the memmaps by filename in their own process.
            pool = WorkerPool(worker=VisualizeImageWorker,
                    memmap_grid_info=[
                        {layer: (g.filename, g.shape, g.dtype)
                            for layer, g in grid.items()}
                        for grid in [vizgrid, maskgrid, origrid, seggrid]])
        byte_im = byte_im.cpu().numpy()
        numpy_seg = seg.cpu().numpy()
        features = {}
        for index in range(len(byte_im)):
            # Recover the original dataset index of this batch item.
            imgnum = needed_sample.samples[index + i*segloader.batch_size]
            for layer, unit, rank in needed_images[imgnum]:
                if layer not in features:
                    # Lazily pull each layer's features to cpu, once.
                    features[layer] = torch_features[layer].cpu().numpy()
                pool.add(layer, unit, rank,
                        byte_im[index],
                        features[layer][index, unit],
                        levels[layer][unit],
                        scale_offset[layer] if scale_offset else None,
                        numpy_seg[index])
    pool.join()
    # Pass 3: slice each unit's row out of the grids and save to disk.
    pool = WorkerPool(worker=SaveImageWorker)
    for layer, vg in progress(vizgrid.items(), desc='Saving images'):
        os.makedirs(os.path.join(outdir, safe_dir_name(layer),
            prefix + 'image'), exist_ok=True)
        if single_images:
            os.makedirs(os.path.join(outdir, safe_dir_name(layer),
                prefix + 's-image'), exist_ok=True)
        og, sg, mg = origrid[layer], seggrid[layer], maskgrid[layer]
        for unit in progress(range(len(vg)), desc='Units'):
            for suffix, grid in [('top.jpg', vg), ('orig.jpg', og),
                    ('seg.png', sg), ('mask.png', mg)]:
                # Flatten (height, row_length, width+gap) into one strip.
                strip = grid[unit].reshape(
                        (grid.shape[1], grid.shape[2] * grid.shape[3],
                            grid.shape[4]))
                if row_images:
                    filename = os.path.join(outdir, safe_dir_name(layer),
                            prefix + 'image', '%d-%s' % (unit, suffix))
                    # Trim the trailing gap before saving.
                    pool.add(strip[:,:-gap_pixels,:].copy(), filename)
                if single_images:
                    single_filename = os.path.join(outdir, safe_dir_name(layer),
                            prefix + 's-image', '%d-%s' % (unit, suffix))
                    # Single image: just the first cell of the strip.
                    pool.add(strip[:,:strip.shape[1] // row_length
                        - gap_pixels,:].copy(), single_filename)
    pool.join()
    # Release and delete the temporary shared memmap files.
    clear_global_shared_files([g.filename
            for grid in [vizgrid, maskgrid, origrid, seggrid]
            for g in grid.values()])
|
|
|
# Registry of memmap arrays shared with worker processes, keyed by the
# absolute path of the backing file.
global_shared_files = {}

def create_temp_mmap_grid(shape, dtype, fill):
    '''Allocate a fresh temp-file-backed numpy memmap with the given
    shape and dtype, pre-filled with `fill`, and register it in
    global_shared_files so workers can reattach to it by filename.'''
    dtype = numpy.dtype(dtype)
    dims = 'x'.join('%d' % s for s in shape)
    filename = os.path.join(tempfile.mkdtemp(),
            'temp-%s-%s.mmap' % (dims, dtype.name))
    handle = open(filename, mode='w+b')
    grid = numpy.memmap(handle, dtype=dtype, mode='w+', shape=shape)
    # Keep the file handle alive for as long as the array exists.
    grid.fid = handle
    grid[...] = fill
    global_shared_files[filename] = grid
    return grid
|
|
|
def shared_temp_mmap_grid(filename, shape, dtype):
    '''Return the shared memmap for `filename`, reopening it read-write
    and caching it in global_shared_files on first access.'''
    cached = global_shared_files.get(filename)
    if cached is None:
        cached = numpy.memmap(filename, dtype=dtype, mode='r+', shape=shape)
        global_shared_files[filename] = cached
    return cached
|
|
|
def clear_global_shared_files(filenames):
    '''Drop cached memmap handles and best-effort delete their files.'''
    for path in filenames:
        global_shared_files.pop(path, None)
        try:
            os.unlink(path)
        except OSError:
            # Best effort: the file may already have been removed.
            pass
|
|
|
class VisualizeImageWorker(WorkerBase):
    '''Worker that fills one (layer, unit, rank) cell of the shared
    grids: original image, activation overlay, activation mask, and
    segmentation rendering.'''
    def setup(self, memmap_grid_info):
        # Reattach to the four shared memmap grids by filename.
        grids = [{layer: shared_temp_mmap_grid(*info)
                  for layer, info in grid.items()}
                 for grid in memmap_grid_info]
        self.vizgrid, self.maskgrid, self.origrid, self.seggrid = grids
    def work(self, layer, unit, rank,
            byte_im, acts, level, scale_offset, seg):
        height = byte_im.shape[0]
        cell = (unit, slice(None), rank, slice(0, height), slice(None))
        self.origrid[layer][cell] = byte_im
        viz, mask = activation_visualization(
                byte_im,
                acts,
                level,
                scale_offset=scale_offset,
                return_mask=True)
        self.vizgrid[layer][cell] = viz
        self.maskgrid[layer][cell] = mask
        self.seggrid[layer][cell] = (
                segment_visualization(seg, byte_im.shape[0:2]))
|
|
|
class SaveImageWorker(WorkerBase):
    '''Worker that encodes one image array and writes it to disk.'''
    def work(self, data, filename):
        image = Image.fromarray(data)
        image.save(filename, optimize=True, quality=80)
|
|
|
def score_tally_stats(label_category, tc, truth, cc, ic):
    '''Compute per-(label, unit) IoU and IQR scores from tally counts.

    label_category: category index for each label.
    tc: per-category total pixel counts.
    truth: per-label labeled-pixel counts.
    cc: per-(category, unit) activation counts.
    ic: per-(label, unit) intersection counts.
    Returns (iou, iqr), two double tensors shaped like ic.
    '''
    eps = 1e-20
    pred = tc[label_category]
    # Broadcast per-label counts against the unit dimension.
    pred = cc[label_category]
    total = tc[label_category][:, None]
    truth = truth[:, None]
    union = pred + truth - ic
    iou = ic.double() / (union.double() + eps)
    # Build a 2x2 contingency table per (label, unit):
    # [[intersection, activation-only], [label-only, neither]].
    table = torch.empty(size=(2, 2) + ic.shape, dtype=ic.dtype,
                        device=ic.device)
    table[0, 0] = ic
    table[0, 1] = pred - ic
    table[1, 0] = truth - ic
    table[1, 1] = total - union
    table = table.double() / total.double()
    # Information quality ratio: mutual information / joint entropy.
    iqr = mutual_information(table) / joint_entropy(table)
    iqr[torch.isnan(iqr)] = 0
    return iou, iqr
|
|
|
def collect_quantiles_and_topk(outdir, model, segloader,
        segrunner, k=100, resolution=1024):
    '''
    Collects (estimated) quantile information and (exact) sorted top-K lists
    for every channel in the retained layers of the model. Returns
    a map of quantiles (one RunningQuantile for each layer) along with
    a map of topk (one RunningTopK for each layer).

    Results are cached per layer as quantiles.npz / topk.npz under
    outdir/[layer]; when every layer's cache is present, the cached
    results are returned without touching the model.
    NOTE(review): the `segrunner` argument is not used in this function.
    '''
    device = next(model.parameters()).device
    features = model.retained_features()
    # Try to satisfy the request entirely from per-layer cache files.
    cached_quantiles = {
        layer: load_quantile_if_present(os.path.join(outdir,
            safe_dir_name(layer)), 'quantiles.npz',
            device=torch.device('cpu'))
        for layer in features }
    cached_topks = {
        layer: load_topk_if_present(os.path.join(outdir,
            safe_dir_name(layer)), 'topk.npz',
            device=torch.device('cpu'))
        for layer in features }
    if (all(value is not None for value in cached_quantiles.values()) and
            all(value is not None for value in cached_topks.values())):
        return cached_quantiles, cached_topks

    # Process layers in small groups to bound accumulator memory; the
    # dataset is re-scanned once per group.
    layer_batch_size = 8
    all_layers = list(features.keys())
    layer_batches = [all_layers[i:i+layer_batch_size]
            for i in range(0, len(all_layers), layer_batch_size)]

    quantiles, topks = {}, {}
    progress = default_progress()
    for layer_batch in layer_batches:
        for i, batch in enumerate(progress(segloader, desc='Quantiles')):
            # We don't care about the model output; only the retained
            # intermediate features populated by this forward pass.
            model(batch[0].to(device))
            features = model.retained_features()

            for key in layer_batch:
                value = features[key]
                if topks.get(key, None) is None:
                    topks[key] = RunningTopK(k)
                if quantiles.get(key, None) is None:
                    quantiles[key] = RunningQuantile(resolution=resolution)
                topvalue = value
                if len(value.shape) > 2:
                    # Top-k ranks images by each channel's spatial max.
                    topvalue, _ = value.view(*(value.shape[:2] + (-1,))).max(2)
                    # Quantiles are taken over individual pixels: move
                    # channels last and flatten to (pixels, channels).
                    value = value.permute(
                            (0,) + tuple(range(2, len(value.shape))) + (1,)
                            ).contiguous().view(-1, value.shape[1])
                quantiles[key].add(value)
                topks[key].add(topvalue)
        # Move this group's accumulators off the GPU before the next group.
        for key in layer_batch:
            quantiles[key].to_(torch.device('cpu'))
            topks[key].to_(torch.device('cpu'))
    # Persist per-layer caches for future runs.
    for layer in quantiles:
        save_state_dict(quantiles[layer],
                os.path.join(outdir, safe_dir_name(layer), 'quantiles.npz'))
        save_state_dict(topks[layer],
                os.path.join(outdir, safe_dir_name(layer), 'topk.npz'))
    return quantiles, topks
|
|
|
def collect_bincounts(outdir, model, segloader, levels, segrunner):
    '''
    Returns label_counts, category_activation_counts, and intersection_counts,
    across the data set, counting the pixels of intersection between upsampled,
    thresholded model featuremaps, with segmentation classes in the segloader.

    label_counts (independent of model): pixels across the data set that
        are labeled with the given label.
    category_activation_counts (one per layer): for each feature channel,
        pixels across the dataset where the channel exceeds the level
        threshold. There is one count per category: activations only
        contribute to the categories for which any category labels are
        present on the images.
    intersection_counts (one per layer): for each feature channel and
        label, pixels across the dataset where the channel exceeds
        the level, and the labeled segmentation class is also present.

    This is a performance-sensitive function. Best performance is
    achieved with a counting scheme which assumes a segloader with
    batch_size 1.

    Also returns derived iou_scores and iqr_scores (via score_tally_stats),
    and caches everything per layer as bincounts.npz under outdir/[layer].
    '''
    # Attempt to satisfy the request from per-layer cache files.
    (iou_scores, iqr_scores,
            total_counts, label_counts, category_activation_counts,
            intersection_counts) = {}, {}, None, None, {}, {}
    found_all = True
    for layer in model.retained_features():
        filename = os.path.join(outdir, safe_dir_name(layer), 'bincounts.npz')
        if os.path.isfile(filename):
            data = numpy.load(filename)
            iou_scores[layer] = torch.from_numpy(data['iou_scores'])
            iqr_scores[layer] = torch.from_numpy(data['iqr_scores'])
            total_counts = torch.from_numpy(data['total_counts'])
            label_counts = torch.from_numpy(data['label_counts'])
            category_activation_counts[layer] = torch.from_numpy(
                    data['category_activation_counts'])
            intersection_counts[layer] = torch.from_numpy(
                    data['intersection_counts'])
        else:
            found_all = False
    if found_all:
        return (iou_scores, iqr_scores,
                total_counts, label_counts, category_activation_counts,
                intersection_counts)

    device = next(model.parameters()).device
    labelcat, categories = segrunner.get_label_and_category_names()
    # Category index for each label (0 when the category is unrecognized).
    label_category = [categories.index(c) if c in categories else 0
            for l, c in labelcat]
    num_labels, num_categories = (len(n) for n in [labelcat, categories])

    # One-hot (num_labels x num_categories) matrix mapping each label to
    # its category.
    labelcat = torch.zeros(num_labels, num_categories,
            dtype=torch.long, device=device)
    labelcat.scatter_(1, torch.from_numpy(numpy.array(label_category,
        dtype='int64')).to(device)[:,None], 1)

    # The counting scheme below assumes one image per batch.
    assert segloader.batch_size == 1
    category_activation_counts = {}
    intersection_counts = {}
    label_counts = torch.zeros(num_labels, dtype=torch.long, device=device)
    total_counts = torch.zeros(num_categories, dtype=torch.long, device=device)
    progress = default_progress()
    scale_offset_map = getattr(model, 'scale_offset', None)
    upsample_grids = {}

    for i, batch in enumerate(progress(segloader, desc='Bincounts')):
        # Run the model and segment the batch; batch_label_counts is the
        # per-label pixel bincount of the segmentation.
        seg, batch_label_counts, _, imshape = segrunner.run_and_segment_batch(
                batch, model, want_bincount=True, want_rgb=True)
        bc = batch_label_counts.cpu()
        batch_label_counts = batch_label_counts.to(device)
        seg = seg.to(device)
        features = model.retained_features()

        # Accumulate labeled-pixel counts, and find which labels and
        # categories are present in this image.
        label_counts += batch_label_counts[0]
        batch_labels = bc[0].nonzero()[:,0]
        batch_categories = labelcat[batch_labels].max(0)[0]
        # Every pixel of the image counts toward each present category.
        total_counts += batch_categories * (
                seg.shape[0] * seg.shape[2] * seg.shape[3])
        for key, value in features.items():
            if key not in upsample_grids:
                # Lazily build (and reuse) a grid_sample grid that
                # upsamples this layer's featuremap to segmentation size.
                upsample_grids[key] = upsample_grid(value.shape[2:],
                        seg.shape[2:], imshape,
                        scale_offset=scale_offset_map.get(key, None)
                            if scale_offset_map is not None else None,
                        dtype=value.dtype, device=value.device)
            upsampled = torch.nn.functional.grid_sample(value,
                    upsample_grids[key], padding_mode='border')
            # Boolean activation mask: featuremap above its unit level.
            amask = (upsampled > levels[key][None,:,None,None].to(
                upsampled.device))
            ac = amask.int().view(amask.shape[1], -1).sum(1)

            # Intersection counts: activated pixels that also carry each
            # present label (label 0 is background and is skipped).
            for label in batch_labels.tolist():
                if label == 0:
                    continue
                imask = amask * ((seg == label).max(dim=1, keepdim=True)[0])
                ic = imask.int().view(imask.shape[1], -1).sum(1)
                if key not in intersection_counts:
                    intersection_counts[key] = torch.zeros(num_labels,
                            amask.shape[1], dtype=torch.long, device=device)
                intersection_counts[key][label] += ic

            # Activation counts, attributed only to categories that have
            # any label present in this image.
            cc = batch_categories[:,None] * ac[None,:]
            if key not in category_activation_counts:
                category_activation_counts[key] = cc
            else:
                category_activation_counts[key] += cc
    # Derive iou/iqr scores from the raw tallies.
    iou_scores = {}
    iqr_scores = {}
    for k in intersection_counts:
        iou_scores[k], iqr_scores[k] = score_tally_stats(
            label_category, total_counts, label_counts,
            category_activation_counts[k], intersection_counts[k])
    # Cache everything per layer for future runs.
    for k in intersection_counts:
        numpy.savez(os.path.join(outdir, safe_dir_name(k), 'bincounts.npz'),
                iou_scores=iou_scores[k].cpu().numpy(),
                iqr_scores=iqr_scores[k].cpu().numpy(),
                total_counts=total_counts.cpu().numpy(),
                label_counts=label_counts.cpu().numpy(),
                category_activation_counts=category_activation_counts[k]
                    .cpu().numpy(),
                intersection_counts=intersection_counts[k].cpu().numpy(),
                levels=levels[k].cpu().numpy())
    return (iou_scores, iqr_scores,
            total_counts, label_counts, category_activation_counts,
            intersection_counts)
|
|
|
def collect_cond_quantiles(outdir, model, segloader, segrunner):
    '''
    Collects per-layer RunningConditionalQuantile statistics over the
    data set, conditioned on ('all',), on each ('label', i), and (when
    there is more than one category) on each ('cat', i).

    Returns (conditional_quantiles, label_fracs): a dict mapping layer
    name to its RunningConditionalQuantile, and a (num_labels, 1, 1)
    float tensor of each label's fraction of all segmented pixels.
    Results are cached under outdir and reloaded on subsequent calls.

    This is a performance-sensitive function.  Best performance is
    achieved with a counting scheme which assumes a segloader with
    batch_size 1.
    '''
    device = next(model.parameters()).device
    # Try to reuse previously computed results saved under outdir.
    cached_cond_quantiles = {
        layer: load_conditional_quantile_if_present(os.path.join(outdir,
            safe_dir_name(layer)), 'cond_quantiles.npz')
        for layer in model.retained_features() }
    label_fracs = load_npy_if_present(outdir, 'label_fracs.npy', 'cpu')
    if label_fracs is not None and all(
            value is not None for value in cached_cond_quantiles.values()):
        return cached_cond_quantiles, label_fracs

    labelcat, categories = segrunner.get_label_and_category_names()
    # Map each label to its category index (0 if the category is unknown).
    label_category = [categories.index(c) if c in categories else 0
            for l, c in labelcat]
    num_labels, num_categories = (len(n) for n in [labelcat, categories])

    # Rebind labelcat as a (num_labels, num_categories) one-hot matrix
    # marking each label's category.
    labelcat = torch.zeros(num_labels, num_categories,
            dtype=torch.long, device=device)
    labelcat.scatter_(1, torch.from_numpy(numpy.array(label_category,
        dtype='int64')).to(device)[:,None], 1)

    # The per-image counting below indexes batch slot [0] directly,
    # so it requires batch_size 1.
    assert segloader.batch_size == 1
    conditional_quantiles = {}
    label_counts = torch.zeros(num_labels, dtype=torch.long, device=device)
    pixel_count = 0
    progress = default_progress()
    scale_offset_map = getattr(model, 'scale_offset', None)
    upsample_grids = {}
    common_conditions = set()
    # load_npy_if_present returns 0 (not None) when the file is absent,
    # hence the == 0 test.  NOTE(review): if label_fracs were a loaded
    # multi-element tensor here, `label_fracs == 0` would be elementwise
    # and this `if` would raise; in practice the early return above
    # handles the loaded case first — confirm.
    if label_fracs is None or label_fracs == 0:
        # First pass over the data just to measure label pixel fractions.
        for i, batch in enumerate(progress(segloader, desc='label fracs')):
            seg, batch_label_counts, im, _ = segrunner.run_and_segment_batch(
                    batch, model, want_bincount=True, want_rgb=True)
            batch_label_counts = batch_label_counts.to(device)
            features = model.retained_features()
            label_counts += batch_label_counts[0]
            pixel_count += seg.shape[2] * seg.shape[3]
        label_fracs = (label_counts.cpu().float() / pixel_count)[:, None, None]
        numpy.save(os.path.join(outdir, 'label_fracs.npy'), label_fracs)

    # Labels rarer than this fraction of pixels are skipped entirely.
    skip_threshold = 1e-4
    skip_labels = set(i.item()
        for i in (label_fracs.view(-1) < skip_threshold).nonzero().view(-1))

    # One full pass over the data per layer, so only one layer's
    # statistics occupy GPU memory at a time.
    for layer in progress(model.retained_features().keys(), desc='CQ layers'):
        if cached_cond_quantiles.get(layer, None) is not None:
            conditional_quantiles[layer] = cached_cond_quantiles[layer]
            continue

        for i, batch in enumerate(progress(segloader, desc='Condquant')):
            seg, batch_label_counts, _, imshape = (
                    segrunner.run_and_segment_batch(
                    batch, model, want_bincount=True, want_rgb=True))
            bc = batch_label_counts.cpu()
            batch_label_counts = batch_label_counts.to(device)
            features = model.retained_features()
            label_counts += batch_label_counts[0]
            pixel_count += seg.shape[2] * seg.shape[3]
            # Labels present in this (single-image) batch, and the
            # one-hot union of their categories.
            batch_labels = bc[0].nonzero()[:,0]
            batch_categories = labelcat[batch_labels].max(0)[0]
            cpu_seg = None
            value = features[layer]
            # Cache the grid_sample grid per layer: it depends only on
            # the layer's feature shape, not on the batch contents.
            if layer not in upsample_grids:
                upsample_grids[layer] = upsample_grid(value.shape[2:],
                        seg.shape[2:], imshape,
                        scale_offset=scale_offset_map.get(layer, None)
                        if scale_offset_map is not None else None,
                        dtype=value.dtype, device=value.device)
            if layer not in conditional_quantiles:
                conditional_quantiles[layer] = RunningConditionalQuantile(
                        resolution=2048)
            # Upsample features to segmentation resolution and flatten
            # to (channels, pixels).
            upsampled = torch.nn.functional.grid_sample(value,
                    upsample_grids[layer], padding_mode='border').view(
                    value.shape[1], -1)
            conditional_quantiles[layer].add(('all',), upsampled.t())
            cpu_upsampled = None
            for label in batch_labels.tolist():
                if label in skip_labels:
                    continue
                label_key = ('label', label)
                if label_key in common_conditions:
                    # Frequent labels are accumulated on the GPU.
                    imask = (seg == label).max(dim=1)[0].view(-1)
                    intersected = upsampled[:, imask]
                    conditional_quantiles[layer].add(('label', label),
                            intersected.t())
                else:
                    # Rare labels are accumulated on the CPU; the CPU
                    # copies are made lazily, at most once per batch.
                    if cpu_seg is None:
                        cpu_seg = seg.cpu()
                    if cpu_upsampled is None:
                        cpu_upsampled = upsampled.cpu()
                    imask = (cpu_seg == label).max(dim=1)[0].view(-1)
                    intersected = cpu_upsampled[:, imask]
                    conditional_quantiles[layer].add(('label', label),
                            intersected.t())
            if num_categories > 1:
                for cat in batch_categories.nonzero()[:,0]:
                    conditional_quantiles[layer].add(('cat', cat.item()),
                            upsampled.t())

            # At every power-of-two batch index, refresh the set of the
            # 64 most common conditions and evict the rest to CPU.
            if i and not i & (i - 1):
                cq = conditional_quantiles[layer]
                common_conditions = set(cq.most_common_conditions(64))
                cq.to_('cpu', [k for k in cq.running_quantiles.keys()
                    if k not in common_conditions])

        # Move the layer's finished statistics off the GPU.
        conditional_quantiles[layer].to_('cpu')

    label_fracs = (label_counts.cpu().float() / pixel_count)[:, None, None]

    for cq in conditional_quantiles.values():
        cq.to_('cpu')

    # Persist results so later calls can return the cached versions.
    for layer in conditional_quantiles:
        save_state_dict(conditional_quantiles[layer],
                os.path.join(outdir, safe_dir_name(layer), 'cond_quantiles.npz'))
    numpy.save(os.path.join(outdir, 'label_fracs.npy'), label_fracs)

    return conditional_quantiles, label_fracs
|
|
|
|
|
def collect_maxiou(outdir, model, segloader, segrunner):
    '''
    Returns maxiou and maxiou_level across the data set, one per layer.

    For each (label, unit) pair, sweeps 100 log-spaced activation
    quantile thresholds and records the best IoU between the
    thresholded unit and the label, along with the activation level
    and quantile at which that best IoU occurs.  Results are also
    saved to max_iou.npz under each layer's directory.

    This is a performance-sensitive function.  Best performance is
    achieved with a counting scheme which assumes a segloader with
    batch_size 1.
    '''
    # NOTE(review): device is assigned but not used in this function.
    device = next(model.parameters()).device
    conditional_quantiles, label_fracs = collect_cond_quantiles(
            outdir, model, segloader, segrunner)

    labelcat, categories = segrunner.get_label_and_category_names()
    # Map each label to its category index (0 if the category is unknown).
    label_category = [categories.index(c) if c in categories else 0
            for l, c in labelcat]
    num_labels, num_categories = (len(n) for n in [labelcat, categories])

    # Condition keys under which the quantile stats were collected.
    label_list = [('label', i) for i in range(num_labels)]
    category_list = [('all',)] if num_categories <= 1 else (
            [('cat', i) for i in range(num_categories)])
    max_iou, max_iou_level, max_iou_quantile = {}, {}, {}
    # 100 candidate activation fractions from 0.001 to 1.
    fracs = torch.logspace(-3, 0, 100)
    progress = default_progress()
    for layer, cq in progress(conditional_quantiles.items(), desc='Maxiou'):
        # Activation levels whose exceedance fraction (over all pixels)
        # equals each candidate frac, per unit.
        levels = cq.conditional(('all',)).quantiles(1 - fracs)
        # Fraction of category pixels above each level: |activation|.
        denoms = 1 - cq.collected_normalize(category_list, levels)
        # Fraction of all pixels that are both labeled and activated:
        # P(above level | label) * P(label).
        isects = (1 - cq.collected_normalize(label_list, levels)) * label_fracs
        # |label| + |activation| - |intersection|, within the label's category.
        unions = label_fracs + denoms[label_category, :, :] - isects
        iou = isects / unions

        # Best IoU over the threshold sweep (dim 2), plus the level and
        # quantile at which it is achieved.
        max_iou[layer], level_bucket = iou.max(2)
        max_iou_level[layer] = levels[
                torch.arange(levels.shape[0])[None,:], level_bucket]
        max_iou_quantile[layer] = fracs[level_bucket]
    for layer in model.retained_features():
        numpy.savez(os.path.join(outdir, safe_dir_name(layer), 'max_iou.npz'),
                max_iou=max_iou[layer].cpu().numpy(),
                max_iou_level=max_iou_level[layer].cpu().numpy(),
                max_iou_quantile=max_iou_quantile[layer].cpu().numpy())
    return (max_iou, max_iou_level, max_iou_quantile)
|
|
|
def collect_iqr(outdir, model, segloader, segrunner):
    '''
    Returns iqr and iqr_level.

    For each (label, unit) pair, sweeps 100 activation quantile
    thresholds and scores the information quality ratio (mutual
    information / joint entropy) of the 2x2 contingency table between
    "pixel is labeled" and "unit is above threshold".  Returns, per
    layer, the best IQR, the level and quantile achieving it, the IoU
    at that threshold, and the agreement (diagonal mass) there.
    Results are cached to iqr.npz under each layer's directory.

    This is a performance-sensitive function.  Best performance is
    achieved with a counting scheme which assumes a segloader with
    batch_size 1.
    '''
    max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou = {}, {}, {}, {}
    max_iqr_agreement = {}
    # First try to satisfy the request entirely from cached files.
    found_all = True
    for layer in model.retained_features():
        filename = os.path.join(outdir, safe_dir_name(layer), 'iqr.npz')
        if os.path.isfile(filename):
            data = numpy.load(filename)
            max_iqr[layer] = torch.from_numpy(data['max_iqr'])
            max_iqr_level[layer] = torch.from_numpy(data['max_iqr_level'])
            max_iqr_quantile[layer] = torch.from_numpy(data['max_iqr_quantile'])
            max_iqr_iou[layer] = torch.from_numpy(data['max_iqr_iou'])
            max_iqr_agreement[layer] = torch.from_numpy(
                    data['max_iqr_agreement'])
        else:
            found_all = False
    if found_all:
        return (max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou,
                max_iqr_agreement)

    device = next(model.parameters()).device
    conditional_quantiles, label_fracs = collect_cond_quantiles(
            outdir, model, segloader, segrunner)

    labelcat, categories = segrunner.get_label_and_category_names()
    # Map each label to its category index (0 if the category is unknown).
    label_category = [categories.index(c) if c in categories else 0
            for l, c in labelcat]
    num_labels, num_categories = (len(n) for n in [labelcat, categories])

    label_list = [('label', i) for i in range(num_labels)]
    category_list = [('all',)] if num_categories <= 1 else (
            [('cat', i) for i in range(num_categories)])
    full_mi, full_je, full_iqr = {}, {}, {}
    # 100 candidate activation fractions from 0.001 to 1.
    fracs = torch.logspace(-3, 0, 100)
    progress = default_progress()
    for layer, cq in progress(conditional_quantiles.items(), desc='IQR'):
        # Per-unit activation levels for each candidate quantile.
        levels = cq.conditional(('all',)).quantiles(1 - fracs)
        # truth: P(label); preds: P(activation above level) within the
        # label's category; isects: P(label and activation).
        truth = label_fracs.to(device)
        preds = (1 - cq.collected_normalize(category_list, levels)
                )[label_category, :, :].to(device)
        cond_isects = 1 - cq.collected_normalize(label_list, levels).to(device)
        isects = cond_isects * truth
        unions = truth + preds - isects
        # Assemble the 2x2 contingency table over (label, activation):
        # arr[j, k] is P(label == j-th state, activation == k-th state).
        arr = torch.empty(size=(2, 2) + isects.shape, dtype=isects.dtype,
                device=device)
        arr[0, 0] = isects
        arr[0, 1] = preds - isects
        arr[1, 0] = truth - isects
        arr[1, 1] = 1 - unions
        # Clamp away small negative values from floating-point estimates.
        arr.clamp_(0, 1)
        mi = mutual_information(arr)
        # Exclude the final (frac == 1, everything active) threshold.
        mi[:,:,-1] = 0

        # Zero out labels too rare to score reliably.
        mi[label_fracs.view(-1) < 1e-3, :, :] = 0
        je = joint_entropy(arr)
        iqr = mi / je
        iqr[torch.isnan(iqr)] = 0
        full_mi[layer] = mi.cpu()
        full_je[layer] = je.cpu()
        full_iqr[layer] = iqr.cpu()
        del mi, je
        # Agreement is the contingency table's diagonal mass:
        # P(both present) + P(both absent).
        agreement = isects + arr[1, 1]

        # Only keep thresholds with positive correlation: at least 80%
        # agreement and a strictly positive activation level.
        # NOTE(review): positive_iqr aliases iqr, so these writes also
        # mutate the tensor already saved in full_iqr — presumably
        # intentional; confirm downstream consumers expect that.
        positive_iqr = iqr
        positive_iqr[agreement <= 0.8] = 0
        positive_iqr[(levels <= 0.0)[None, :, :].expand(positive_iqr.shape)] = 0

        # Best IQR over the threshold sweep, and associated statistics
        # gathered at the winning threshold index (level_bucket).
        maxiqr, level_bucket = positive_iqr.max(2)
        max_iqr[layer] = maxiqr.cpu()
        max_iqr_level[layer] = levels.to(device)[
                torch.arange(levels.shape[0])[None,:], level_bucket].cpu()
        max_iqr_quantile[layer] = fracs.to(device)[level_bucket].cpu()
        max_iqr_agreement[layer] = agreement[
                torch.arange(agreement.shape[0])[:, None],
                torch.arange(agreement.shape[1])[None, :],
                level_bucket].cpu()

        # IoU evaluated at the IQR-maximizing threshold.
        matching_iou = (isects[
                torch.arange(isects.shape[0])[:, None],
                torch.arange(isects.shape[1])[None, :],
                level_bucket] /
                        unions[
                torch.arange(unions.shape[0])[:, None],
                torch.arange(unions.shape[1])[None, :],
                level_bucket])
        matching_iou[torch.isnan(matching_iou)] = 0
        max_iqr_iou[layer] = matching_iou.cpu()
    for layer in model.retained_features():
        numpy.savez(os.path.join(outdir, safe_dir_name(layer), 'iqr.npz'),
                max_iqr=max_iqr[layer].cpu().numpy(),
                max_iqr_level=max_iqr_level[layer].cpu().numpy(),
                max_iqr_quantile=max_iqr_quantile[layer].cpu().numpy(),
                max_iqr_iou=max_iqr_iou[layer].cpu().numpy(),
                max_iqr_agreement=max_iqr_agreement[layer].cpu().numpy(),
                full_mi=full_mi[layer].cpu().numpy(),
                full_je=full_je[layer].cpu().numpy(),
                full_iqr=full_iqr[layer].cpu().numpy())
    return (max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou,
            max_iqr_agreement)
|
|
|
def mutual_information(arr):
    '''
    Mutual information of the joint distribution given by the first two
    axes of arr: sums joint * log(joint / (marginal_j * marginal_k))
    over every (j, k) cell, elementwise across any remaining axes.
    NaN terms (from zero-probability cells) contribute zero, and the
    result is clamped to be nonnegative.
    '''
    row_marginals = [arr[j, :].sum(dim=0) for j in range(arr.shape[0])]
    col_marginals = [arr[:, k].sum(dim=0) for k in range(arr.shape[1])]
    result = 0
    for j, pj in enumerate(row_marginals):
        for k, pk in enumerate(col_marginals):
            joint = arr[j, k]
            contrib = joint * (joint / (pj * pk)).log()
            # 0 * log(0/x) yields NaN; treat it as a zero contribution.
            contrib[torch.isnan(contrib)] = 0
            result = result + contrib
    return result.clamp_(0)
|
|
|
def joint_entropy(arr):
    '''
    Joint entropy of the distribution given by the first two axes of
    arr: -sum p log p over every (j, k) cell, elementwise across any
    remaining axes.  Zero-probability cells (whose p log p is NaN)
    contribute zero, and the result is clamped to be nonnegative.
    '''
    flat = arr.reshape((-1,) + tuple(arr.shape[2:]))
    accum = 0
    for idx in range(flat.shape[0]):
        cell = flat[idx]
        plogp = cell * cell.log()
        # 0 * log(0) yields NaN; treat it as a zero contribution.
        plogp[torch.isnan(plogp)] = 0
        accum = accum + plogp
    return (-accum).clamp_(0)
|
|
|
def information_quality_ratio(arr):
    '''
    Ratio of mutual information to joint entropy for the contingency
    table arr, with 0/0 cases mapped to zero.
    '''
    ratio = mutual_information(arr) / joint_entropy(arr)
    ratio[torch.isnan(ratio)] = 0
    return ratio
|
|
|
def collect_covariance(outdir, model, segloader, segrunner):
    '''
    Returns label_mean, label_variance, unit_mean, unit_variance,
    and cross_covariance across the data set.

    label_mean, label_variance (independent of model):
        treating the label as a one-hot, each label's mean and variance.
    unit_mean, unit_variance (one per layer): for each feature channel,
        the mean and variance of the activations in that channel.
    cross_covariance (one per layer): the cross covariance between the
        labels and the units in the layer.

    NOTE(review): the actual return value is a dict mapping layer name
    to a RunningCrossCovariance accumulator, from which the statistics
    above can presumably be read — confirm against RunningCrossCovariance.
    '''
    device = next(model.parameters()).device
    # Reuse cached per-layer covariance files when all are present.
    cached_covariance = {
        layer: load_covariance_if_present(os.path.join(outdir,
            safe_dir_name(layer)), 'covariance.npz', device=device)
        for layer in model.retained_features() }
    if all(value is not None for value in cached_covariance.values()):
        return cached_covariance
    labelcat, categories = segrunner.get_label_and_category_names()
    # Map each label to its category index (0 if the category is unknown).
    label_category = [categories.index(c) if c in categories else 0
            for l, c in labelcat]
    num_labels, num_categories = (len(n) for n in [labelcat, categories])

    cov = {}
    progress = default_progress()
    scale_offset_map = getattr(model, 'scale_offset', None)
    upsample_grids = {}
    for i, batch in enumerate(progress(segloader, desc='Covariance')):
        seg, _, _, imshape = segrunner.run_and_segment_batch(batch, model,
                want_rgb=True)
        features = model.retained_features()
        # One-hot encode the multilabel segmentation, dropping label 0
        # (the "no label" background).
        ohfeats = multilabel_onehot(seg, num_labels, ignore_index=0)
        for key, value in features.items():
            # Grid depends only on layer/segmentation geometry; cache it.
            if key not in upsample_grids:
                upsample_grids[key] = upsample_grid(value.shape[2:],
                        seg.shape[2:], imshape,
                        scale_offset=scale_offset_map.get(key, None)
                        if scale_offset_map is not None else None,
                        dtype=value.dtype, device=value.device)
            # Upsample features to segmentation resolution; the cached
            # (1, H, W, 2) grid is expanded across the batch dimension.
            upsampled = torch.nn.functional.grid_sample(value,
                    upsample_grids[key].expand(
                        (value.shape[0],) + upsample_grids[key].shape[1:]),
                    padding_mode='border')
            if key not in cov:
                cov[key] = RunningCrossCovariance()
            cov[key].add(upsampled, ohfeats)
    # Persist each layer's accumulated statistics.
    for layer in cov:
        save_state_dict(cov[layer],
                os.path.join(outdir, safe_dir_name(layer), 'covariance.npz'))
    return cov
|
|
|
def multilabel_onehot(labels, num_labels, dtype=None, ignore_index=None):
    '''
    Converts a multilabel tensor into a onehot tensor.

    The input labels is a tensor of shape (samples, multilabels, y, x).
    The output is a tensor of shape (samples, num_labels, y, x).
    If ignore_index is specified, labels with that index are ignored.
    Each x in labels should be 0 <= x < num_labels, or x == ignore_index.
    '''
    assert ignore_index is None or ignore_index <= 0
    out_dtype = torch.float if dtype is None else dtype
    # A negative ignore_index is handled by shifting all labels up into
    # a widened channel range, then cropping the extra channels off, so
    # the ignored values scatter into channels that are discarded.
    shift = -ignore_index if ignore_index else 0
    shape = (labels.shape[0], num_labels + shift) + tuple(labels.shape[2:])
    onehot = torch.zeros(shape, device=labels.device, dtype=out_dtype)
    index = labels + shift if shift else labels
    onehot.scatter_(1, index, 1)
    if shift:
        return onehot[:, shift:]
    if ignore_index is not None:
        # ignore_index == 0 here: simply zero out the ignored channel.
        onehot[:, ignore_index] = 0
    return onehot
|
|
|
def load_npy_if_present(outdir, filename, device):
    '''
    Loads a cached .npy array from outdir/filename as a tensor on
    device.  Returns 0 (not None) when the file is absent — callers
    test against this sentinel.
    '''
    path = os.path.join(outdir, filename)
    if not os.path.isfile(path):
        return 0
    return torch.from_numpy(numpy.load(path)).to(device)
|
|
|
def load_npz_if_present(outdir, filename, varnames, device):
    '''
    Loads the named arrays from a cached .npz archive at
    outdir/filename, returning them as a tuple of tensors on device,
    or None if the file does not exist.
    '''
    path = os.path.join(outdir, filename)
    if not os.path.isfile(path):
        return None
    archive = numpy.load(path)
    return tuple(torch.from_numpy(archive[name]).to(device)
            for name in varnames)
|
|
|
def load_quantile_if_present(outdir, filename, device):
    '''
    Restores a cached RunningQuantile from outdir/filename, moved to
    device, or returns None if the file does not exist.
    '''
    path = os.path.join(outdir, filename)
    if not os.path.isfile(path):
        return None
    quantile = RunningQuantile(state=numpy.load(path))
    quantile.to_(device)
    return quantile
|
|
|
def load_conditional_quantile_if_present(outdir, filename):
    '''
    Restores a cached RunningConditionalQuantile from outdir/filename,
    or returns None if the file does not exist.  Unlike the other
    loaders, this one performs no device move.
    '''
    path = os.path.join(outdir, filename)
    if not os.path.isfile(path):
        return None
    return RunningConditionalQuantile(state=numpy.load(path))
|
|
|
def load_topk_if_present(outdir, filename, device):
    '''
    Restores a cached RunningTopK from outdir/filename, moved to
    device, or returns None if the file does not exist.
    '''
    path = os.path.join(outdir, filename)
    if not os.path.isfile(path):
        return None
    topk = RunningTopK(state=numpy.load(path))
    topk.to_(device)
    return topk
|
|
|
def load_covariance_if_present(outdir, filename, device):
    '''
    Restores a cached RunningCrossCovariance from outdir/filename,
    moved to device, or returns None if the file does not exist.
    '''
    path = os.path.join(outdir, filename)
    if not os.path.isfile(path):
        return None
    cov = RunningCrossCovariance(state=numpy.load(path))
    cov.to_(device)
    return cov
|
|
|
def save_state_dict(obj, filepath):
    '''
    Serializes obj.state_dict() into an .npz archive at filepath,
    creating the parent directory if needed.
    '''
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    numpy.savez(filepath, **obj.state_dict())
|
|
|
def upsample_grid(data_shape, target_shape, input_shape=None,
        scale_offset=None, dtype=torch.float, device=None):
    '''Prepares a grid to use with grid_sample to upsample a batch of
    features in data_shape to the target_shape. Can use scale_offset
    and input_shape to center the grid in a nondefault way: scale_offset
    maps feature pixels to input_shape pixels, and it is assumed that
    the target_shape is a uniform downsampling of input_shape.

    Returns a (1, target_h, target_w, 2) grid of (x, y) sampling
    coordinates normalized to [-1, 1], as grid_sample expects; callers
    expand dim 0 across the batch as needed.
    '''
    # With no target, sample the data at its own resolution.
    if target_shape is None:
        target_shape = data_shape

    # Per-axis (scale, offset) mapping feature pixels to target pixels:
    # derived from the shape ratio with the usual half-pixel centering,
    # or taken from the caller's explicit (scale, offset) pairs.
    if scale_offset is None:
        scale = tuple(float(ts) / ds
                for ts, ds in zip(target_shape, data_shape))
        offset = tuple(0.5 * s - 0.5 for s in scale)
    else:
        scale, offset = (v for v in zip(*scale_offset))

    # If scale_offset was expressed in input_shape pixels, convert it
    # to target_shape pixels (assumes uniform downsampling).
    if input_shape is not None:
        scale = tuple(s * (ts - 1) / (ns - 1)
                for s, ns, ts in zip(scale, input_shape, target_shape))
        offset = tuple(o * (ts - 1) / (ns - 1)
                for o, ns, ts in zip(offset, input_shape, target_shape))

    # One coordinate vector per axis, transformed into grid_sample's
    # normalized [-1, 1] coordinate system over the data extent.
    ty, tx = (((torch.arange(ts, dtype=dtype, device=device) - o)
                  * (2 / (s * (ss - 1))) - 1)
            for ts, ss, s, o, in zip(target_shape, data_shape, scale, offset))

    # Stack into (1, H, W, 2); the last axis is (x, y) per grid_sample.
    grid = torch.stack(
        (tx[None,:].expand(target_shape), ty[:,None].expand(target_shape)),2
        )[None,:,:,:].expand((1, target_shape[0], target_shape[1], 2))
    return grid
|
|
|
def safe_dir_name(filename):
    '''
    Reduces filename to a filesystem-safe directory name by keeping
    only alphanumeric characters plus space, dot, underscore, and
    dash, then stripping any trailing whitespace.
    '''
    allowed = set(' ._-')
    kept = [ch for ch in filename if ch.isalnum() or ch in allowed]
    return ''.join(kept).rstrip()
|
|
|
# Per-category color pairs cycled by make_svg_bargraph: each entry is
# (saturated fill for the unit bars, lighter fill for the category's
# background band).
bargraph_palette = [
    ('#4B4CBF', '#B6B6F2'),
    ('#55B05B', '#B6F2BA'),
    ('#50BDAC', '#A5E5DB'),
    ('#81C679', '#C0FF9B'),
    ('#F0883B', '#F2CFB6'),
    ('#D4CF24', '#F2F1B6'),
    ('#D92E2B', '#F2B6B6'),
    ('#AB6BC6', '#CFAAFF'),
]
|
|
|
def make_svg_bargraph(labels, heights, categories, |
|
barheight=100, barwidth=12, show_labels=True, filename=None): |
|
|
|
|
|
unitheight = float(barheight) / max(max(heights, default=1), 1) |
|
textheight = barheight if show_labels else 0 |
|
labelsize = float(barwidth) |
|
gap = float(barwidth) / 4 |
|
textsize = barwidth + gap |
|
rollup = max(heights, default=1) |
|
textmargin = float(labelsize) * 2 / 3 |
|
leftmargin = 32 |
|
rightmargin = 8 |
|
svgwidth = len(heights) * (barwidth + gap) + 2 * leftmargin + rightmargin |
|
svgheight = barheight + textheight |
|
|
|
|
|
svg = et.Element('svg', width=str(svgwidth), height=str(svgheight), |
|
version='1.1', xmlns='http://www.w3.org/2000/svg') |
|
|
|
|
|
basey = svgheight - textheight |
|
x = leftmargin |
|
|
|
if len(heights): |
|
for h in [1, (max(heights) + 1) // 2, max(heights)]: |
|
et.SubElement(svg, 'text', x='0', y='0', |
|
style=('font-family:sans-serif;font-size:%dpx;' + |
|
'text-anchor:end;alignment-baseline:hanging;' + |
|
'transform:translate(%dpx, %dpx);') % |
|
(textsize, x - gap, basey - h * unitheight)).text = str(h) |
|
et.SubElement(svg, 'text', x='0', y='0', |
|
style=('font-family:sans-serif;font-size:%dpx;' + |
|
'text-anchor:middle;' + |
|
'transform:translate(%dpx, %dpx) rotate(-90deg)') % |
|
(textsize, x - gap - textsize, basey - h * unitheight / 2) |
|
).text = 'units' |
|
|
|
for catindex, (cat, catcount) in enumerate(categories): |
|
if not catcount: |
|
continue |
|
et.SubElement(svg, 'rect', x=str(x), y=str(basey - rollup * unitheight), |
|
width=(str((barwidth + gap) * catcount - gap)), |
|
height = str(rollup*unitheight), |
|
fill=bargraph_palette[catindex % len(bargraph_palette)][1]) |
|
x += (barwidth + gap) * catcount |
|
|
|
x = leftmargin |
|
catindex = -1 |
|
catcount = 0 |
|
for label, height in zip(labels, heights): |
|
while not catcount and catindex <= len(categories): |
|
catindex += 1 |
|
catcount = categories[catindex][1] |
|
color = bargraph_palette[catindex % len(bargraph_palette)][0] |
|
et.SubElement(svg, 'rect', x=str(x), y=str(basey-(height * unitheight)), |
|
width=str(barwidth), height=str(height * unitheight), |
|
fill=color) |
|
x += barwidth |
|
if show_labels: |
|
et.SubElement(svg, 'text', x='0', y='0', |
|
style=('font-family:sans-serif;font-size:%dpx;text-anchor:end;'+ |
|
'transform:translate(%dpx, %dpx) rotate(-45deg);') % |
|
(labelsize, x, basey + textmargin)).text = readable(label) |
|
x += gap |
|
catcount -= 1 |
|
|
|
x = leftmargin |
|
for cat, catcount in categories: |
|
if not catcount: |
|
continue |
|
et.SubElement(svg, 'text', x='0', y='0', |
|
style=('font-family:sans-serif;font-size:%dpx;text-anchor:end;'+ |
|
'transform:translate(%dpx, %dpx) rotate(-90deg);') % |
|
(textsize, x + (barwidth + gap) * catcount - gap, |
|
basey - rollup * unitheight + gap)).text = '%d %s' % ( |
|
catcount, readable(cat + ('s' if catcount != 1 else ''))) |
|
x += (barwidth + gap) * catcount |
|
|
|
result = et.tostring(svg) |
|
if filename: |
|
f = open(filename, 'wb') |
|
|
|
f.write(''.join([ |
|
'<?xml version=\"1.0\" standalone=\"no\"?>\n', |
|
'<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n', |
|
'\"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n'] |
|
).encode('utf-8')) |
|
f.write(result) |
|
f.close() |
|
return result |
|
|
|
# Regex rewrites applied in order by readable() to turn dataset label
# codes into display text: strip a trailing '-s' or '-c' suffix, then
# replace underscores with spaces.
readable_replacements = [(re.compile(r[0]), r[1]) for r in [
    (r'-[sc]$', ''),
    (r'_', ' '),
    ]]
|
|
|
def readable(label):
    '''
    Translates a dataset label code into human-readable display text
    by applying the readable_replacements substitutions in order.
    '''
    text = label
    for pattern, replacement in readable_replacements:
        text = re.sub(pattern, replacement, text)
    return text
|
|
|
def reverse_normalize_from_transform(transform):
    '''
    Crawls transform (and anything reachable through its .transform or
    .transforms attributes) looking for a torchvision Normalize, and
    returns a matching ReverseNormalize, or None if none is found.
    '''
    if isinstance(transform, torchvision.transforms.Normalize):
        return ReverseNormalize(transform.mean, transform.std)
    nested = getattr(transform, 'transform', None)
    if nested is not None:
        return reverse_normalize_from_transform(nested)
    children = getattr(transform, 'transforms', None)
    if children is None:
        return None
    # Search composed transforms from last to first, since a Normalize
    # typically appears at the end of a pipeline.
    for child in reversed(children):
        found = reverse_normalize_from_transform(child)
        if found is not None:
            return found
    return None
|
|
|
class ReverseNormalize:
    '''
    Undoes torchvision.transforms.Normalize, mapping a normalized
    batch back to its original range via data * stdev + mean.
    '''
    def __init__(self, mean, stdev):
        # Store as (1, channels, 1, 1) float tensors so they broadcast
        # across a (batch, channels, y, x) image batch.
        self.mean = torch.from_numpy(
                numpy.array(mean))[None, :, None, None].float()
        self.stdev = torch.from_numpy(
                numpy.array(stdev))[None, :, None, None].float()

    def __call__(self, data):
        # mul allocates a fresh tensor; the subsequent add_ mutates
        # only that copy, leaving the caller's data untouched.
        dev = data.device
        return data.mul(self.stdev.to(dev)).add_(self.mean.to(dev))
|
|
|
class ImageOnlySegRunner:
    '''
    SegRunner for a plain image dataset with no segmentation labels:
    every image receives a trivial 1x1 all-background segmentation, so
    dissection uses only the model activations (plus optional RGB
    visualizations recovered from the normalized images).
    '''
    def __init__(self, dataset, recover_image=None):
        self.dataset = dataset
        self.recover_image = (recover_image if recover_image is not None
                else reverse_normalize_from_transform(dataset))

    def get_label_and_category_names(self):
        # A single placeholder label and category; nothing is segmented.
        return [('-', '-')], ['-']

    def run_and_segment_batch(self, batch, model,
            want_bincount=False, want_rgb=False):
        '''
        Forwards one [im] batch through the model (for its retained
        activations) and returns (seg, bc, rgb, shape): a trivial
        (n, 1, 1, 1) zero segmentation, a bincount of one background
        pixel per image, an optional (n, y, x, rgb) byte image, and
        the (y, x) image size.
        '''
        [im] = batch
        device = next(model.parameters()).device
        rgb = None
        if want_rgb:
            rgb = self.recover_image(im.clone()
                    ).permute(0, 2, 3, 1).mul_(255).clamp(0, 255).byte()
        # Trivial segmentation: one "no label" pixel per image, with a
        # matching one-pixel bincount.
        seg = torch.zeros(im.shape[0], 1, 1, 1, dtype=torch.long)
        bc = torch.ones(im.shape[0], 1, dtype=torch.long)
        # Run the model only for the side effect of retaining features.
        model(im.to(device))
        return seg, bc, rgb, im.shape[2:]
|
|
|
class ClassifierSegRunner:
    '''
    SegRunner for a labeled segmentation dataset (e.g. Broden): each
    batch already carries its ground-truth multilabel segmentation and
    per-image label bincounts, so running a batch just forwards the
    images through the model and passes the labels through unchanged.
    '''
    def __init__(self, dataset, recover_image=None):
        self.dataset = dataset
        self.recover_image = (recover_image if recover_image is not None
                else reverse_normalize_from_transform(dataset))

    def get_label_and_category_names(self):
        # Pair each dataset label (made readable) with its category name.
        catnames = self.dataset.categories
        pairs = []
        for i, label in enumerate(self.dataset.labels):
            pairs.append((readable(label),
                catnames[self.dataset.label_category[i]]))
        return pairs, catnames

    def run_and_segment_batch(self, batch, model,
            want_bincount=False, want_rgb=False):
        '''
        Runs the dissected model on one (im, seg, bc) batch and returns
        (seg, bc, rgb, shape).  seg is the dataset's (n, d, y//r, x//r)
        long multilabel segmentation, where d is the maximum number of
        simultaneous labels per pixel, r an optional resolution
        reduction, and label 0 is reserved for the background
        "no-label".  bc is the per-image label bincount; rgb is a
        denormalized viewable (n, y, x, rgb) byte image tensor when
        want_rgb is set (else None); shape is the (y, x) data size.
        '''
        im, seg, bc = batch
        device = next(model.parameters()).device
        rgb = None
        if want_rgb:
            rgb = self.recover_image(im.clone()
                    ).permute(0, 2, 3, 1).mul_(255).clamp(0, 255).byte()
        # Run the model only for the side effect of retaining features.
        model(im.to(device))
        return seg, bc, rgb, im.shape[2:]
|
|
|
class GeneratorSegRunner:
    '''
    SegRunner for a generative model: each batch holds latent z
    vectors, the model maps them to images, and a segmenter network
    labels the generated images on the fly.
    '''
    def __init__(self, segmenter):
        # Default to the unified-parsing segmenter configuration.
        if segmenter is None:
            segmenter = UnifiedParsingSegmenter(segsizes=[256], segdiv='quad')
        self.segmenter = segmenter
        self.num_classes = len(segmenter.get_label_and_category_names()[0])
    def get_label_and_category_names(self):
        # Delegates to the wrapped segmenter's label/category tables.
        return self.segmenter.get_label_and_category_names()
    def run_and_segment_batch(self, batch, model,
            want_bincount=False, want_rgb=False):
        '''
        Runs the dissected model on one batch of the dataset, and
        returns a multilabel semantic segmentation for the data.
        Given a batch of size (n, c, y, x) the segmentation should
        be a (long integer) tensor of size (n, d, y//r, x//r) where
        d is the maximum number of simultaneous labels given to a pixel,
        and where r is some (optional) resolution reduction factor.
        In the segmentation returned, the label `0` is reserved for
        the background "no-label".

        In addition to the segmentation, bc, rgb, and shape are returned
        where bc is a per-image bincount counting returned label pixels,
        rgb is a viewable (n, y, x, rgb) byte image tensor for the data
        for visualizations (reversing normalizations, for example), and
        shape is the (y, x) size of the data. If want_bincount or
        want_rgb are False, those return values may be None.
        '''
        device = next(model.parameters()).device
        z_batch = batch[0]
        # Generate images from the latents; retained activations are
        # captured by the instrumented model as a side effect.
        tensor_images = model(z_batch.to(device))
        seg = self.segmenter.segment_batch(tensor_images, downsample=2)
        if want_bincount:
            # Offset each image's labels by image_index * num_classes so
            # one flat bincount yields a per-image (n, num_classes) table.
            index = torch.arange(z_batch.shape[0],
                    dtype=torch.long, device=device)
            bc = (seg + index[:, None, None, None] * self.num_classes).view(-1
                ).bincount(minlength=z_batch.shape[0] * self.num_classes)
            bc = bc.view(z_batch.shape[0], self.num_classes)
        else:
            bc = None
        if want_rgb:
            # Maps generator output to [0, 255] byte images.
            # NOTE(review): assumes the generator emits images scaled to
            # [-1, 1] (e.g. tanh output) — confirm for the model in use.
            images = ((tensor_images + 1) / 2 * 255)
            rgb = images.permute(0, 2, 3, 1).clamp(0, 255).byte()
        else:
            rgb = None
        return seg, bc, rgb, tensor_images.shape[2:]
|
|