import os, sys, numpy, torch, argparse, skimage, json, shutil |
from PIL import Image |
from torch.utils.data import TensorDataset |
from matplotlib.figure import Figure |
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas |
import matplotlib.gridspec as gridspec |
from scipy.ndimage.morphology import binary_dilation |
import netdissect.zdataset |
import netdissect.nethook |
from netdissect.dissection import safe_dir_name |
from netdissect.progress import verbose_progress, default_progress |
from netdissect.progress import print_progress, desc_progress, post_progress |
from netdissect.easydict import EasyDict |
from netdissect.workerpool import WorkerPool, WorkerBase |
from netdissect.runningstats import RunningQuantile |
from netdissect.pidfile import pidfile_taken |
from netdissect.modelconfig import create_instrumented_model |
from netdissect.autoeval import autoimport_eval |
def main():
    """Parse command-line options and run the ACE optimization.

    ``--outdir`` and ``--classname`` are required: ``run_command`` joins
    both into filesystem paths, so leaving either unset could only end in
    an obscure ``TypeError``.  All other options keep their defaults.

    After parsing, ``args.cuda`` is set to True only when ``--no-cuda`` was
    not given and CUDA is available, cudnn benchmarking is switched on, and
    control passes to ``run_command``.
    """
    parser = argparse.ArgumentParser(description='ACE optimization utility',
            prog='python -m netdissect.aceoptimize')
    parser.add_argument('--model', type=str, default=None,
            help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
            help='filename of .pth file for the model')
    parser.add_argument('--segmenter', type=str, default=None,
            help='constructor for a segmenter class')
    parser.add_argument('--classname', type=str, required=True,
            help='intervention classname')
    parser.add_argument('--layer', type=str, default='layer4',
            help='layer name')
    parser.add_argument('--search_size', type=int, default=10000,
            help='size of search for finding training locations')
    parser.add_argument('--train_size', type=int, default=1000,
            help='size of training set')
    parser.add_argument('--eval_size', type=int, default=200,
            help='size of eval set')
    parser.add_argument('--inference_batch_size', type=int, default=10,
            help='forward pass batch size')
    parser.add_argument('--train_batch_size', type=int, default=2,
            help='backprop pass batch size')
    parser.add_argument('--train_update_freq', type=int, default=10,
            help='number of batches for each training update')
    parser.add_argument('--train_epochs', type=int, default=10,
            help='number of epochs of training')
    parser.add_argument('--l2_lambda', type=float, default=0.005,
            help='l2 regularizer hyperparameter')
    parser.add_argument('--eval_only', action='store_true', default=False,
            help='reruns eval only on trained snapshots')
    parser.add_argument('--no-cuda', action='store_true', default=False,
            help='disables CUDA usage')
    parser.add_argument('--no-cache', action='store_true', default=False,
            help='disables reading of cache')
    parser.add_argument('--outdir', type=str, required=True,
            help='dissection directory')
    parser.add_argument('--variant', type=str, default=None,
            help='experiment variant')
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.backends.cudnn.benchmark = True

    run_command(args)
def run_command(args):
    """Run the full ACE pipeline for one (layer, class, variant) triple.

    ``args`` is the namespace built by ``main``.  Missing ``model``,
    ``pthfile`` and ``segmenter`` values are filled in from the
    ``settings`` entry of ``<outdir>/dissect.json``.  Intermediate tensors
    are cached in ``<outdir>/<layer>/<variant>/<classname>/corpus.npz``.

    The function exits the process (``sys.exit``) when the lock file is
    already taken, when the work appears to be done already, or when no
    model could be created.

    Raises:
        AssertionError: if ``<outdir>/dissect.json`` does not exist.
        IndexError: if ``args.classname`` is not among the segmenter labels.
    """
    verbose_progress(True)
    progress = default_progress()
    classname = args.classname
    layer = args.layer

    assert os.path.isfile(os.path.join(args.outdir, 'dissect.json')), (
            "Should be a dissection directory")

    if args.variant is None:
        args.variant = 'ace'

    # A non-default regularizer gets its own variant name (and directory).
    if args.l2_lambda != 0.005:
        args.variant = '%s_reg%g' % (args.variant, args.l2_lambda)

    cachedir = os.path.join(args.outdir, safe_dir_name(layer), args.variant,
            classname)

    if pidfile_taken(os.path.join(cachedir, 'lock.pid'), True):
        sys.exit(0)

    # Take defaults for the model constructor etc. from dissect.json.
    with open(os.path.join(args.outdir, 'dissect.json')) as f:
        dissection = EasyDict(json.load(f))
    if args.model is None:
        args.model = dissection.settings.model
    if args.pthfile is None:
        args.pthfile = dissection.settings.pthfile
    if args.segmenter is None:
        args.segmenter = dissection.settings.segmenter
    # Fall back to a default segmenter if dissect.json names none either.
    if args.segmenter is None:
        args.segmenter = ("netdissect.segmenter.UnifiedParsingSegmenter(" +
                "segsizes=[256], segdiv='quad')")

    # NOTE(review): nothing in this file writes 'report.json' (the summary
    # is written as 'summary.json'), so this shortcut only triggers if some
    # other tool produces that file -- confirm that is intended.
    if (not args.no_cache and
        os.path.isfile(os.path.join(cachedir, 'snapshots', 'epoch-%d.npy' % (
            args.train_epochs - 1))) and
        os.path.isfile(os.path.join(cachedir, 'report.json'))):
        print('%s already done' % cachedir)
        sys.exit(0)

    os.makedirs(cachedir, exist_ok=True)

    # Instantiate the generator, instrumented for editing at one layer.
    model = create_instrumented_model(args, gen=True, edit=True,
            layers=[args.layer])
    if model is None:
        print('No model specified')
        sys.exit(1)
    # Instantiate the segmenter and look up the numeric id of the class.
    segmenter = autoimport_eval(args.segmenter)
    labelnames, catname = segmenter.get_label_and_category_names()
    classnum = [i for i, (n, c) in enumerate(labelnames) if n == classname][0]
    with open(os.path.join(cachedir, 'labelnames.json'), 'w') as f:
        json.dump(labelnames, f, indent=1)

    # Two independently seeded latent samples: one to locate the object,
    # one to search for insertion candidates.
    full_sample = netdissect.zdataset.z_sample_for_model(model,
            args.search_size, seed=10)
    second_sample = netdissect.zdataset.z_sample_for_model(model,
            args.search_size, seed=11)

    # Load any cached data.  Reading the cache is best-effort: a missing
    # file is normal on a first run, and an unreadable file is reported and
    # ignored so that the corpus is simply recomputed.
    cache_filename = os.path.join(cachedir, 'corpus.npz')
    corpus = EasyDict()
    if not args.no_cache:
        try:
            corpus = EasyDict({k: torch.from_numpy(v)
                for k, v in numpy.load(cache_filename).items()})
        except FileNotFoundError:
            pass
        except Exception as e:
            print('Ignoring unreadable cache %s: %s' % (cache_filename, e))

    # The steps of the computation; each one skips itself when its outputs
    # are already present in the corpus.
    compute_present_locations(args, corpus, cache_filename,
            model, segmenter, classnum, full_sample)
    compute_mean_present_features(args, corpus, cache_filename, model)
    compute_feature_quantiles(args, corpus, cache_filename, model, full_sample)
    compute_candidate_locations(args, corpus, cache_filename, model, segmenter,
            classnum, second_sample)
    init_ablation = initial_ablation(args, args.outdir)
    scores = train_ablation(args, corpus, cache_filename,
            model, segmenter, classnum, init_ablation)
    summarize_scores(args, corpus, cachedir, layer, classname,
            args.variant, scores)
    # Only the plain 'ace' variant is merged back into dissect.json.
    if args.variant == 'ace':
        add_ace_ranking_to_dissection(args.outdir, layer, classname, scores)
