# TextureScraping / swapae /evaluation /structure_style_grid_generation_evaluator.py
# sunshineatnoon
# Add application file
# 1b2a9b1
# raw
# history blame
# 3.44 kB
import os
import torch
from swapae.evaluation import BaseEvaluator
import swapae.util as util
import numpy as np
from PIL import Image
class StructureStyleGridGenerationEvaluator(BaseEvaluator):
    """Generate a grid of structure/style swapping images and save it to disk.

    Every sample of the dataset is classified as a *structure* image (its path
    contains a ``structure`` directory component) or as a *style* image
    (everything else).  Each structure image is then decoded together with
    every style image, and the inputs and results are added to a web page.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Return ``parser`` unchanged; this evaluator adds no options."""
        return parser

    def create_webpage(self, nsteps):
        """Create the output directory and the ``util.HTML`` page for the results.

        Args:
            nsteps: Iteration label used in the directory name and page title.
                When ``None``, ``self.opt.resume_iter`` is used instead.
        """
        if nsteps is None:
            nsteps = self.opt.resume_iter
        savedir = os.path.join(
            self.output_dir(), "%s_%s" % (self.target_phase, nsteps)
        )
        os.makedirs(savedir, exist_ok=True)
        webpage_title = "%s. iter=%s. phase=%s" % (
            self.opt.name, str(nsteps), self.target_phase)
        self.webpage = util.HTML(savedir, webpage_title)

    def add_to_webpage(self, images, filenames, tile=1):
        """Convert tensors to PIL images and add them to the web page.

        The page is not saved here; ``evaluate`` saves it once at the end.

        Args:
            images: Sequence whose items are tensors or lists of tensors.  A
                list is stacked along a new leading dimension and its first
                two dimensions are then merged into one.
            filenames: File names handed to ``self.webpage.add_images``
                together with the converted images.
            tile: Upper bound for the ``tile`` argument of ``util.tensor2im``;
                it is capped at the size of the tensor's first dimension.
        """
        converted_images = []
        for image in images:
            if isinstance(image, list):
                image = torch.stack(image, dim=0).flatten(0, 1)
            array = util.tensor2im(image, tile=min(image.size(0), tile))
            converted_images.append(Image.fromarray(array))
        self.webpage.add_images(converted_images, filenames)
        print("saved %s" % str(filenames))

    @staticmethod
    def _collect_images(dataset):
        """Split all samples of ``dataset`` into structure and style images.

        Returns:
            A ``(structure_images, style_images)`` pair of dicts that map the
            file name without extension to a single-sample slice of
            ``real_A``.
        """
        structure_images, style_images = {}, {}
        for batch in dataset:
            images, paths = batch["real_A"], batch["path_A"]
            for j in range(images.size(0)):
                path = paths[j]
                name = os.path.splitext(os.path.basename(path))[0]
                # Normalise separators so the directory test also matches
                # paths written with the platform's native separator.
                is_structure = "/structure/" in path.replace(os.sep, "/")
                target = structure_images if is_structure else style_images
                # NOTE: samples sharing the same file stem overwrite each
                # other.  The slice j:j+1 keeps the leading batch dimension,
                # so the image can be passed to the model unchanged.
                target[name] = images[j:j + 1]
        return structure_images, style_images

    def evaluate(self, model, dataset, nsteps=None):
        """Build the swapping grid for ``dataset`` and save the web page.

        Args:
            model: Callable used as ``model(image, command="encode")``, whose
                result is indexed with ``[0]`` for structure images and
                ``[1]`` for style images, and as
                ``model(sp, gl, command="decode")``.
            dataset: Iterable of batches; each batch is a mapping with a
                ``"real_A"`` tensor and a ``"path_A"`` sequence of paths.
            nsteps: Iteration label forwarded to ``create_webpage``.

        Returns:
            An empty dict.
        """
        self.create_webpage(nsteps)
        structure_images, style_images = self._collect_images(dataset)
        structure_names = list(structure_images)
        style_names = list(style_images)

        # The blank corner cell copies the shape of an input image.  Fall
        # back to a structure image when there are no style images, instead
        # of failing with an IndexError on an empty list.
        if style_names:
            reference = style_images[style_names[0]]
        elif structure_names:
            reference = structure_images[structure_names[0]]
        else:
            print("no images found in the dataset; the grid is empty")
            self.webpage.save()
            return {}

        # Inference only: nothing is back-propagated, so skip autograd
        # bookkeeping for the encode/decode passes.
        with torch.no_grad():
            gls = [
                model(style_images[name].cuda(), command="encode")[1]
                for name in style_names
            ]
            sps = [
                model(structure_images[name].cuda(), command="encode")[0]
                for name in structure_names
            ]

            # top row to show the input images
            blank_image = torch.ones_like(reference)
            self.add_to_webpage(
                [blank_image] + [style_images[name] for name in style_names],
                ["blank.png"] + [name + ".png" for name in style_names],
                tile=1,
            )

            # swapping: one call per structure image, pairing it with
            # every style image
            for structure_name, sp in zip(structure_names, sps):
                swaps = [model(sp, gl, command="decode") for gl in gls]
                filenames = [
                    structure_name + "_" + style_name + ".png"
                    for style_name in style_names
                ]
                self.add_to_webpage(
                    [structure_images[structure_name]] + swaps,
                    [structure_name + ".png"] + filenames,
                    tile=1,
                )

        self.webpage.save()
        return {}