Spaces:
Runtime error
Runtime error
import os | |
import torch | |
from swapae.evaluation import BaseEvaluator | |
import swapae.util as util | |
import numpy as np | |
from PIL import Image | |
class StructureStyleGridGenerationEvaluator(BaseEvaluator): | |
""" generate swapping images and save to disk """ | |
def modify_commandline_options(parser, is_train): | |
return parser | |
def create_webpage(self, nsteps): | |
nsteps = self.opt.resume_iter if nsteps is None else nsteps | |
savedir = os.path.join(self.output_dir(), "%s_%s" % (self.target_phase, nsteps)) | |
os.makedirs(savedir, exist_ok=True) | |
webpage_title = "%s. iter=%s. phase=%s" % \ | |
(self.opt.name, str(nsteps), self.target_phase) | |
self.webpage = util.HTML(savedir, webpage_title) | |
def add_to_webpage(self, images, filenames, tile=1): | |
converted_images = [] | |
for image in images: | |
if isinstance(image, list): | |
image = torch.stack(image, dim=0).flatten(0, 1) | |
image = Image.fromarray(util.tensor2im(image, tile=min(image.size(0), tile))) | |
converted_images.append(image) | |
self.webpage.add_images(converted_images, | |
filenames) | |
print("saved %s" % str(filenames)) | |
#self.webpage.save() | |
def evaluate(self, model, dataset, nsteps=None): | |
self.create_webpage(nsteps) | |
structure_images, style_images = {}, {} | |
for i, data_i in enumerate(dataset): | |
bs = data_i["real_A"].size(0) | |
#sp, gl = model(data_i["real_A"].cuda(), command="encode") | |
for j in range(bs): | |
image = data_i["real_A"][j:j+1] | |
path = data_i["path_A"][j] | |
imagename = os.path.splitext(os.path.basename(path))[0] | |
if "/structure/" in path: | |
structure_images[imagename] = image | |
else: | |
style_images[imagename] = image | |
gls = [] | |
style_paths = list(style_images.keys()) | |
for style_path in style_paths: | |
style_image = style_images[style_path].cuda() | |
gls.append(model(style_image, command="encode")[1]) | |
sps = [] | |
structure_paths = list(structure_images.keys()) | |
for structure_path in structure_paths: | |
structure_image = structure_images[structure_path].cuda() | |
sps.append(model(structure_image, command="encode")[0]) | |
# top row to show the input images | |
blank_image = style_images[style_paths[0]] * 0.0 + 1.0 | |
self.add_to_webpage([blank_image] + [style_images[style_path] for style_path in style_paths], | |
["blank.png"] + [style_path + ".png" for style_path in style_paths], | |
tile=1) | |
# swapping | |
for i, structure_path in enumerate(structure_paths): | |
structure_image = structure_images[structure_path] | |
swaps = [] | |
filenames = [] | |
for j, style_path in enumerate(style_paths): | |
swaps.append(model(sps[i], gls[j], command="decode")) | |
filenames.append(structure_path + "_" + style_path + ".png") | |
self.add_to_webpage([structure_image] + swaps, | |
[structure_path + ".png"] + filenames, | |
tile=1) | |
self.webpage.save() | |
return {} | |