class SaveImageWorker(WorkerBase):
    """Worker that writes one image array to disk per work item."""

    def work(self, data, filename):
        """Convert ``data`` to a PIL image and save it to ``filename``."""
        picture = Image.fromarray(data)
        picture.save(filename, optimize=True, quality=80)
def plot_heatmap(output_filename, data, size=256):
    """Render ``data`` as a borderless 'hot' heatmap and save it as a PNG.

    The figure is one inch square at ``size`` dpi, and the color scale is
    fixed to the range [-1, 1].
    """
    figure = Figure(figsize=(1, 1), dpi=size)
    canvas = FigureCanvas(figure)
    # A single grid cell that fills the whole figure, with no margins.
    full_bleed = gridspec.GridSpec(
        1, 1, left=0.0, right=1.0, bottom=0.0, top=1.0)
    axes = figure.add_subplot(full_bleed[0])
    axes.set_axis_off()
    render_options = dict(cmap='hot', aspect='equal',
                          interpolation='nearest', vmin=-1, vmax=1)
    axes.imshow(data, **render_options)
    canvas.print_figure(output_filename, format='png')
def draw_heatmap(output_filename, data, size=256):
    """Render ``data`` as a borderless 'hot' heatmap and return the pixels.

    The figure is one inch square at ``size`` dpi, with the color scale
    fixed to [-1, 1].  ``output_filename`` is accepted for signature parity
    with ``plot_heatmap`` but nothing is written to disk.

    Returns:
        A writable ``uint8`` numpy array of shape ``(size, size, 3)`` (RGB).
    """
    fig = Figure(figsize=(1, 1), dpi=size)
    canvas = FigureCanvas(fig)
    gs = gridspec.GridSpec(1, 1, left=0.0, right=1.0, bottom=0.0, top=1.0)
    ax = fig.add_subplot(gs[0])
    ax.set_axis_off()
    ax.imshow(data, cmap='hot', aspect='equal', interpolation='nearest',
            vmin=-1, vmax=1)
    canvas.draw()
    if hasattr(canvas, 'tostring_rgb'):
        # numpy.fromstring is deprecated for binary input; frombuffer is
        # the supported way to reinterpret raw bytes.
        image = numpy.frombuffer(canvas.tostring_rgb(), dtype='uint8'
                ).reshape((size, size, 3))
    else:
        # Matplotlib releases without tostring_rgb: take the RGBA buffer
        # and drop the alpha channel.
        image = numpy.asarray(canvas.buffer_rgba())[:, :, :3]
    # Copy so the caller gets an independent, writable array, as the old
    # fromstring call provided.
    return image.copy()
def compute_present_locations(args, corpus, cache_filename,
        model, segmenter, classnum, full_sample):
    """Find images and feature-map locations where the class is present.

    Runs the model over ``full_sample``, segments every generated image and
    scores each feature-map location of ``args.layer`` by the class mask
    average-pooled down to the feature resolution.  The
    ``args.train_size + args.eval_size`` images with the highest peak score
    are kept.  That selection is then re-sorted by sample index, so the
    first ``train_size`` indices become the training set and the last
    ``eval_size`` indices the eval set.

    Results are stored on ``corpus`` (``present_indices``,
    ``object_present_sample``, ``object_present_location``,
    ``object_location_popularity``, ``weighted_mean_present_feature`` and
    the ``eval_present_*`` counterparts).  If ``cache_filename`` is truthy
    the whole corpus is saved with ``numpy.savez``.

    Returns immediately when the training-side keys are already in
    ``corpus``.
    """
    if all(k in corpus for k in ['present_indices',
            'object_present_sample', 'object_present_location',
            'object_location_popularity', 'weighted_mean_present_feature']):
        return
    progress = default_progress()
    # Spatial dims and channel count of the instrumented layer.
    feature_shape = model.feature_shape[args.layer][2:]
    num_locations = numpy.prod(feature_shape).item()
    num_units = model.feature_shape[args.layer][1]
    with torch.no_grad():
        weighted_feature_sum = torch.zeros(num_units).cuda()
        object_presence_scores = []
        for [zbatch] in progress(
                torch.utils.data.DataLoader(TensorDataset(full_sample),
                batch_size=args.inference_batch_size, num_workers=10,
                pin_memory=True),
                desc="Object pool"):
            zbatch = zbatch.cuda()
            tensor_image = model(zbatch)
            segmented_image = segmenter.segment_batch(tensor_image,
                    downsample=2)
            # True wherever any entry along dim 1 equals classnum
            # (presumably dim 1 holds several label layers per pixel --
            # TODO confirm against the segmenter).
            mask = (segmented_image == classnum).max(1)[0]
            # Fraction of class pixels under each feature-map location.
            score = torch.nn.functional.adaptive_avg_pool2d(
                    mask.float(), feature_shape)
            object_presence_scores.append(score.cpu())
            feat = model.retained_layer(args.layer)
            # Accumulate per-unit activations weighted by object presence.
            weighted_feature_sum += (feat * score[:,None,:,:]).view(
                    feat.shape[0],feat.shape[1], -1).sum(2).sum(0)
        object_presence_at_feature = torch.cat(object_presence_scores)
        # Best score per image and the flat location where it occurs.
        object_presence_at_image, object_location_in_image = (
                object_presence_at_feature.view(args.search_size, -1).max(1))
        # Negate so the ascending sort lists the highest scores first.
        best_presence_scores, best_presence_images = torch.sort(
                -object_presence_at_image)
        # Keep the top images, then order them by sample index.
        all_present_indices = torch.sort(
                best_presence_images[:(args.train_size+args.eval_size)])[0]
        corpus.present_indices = all_present_indices[:args.train_size]
        corpus.object_present_sample = full_sample[corpus.present_indices]
        corpus.object_present_location = object_location_in_image[
                corpus.present_indices]
        # How often each location is the best one among training images.
        corpus.object_location_popularity = torch.bincount(
            corpus.object_present_location,
            minlength=num_locations)
        # The 1e-20 term guards against division by zero when the class
        # never appears.
        corpus.weighted_mean_present_feature = (weighted_feature_sum.cpu() / (
            1e-20 + object_presence_at_feature.view(-1).sum()))
        corpus.eval_present_indices = all_present_indices[-args.eval_size:]
        corpus.eval_present_sample = full_sample[corpus.eval_present_indices]
        corpus.eval_present_location = object_location_in_image[
                corpus.eval_present_indices]

    if cache_filename:
        numpy.savez(cache_filename, **corpus)
