'''aceoptimize: learns an "average causal effect" (ACE) unit ranking for
one object class at one layer of a GAN generator.  Starting from a
dissected model directory, it finds images where the class appears and
candidate locations where it could be added, then optimizes a per-unit
ablation vector that best erases the class when zeroed out and best adds
it when activated.  Snapshots, a summary.json, and (for the default
variant) an updated dissect.json ranking are written into the dissection
directory.'''
import os, sys, numpy, torch, argparse, skimage, json, shutil |
|
from PIL import Image |
|
from torch.utils.data import TensorDataset |
|
from matplotlib.figure import Figure |
|
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas |
|
import matplotlib.gridspec as gridspec |
|
from scipy.ndimage import binary_dilation  # scipy.ndimage.morphology is deprecated
|
|
|
import netdissect.zdataset |
|
import netdissect.nethook |
|
from netdissect.dissection import safe_dir_name |
|
from netdissect.progress import verbose_progress, default_progress |
|
from netdissect.progress import print_progress, desc_progress, post_progress |
|
from netdissect.easydict import EasyDict |
|
from netdissect.workerpool import WorkerPool, WorkerBase |
|
from netdissect.runningstats import RunningQuantile |
|
from netdissect.pidfile import pidfile_taken |
|
from netdissect.modelconfig import create_instrumented_model |
|
from netdissect.autoeval import autoimport_eval |
|
|
|
def main(): |
|
parser = argparse.ArgumentParser(description='ACE optimization utility', |
|
prog='python -m netdissect.aceoptimize') |
|
parser.add_argument('--model', type=str, default=None, |
|
help='constructor for the model to test') |
|
parser.add_argument('--pthfile', type=str, default=None, |
|
help='filename of .pth file for the model') |
|
parser.add_argument('--segmenter', type=str, default=None, |
|
                        help='constructor for a segmenter class')
|
parser.add_argument('--classname', type=str, default=None, |
|
help='intervention classname') |
|
parser.add_argument('--layer', type=str, default='layer4', |
|
help='layer name') |
|
parser.add_argument('--search_size', type=int, default=10000, |
|
help='size of search for finding training locations') |
|
parser.add_argument('--train_size', type=int, default=1000, |
|
help='size of training set') |
|
parser.add_argument('--eval_size', type=int, default=200, |
|
help='size of eval set') |
|
parser.add_argument('--inference_batch_size', type=int, default=10, |
|
help='forward pass batch size') |
|
parser.add_argument('--train_batch_size', type=int, default=2, |
|
help='backprop pass batch size') |
|
parser.add_argument('--train_update_freq', type=int, default=10, |
|
help='number of batches for each training update') |
|
parser.add_argument('--train_epochs', type=int, default=10, |
|
help='number of epochs of training') |
|
parser.add_argument('--l2_lambda', type=float, default=0.005, |
|
help='l2 regularizer hyperparameter') |
|
parser.add_argument('--eval_only', action='store_true', default=False, |
|
help='reruns eval only on trained snapshots') |
|
parser.add_argument('--no-cuda', action='store_true', default=False, |
|
help='disables CUDA usage') |
|
parser.add_argument('--no-cache', action='store_true', default=False, |
|
help='disables reading of cache') |
|
parser.add_argument('--outdir', type=str, default=None, |
|
help='dissection directory') |
|
parser.add_argument('--variant', type=str, default=None, |
|
help='experiment variant') |
|
args = parser.parse_args() |
|
args.cuda = not args.no_cuda and torch.cuda.is_available() |
|
torch.backends.cudnn.benchmark = True |
|
|
|
run_command(args) |
|
|
|
def run_command(args): |
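    # End-to-end ACE pipeline for one (layer, classname) pair:
    #  1. Load dissection settings; build the instrumented generator and
    #     the segmenter.
    #  2. Build a corpus: images/locations where the class is present,
    #     and candidate locations where it could be inserted.
    #  3. Optimize a per-unit ablation vector against the ACE loss.
    #  4. Summarize the scores and, for the plain 'ace' variant, merge
    #     the ranking back into dissect.json.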
|
verbose_progress(True) |
|
progress = default_progress() |
|
classname = args.classname |
|
layer = args.layer |
|
num_eval_units = 20 |
|
|
|
assert os.path.isfile(os.path.join(args.outdir, 'dissect.json')), ( |
|
"Should be a dissection directory") |
|
|
|
if args.variant is None: |
|
args.variant = 'ace' |
|
|
|
if args.l2_lambda != 0.005: |
|
args.variant = '%s_reg%g' % (args.variant, args.l2_lambda) |
|
|
|
cachedir = os.path.join(args.outdir, safe_dir_name(layer), args.variant, |
|
classname) |
|
|
|
if pidfile_taken(os.path.join(cachedir, 'lock.pid'), True): |
|
sys.exit(0) |
|
|
|
|
|
with open(os.path.join(args.outdir, 'dissect.json')) as f: |
|
dissection = EasyDict(json.load(f)) |
|
if args.model is None: |
|
args.model = dissection.settings.model |
|
if args.pthfile is None: |
|
args.pthfile = dissection.settings.pthfile |
|
if args.segmenter is None: |
|
args.segmenter = dissection.settings.segmenter |
|
|
|
if args.segmenter is None: |
|
args.segmenter = ("netdissect.segmenter.UnifiedParsingSegmenter(" + |
|
"segsizes=[256], segdiv='quad')") |
|
|
|
if (not args.no_cache and |
|
os.path.isfile(os.path.join(cachedir, 'snapshots', 'epoch-%d.npy' % ( |
|
args.train_epochs - 1))) and |
|
os.path.isfile(os.path.join(cachedir, 'report.json'))): |
|
print('%s already done' % cachedir) |
|
sys.exit(0) |
|
|
|
os.makedirs(cachedir, exist_ok=True) |
|
|
|
|
|
model = create_instrumented_model(args, gen=True, edit=True, |
|
layers=[args.layer]) |
|
if model is None: |
|
print('No model specified') |
|
sys.exit(1) |
|
|
|
segmenter = autoimport_eval(args.segmenter) |
|
labelnames, catname = segmenter.get_label_and_category_names() |
|
    matching = [i for i, (n, c) in enumerate(labelnames) if n == classname]
    assert matching, 'Unknown segmentation class: %s' % classname
    classnum = matching[0]
|
num_classes = len(labelnames) |
|
with open(os.path.join(cachedir, 'labelnames.json'), 'w') as f: |
|
json.dump(labelnames, f, indent=1) |
|
|
|
|
|
full_sample = netdissect.zdataset.z_sample_for_model(model, |
|
args.search_size, seed=10) |
|
second_sample = netdissect.zdataset.z_sample_for_model(model, |
|
args.search_size, seed=11) |
|
|
|
cache_filename = os.path.join(cachedir, 'corpus.npz') |
|
corpus = EasyDict() |
|
    try:
        if not args.no_cache:
            corpus = EasyDict({k: torch.from_numpy(v)
                for k, v in numpy.load(cache_filename).items()})
    except Exception:
        # Cache missing or unreadable: recompute everything below.
        pass
|
|
|
|
|
compute_present_locations(args, corpus, cache_filename, |
|
model, segmenter, classnum, full_sample) |
|
compute_mean_present_features(args, corpus, cache_filename, model) |
|
compute_feature_quantiles(args, corpus, cache_filename, model, full_sample) |
|
compute_candidate_locations(args, corpus, cache_filename, model, segmenter, |
|
classnum, second_sample) |
|
|
|
init_ablation = initial_ablation(args, args.outdir) |
|
scores = train_ablation(args, corpus, cache_filename, |
|
model, segmenter, classnum, init_ablation) |
|
summarize_scores(args, corpus, cachedir, layer, classname, |
|
args.variant, scores) |
|
if args.variant == 'ace': |
|
add_ace_ranking_to_dissection(args.outdir, layer, classname, scores) |
|
|
|
|
|
class SaveImageWorker(WorkerBase): |
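    # WorkerPool worker that writes one image array to disk per request.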
|
def work(self, data, filename): |
|
Image.fromarray(data).save(filename, optimize=True, quality=80) |
|
|
|
def plot_heatmap(output_filename, data, size=256): |
|
fig = Figure(figsize=(1, 1), dpi=size) |
|
canvas = FigureCanvas(fig) |
|
gs = gridspec.GridSpec(1, 1, left=0.0, right=1.0, bottom=0.0, top=1.0) |
|
ax = fig.add_subplot(gs[0]) |
|
ax.set_axis_off() |
|
ax.imshow(data, cmap='hot', aspect='equal', interpolation='nearest', |
|
vmin=-1, vmax=1) |
|
canvas.print_figure(output_filename, format='png') |
|
|
|
|
|
def draw_heatmap(output_filename, data, size=256): |
|
fig = Figure(figsize=(1, 1), dpi=size) |
|
canvas = FigureCanvas(fig) |
|
gs = gridspec.GridSpec(1, 1, left=0.0, right=1.0, bottom=0.0, top=1.0) |
|
ax = fig.add_subplot(gs[0]) |
|
ax.set_axis_off() |
|
ax.imshow(data, cmap='hot', aspect='equal', interpolation='nearest', |
|
vmin=-1, vmax=1) |
|
    canvas.draw()
    # numpy.fromstring and tostring_rgb are deprecated; read the agg RGBA
    # buffer and drop the alpha channel instead.
    image = numpy.frombuffer(canvas.buffer_rgba(), dtype='uint8').reshape(
        (size, size, 4))[:, :, :3].copy()
    return image
|
|
|
def compute_present_locations(args, corpus, cache_filename, |
|
model, segmenter, classnum, full_sample): |
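    # Phase 1: find images where the target class already appears.  Pool
    # the segmentation mask down to feature-map resolution, keep the
    # strongest location per image, and accumulate a presence-weighted
    # mean feature vector that is later used as a replacement value.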
|
|
|
|
|
|
|
if all(k in corpus for k in ['present_indices', |
|
'object_present_sample', 'object_present_location', |
|
'object_location_popularity', 'weighted_mean_present_feature']): |
|
return |
|
progress = default_progress() |
|
feature_shape = model.feature_shape[args.layer][2:] |
|
num_locations = numpy.prod(feature_shape).item() |
|
num_units = model.feature_shape[args.layer][1] |
|
with torch.no_grad(): |
|
weighted_feature_sum = torch.zeros(num_units).cuda() |
|
object_presence_scores = [] |
|
for [zbatch] in progress( |
|
torch.utils.data.DataLoader(TensorDataset(full_sample), |
|
batch_size=args.inference_batch_size, num_workers=10, |
|
pin_memory=True), |
|
desc="Object pool"): |
|
zbatch = zbatch.cuda() |
|
tensor_image = model(zbatch) |
|
segmented_image = segmenter.segment_batch(tensor_image, |
|
downsample=2) |
|
mask = (segmented_image == classnum).max(1)[0] |
|
score = torch.nn.functional.adaptive_avg_pool2d( |
|
mask.float(), feature_shape) |
|
object_presence_scores.append(score.cpu()) |
|
feat = model.retained_layer(args.layer) |
|
weighted_feature_sum += (feat * score[:,None,:,:]).view( |
|
feat.shape[0],feat.shape[1], -1).sum(2).sum(0) |
|
object_presence_at_feature = torch.cat(object_presence_scores) |
|
object_presence_at_image, object_location_in_image = ( |
|
object_presence_at_feature.view(args.search_size, -1).max(1)) |
|
best_presence_scores, best_presence_images = torch.sort( |
|
-object_presence_at_image) |
|
all_present_indices = torch.sort( |
|
best_presence_images[:(args.train_size+args.eval_size)])[0] |
|
corpus.present_indices = all_present_indices[:args.train_size] |
|
corpus.object_present_sample = full_sample[corpus.present_indices] |
|
corpus.object_present_location = object_location_in_image[ |
|
corpus.present_indices] |
|
corpus.object_location_popularity = torch.bincount( |
|
corpus.object_present_location, |
|
minlength=num_locations) |
|
corpus.weighted_mean_present_feature = (weighted_feature_sum.cpu() / ( |
|
1e-20 + object_presence_at_feature.view(-1).sum())) |
|
corpus.eval_present_indices = all_present_indices[-args.eval_size:] |
|
corpus.eval_present_sample = full_sample[corpus.eval_present_indices] |
|
corpus.eval_present_location = object_location_in_image[ |
|
corpus.eval_present_indices] |
|
|
|
if cache_filename: |
|
numpy.savez(cache_filename, **corpus) |
|
|
|
def compute_mean_present_features(args, corpus, cache_filename, model): |
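    # Compute the mean feature vector at the strongest class location of
    # each present-class image (the '_tcm' variant uses this as its
    # replacement vector instead of the weighted mean).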
|
|
|
|
|
if all(k in corpus for k in ['mean_present_feature']): |
|
return |
|
progress = default_progress() |
|
with torch.no_grad(): |
|
total_present_feature = 0 |
|
for [zbatch, featloc] in progress( |
|
torch.utils.data.DataLoader(TensorDataset( |
|
corpus.object_present_sample, |
|
corpus.object_present_location), |
|
batch_size=args.inference_batch_size, num_workers=10, |
|
pin_memory=True), |
|
desc="Mean activations"): |
|
zbatch = zbatch.cuda() |
|
featloc = featloc.cuda() |
|
tensor_image = model(zbatch) |
|
feat = model.retained_layer(args.layer) |
|
flatfeat = feat.view(feat.shape[0], feat.shape[1], -1) |
|
sum_feature_at_obj = flatfeat[ |
|
torch.arange(feat.shape[0]).to(feat.device), :, featloc |
|
].sum(0) |
|
total_present_feature = total_present_feature + sum_feature_at_obj |
|
corpus.mean_present_feature = (total_present_feature / len( |
|
corpus.object_present_sample)).cpu() |
|
if cache_filename: |
|
numpy.savez(cache_filename, **corpus) |
|
|
|
def compute_feature_quantiles(args, corpus, cache_filename, model, full_sample): |
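    # Collect per-unit activation quantiles over the search sample; the
    # 0.99 quantile backs the '_h99' replacement variant.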
|
|
|
if all(k in corpus for k in ['feature_99', 'feature_999']): |
|
return |
|
progress = default_progress() |
|
with torch.no_grad(): |
|
rq = RunningQuantile(resolution=10000) |
|
for [zbatch] in progress( |
|
torch.utils.data.DataLoader(TensorDataset(full_sample), |
|
batch_size=args.inference_batch_size, num_workers=10, |
|
pin_memory=True), |
|
desc="Calculating 0.999 quantile"): |
|
zbatch = zbatch.cuda() |
|
tensor_image = model(zbatch) |
|
feat = model.retained_layer(args.layer) |
|
rq.add(feat.permute(0, 2, 3, 1 |
|
).contiguous().view(-1, feat.shape[1])) |
|
result = rq.quantiles([0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999]) |
|
corpus.feature_001 = result[:, 0].cpu() |
|
corpus.feature_01 = result[:, 1].cpu() |
|
corpus.feature_10 = result[:, 2].cpu() |
|
corpus.feature_50 = result[:, 3].cpu() |
|
corpus.feature_90 = result[:, 4].cpu() |
|
corpus.feature_99 = result[:, 5].cpu() |
|
corpus.feature_999 = result[:, 6].cpu() |
|
numpy.savez(cache_filename, **corpus) |
|
|
|
def compute_candidate_locations(args, corpus, cache_filename, model, |
|
segmenter, classnum, second_sample): |
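    # Phase 2: find candidate locations where the class is absent but can
    # be added.  At a handful of sampled locations per image, force-insert
    # the mean present feature and score how much class area appears.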
|
|
|
|
|
|
|
if all(k in corpus for k in ['candidate_indices', |
|
'candidate_sample', 'candidate_score', |
|
'candidate_location', 'object_score_at_candidate', |
|
'candidate_location_popularity']): |
|
return |
|
progress = default_progress() |
|
feature_shape = model.feature_shape[args.layer][2:] |
|
num_locations = numpy.prod(feature_shape).item() |
|
with torch.no_grad(): |
|
|
|
possible_locations = numpy.arange(num_locations) |
|
|
|
|
|
|
|
        # Sample candidate edit locations in proportion to where the class
        # usually appears, smoothed so every location keeps some chance.
        location_weights = (corpus.object_location_popularity).double()
        location_weights += (location_weights.mean()) / 10.0
        location_weights = location_weights / location_weights.sum()
|
|
|
candidate_scores = [] |
|
object_scores = [] |
|
prng = numpy.random.RandomState(1) |
|
for [zbatch] in progress( |
|
torch.utils.data.DataLoader(TensorDataset(second_sample), |
|
batch_size=args.inference_batch_size, num_workers=10, |
|
pin_memory=True), |
|
desc="Candidate pool"): |
|
batch_scores = torch.zeros((len(zbatch),) + feature_shape).cuda() |
|
flat_batch_scores = batch_scores.view(len(zbatch), -1) |
|
            zbatch = zbatch.cuda()
            model.ablation[args.layer] = None  # clear any edit left from the last batch
            tensor_image = model(zbatch)
|
segmented_image = segmenter.segment_batch(tensor_image, |
|
downsample=2) |
|
mask = (segmented_image == classnum).max(1)[0] |
|
object_score = torch.nn.functional.adaptive_avg_pool2d( |
|
mask.float(), feature_shape) |
|
baseline_presence = mask.float().view(mask.shape[0], -1).sum(1) |
|
|
|
edit_mask = torch.zeros((1, 1) + feature_shape).cuda() |
|
if '_tcm' in args.variant: |
|
|
|
replace_vec = (corpus.mean_present_feature |
|
[None,:,None,None].cuda()) |
|
else: |
|
replace_vec = (corpus.weighted_mean_present_feature |
|
[None,:,None,None].cuda()) |
|
|
|
for loc in prng.choice(possible_locations, replace=False, |
|
p=location_weights, size=5): |
|
edit_mask.zero_() |
|
edit_mask.view(-1)[loc] = 1 |
|
model.edit_layer(args.layer, |
|
ablation=edit_mask, replacement=replace_vec) |
|
tensor_image = model(zbatch) |
|
segmented_image = segmenter.segment_batch(tensor_image, |
|
downsample=2) |
|
mask = (segmented_image == classnum).max(1)[0] |
|
modified_presence = mask.float().view( |
|
mask.shape[0], -1).sum(1) |
|
flat_batch_scores[:,loc] = ( |
|
modified_presence - baseline_presence) |
|
candidate_scores.append(batch_scores.cpu()) |
|
object_scores.append(object_score.cpu()) |
|
|
|
object_scores = torch.cat(object_scores) |
|
candidate_scores = torch.cat(candidate_scores) |
|
|
|
        # Only locations where the class was originally absent qualify.
        candidate_scores = candidate_scores * (object_scores == 0).float()
|
candidate_score_at_image, candidate_location_in_image = ( |
|
candidate_scores.view(args.search_size, -1).max(1)) |
|
best_candidate_scores, best_candidate_images = torch.sort( |
|
-candidate_score_at_image) |
|
all_candidate_indices = torch.sort( |
|
best_candidate_images[:(args.train_size+args.eval_size)])[0] |
|
corpus.candidate_indices = all_candidate_indices[:args.train_size] |
|
corpus.candidate_sample = second_sample[corpus.candidate_indices] |
|
corpus.candidate_location = candidate_location_in_image[ |
|
corpus.candidate_indices] |
|
corpus.candidate_score = candidate_score_at_image[ |
|
corpus.candidate_indices] |
|
corpus.object_score_at_candidate = object_scores.view( |
|
len(object_scores), -1)[ |
|
corpus.candidate_indices, corpus.candidate_location] |
|
corpus.candidate_location_popularity = torch.bincount( |
|
corpus.candidate_location, |
|
minlength=num_locations) |
|
corpus.eval_candidate_indices = all_candidate_indices[ |
|
-args.eval_size:] |
|
corpus.eval_candidate_sample = second_sample[ |
|
corpus.eval_candidate_indices] |
|
corpus.eval_candidate_location = candidate_location_in_image[ |
|
corpus.eval_candidate_indices] |
|
numpy.savez(cache_filename, **corpus) |
|
|
|
def visualize_training_locations(args, corpus, cachedir, model): |
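    # Debugging aid (not called from run_command): renders each training
    # image with its chosen feature location tinted yellow and saves it.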
|
|
|
progress = default_progress() |
|
feature_shape = model.feature_shape[args.layer][2:] |
|
num_locations = numpy.prod(feature_shape).item() |
|
with torch.no_grad(): |
|
imagedir = os.path.join(cachedir, 'image') |
|
os.makedirs(imagedir, exist_ok=True) |
|
image_saver = WorkerPool(SaveImageWorker) |
|
for group, group_sample, group_location, group_indices in [ |
|
('present', |
|
corpus.object_present_sample, |
|
corpus.object_present_location, |
|
corpus.present_indices), |
|
('candidate', |
|
corpus.candidate_sample, |
|
corpus.candidate_location, |
|
corpus.candidate_indices)]: |
|
for [zbatch, featloc, indices] in progress( |
|
torch.utils.data.DataLoader(TensorDataset( |
|
group_sample, group_location, group_indices), |
|
batch_size=args.inference_batch_size, num_workers=10, |
|
pin_memory=True), |
|
desc="Visualize %s" % group): |
|
zbatch = zbatch.cuda() |
|
tensor_image = model(zbatch) |
|
feature_mask = torch.zeros((len(zbatch), 1) + feature_shape) |
|
feature_mask.view(len(zbatch), -1).scatter_( |
|
1, featloc[:,None], 1) |
|
feature_mask = torch.nn.functional.adaptive_max_pool2d( |
|
feature_mask.float(), tensor_image.shape[-2:]).cuda() |
|
yellow = torch.Tensor([1.0, 1.0, -1.0] |
|
)[None, :, None, None].cuda() |
|
tensor_image = tensor_image * (1 - 0.5 * feature_mask) + ( |
|
0.5 * feature_mask * yellow) |
|
byte_image = (((tensor_image+1)/2)*255).clamp(0, 255).byte() |
|
numpy_image = byte_image.permute(0, 2, 3, 1).cpu().numpy() |
|
for i, index in enumerate(indices): |
|
image_saver.add(numpy_image[i], os.path.join(imagedir, |
|
'%s_%d.jpg' % (group, index))) |
|
image_saver.join() |
|
|
|
def scale_summary(scale, lownums, highnums): |
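    # One-line summary of the most and least ablated units, formatted as
    # 'unit: score' pairs, strongest first.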
|
value, order = (-(scale.detach())).cpu().sort(0) |
|
lowsum = ' '.join('%d: %.3g' % (o.item(), -v.item()) |
|
for v, o in zip(value[:lownums], order[:lownums])) |
|
highsum = ' '.join('%d: %.3g' % (o.item(), -v.item()) |
|
for v, o in zip(value[-highnums:], order[-highnums:])) |
|
return lowsum + ' ... ' + highsum |
|
|
|
|
def initial_ablation(args, dissectdir): |
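    # Seed the ablation vector from the dissection's per-unit IoU ranking
    # for this class, scaled so the strongest unit starts at 1.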
|
|
|
with open(os.path.join(dissectdir, 'dissect.json')) as f: |
|
dissection = EasyDict(json.load(f)) |
|
lrec = [l for l in dissection.layers if l.layer == args.layer][0] |
|
rrec = [r for r in lrec.rankings if r.name == '%s-iou' % args.classname |
|
][0] |
|
init_scores = -torch.tensor(rrec.score) |
|
return init_scores / init_scores.max() |
|
|
|
def ace_loss(segmenter, classnum, model, layer, high_replacement, ablation, |
|
pbatch, ploc, cbatch, cloc, run_backward=False, |
|
discrete_pixels=False, |
|
discrete_units=False, |
|
mixed_units=False, |
|
ablation_only=False, |
|
fullimage_measurement=False, |
|
fullimage_ablation=False, |
|
): |
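    # Two-part ACE loss.  Erase pass: ablate the selected units where the
    # class is present and count the surviving class pixels.  Add pass:
    # drive the same units to high_replacement at a candidate location and
    # subtract the class pixels that appear.  With run_backward=True each
    # half is backpropagated right away, so only one generator graph is
    # held in memory at a time.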
|
feature_shape = model.feature_shape[layer][2:] |
|
if discrete_units: |
|
assert discrete_units > 0 |
|
d = torch.zeros_like(ablation) |
|
top_units = torch.topk(ablation.view(-1), discrete_units)[1] |
|
if mixed_units: |
|
d.view(-1)[top_units] = ablation.view(-1)[top_units] |
|
else: |
|
d.view(-1)[top_units] = 1 |
|
ablation = d |
|
|
|
|
|
    # Erase pass: measure the class area that survives ablation at the
    # locations where the class is present.
    p_mask = torch.zeros((len(pbatch), 1) + feature_shape)
|
if fullimage_ablation: |
|
p_mask[...] = 1 |
|
else: |
|
p_mask.view(len(pbatch), -1).scatter_(1, ploc[:,None], 1) |
|
p_mask = p_mask.cuda() |
|
a_p_mask = (ablation * p_mask) |
|
model.edit_layer(layer, ablation=a_p_mask, replacement=None) |
|
tensor_images = model(pbatch.cuda()) |
|
assert model._ablation[layer] is a_p_mask |
|
erase_effect, erased_mask = segmenter.predict_single_class( |
|
tensor_images, classnum, downsample=2) |
|
if discrete_pixels: |
|
erase_effect = erased_mask.float() |
|
erase_downsampled = torch.nn.functional.adaptive_avg_pool2d( |
|
erase_effect[:,None,:,:], feature_shape)[:,0,:,:] |
|
if fullimage_measurement: |
|
erase_loss = erase_downsampled.sum() |
|
else: |
|
erase_at_loc = erase_downsampled.view(len(erase_downsampled), -1 |
|
)[torch.arange(len(erase_downsampled)), ploc] |
|
erase_loss = erase_at_loc.sum() |
|
if run_backward: |
|
erase_loss.backward() |
|
if ablation_only: |
|
return erase_loss |
|
|
|
|
|
    # Add pass: measure the class area created by activating the units at
    # the candidate locations.
    c_mask = torch.zeros((len(cbatch), 1) + feature_shape)
|
c_mask.view(len(cbatch), -1).scatter_(1, cloc[:,None], 1) |
|
c_mask = c_mask.cuda() |
|
a_c_mask = (ablation * c_mask) |
|
model.edit_layer(layer, ablation=a_c_mask, replacement=high_replacement) |
|
tensor_images = model(cbatch.cuda()) |
|
assert model._ablation[layer] is a_c_mask |
|
add_effect, added_mask = segmenter.predict_single_class( |
|
tensor_images, classnum, downsample=2) |
|
if discrete_pixels: |
|
add_effect = added_mask.float() |
|
add_effect = -add_effect |
|
add_downsampled = torch.nn.functional.adaptive_avg_pool2d( |
|
add_effect[:,None,:,:], feature_shape)[:,0,:,:] |
|
    if fullimage_measurement:
        add_loss = add_downsampled.sum()
    else:
        # Measure the added class signal at each candidate location.
        add_at_loc = add_downsampled.view(len(add_downsampled), -1
                )[torch.arange(len(add_downsampled)), cloc]
        add_loss = add_at_loc.sum()
|
if run_backward: |
|
add_loss.backward() |
|
return erase_loss + add_loss |
|
|
|
def train_ablation(args, corpus, cachefile, model, segmenter, classnum, |
|
initial_ablation=None): |
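    # Phase 3: optimize the ablation vector with Adam, accumulating
    # gradients over train_update_freq batches, adding an L2 penalty, and
    # clamping the vector to [0, 1] after every step.  Per-epoch snapshots
    # make the run resumable.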
|
progress = default_progress() |
|
cachedir = os.path.dirname(cachefile) |
|
snapdir = os.path.join(cachedir, 'snapshots') |
|
os.makedirs(snapdir, exist_ok=True) |
|
|
|
|
|
    # Choose the feature vector used when force-inserting the class.
    if '_h99' in args.variant:
        high_replacement = corpus.feature_99[None,:,None,None].cuda()
    elif '_tcm' in args.variant:
        # '_tcm': mean feature at the strongest class locations.
        high_replacement = (
            corpus.mean_present_feature[None,:,None,None].cuda())
    else:
        high_replacement = (
            corpus.weighted_mean_present_feature[None,:,None,None].cuda())
|
fullimage_measurement = False |
|
ablation_only = False |
|
fullimage_ablation = False |
|
if '_fim' in args.variant: |
|
fullimage_measurement = True |
|
elif '_fia' in args.variant: |
|
fullimage_measurement = True |
|
ablation_only = True |
|
fullimage_ablation = True |
|
high_replacement.requires_grad = False |
|
for p in model.parameters(): |
|
p.requires_grad = False |
|
|
|
ablation = torch.zeros(high_replacement.shape).cuda() |
|
if initial_ablation is not None: |
|
ablation.view(-1)[...] = initial_ablation |
|
ablation.requires_grad = True |
|
optimizer = torch.optim.Adam([ablation], lr=0.01) |
|
start_epoch = 0 |
|
epoch = 0 |
|
|
|
def eval_loss_and_reg(): |
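        # Held-out evaluation: the ACE loss of the current ablation, plus
        # the same loss with the ablation discretized to its top-k units
        # (the discrete_experiments below), which is how a final unit
        # ranking is actually applied.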
|
        discrete_experiments = dict(
            dboth20=dict(discrete_units=20, discrete_pixels=True),
            fimadbm10=dict(discrete_units=10, mixed_units=True,
                discrete_pixels=True,
                ablation_only=True,
                fullimage_ablation=True,
                fullimage_measurement=True),
            fimadbm20=dict(discrete_units=20, mixed_units=True,
                discrete_pixels=True,
                ablation_only=True,
                fullimage_ablation=True,
                fullimage_measurement=True))
|
with torch.no_grad(): |
|
total_loss = 0 |
|
discrete_losses = {k: 0 for k in discrete_experiments} |
|
for [pbatch, ploc, cbatch, cloc] in progress( |
|
torch.utils.data.DataLoader(TensorDataset( |
|
corpus.eval_present_sample, |
|
corpus.eval_present_location, |
|
corpus.eval_candidate_sample, |
|
corpus.eval_candidate_location), |
|
batch_size=args.inference_batch_size, num_workers=10, |
|
shuffle=False, pin_memory=True), |
|
desc="Eval"): |
|
|
|
|
|
total_loss = total_loss + ace_loss(segmenter, classnum, |
|
model, args.layer, high_replacement, ablation, |
|
pbatch, ploc, cbatch, cloc, run_backward=False, |
|
ablation_only=ablation_only, |
|
fullimage_measurement=fullimage_measurement) |
|
for k, config in discrete_experiments.items(): |
|
discrete_losses[k] = discrete_losses[k] + ace_loss( |
|
segmenter, classnum, |
|
model, args.layer, high_replacement, ablation, |
|
pbatch, ploc, cbatch, cloc, run_backward=False, |
|
**config) |
|
avg_loss = (total_loss / args.eval_size).item() |
|
avg_d_losses = {k: (d / args.eval_size).item() |
|
for k, d in discrete_losses.items()} |
|
regularizer = (args.l2_lambda * ablation.pow(2).sum()) |
|
print_progress('Epoch %d Loss %g Regularizer %g' % |
|
(epoch, avg_loss, regularizer)) |
|
print_progress(' '.join('%s: %g' % (k, d) |
|
for k, d in avg_d_losses.items())) |
|
print_progress(scale_summary(ablation.view(-1), 10, 3)) |
|
return avg_loss, regularizer, avg_d_losses |
|
|
|
    if args.eval_only:
        # Do not optimize; re-evaluate each saved snapshot and rewrite it
        # with refreshed eval metrics.
        for epoch in range(-1, args.train_epochs):
|
snapfile = os.path.join(snapdir, 'epoch-%d.pth' % epoch) |
|
if not os.path.exists(snapfile): |
|
data = {} |
|
if epoch >= 0: |
|
print('No epoch %d' % epoch) |
|
continue |
|
else: |
|
data = torch.load(snapfile) |
|
with torch.no_grad(): |
|
ablation[...] = data['ablation'].to(ablation.device) |
|
optimizer.load_state_dict(data['optimizer']) |
|
avg_loss, regularizer, new_extra = eval_loss_and_reg() |
|
|
|
extra = {k: v for k, v in data.items() |
|
if k not in ['ablation', 'optimizer', 'avg_loss']} |
|
extra.update(new_extra) |
|
torch.save(dict(ablation=ablation, optimizer=optimizer.state_dict(), |
|
avg_loss=avg_loss, **extra), |
|
os.path.join(snapdir, 'epoch-%d.pth' % epoch)) |
|
|
|
return ablation.view(-1).detach().cpu().numpy() |
|
|
|
    # Resume from the most recent snapshot, if caching is allowed.
    if not args.no_cache:
|
for start_epoch in reversed(range(args.train_epochs)): |
|
snapfile = os.path.join(snapdir, 'epoch-%d.pth' % start_epoch) |
|
if os.path.exists(snapfile): |
|
data = torch.load(snapfile) |
|
with torch.no_grad(): |
|
ablation[...] = data['ablation'].to(ablation.device) |
|
optimizer.load_state_dict(data['optimizer']) |
|
start_epoch += 1 |
|
break |
|
|
|
if start_epoch < args.train_epochs: |
|
epoch = start_epoch - 1 |
|
avg_loss, regularizer, extra = eval_loss_and_reg() |
|
if epoch == -1: |
|
torch.save(dict(ablation=ablation, optimizer=optimizer.state_dict(), |
|
avg_loss=avg_loss, **extra), |
|
os.path.join(snapdir, 'epoch-%d.pth' % epoch)) |
|
|
|
update_size = args.train_update_freq * args.train_batch_size |
|
for epoch in range(start_epoch, args.train_epochs): |
|
candidate_shuffle = torch.randperm(len(corpus.candidate_sample)) |
|
train_loss = 0 |
|
for batch_num, [pbatch, ploc, cbatch, cloc] in enumerate(progress( |
|
torch.utils.data.DataLoader(TensorDataset( |
|
corpus.object_present_sample, |
|
corpus.object_present_location, |
|
corpus.candidate_sample[candidate_shuffle], |
|
corpus.candidate_location[candidate_shuffle]), |
|
batch_size=args.train_batch_size, num_workers=10, |
|
shuffle=True, pin_memory=True), |
|
desc="ACE opt epoch %d" % epoch)): |
|
if batch_num % args.train_update_freq == 0: |
|
optimizer.zero_grad() |
|
|
|
|
|
loss = ace_loss(segmenter, classnum, |
|
model, args.layer, high_replacement, ablation, |
|
pbatch, ploc, cbatch, cloc, run_backward=True, |
|
ablation_only=ablation_only, |
|
fullimage_measurement=fullimage_measurement) |
|
with torch.no_grad(): |
|
train_loss = train_loss + loss |
|
            if (batch_num + 1) % args.train_update_freq == 0:
                # The accumulated loss covers update_size samples, so scale
                # the L2 penalty to match before stepping.
                regularizer = (args.l2_lambda * update_size
                        * ablation.pow(2).sum())
                regularizer.backward()
                optimizer.step()
|
with torch.no_grad(): |
|
ablation.clamp_(0, 1) |
|
post_progress(l=(train_loss/update_size).item(), |
|
r=(regularizer/update_size).item()) |
|
train_loss = 0 |
|
|
|
avg_loss, regularizer, extra = eval_loss_and_reg() |
|
torch.save(dict(ablation=ablation, optimizer=optimizer.state_dict(), |
|
avg_loss=avg_loss, **extra), |
|
os.path.join(snapdir, 'epoch-%d.pth' % epoch)) |
|
numpy.save(os.path.join(snapdir, 'epoch-%d.npy' % epoch), |
|
ablation.detach().cpu().numpy()) |
|
|
|
|
|
return ablation.view(-1).detach().cpu().numpy() |
|
|
|
|
|
def tensor_to_numpy_image_batch(tensor_image): |
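    # Map generator output in [-1, 1] to an NHWC uint8 numpy batch.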
|
byte_image = (((tensor_image+1)/2)*255).clamp(0, 255).byte() |
|
numpy_image = byte_image.permute(0, 2, 3, 1).cpu().numpy() |
|
return numpy_image |
|
|
|
|
|
|
|
def evaluate_ablation(args, model, segmenter, eval_sample, classnum, layer, |
|
ordering): |
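    # Measures class area as units are knocked out cumulatively in the
    # given ordering: total_scores[0] is the baseline class area, and
    # total_scores[k] is the area left at the original class locations
    # after ablating the first k units of 'ordering'.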
|
    progress = default_progress()
    device = next(model.parameters()).device
    for l in model.ablation:
        model.ablation[l] = None
|
feature_units = model.feature_shape[args.layer][1] |
|
feature_shape = model.feature_shape[args.layer][2:] |
|
repeats = len(ordering) |
|
total_scores = torch.zeros(repeats + 1) |
|
for i, batch in enumerate(progress(torch.utils.data.DataLoader( |
|
TensorDataset(eval_sample), |
|
batch_size=args.inference_batch_size, num_workers=10, |
|
pin_memory=True), |
|
desc="Evaluate interventions")): |
|
        zbatch = batch[0].to(device)
        model.ablation[args.layer] = None  # no edits during the baseline pass
        tensor_image = model(zbatch)
|
segmented_image = segmenter.segment_batch(tensor_image, |
|
downsample=2) |
|
mask = (segmented_image == classnum).max(1)[0] |
|
downsampled_seg = torch.nn.functional.adaptive_avg_pool2d( |
|
mask.float()[:,None,:,:], feature_shape)[:,0,:,:] |
|
total_scores[0] += downsampled_seg.sum().cpu() |
|
|
|
|
|
interventions_needed = downsampled_seg.nonzero() |
|
location_count = len(interventions_needed) |
|
if location_count == 0: |
|
continue |
|
interventions_needed = interventions_needed.repeat(repeats, 1) |
|
inter_z = batch[0][interventions_needed[:,0]].to(device) |
|
inter_chan = torch.zeros(repeats, location_count, feature_units, |
|
device=device) |
|
for j, u in enumerate(ordering): |
|
inter_chan[j:, :, u] = 1 |
|
inter_chan = inter_chan.view(len(inter_z), feature_units) |
|
inter_loc = interventions_needed[:,1:] |
|
scores = torch.zeros(len(inter_z)) |
|
batch_size = len(batch[0]) |
|
for j in range(0, len(inter_z), batch_size): |
|
ibz = inter_z[j:j+batch_size] |
|
ibl = inter_loc[j:j+batch_size].t() |
|
imask = torch.zeros((len(ibz),) + feature_shape, device=ibz.device) |
|
imask[(torch.arange(len(ibz)),) + tuple(ibl)] = 1 |
|
ibc = inter_chan[j:j+batch_size] |
|
model.edit_layer(args.layer, ablation=( |
|
imask.float()[:,None,:,:] * ibc[:,:,None,None])) |
|
            # Regenerate with the intervention in place and re-segment.
            tensor_image = model(ibz)
            seg = segmenter.segment_batch(tensor_image, downsample=2)
|
mask = (seg == classnum).max(1)[0] |
|
downsampled_iseg = torch.nn.functional.adaptive_avg_pool2d( |
|
mask.float()[:,None,:,:], feature_shape)[:,0,:,:] |
|
scores[j:j+batch_size] = downsampled_iseg[ |
|
(torch.arange(len(ibz)),) + tuple(ibl)] |
|
scores = scores.view(repeats, location_count).sum(1) |
|
total_scores[1:] += scores |
|
return total_scores |
|
|
|
def evaluate_interventions(args, model, segmenter, eval_sample, |
|
classnum, layer, units): |
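    # Same measurement as evaluate_ablation, but takes an explicit unit
    # list and writes the ablation mask onto the layer directly.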
|
    progress = default_progress()
    device = next(model.parameters()).device
    for l in model.ablation:
        model.ablation[l] = None
    feature_units = model.feature_shape[args.layer][1]
    feature_shape = model.feature_shape[args.layer][2:]
    ordering = units  # knock out the given units cumulatively, in order
    repeats = len(ordering)
|
total_scores = torch.zeros(repeats + 1) |
|
for i, batch in enumerate(progress(torch.utils.data.DataLoader( |
|
TensorDataset(eval_sample), |
|
batch_size=args.inference_batch_size, num_workers=10, |
|
pin_memory=True), |
|
desc="Evaluate interventions")): |
|
        zbatch = batch[0].to(device)
        model.ablation[args.layer] = None  # no edits during the baseline pass
        tensor_image = model(zbatch)
|
segmented_image = segmenter.segment_batch(tensor_image, |
|
downsample=2) |
|
mask = (segmented_image == classnum).max(1)[0] |
|
downsampled_seg = torch.nn.functional.adaptive_avg_pool2d( |
|
mask.float()[:,None,:,:], feature_shape)[:,0,:,:] |
|
total_scores[0] += downsampled_seg.sum().cpu() |
|
|
|
|
|
interventions_needed = downsampled_seg.nonzero() |
|
location_count = len(interventions_needed) |
|
if location_count == 0: |
|
continue |
|
interventions_needed = interventions_needed.repeat(repeats, 1) |
|
inter_z = batch[0][interventions_needed[:,0]].to(device) |
|
inter_chan = torch.zeros(repeats, location_count, feature_units, |
|
device=device) |
|
for j, u in enumerate(ordering): |
|
inter_chan[j:, :, u] = 1 |
|
inter_chan = inter_chan.view(len(inter_z), feature_units) |
|
inter_loc = interventions_needed[:,1:] |
|
scores = torch.zeros(len(inter_z)) |
|
batch_size = len(batch[0]) |
|
for j in range(0, len(inter_z), batch_size): |
|
ibz = inter_z[j:j+batch_size] |
|
ibl = inter_loc[j:j+batch_size].t() |
|
imask = torch.zeros((len(ibz),) + feature_shape, device=ibz.device) |
|
imask[(torch.arange(len(ibz)),) + tuple(ibl)] = 1 |
|
ibc = inter_chan[j:j+batch_size] |
|
model.ablation[args.layer] = ( |
|
imask.float()[:,None,:,:] * ibc[:,:,None,None]) |
|
            # Regenerate with the intervention in place and re-segment.
            tensor_image = model(ibz)
            seg = segmenter.segment_batch(tensor_image, downsample=2)
|
mask = (seg == classnum).max(1)[0] |
|
downsampled_iseg = torch.nn.functional.adaptive_avg_pool2d( |
|
mask.float()[:,None,:,:], feature_shape)[:,0,:,:] |
|
scores[j:j+batch_size] = downsampled_iseg[ |
|
(torch.arange(len(ibz)),) + tuple(ibl)] |
|
scores = scores.view(repeats, location_count).sum(1) |
|
total_scores[1:] += scores |
|
return total_scores |
|
|
|
|
|
def add_ace_ranking_to_dissection(outdir, layer, classname, total_scores): |
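    # Merge the learned ACE ranking into dissect.json (keeping a one-time
    # .bak copy) so downstream tools can sort units by it.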
|
source_filename = os.path.join(outdir, 'dissect.json') |
|
source_filename_bak = os.path.join(outdir, 'dissect.json.bak') |
|
|
|
|
|
if not os.path.exists(source_filename_bak): |
|
shutil.copy(source_filename, source_filename_bak) |
|
|
|
with open(source_filename) as f: |
|
dissection = EasyDict(json.load(f)) |
|
|
|
ranking_name = '%s-ace' % classname |
|
|
|
|
|
lrec = [l for l in dissection.layers if l.layer == layer][0] |
|
lrec.rankings = [r for r in lrec.rankings if r.name != ranking_name] |
|
|
|
|
|
new_rankings = [dict( |
|
name=ranking_name, |
|
score=(-total_scores).flatten().tolist(), |
|
metric='ace')] |
|
|
|
|
|
    # Splice the new ranking in near the head of the rankings list.
    lrec.rankings[2:2] = new_rankings
|
|
|
|
|
with open(source_filename, 'w') as f: |
|
json.dump(dissection, f, indent=1) |
|
|
|
def summarize_scores(args, corpus, cachedir, layer, classname, variant, scores): |
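    # Write a standalone summary.json for this variant, mirroring the
    # layers/rankings structure of dissect.json.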
|
target_filename = os.path.join(cachedir, 'summary.json') |
|
|
|
ranking_name = '%s-%s' % (classname, variant) |
|
|
|
new_rankings = [dict( |
|
name=ranking_name, |
|
score=(-scores).flatten().tolist(), |
|
metric=variant)] |
|
result = dict(layers=[dict(layer=layer, rankings=new_rankings)]) |
|
|
|
|
|
with open(target_filename, 'w') as f: |
|
json.dump(result, f, indent=1) |
|
|
|
if __name__ == '__main__': |
|
main() |
|
|