|
'''
A simple tool to generate samples of output from a GAN,
subject to filtering, sorting, or intervention.
'''
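# Example invocation (hypothetical: the script name, constructor string,
# and paths below are illustrative; the --model string is evaluated by
# autoimport_eval to construct the generator):
#
#   python sample_gan.py \
#       --model 'mypackage.MyGenerator()' \
#       --pthfile generator.pth \
#       --layer layer4 --size 100 --outdir images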
|
|
|
import torch, os, argparse, sys, shutil, errno, numbers
from PIL import Image
from torch.utils.data import TensorDataset
from netdissect.zdataset import standard_z_sample
from netdissect.progress import default_progress, verbose_progress
from netdissect.autoeval import autoimport_eval
from netdissect.workerpool import WorkerBase, WorkerPool
from netdissect.nethook import retain_layers
from netdissect.runningstats import RunningTopK
|
|
|
def main():
    parser = argparse.ArgumentParser(description='GAN sample making utility')
    parser.add_argument('--model', type=str, default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--outdir', type=str, default='images',
                        help='directory for image output')
    parser.add_argument('--size', type=int, default=100,
                        help='number of images to output')
    parser.add_argument('--test_size', type=int, default=None,
                        help='number of z samples to search (default: 20 * size)')
    parser.add_argument('--layer', type=str, default=None,
                        help='layer to inspect')
    parser.add_argument('--seed', type=int, default=1,
                        help='random seed for z sampling')
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    verbose_progress(not args.quiet)

    # Instantiate the model, and load weights from a .pth file if given.
    model = autoimport_eval(args.model)
    if args.pthfile is not None:
        data = torch.load(args.pthfile)
        if 'state_dict' in data:
            # Pull out any scalar metadata saved alongside the weights.
            meta = {}
            for key in data:
                if isinstance(data[key], numbers.Number):
                    meta[key] = data[key]
            data = data['state_dict']
        model.load_state_dict(data)

    # Unwrap DataParallel so layer names refer to the bare model.
    if isinstance(model, torch.nn.DataParallel):
        model = next(model.children())

    # Infer the z input shape from the first conv or linear layer.
    first_layer = [c for c in model.modules()
                   if isinstance(c, (torch.nn.Conv2d, torch.nn.ConvTranspose2d,
                                     torch.nn.Linear))][0]
    if isinstance(first_layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        # Convolutional input: z needs trailing 1x1 spatial dimensions.
        z_channels = first_layer.in_channels
        spatialdims = (1, 1)
    else:
        z_channels = first_layer.in_features
        spatialdims = ()

    # Hook the requested layer so its activations are retained, then
    # search a pool of z samples for the top activations of each unit
    # and save the corresponding generated images.
    retain_layers(model, [args.layer])
    model.cuda()

    if args.test_size is None:
        args.test_size = args.size * 20
    z_universe = standard_z_sample(args.test_size, z_channels,
                                   seed=args.seed)
    z_universe = z_universe.view(tuple(z_universe.shape) + spatialdims)
    indexes = get_all_highest_znums(
        model, z_universe, args.size, seed=args.seed)
    save_chosen_unit_images(args.outdir, model, z_universe, indexes,
                            lightbox=True)
|
|
def get_all_highest_znums(model, z_universe, size,
                          batch_size=10, seed=1):
    '''
    Returns a (units, size) tensor: for each unit of the retained
    layer, the indices (in sorted order) of the z samples in
    z_universe that produce the unit's highest maximum activations.
    '''
    retained_items = list(model.retained.items())
    assert len(retained_items) == 1
    layer = retained_items[0][0]

    progress = default_progress()
    num_units = None
    with torch.no_grad():
        z_loader = torch.utils.data.DataLoader(TensorDataset(z_universe),
                                               batch_size=batch_size,
                                               num_workers=2,
                                               pin_memory=True)
        rtk = RunningTopK(k=size)
        for [z] in progress(z_loader, desc='Finding max activations'):
            z = z.cuda()
            model(z)
            feature = model.retained[layer]
            num_units = feature.shape[1]
            # Reduce each unit's featuremap to its single maximum value.
            max_feature = feature.view(
                feature.shape[0], num_units, -1).max(2)[0]
            rtk.add(max_feature)
        td, ti = rtk.result()
    highest = ti.sort(1)[0]
    return highest
|
|
def save_chosen_unit_images(dirname, model, z_universe, indices,
                            shared_dir="shared_images",
                            unitdir_template="unit_{}",
                            name_template="image_{}.jpg",
                            lightbox=False, batch_size=50, seed=1):
    '''
    Renders each image selected by indices into z_universe exactly once,
    into a shared directory, then hard-links the images into one
    directory per unit.
    '''
    all_indices = torch.unique(indices.view(-1), sorted=True)
    z_sample = z_universe[all_indices]
    progress = default_progress()
    sdir = os.path.join(dirname, shared_dir)
    created_hashdirs = set()
    for index in range(len(z_universe)):
        hd = hashdir(index)
        if hd not in created_hashdirs:
            created_hashdirs.add(hd)
            os.makedirs(os.path.join(sdir, hd), exist_ok=True)
    with torch.no_grad():
        # Render each needed image once, saving in parallel workers.
        z_loader = torch.utils.data.DataLoader(TensorDataset(z_sample),
                                               batch_size=batch_size,
                                               num_workers=2,
                                               pin_memory=True)
        saver = WorkerPool(SaveImageWorker)
        for batch_num, [z] in enumerate(progress(z_loader,
                                                 desc='Saving images')):
            z = z.cuda()
            start_index = batch_num * batch_size
            # Scale model output from [-1, 1] to [0, 255] HWC bytes.
            im = ((model(z) + 1) / 2 * 255).clamp(0, 255).byte().permute(
                0, 2, 3, 1).cpu()
            for i in range(len(im)):
                index = all_indices[i + start_index].item()
                filename = os.path.join(sdir, hashdir(index),
                                        name_template.format(index))
                saver.add(im[i].numpy(), filename)
        saver.join()
    # Hard-link each shared image into the directories of the units
    # that chose it.
    linker = WorkerPool(MakeLinkWorker)
    for u in progress(range(len(indices)), desc='Making links'):
        udir = os.path.join(dirname, unitdir_template.format(u))
        os.makedirs(udir, exist_ok=True)
        for r in range(indices.shape[1]):
            index = indices[u, r].item()
            fn = name_template.format(index)
            sourcename = os.path.join(sdir, hashdir(index), fn)
            targname = os.path.join(udir, fn)
            linker.add(sourcename, targname)
        if lightbox:
            copy_lightbox_to(udir)
    linker.join()
|
|
def copy_lightbox_to(dirname):
    '''
    Copies the lightbox.html viewer from the script's directory into
    the given image directory.
    '''
    srcdir = os.path.realpath(
        os.path.join(os.getcwd(), os.path.dirname(__file__)))
    shutil.copy(os.path.join(srcdir, 'lightbox.html'),
                os.path.join(dirname, '+lightbox.html'))
|
|
def hashdir(index):
    '''
    Spreads shared images over 100 subdirectories, named by the last
    two digits of the image index, e.g. hashdir(1234) == '34'.
    '''
    return '%02d' % (index % 100)
|
|
class SaveImageWorker(WorkerBase):
    '''
    Worker that encodes a numpy image array and writes it to a file.
    '''
    def work(self, data, filename):
        Image.fromarray(data).save(filename, optimize=True, quality=100)
|
|
class MakeLinkWorker(WorkerBase):
    '''
    Worker that creates hard links, replacing any existing target.
    '''
    def work(self, sourcename, targname):
        try:
            os.link(sourcename, targname)
        except OSError as e:
            if e.errno == errno.EEXIST:
                os.remove(targname)
                os.link(sourcename, targname)
            else:
                raise
|
|
class MakeSymlinkWorker(WorkerBase):
    '''
    Worker that creates symbolic links, replacing any existing target.
    '''
    def work(self, sourcename, targname):
        try:
            os.symlink(sourcename, targname)
        except OSError as e:
            if e.errno == errno.EEXIST:
                os.remove(targname)
                os.symlink(sourcename, targname)
            else:
                raise
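
# MakeSymlinkWorker is an unused drop-in alternative to MakeLinkWorker,
# e.g. for output directories spanning filesystems without hard-link
# support. An untested sketch of the swap in save_chosen_unit_images:
#
#   linker = WorkerPool(MakeSymlinkWorker)
#
# Note that os.symlink stores sourcename literally, so a relative path
# would be resolved relative to the link's own directory.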
|
|
|
if __name__ == '__main__':
    main()