def compute_mean_present_features(args, corpus, cache_filename, model):
    """Average the layer activation at each training image's object location.

    For every (latent, location) pair in ``corpus.object_present_sample`` /
    ``corpus.object_present_location`` the feature vector of ``args.layer``
    at that flat location is read.  The mean over all pairs is stored as
    ``corpus.mean_present_feature``.  The corpus is saved with
    ``numpy.savez`` when ``cache_filename`` is truthy.  Does nothing if the
    key is already present.
    """
    if 'mean_present_feature' in corpus:
        return
    progress = default_progress()
    with torch.no_grad():
        loader = torch.utils.data.DataLoader(
            TensorDataset(corpus.object_present_sample,
                          corpus.object_present_location),
            batch_size=args.inference_batch_size, num_workers=10,
            pin_memory=True)
        running_sum = 0
        for [zbatch, featloc] in progress(loader, desc="Mean activations"):
            zbatch = zbatch.cuda()
            featloc = featloc.cuda()
            # The forward pass is run only to populate the retained layer.
            model(zbatch)
            feat = model.retained_layer(args.layer)
            batch_len, unit_count = feat.shape[0], feat.shape[1]
            flat = feat.view(batch_len, unit_count, -1)
            rows = torch.arange(batch_len).to(feat.device)
            # Pick, for each image, the feature column at its own location.
            picked = flat[rows, :, featloc]
            running_sum = running_sum + picked.sum(0)
        sample_count = len(corpus.object_present_sample)
        corpus.mean_present_feature = (running_sum / sample_count).cpu()
    if cache_filename:
        numpy.savez(cache_filename, **corpus)
def compute_feature_quantiles(args, corpus, cache_filename, model, full_sample):
    """Estimate per-unit activation quantiles of ``args.layer``.

    Every spatial position of every image generated from ``full_sample`` is
    fed to a ``RunningQuantile``.  The 0.1%, 1%, 10%, 50%, 90%, 99% and
    99.9% quantiles are stored on ``corpus`` as ``feature_001`` through
    ``feature_999``, and the corpus is then saved with ``numpy.savez``.
    Does nothing if ``feature_99`` and ``feature_999`` are already present.
    """
    if 'feature_99' in corpus and 'feature_999' in corpus:
        return
    progress = default_progress()
    levels = [0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999]
    # Corpus attribute names, in the same order as ``levels``.
    names = ['feature_001', 'feature_01', 'feature_10', 'feature_50',
             'feature_90', 'feature_99', 'feature_999']
    with torch.no_grad():
        rq = RunningQuantile(resolution=10000)
        loader = torch.utils.data.DataLoader(
            TensorDataset(full_sample),
            batch_size=args.inference_batch_size, num_workers=10,
            pin_memory=True)
        for [zbatch] in progress(loader, desc="Calculating 0.999 quantile"):
            zbatch = zbatch.cuda()
            # The forward pass is run only to populate the retained layer.
            model(zbatch)
            feat = model.retained_layer(args.layer)
            # Move channels last so each row is one spatial position.
            channels_last = feat.permute(0, 2, 3, 1).contiguous()
            rq.add(channels_last.view(-1, feat.shape[1]))
        result = rq.quantiles(levels)
        for column, name in enumerate(names):
            setattr(corpus, name, result[:, column].cpu())
    numpy.savez(cache_filename, **corpus)
