import argparse
import json
import os
import shutil
import sys

import matplotlib
import numpy
import skimage
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator


def main():
    """Parse command-line options and hand them to ``run_command``.

    Options read from ``sys.argv``:
      --classname  intervention classname (required)
      --layer      layer name, defaults to ``'layer4'``
      --l2_lambda  one or more l2 regularizer values (required)
      --outdir     dissection directory (required)
      --variant    experiment variant, defaults to ``'ace'``

    Raises:
        SystemExit: raised by argparse on a usage error (for example a
            missing required option) or when ``--help`` is requested.
    """
    parser = argparse.ArgumentParser(
        description='ACE optimization utility',
        prog='python -m netdissect.aceoptimize')
    # run_command joins classname and outdir into filesystem paths and
    # iterates over l2_lambda, so a missing value would surface there as a
    # TypeError on None.  Requiring them gives a usage message instead.
    parser.add_argument('--classname', type=str, required=True,
                        help='intervention classname')
    parser.add_argument('--layer', type=str, default='layer4',
                        help='layer name')
    parser.add_argument('--l2_lambda', type=float, nargs='+', required=True,
                        help='l2 regularizer hyperparameter')
    parser.add_argument('--outdir', type=str, required=True,
                        help='dissection directory')
    parser.add_argument('--variant', type=str, default='ace',
                        help='experiment variant')
    args = parser.parse_args()

    run_command(args)


def run_command(args, num_epochs=10):
    """Plot per-epoch average loss for each requested l2 regularizer value.

    For every value in ``args.l2_lambda`` the snapshots
    ``epoch-0.pth`` .. ``epoch-<num_epochs - 1>.pth`` are loaded from
    ``<outdir>/<layer>/<variant>/<classname>/snapshots``.  A value of 0.01
    reads from the plain ``args.variant`` directory; any other value reads
    from ``'<variant>_reg<value>'``.  Each snapshot's ``'avg_loss'`` and the
    number of ``'ablation'`` entries equal to 1 are printed, and one loss
    curve per value is drawn.  The figure is written to
    ``<outdir>/<layer>/<args.variant>/<classname>/loss-plot.png``.

    Args:
        args: namespace providing ``l2_lambda`` (iterable of floats),
            ``variant``, ``outdir``, ``layer`` and ``classname``.
        num_epochs: number of consecutive epoch snapshots to read for each
            regularizer value.

    Returns:
        None.  If any snapshot cannot be loaded, a message is printed and
        the function returns without writing a plot.
    """
    curves = []
    for l2_lambda in args.l2_lambda:
        variant = args.variant
        if l2_lambda != 0.01:
            variant += '_reg%g' % l2_lambda

        dirname = os.path.join(args.outdir, args.layer, variant,
                               args.classname)
        snapshots = os.path.join(dirname, 'snapshots')
        try:
            # map_location='cpu' lets snapshots saved from a GPU run load on
            # a machine without one; plotting needs no accelerator.
            dat = [
                torch.load(os.path.join(snapshots, 'epoch-%d.pth' % i),
                           map_location='cpu')
                for i in range(num_epochs)
            ]
        except Exception:
            # Best-effort: any load failure is reported as missing data.
            # NOTE(review): this abandons the whole plot, including values
            # already loaded -- confirm that `continue` is not what is wanted.
            print('Missing %s snapshots' % dirname)
            return

        label = 'reg %g' % l2_lambda
        print(label)
        for epoch, snap in enumerate(dat):
            # Assuming 'ablation' is a torch tensor, nonzero() has one row
            # per matching element, so len() is the count of entries == 1.
            print(epoch, snap['avg_loss'],
                  len((snap['ablation'] == 1).nonzero()))
        curves.append((label, [snap['avg_loss'] for snap in dat]))

    # Build the figure only once every requested curve has loaded.
    fig = Figure(figsize=(4.5, 3.5))
    FigureCanvas(fig)
    ax = fig.add_subplot(111)
    for label, losses in curves:
        ax.plot(losses, label=label)
    ax.set_title('%s %s' % (args.classname, args.variant))
    ax.grid(True)
    ax.legend()
    ax.set_ylabel('Loss')
    ax.set_xlabel('Epochs')
    fig.tight_layout()

    dirname = os.path.join(args.outdir, args.layer,
                           args.variant, args.classname)
    # The unsuffixed variant directory may not exist when only regularized
    # variants were run, so create it before saving.
    os.makedirs(dirname, exist_ok=True)
    fig.savefig(os.path.join(dirname, 'loss-plot.png'))


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()