|
''' |
|
A simple tool to generate sample outputs from a GAN
and apply semantic segmentation to those outputs.
|
''' |
|
|
|
import torch, numpy, os, argparse, sys, shutil |
|
from PIL import Image |
|
from torch.utils.data import TensorDataset |
|
from netdissect.zdataset import standard_z_sample, z_dataset_for_model |
|
from netdissect.progress import default_progress, verbose_progress |
|
from netdissect.autoeval import autoimport_eval |
|
from netdissect.workerpool import WorkerBase, WorkerPool |
|
from netdissect.nethook import edit_layers, retain_layers |
|
from netdissect.segviz import segment_visualization |
|
from netdissect.segmenter import UnifiedParsingSegmenter |
|
from scipy.io import savemat |
|
|
|
def _parse_args(argv=None):
    """Parse the command-line options for this tool.

    Args:
        argv: list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``model``, ``outdir``, ``size``, ``seed``
        and ``quiet`` attributes.
    """
    parser = argparse.ArgumentParser(description='GAN output segmentation util')
    parser.add_argument('--model', type=str, default=
            'netdissect.proggan.from_pth_file("' +
            'models/karras/churchoutdoor_lsun.pth")',
            help='constructor for the model to test')
    parser.add_argument('--outdir', type=str, default='images',
            help='directory for image output')
    parser.add_argument('--size', type=int, default=100,
            help='number of images to output')
    parser.add_argument('--seed', type=int, default=1,
            help='seed')
    parser.add_argument('--quiet', action='store_true', default=False,
            help='silences console output')
    return parser.parse_args(argv)


def _write_labels(dirname, labels):
    """Write ``labels.txt`` into *dirname*, creating the directory if needed.

    Each element of *labels* is unpacked as a ``(label, category)`` pair and
    written as one ``"label category"`` line.
    """
    # A fresh --outdir may not exist yet; without this, open() below raised
    # FileNotFoundError before any image was produced.
    os.makedirs(dirname, exist_ok=True)
    with open(os.path.join(dirname, 'labels.txt'), 'w') as f:
        for label, cat in labels:
            f.write('%s %s\n' % (label, cat))


def main(argv=None):
    """Sample images from a GAN, segment them, and save everything to disk.

    For each sample index ``k`` the following files are written to
    ``--outdir``:

    * ``k_img.jpg``: the generated image;
    * ``k_seg.mat``: the raw segmentation, under the key ``seg``;
    * ``k_seg.png``: a visualization of the segmentation.

    It also writes ``labels.txt`` and copies ``lightbox.html`` from this
    script's directory to ``+lightbox.html``.

    Args:
        argv: optional argument list for testing; ``None`` reads sys.argv.
    """
    args = _parse_args(argv)
    verbose_progress(not args.quiet)

    model = autoimport_eval(args.model)
    # Forward --seed: it used to be parsed but ignored, so every run drew
    # the same z samples regardless of the requested seed.
    z_dataset = z_dataset_for_model(model, size=args.size, seed=args.seed)

    segmenter = UnifiedParsingSegmenter()
    labels, _ = segmenter.get_label_and_category_names()
    dirname = args.outdir
    _write_labels(dirname, labels)

    model.cuda()
    batch_size = 10
    progress = default_progress()

    with torch.no_grad():
        z_loader = torch.utils.data.DataLoader(z_dataset,
                batch_size=batch_size, num_workers=2,
                pin_memory=True)
        for batch_num, [z] in enumerate(progress(z_loader,
                desc='Saving images')):
            z = z.cuda()
            start_index = batch_num * batch_size
            tensor_im = model(z)
            # Rescale with (x + 1) / 2 * 255 (i.e. assumes output roughly in
            # [-1, 1]), clamp, and permute(0, 2, 3, 1) to channels-last for
            # PIL -- assumes the model returns NCHW; TODO confirm.
            byte_im = ((tensor_im + 1) / 2 * 255).clamp(0, 255).byte().permute(
                    0, 2, 3, 1).cpu()
            seg = segmenter.segment_batch(tensor_im)
            size = tensor_im.shape[2:]
            for i, img in enumerate(byte_im):
                index = i + start_index
                filename = os.path.join(dirname, '%d_img.jpg' % index)
                Image.fromarray(img.numpy()).save(
                        filename, optimize=True, quality=100)
                # Transfer each segmentation to host memory once, then reuse.
                seg_np = seg[i].cpu().numpy()
                filename = os.path.join(dirname, '%d_seg.mat' % index)
                savemat(filename, dict(seg=seg_np))
                filename = os.path.join(dirname, '%d_seg.png' % index)
                Image.fromarray(segment_visualization(seg_np, size)).save(
                        filename)

    srcdir = os.path.realpath(
            os.path.join(os.getcwd(), os.path.dirname(__file__)))
    shutil.copy(os.path.join(srcdir, 'lightbox.html'),
            os.path.join(dirname, '+lightbox.html'))
|
|
|
if __name__ == '__main__':
    # Only launch the command-line tool when run as a script; importing
    # this module has no side effects beyond its definitions.
    main()
|
|