def compute_candidate_locations(args, corpus, cache_filename, model,
        segmenter, classnum, second_sample):
    """Find images/locations where an edit makes the class appear.

    For each batch of ``second_sample``, five feature-map locations are
    drawn without replacement, weighted by
    ``corpus.object_location_popularity`` (plus a tenth of its mean so no
    location has zero probability).  At each drawn location the layer is
    edited with a one-hot spatial mask and a replacement vector
    (``mean_present_feature`` for '_tcm' variants, otherwise
    ``weighted_mean_present_feature``).  The score of a location is the
    change in the number of class pixels relative to the first forward pass
    of that batch.  Scores are zeroed wherever the first pass already
    showed the class.

    The top ``args.train_size + args.eval_size`` images by best score are
    kept and re-sorted by index; the first ``train_size`` become
    ``candidate_*`` entries and the last ``eval_size`` become
    ``eval_candidate_*`` entries on ``corpus``.  The corpus is then saved
    with ``numpy.savez``.

    Returns immediately when the training-side keys are already in
    ``corpus``.
    """
    if all(k in corpus for k in ['candidate_indices',
            'candidate_sample', 'candidate_score',
            'candidate_location', 'object_score_at_candidate',
            'candidate_location_popularity']):
        return
    progress = default_progress()
    feature_shape = model.feature_shape[args.layer][2:]
    num_locations = numpy.prod(feature_shape).item()
    with torch.no_grad():
        # Sampling distribution over flat feature-map locations.
        possible_locations = numpy.arange(num_locations)
        location_weights = (corpus.object_location_popularity).double()
        # Smooth so that never-seen locations can still be drawn.
        location_weights += (location_weights.mean()) / 10.0
        location_weights = location_weights / location_weights.sum()

        candidate_scores = []
        object_scores = []
        # Fixed seed so the drawn locations are reproducible.
        prng = numpy.random.RandomState(1)
        for [zbatch] in progress(
                torch.utils.data.DataLoader(TensorDataset(second_sample),
                batch_size=args.inference_batch_size, num_workers=10,
                pin_memory=True),
                desc="Candidate pool"):
            batch_scores = torch.zeros((len(zbatch),) + feature_shape).cuda()
            # View sharing storage with batch_scores, indexed by flat loc.
            flat_batch_scores = batch_scores.view(len(zbatch), -1)
            zbatch = zbatch.cuda()
            # NOTE(review): from the second batch on, this "baseline" pass
            # runs after the previous batch's last edit_layer call.  Whether
            # that edit is still active depends on model.edit_layer
            # semantics not visible here -- confirm.
            tensor_image = model(zbatch)
            segmented_image = segmenter.segment_batch(tensor_image,
                    downsample=2)
            mask = (segmented_image == classnum).max(1)[0]
            object_score = torch.nn.functional.adaptive_avg_pool2d(
                    mask.float(), feature_shape)
            # Number of class pixels per image before the edit.
            baseline_presence = mask.float().view(mask.shape[0], -1).sum(1)

            edit_mask = torch.zeros((1, 1) + feature_shape).cuda()
            if '_tcm' in args.variant:
                replace_vec = (corpus.mean_present_feature
                        [None,:,None,None].cuda())
            else:
                replace_vec = (corpus.weighted_mean_present_feature
                        [None,:,None,None].cuda())
            for loc in prng.choice(possible_locations, replace=False,
                    p=location_weights, size=5):
                # One-hot mask selecting only the drawn location.
                edit_mask.zero_()
                edit_mask.view(-1)[loc] = 1
                model.edit_layer(args.layer,
                        ablation=edit_mask, replacement=replace_vec)
                tensor_image = model(zbatch)
                segmented_image = segmenter.segment_batch(tensor_image,
                        downsample=2)
                mask = (segmented_image == classnum).max(1)[0]
                modified_presence = mask.float().view(
                        mask.shape[0], -1).sum(1)
                # Gain in class pixels caused by the edit at this location.
                flat_batch_scores[:,loc] = (
                        modified_presence - baseline_presence)
            candidate_scores.append(batch_scores.cpu())
            object_scores.append(object_score.cpu())

        object_scores = torch.cat(object_scores)
        candidate_scores = torch.cat(candidate_scores)
        # Only count locations where the first pass showed no object.
        candidate_scores = candidate_scores * (object_scores == 0).float()
        candidate_score_at_image, candidate_location_in_image = (
                candidate_scores.view(args.search_size, -1).max(1))
        # Negate so the ascending sort lists the highest scores first.
        best_candidate_scores, best_candidate_images = torch.sort(
                -candidate_score_at_image)
        # Keep the top images, then order them by sample index.
        all_candidate_indices = torch.sort(
                best_candidate_images[:(args.train_size+args.eval_size)])[0]
        corpus.candidate_indices = all_candidate_indices[:args.train_size]
        corpus.candidate_sample = second_sample[corpus.candidate_indices]
        corpus.candidate_location = candidate_location_in_image[
                corpus.candidate_indices]
        corpus.candidate_score = candidate_score_at_image[
                corpus.candidate_indices]
        corpus.object_score_at_candidate = object_scores.view(
                len(object_scores), -1)[
                corpus.candidate_indices, corpus.candidate_location]
        corpus.candidate_location_popularity = torch.bincount(
            corpus.candidate_location,
            minlength=num_locations)
        corpus.eval_candidate_indices = all_candidate_indices[
                -args.eval_size:]
        corpus.eval_candidate_sample = second_sample[
                corpus.eval_candidate_indices]
        corpus.eval_candidate_location = candidate_location_in_image[
                corpus.eval_candidate_indices]
    numpy.savez(cache_filename, **corpus)
def visualize_training_locations(args, corpus, cachedir, model):
    """Save training images with the chosen location highlighted in yellow.

    For both the 'present' and 'candidate' groups in ``corpus``, each image
    is generated and blended 50/50 with yellow over the area of its
    selected feature-map location.  Images are written to
    ``<cachedir>/image/<group>_<index>.jpg`` through a worker pool.
    """
    progress = default_progress()
    feature_shape = model.feature_shape[args.layer][2:]
    num_locations = numpy.prod(feature_shape).item()
    with torch.no_grad():
        imagedir = os.path.join(cachedir, 'image')
        os.makedirs(imagedir, exist_ok=True)
        image_saver = WorkerPool(SaveImageWorker)
        groups = [
            ('present',
             corpus.object_present_sample,
             corpus.object_present_location,
             corpus.present_indices),
            ('candidate',
             corpus.candidate_sample,
             corpus.candidate_location,
             corpus.candidate_indices),
        ]
        for group, group_sample, group_location, group_indices in groups:
            loader = torch.utils.data.DataLoader(
                TensorDataset(group_sample, group_location, group_indices),
                batch_size=args.inference_batch_size, num_workers=10,
                pin_memory=True)
            for [zbatch, featloc, indices] in progress(
                    loader, desc="Visualize %s" % group):
                zbatch = zbatch.cuda()
                tensor_image = model(zbatch)
                batch_len = len(zbatch)
                # One-hot mask at feature resolution, one location per image.
                feature_mask = torch.zeros((batch_len, 1) + feature_shape)
                feature_mask.view(batch_len, -1).scatter_(
                    1, featloc[:, None], 1)
                # Blow the mask up to image resolution.
                feature_mask = torch.nn.functional.adaptive_max_pool2d(
                    feature_mask.float(), tensor_image.shape[-2:]).cuda()
                yellow = torch.Tensor([1.0, 1.0, -1.0])
                yellow = yellow[None, :, None, None].cuda()
                dimmed = tensor_image * (1 - 0.5 * feature_mask)
                tensor_image = dimmed + (0.5 * feature_mask * yellow)
                # Map [-1, 1] to bytes and move channels last.
                byte_image = (((tensor_image + 1) / 2) * 255).clamp(
                    0, 255).byte()
                numpy_image = byte_image.permute(0, 2, 3, 1).cpu().numpy()
                for i, index in enumerate(indices):
                    target = os.path.join(
                        imagedir, '%s_%d.jpg' % (group, index))
                    image_saver.add(numpy_image[i], target)
        image_saver.join()
def scale_summary(scale, lownums, highnums):
    """Summarize a 1-d tensor as ``'index: value'`` pairs.

    Entries are ordered from largest to smallest value.  The first
    ``lownums`` entries and the last ``highnums`` entries of that ordering
    are formatted with three significant digits and joined by ``' ... '``.
    """
    # Sorting the negation ascending gives a descending order of ``scale``.
    negated_values, unit_order = (-(scale.detach())).cpu().sort(0)

    def render(values, units):
        pieces = []
        for v, o in zip(values, units):
            pieces.append('%d: %.3g' % (o.item(), -v.item()))
        return ' '.join(pieces)

    head = render(negated_values[:lownums], unit_order[:lownums])
    tail = render(negated_values[-highnums:], unit_order[-highnums:])
    return head + ' ... ' + tail
def initial_ablation(args, dissectdir):
    """Build the starting ablation vector from the dissection's IoU ranking.

    Reads ``<dissectdir>/dissect.json`` and takes the ranking named
    ``'<classname>-iou'`` of layer ``args.layer``.  Returns the negated
    scores divided by their maximum, as a tensor.

    Raises:
        IndexError: if the layer or the ranking is not in the dissection.
    """
    dissect_path = os.path.join(dissectdir, 'dissect.json')
    with open(dissect_path) as f:
        dissection = EasyDict(json.load(f))
    matching_layers = [
        record for record in dissection.layers if record.layer == args.layer]
    layer_record = matching_layers[0]
    matching_rankings = [
        ranking for ranking in layer_record.rankings
        if ranking.name == '%s-iou' % args.classname]
    iou_ranking = matching_rankings[0]
    # Rankings are stored negated; flip the sign before normalizing.
    flipped = -torch.tensor(iou_ranking.score)
    return flipped / flipped.max()
def ace_loss(segmenter, classnum, model, layer, high_replacement, ablation,
        pbatch, ploc, cbatch, cloc, run_backward=False,
        discrete_pixels=False,
        discrete_units=False,
        mixed_units=False,
        ablation_only=False,
        fullimage_measurement=False,
        fullimage_ablation=False,
        ):
    """Compute the ACE loss for one batch of present and candidate images.

    The loss has two parts:

    * Erase: ``ablation`` is applied to ``pbatch`` at ``ploc`` (or at every
      location when ``fullimage_ablation``), and the remaining evidence for
      ``classnum`` is summed at ``ploc`` (or over the whole feature map
      when ``fullimage_measurement``).
    * Add: ``ablation`` is applied to ``cbatch`` at ``cloc`` with
      ``high_replacement`` as the inserted value, and the negated evidence
      for ``classnum`` is summed at ``cloc`` (or averaged over the whole
      feature map when ``fullimage_measurement``).

    Args:
        segmenter: object with ``predict_single_class(images, classnum,
            downsample=2)`` returning ``(effect, mask)``.
        classnum: segmentation class being optimized.
        model: instrumented model with ``feature_shape``, ``edit_layer``
            and ``_ablation``.
        layer: name of the edited layer.
        high_replacement: replacement tensor used for the add step.
        ablation: per-unit ablation strengths; the differentiable input.
        pbatch, ploc: latents and flat locations where the class is present.
        cbatch, cloc: latents and flat locations of insertion candidates.
        run_backward: call ``backward()`` on each part as it is computed.
        discrete_pixels: score with the hard mask rather than the soft
            effect.
        discrete_units: if truthy, keep only that many top units of
            ``ablation``, set to 1, or to their own value with
            ``mixed_units``.
        ablation_only: return after the erase part.

    Returns:
        The erase loss alone when ``ablation_only``, otherwise the sum of
        the erase and add losses.
    """
    feature_shape = model.feature_shape[layer][2:]
    if discrete_units:
        # Replace the ablation by a version keeping only the top-k units.
        assert discrete_units > 0
        d = torch.zeros_like(ablation)
        top_units = torch.topk(ablation.view(-1), discrete_units)[1]
        if mixed_units:
            d.view(-1)[top_units] = ablation.view(-1)[top_units]
        else:
            d.view(-1)[top_units] = 1
        ablation = d

    # Erase step: ablate units where the object is present.
    p_mask = torch.zeros((len(pbatch), 1) + feature_shape)
    if fullimage_ablation:
        p_mask[...] = 1
    else:
        p_mask.view(len(pbatch), -1).scatter_(1, ploc[:,None], 1)
    p_mask = p_mask.cuda()
    a_p_mask = (ablation * p_mask)
    model.edit_layer(layer, ablation=a_p_mask, replacement=None)
    tensor_images = model(pbatch.cuda())
    # Sanity check that the model holds exactly the mask just installed.
    assert model._ablation[layer] is a_p_mask
    erase_effect, erased_mask = segmenter.predict_single_class(
            tensor_images, classnum, downsample=2)
    if discrete_pixels:
        erase_effect = erased_mask.float()
    erase_downsampled = torch.nn.functional.adaptive_avg_pool2d(
            erase_effect[:,None,:,:], feature_shape)[:,0,:,:]
    if fullimage_measurement:
        erase_loss = erase_downsampled.sum()
    else:
        erase_at_loc = erase_downsampled.view(len(erase_downsampled), -1
                )[torch.arange(len(erase_downsampled)), ploc]
        erase_loss = erase_at_loc.sum()
    if run_backward:
        erase_loss.backward()
    if ablation_only:
        return erase_loss

    # Add step: insert the replacement at the candidate locations.
    c_mask = torch.zeros((len(cbatch), 1) + feature_shape)
    c_mask.view(len(cbatch), -1).scatter_(1, cloc[:,None], 1)
    c_mask = c_mask.cuda()
    a_c_mask = (ablation * c_mask)
    model.edit_layer(layer, ablation=a_c_mask, replacement=high_replacement)
    tensor_images = model(cbatch.cuda())
    assert model._ablation[layer] is a_c_mask
    add_effect, added_mask = segmenter.predict_single_class(
            tensor_images, classnum, downsample=2)
    if discrete_pixels:
        add_effect = added_mask.float()
    # Negate: more of the class appearing should lower the loss.
    add_effect = -add_effect
    add_downsampled = torch.nn.functional.adaptive_avg_pool2d(
            add_effect[:,None,:,:], feature_shape)[:,0,:,:]
    if fullimage_measurement:
        add_loss = add_downsampled.mean()
    else:
        # Measure where the insertion was applied (cloc).  The previous
        # code indexed with ploc, the location of an unrelated image from
        # the present batch.
        add_at_loc = add_downsampled.view(len(add_downsampled), -1
                )[torch.arange(len(add_downsampled)), cloc]
        add_loss = add_at_loc.sum()
    if run_backward:
        add_loss.backward()
    return erase_loss + add_loss
def train_ablation(args, corpus, cachefile, model, segmenter, classnum,
        initial_ablation=None):
    """Optimize the per-unit ablation vector with Adam and return it.

    Substrings of ``args.variant`` select the configuration:

    * ``_h99`` / ``_tcm`` / neither: the replacement vector is
      ``feature_99``, ``mean_present_feature`` or
      ``weighted_mean_present_feature`` respectively.
    * ``_fim``: full-image measurement.
    * ``_fia``: full-image measurement, ablation-only loss and full-image
      ablation.

    Snapshots are written to ``<dirname(cachefile)>/snapshots`` as
    ``epoch-<n>.pth`` (ablation, optimizer state, eval losses) and
    ``epoch-<n>.npy``.  Unless ``args.no_cache`` is set, training resumes
    from the latest ``.pth`` snapshot.  With ``args.eval_only`` no training
    happens: existing snapshots are re-evaluated and re-saved.

    Args:
        args: namespace from ``main``.
        corpus: training/eval tensors built by the ``compute_*`` steps.
        cachefile: path of the corpus cache; only its directory is used.
        model: instrumented generator; its parameters are frozen here.
        segmenter: segmenter passed through to ``ace_loss``.
        classnum: segmentation class being optimized.
        initial_ablation: optional initial values for the ablation vector.

    Returns:
        The final ablation vector as a flat numpy array.
    """
    progress = default_progress()
    cachedir = os.path.dirname(cachefile)
    snapdir = os.path.join(cachedir, 'snapshots')
    os.makedirs(snapdir, exist_ok=True)

    # Pick the feature vector that is inserted at candidate locations.
    if '_h99' in args.variant:
        high_replacement = corpus.feature_99[None,:,None,None].cuda()
    elif '_tcm' in args.variant:
        high_replacement = (
            corpus.mean_present_feature[None,:,None,None].cuda())
    else:
        high_replacement = (
            corpus.weighted_mean_present_feature[None,:,None,None].cuda())
    fullimage_measurement = False
    ablation_only = False
    fullimage_ablation = False
    if '_fim' in args.variant:
        fullimage_measurement = True
    elif '_fia' in args.variant:
        fullimage_measurement = True
        ablation_only = True
        fullimage_ablation = True
    # Only the ablation vector is trained; freeze everything else.
    high_replacement.requires_grad = False
    for p in model.parameters():
        p.requires_grad = False

    ablation = torch.zeros(high_replacement.shape).cuda()
    if initial_ablation is not None:
        ablation.view(-1)[...] = initial_ablation
    ablation.requires_grad = True
    optimizer = torch.optim.Adam([ablation], lr=0.01)
    start_epoch = 0
    # Read by eval_loss_and_reg through the closure for its progress line.
    epoch = 0

    def eval_loss_and_reg():
        """Evaluate on the eval split.

        Returns (avg_loss, regularizer, avg discrete-experiment losses).
        """
        discrete_experiments = dict(
            # dpixel=dict(discrete_pixels=True),
            dboth20=dict(discrete_units=20, discrete_pixels=True),
            fimadbm10=dict(discrete_units=10, mixed_units=True,
                discrete_pixels=True,
                ablation_only=True,
                fullimage_ablation=True,
                fullimage_measurement=True),
            fimadbm20=dict(discrete_units=20, mixed_units=True,
                discrete_pixels=True,
                ablation_only=True,
                fullimage_ablation=True,
                fullimage_measurement=True)
        )
        with torch.no_grad():
            total_loss = 0
            discrete_losses = {k: 0 for k in discrete_experiments}
            for [pbatch, ploc, cbatch, cloc] in progress(
                    torch.utils.data.DataLoader(TensorDataset(
                        corpus.eval_present_sample,
                        corpus.eval_present_location,
                        corpus.eval_candidate_sample,
                        corpus.eval_candidate_location),
                    batch_size=args.inference_batch_size, num_workers=10,
                    shuffle=False, pin_memory=True),
                    desc="Eval"):
                # Evaluate under the same configuration used for training,
                # including fullimage_ablation, which was previously
                # computed above but never forwarded.
                total_loss = total_loss + ace_loss(segmenter, classnum,
                        model, args.layer, high_replacement, ablation,
                        pbatch, ploc, cbatch, cloc, run_backward=False,
                        ablation_only=ablation_only,
                        fullimage_measurement=fullimage_measurement,
                        fullimage_ablation=fullimage_ablation)
                for k, config in discrete_experiments.items():
                    discrete_losses[k] = discrete_losses[k] + ace_loss(
                        segmenter, classnum,
                        model, args.layer, high_replacement, ablation,
                        pbatch, ploc, cbatch, cloc, run_backward=False,
                        **config)
            avg_loss = (total_loss / args.eval_size).item()
            avg_d_losses = {k: (d / args.eval_size).item()
                    for k, d in discrete_losses.items()}
            regularizer = (args.l2_lambda * ablation.pow(2).sum())
            print_progress('Epoch %d Loss %g Regularizer %g' %
                    (epoch, avg_loss, regularizer))
            print_progress(' '.join('%s: %g' % (k, d)
                    for k, d in avg_d_losses.items()))
            print_progress(scale_summary(ablation.view(-1), 10, 3))
            return avg_loss, regularizer, avg_d_losses

    if args.eval_only:
        # Re-evaluate every stored snapshot.  Epoch -1 is the untrained
        # starting point and is evaluated even if no snapshot exists yet.
        for epoch in range(-1, args.train_epochs):
            snapfile = os.path.join(snapdir, 'epoch-%d.pth' % epoch)
            if not os.path.exists(snapfile):
                data = {}
                if epoch >= 0:
                    print('No epoch %d' % epoch)
                    continue
            else:
                data = torch.load(snapfile)
                with torch.no_grad():
                    ablation[...] = data['ablation'].to(ablation.device)
                    optimizer.load_state_dict(data['optimizer'])
            avg_loss, regularizer, new_extra = eval_loss_and_reg()
            # Keep any extra fields stored earlier, refreshed with the
            # newly computed losses.
            extra = {k: v for k, v in data.items()
                    if k not in ['ablation', 'optimizer', 'avg_loss']}
            extra.update(new_extra)
            torch.save(dict(ablation=ablation, optimizer=optimizer.state_dict(),
                avg_loss=avg_loss, **extra),
                os.path.join(snapdir, 'epoch-%d.pth' % epoch))
        return ablation.view(-1).detach().cpu().numpy()

    if not args.no_cache:
        # Resume from the most recent snapshot, if there is one.  When none
        # exists the loop ends with start_epoch == 0.
        for start_epoch in reversed(range(args.train_epochs)):
            snapfile = os.path.join(snapdir, 'epoch-%d.pth' % start_epoch)
            if os.path.exists(snapfile):
                data = torch.load(snapfile)
                with torch.no_grad():
                    ablation[...] = data['ablation'].to(ablation.device)
                    optimizer.load_state_dict(data['optimizer'])
                start_epoch += 1
                break

    if start_epoch < args.train_epochs:
        # Report the loss of the state that training starts from.
        epoch = start_epoch - 1
        avg_loss, regularizer, extra = eval_loss_and_reg()
        if epoch == -1:
            torch.save(dict(ablation=ablation, optimizer=optimizer.state_dict(),
                avg_loss=avg_loss, **extra),
                os.path.join(snapdir, 'epoch-%d.pth' % epoch))

    # Gradients are accumulated over train_update_freq batches per step.
    update_size = args.train_update_freq * args.train_batch_size
    for epoch in range(start_epoch, args.train_epochs):
        # Re-pair present and candidate samples differently every epoch.
        candidate_shuffle = torch.randperm(len(corpus.candidate_sample))
        train_loss = 0
        for batch_num, [pbatch, ploc, cbatch, cloc] in enumerate(progress(
                torch.utils.data.DataLoader(TensorDataset(
                    corpus.object_present_sample,
                    corpus.object_present_location,
                    corpus.candidate_sample[candidate_shuffle],
                    corpus.candidate_location[candidate_shuffle]),
                batch_size=args.train_batch_size, num_workers=10,
                shuffle=True, pin_memory=True),
                desc="ACE opt epoch %d" % epoch)):
            if batch_num % args.train_update_freq == 0:
                optimizer.zero_grad()
            # ace_loss runs backward() itself and accumulates gradients.
            loss = ace_loss(segmenter, classnum,
                    model, args.layer, high_replacement, ablation,
                    pbatch, ploc, cbatch, cloc, run_backward=True,
                    ablation_only=ablation_only,
                    fullimage_measurement=fullimage_measurement,
                    fullimage_ablation=fullimage_ablation)
            with torch.no_grad():
                train_loss = train_loss + loss
            if (batch_num + 1) % args.train_update_freq == 0:
                # Add the L2 penalty, scaled to match the summed losses.
                regularizer = (args.l2_lambda * update_size
                        * ablation.pow(2).sum())
                regularizer.backward()
                optimizer.step()
                with torch.no_grad():
                    # Keep ablation strengths within [0, 1].
                    ablation.clamp_(0, 1)
                    post_progress(l=(train_loss/update_size).item(),
                            r=(regularizer/update_size).item())
                    train_loss = 0

        avg_loss, regularizer, extra = eval_loss_and_reg()
        torch.save(dict(ablation=ablation, optimizer=optimizer.state_dict(),
            avg_loss=avg_loss, **extra),
            os.path.join(snapdir, 'epoch-%d.pth' % epoch))
        numpy.save(os.path.join(snapdir, 'epoch-%d.npy' % epoch),
                ablation.detach().cpu().numpy())

    # The result of this optimization is the ablation vector itself.
    return ablation.view(-1).detach().cpu().numpy()
def tensor_to_numpy_image_batch(tensor_image):
    """Convert a batch of images in [-1, 1] to a channels-last uint8 array.

    Values are mapped linearly to [0, 255], clamped, and truncated to
    bytes.  Dimension 1 of the input is moved to the last position of the
    output.
    """
    scaled = ((tensor_image + 1) / 2) * 255
    as_bytes = scaled.clamp(0, 255).byte()
    channels_last = as_bytes.permute(0, 2, 3, 1)
    return channels_last.cpu().numpy()
def evaluate_ablation(args, model, segmenter, eval_sample, classnum, layer,
        ordering):
    """Measure how much of the class remains as units are ablated in order.

    For every feature-map location where the class is present in the
    generated image, the first ``k`` units of ``ordering`` are ablated at
    that location, for ``k = 1 .. len(ordering)``.  The pooled class
    presence at that same location is then summed.

    This version repairs names that were undefined in the body
    (``zbatch``, ``device``, ``recovery``).  The segmentation of the edited
    images is obtained the same way as the baseline:
    ``segmenter.segment_batch(model(z), downsample=2)``.

    Args:
        args: namespace providing ``layer`` and ``inference_batch_size``.
        model: instrumented model (``feature_shape``, ``edit_layer``).
        segmenter: object providing ``segment_batch``.
        eval_sample: tensor of latent vectors to evaluate on.
        classnum: segmentation class to measure.
        layer: unused; ``args.layer`` selects the layer.
        ordering: sequence of unit indices, in the order they are ablated.

    Returns:
        Tensor of length ``len(ordering) + 1``.  Entry 0 is the baseline
        presence and entry ``k`` the presence after ablating the first
        ``k`` units.
    """
    progress = default_progress()
    # NOTE(review): the rest of this file edits the model through
    # model.edit_layer / model._ablation; confirm that ``model.ablation``
    # exists on the instrumented model.
    for l in model.ablation:
        model.ablation[l] = None
    # Run everything on whatever device the model lives on.
    device = next(model.parameters()).device
    feature_units = model.feature_shape[args.layer][1]
    feature_shape = model.feature_shape[args.layer][2:]
    repeats = len(ordering)
    total_scores = torch.zeros(repeats + 1)
    for i, batch in enumerate(progress(torch.utils.data.DataLoader(
            TensorDataset(eval_sample),
            batch_size=args.inference_batch_size, num_workers=10,
            pin_memory=True),
            desc="Evaluate interventions")):
        zbatch = batch[0].to(device)
        # NOTE(review): from the second batch on, this baseline pass runs
        # after the previous batch's last edit_layer call -- confirm that
        # the edit does not leak into the baseline.
        tensor_image = model(zbatch)
        segmented_image = segmenter.segment_batch(tensor_image,
                downsample=2)
        mask = (segmented_image == classnum).max(1)[0]
        downsampled_seg = torch.nn.functional.adaptive_avg_pool2d(
                mask.float()[:,None,:,:], feature_shape)[:,0,:,:]
        total_scores[0] += downsampled_seg.sum().cpu()
        # Rows of (image, y, x) for every location showing the class.
        interventions_needed = downsampled_seg.nonzero()
        location_count = len(interventions_needed)
        if location_count == 0:
            continue
        # One copy of every location per ablation prefix length.
        interventions_needed = interventions_needed.repeat(repeats, 1)
        # Index the on-device batch: the indices live on the device too.
        inter_z = zbatch[interventions_needed[:,0]]
        inter_chan = torch.zeros(repeats, location_count, feature_units,
                device=device)
        # Repeat j ablates units ordering[0..j] (cumulative prefixes).
        for j, u in enumerate(ordering):
            inter_chan[j:, :, u] = 1
        inter_chan = inter_chan.view(len(inter_z), feature_units)
        inter_loc = interventions_needed[:,1:]
        scores = torch.zeros(len(inter_z))
        batch_size = len(batch[0])
        for j in range(0, len(inter_z), batch_size):
            ibz = inter_z[j:j+batch_size]
            ibl = inter_loc[j:j+batch_size].t()
            rows = torch.arange(len(ibz), device=ibz.device)
            # Spatial one-hot mask at each sample's intervention location.
            imask = torch.zeros((len(ibz),) + feature_shape, device=ibz.device)
            imask[(rows,) + tuple(ibl)] = 1
            ibc = inter_chan[j:j+batch_size]
            model.edit_layer(args.layer, ablation=(
                    imask.float()[:,None,:,:] * ibc[:,:,None,None]))
            seg = segmenter.segment_batch(model(ibz), downsample=2)
            mask = (seg == classnum).max(1)[0]
            downsampled_iseg = torch.nn.functional.adaptive_avg_pool2d(
                    mask.float()[:,None,:,:], feature_shape)[:,0,:,:]
            scores[j:j+batch_size] = downsampled_iseg[
                    (rows,) + tuple(ibl)]
        scores = scores.view(repeats, location_count).sum(1)
        total_scores[1:] += scores
    return total_scores
def evaluate_interventions(args, model, segmenter, eval_sample,
        classnum, layer, units):
    """Measure how much of the class remains as ``units`` are ablated.

    For every feature-map location where the class is present in the
    generated image, the first ``k`` entries of ``units`` are ablated at
    that location, for ``k = 1 .. len(units)``.  The pooled class presence
    at that same location is then summed.

    This version repairs names that were undefined in the body
    (``ordering``, ``zbatch``, ``device``, ``recovery``).  ``units`` is
    used as the ablation ordering.  The edit is installed through
    ``model.edit_layer``, as elsewhere in this file.  The segmentation of
    the edited images is obtained the same way as the baseline.

    Args:
        args: namespace providing ``layer`` and ``inference_batch_size``.
        model: instrumented model (``feature_shape``, ``edit_layer``).
        segmenter: object providing ``segment_batch``.
        eval_sample: tensor of latent vectors to evaluate on.
        classnum: segmentation class to measure.
        layer: unused; ``args.layer`` selects the layer.
        units: sequence of unit indices, in the order they are ablated.

    Returns:
        Tensor of length ``len(units) + 1``.  Entry 0 is the baseline
        presence and entry ``k`` the presence after ablating the first
        ``k`` units.
    """
    progress = default_progress()
    # NOTE(review): the rest of this file edits the model through
    # model.edit_layer / model._ablation; confirm that ``model.ablation``
    # exists on the instrumented model.
    for l in model.ablation:
        model.ablation[l] = None
    # Run everything on whatever device the model lives on.
    device = next(model.parameters()).device
    # The body was written against the name ``ordering``.
    ordering = units
    feature_units = model.feature_shape[args.layer][1]
    feature_shape = model.feature_shape[args.layer][2:]
    repeats = len(ordering)
    total_scores = torch.zeros(repeats + 1)
    for i, batch in enumerate(progress(torch.utils.data.DataLoader(
            TensorDataset(eval_sample),
            batch_size=args.inference_batch_size, num_workers=10,
            pin_memory=True),
            desc="Evaluate interventions")):
        zbatch = batch[0].to(device)
        # NOTE(review): from the second batch on, this baseline pass runs
        # after the previous batch's last edit_layer call -- confirm that
        # the edit does not leak into the baseline.
        tensor_image = model(zbatch)
        segmented_image = segmenter.segment_batch(tensor_image,
                downsample=2)
        mask = (segmented_image == classnum).max(1)[0]
        downsampled_seg = torch.nn.functional.adaptive_avg_pool2d(
                mask.float()[:,None,:,:], feature_shape)[:,0,:,:]
        total_scores[0] += downsampled_seg.sum().cpu()
        # Rows of (image, y, x) for every location showing the class.
        interventions_needed = downsampled_seg.nonzero()
        location_count = len(interventions_needed)
        if location_count == 0:
            continue
        # One copy of every location per ablation prefix length.
        interventions_needed = interventions_needed.repeat(repeats, 1)
        # Index the on-device batch: the indices live on the device too.
        inter_z = zbatch[interventions_needed[:,0]]
        inter_chan = torch.zeros(repeats, location_count, feature_units,
                device=device)
        # Repeat j ablates units ordering[0..j] (cumulative prefixes).
        for j, u in enumerate(ordering):
            inter_chan[j:, :, u] = 1
        inter_chan = inter_chan.view(len(inter_z), feature_units)
        inter_loc = interventions_needed[:,1:]
        scores = torch.zeros(len(inter_z))
        batch_size = len(batch[0])
        for j in range(0, len(inter_z), batch_size):
            ibz = inter_z[j:j+batch_size]
            ibl = inter_loc[j:j+batch_size].t()
            rows = torch.arange(len(ibz), device=ibz.device)
            # Spatial one-hot mask at each sample's intervention location.
            imask = torch.zeros((len(ibz),) + feature_shape, device=ibz.device)
            imask[(rows,) + tuple(ibl)] = 1
            ibc = inter_chan[j:j+batch_size]
            model.edit_layer(args.layer, ablation=(
                    imask.float()[:,None,:,:] * ibc[:,:,None,None]))
            seg = segmenter.segment_batch(model(ibz), downsample=2)
            mask = (seg == classnum).max(1)[0]
            downsampled_iseg = torch.nn.functional.adaptive_avg_pool2d(
                    mask.float()[:,None,:,:], feature_shape)[:,0,:,:]
            scores[j:j+batch_size] = downsampled_iseg[
                    (rows,) + tuple(ibl)]
        scores = scores.view(repeats, location_count).sum(1)
        total_scores[1:] += scores
    return total_scores
def add_ace_ranking_to_dissection(outdir, layer, classname, total_scores):
    """Insert the '<classname>-ace' ranking into ``<outdir>/dissect.json``.

    A one-time backup ``dissect.json.bak`` is made before the first
    modification.  Any existing ranking with the same name on the layer is
    dropped.  The new ranking, holding the negated ``total_scores``, is
    inserted at position 2 of the layer's ranking list, and the file is
    rewritten in place.
    """
    source_filename = os.path.join(outdir, 'dissect.json')
    source_filename_bak = os.path.join(outdir, 'dissect.json.bak')

    # Back up the original dissection the first time it is touched.
    if not os.path.exists(source_filename_bak):
        shutil.copy(source_filename, source_filename_bak)

    with open(source_filename) as f:
        dissection = EasyDict(json.load(f))

    ranking_name = '%s-ace' % classname
    matching_layers = [
        record for record in dissection.layers if record.layer == layer]
    layer_record = matching_layers[0]

    # Remove any stale ranking of the same name before adding the new one.
    layer_record.rankings = [
        ranking for ranking in layer_record.rankings
        if ranking.name != ranking_name]
    ace_ranking = dict(
        name=ranking_name,
        score=(-total_scores).flatten().tolist(),
        metric='ace')
    layer_record.rankings.insert(2, ace_ranking)

    with open(source_filename, 'w') as f:
        json.dump(dissection, f, indent=1)
def summarize_scores(args, corpus, cachedir, layer, classname, variant, scores):
    """Write ``<cachedir>/summary.json`` with one ranking for this run.

    The ranking is named ``'<classname>-<variant>'``, uses ``variant`` as
    its metric and stores the negated, flattened ``scores``.  ``args`` and
    ``corpus`` are accepted but not used.
    """
    ranking = dict(
        name='%s-%s' % (classname, variant),
        score=(-scores).flatten().tolist(),
        metric=variant)
    layer_entry = dict(layer=layer, rankings=[ranking])
    summary = dict(layers=[layer_entry])
    summary_path = os.path.join(cachedir, 'summary.json')
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=1)
# Run the command-line entry point only when executed as a script
# (e.g. ``python -m netdissect.aceoptimize``), not when imported.
if __name__ == '__main__':
    main()