# TextureScraping/swapae/evaluation/swap_visualization_evaluator.py
# Author: sunshineatnoon
# Commit: "Add application file" (1b2a9b1)
import os
from PIL import Image
import numpy as np
import torch
from swapae.evaluation import BaseEvaluator
import swapae.util as util
class SwapVisualizationEvaluator(BaseEvaluator):
    """Render grids of pairwise code swaps between a few dataset images.

    Each gathered image is passed through ``model(..., command="encode")``,
    which returns two codes (named ``sp`` and ``gl`` here -- presumably a
    spatial/structure code and a global/texture code; confirm against the
    model). Every ``sp`` is then decoded with every ``gl`` via
    ``model(..., command="decode")``. The results are tiled into a grid whose
    first row and first column show the input images. The grids are written
    to an HTML page under this evaluator's output directory.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register the swap-visualization options on *parser* and return it."""
        parser.add_argument("--swap_num_columns", type=int, default=4,
                            help="number of images to be shown in the swap visualization grid. Setting this value will result in 4x4 swapping grid, with additional row and col for showing original images.")
        parser.add_argument("--swap_num_images", type=int, default=16,
                            help="total number of images to perform swapping. In the end, (swap_num_images / swap_num_columns) grid will be saved to disk")
        return parser

    def gather_images(self, dataset):
        """Pull single-image tensors from *dataset* until one grid's worth is collected.

        Each batch contributes its ``"real_A"`` images and, when the key is
        present, its ``"real_B"`` images. Each image is sliced as ``[i:i+1]``
        so it keeps a leading batch dimension of 1. At most
        ``max(opt.swap_num_columns, opt.num_gpus)`` images are returned.

        Args:
            dataset: an iterator advanced with ``next()``. Each item is a dict
                holding a ``"real_A"`` tensor and optionally ``"real_B"``.

        Returns:
            ``(images, exhausted)``. ``images`` is a list of tensors, or
            ``None`` when the dataset yielded nothing. ``exhausted`` is True
            once ``StopIteration`` was raised.
        """
        num_images_to_gather = max(self.opt.swap_num_columns, self.opt.num_gpus)
        all_images = []
        exhausted = False
        while len(all_images) < num_images_to_gather:
            try:
                data = next(dataset)
            except StopIteration:
                print("Exhausted the dataset at %s" % (self.opt.dataroot))
                exhausted = True
                break
            for i in range(data["real_A"].size(0)):
                all_images.append(data["real_A"][i:i + 1])
                if "real_B" in data:
                    all_images.append(data["real_B"][i:i + 1])
                if len(all_images) >= num_images_to_gather:
                    break
        if not all_images:
            # Must be a 2-tuple: evaluate() unpacks exactly two values.
            return None, True
        # real_A and real_B are appended together, which can overshoot the
        # target by one; trim so the grid has the requested size.
        return all_images[:num_images_to_gather], exhausted

    def generate_mix_grid(self, model, images):
        """Build a PIL image of all pairwise ``sp``/``gl`` swaps among *images*.

        Input image k is placed at cells (0, k+1) and (k+1, 0). Cell
        (i+1, j+1) holds the decode of image i's ``sp`` with image j's ``gl``.
        Each cell is ``opt.load_size`` pixels square, and images are centred
        inside their cell.

        Args:
            model: callable that supports ``command="encode"`` (returning a
                pair of codes) and ``command="decode"``.
            images: list of tensors, each with batch size 1.

        Returns:
            A ``PIL.Image.Image`` with ``(len(images) + 1) * opt.load_size``
            pixels on each side.
        """
        cell = self.opt.load_size

        sps, gls = [], []
        for image in images:
            assert image.size(0) == 1
            # Replicate the image num_gpus times so every device gets a sample
            # (presumably for a data-parallel wrapper); keep only the first copy.
            sp, gl = model(image.expand(self.opt.num_gpus, -1, -1, -1), command="encode")
            sps.append(sp[:1])
            gls.append(gl[:1])
        gl = torch.cat(gls, dim=0)
        n = gl.size(0)

        def put_img(img, canvas, row, col):
            # Paste an (h, w, ...) array centred inside grid cell (row, col).
            h, w = img.shape[0], img.shape[1]
            start_x = int(cell * col + (cell - w) * 0.5)
            start_y = int(cell * row + (cell - h) * 0.5)
            canvas[start_y:start_y + h, start_x:start_x + w] = img

        side = cell * (n + 1)
        # NOTE(review): np.ones with uint8 fills every pixel with value 1
        # (near-black), so the unused top-left cell is dark, not white.
        grid_img = np.ones((side, side, 3), dtype=np.uint8)

        # Header row and header column: the original inputs.
        for i, image in enumerate(images):
            image_np = util.tensor2im(image, tile=False)[0]
            put_img(image_np, grid_img, 0, i + 1)
            put_img(image_np, grid_img, i + 1, 0)

        # Row i: image i's sp decoded against every image's gl in one batch.
        for i, sp in enumerate(sps):
            sp_for_current_row = sp.repeat(n, 1, 1, 1)
            mix_row = model(sp_for_current_row, gl, command="decode")
            mix_row = util.tensor2im(mix_row, tile=False)
            for j, mix in enumerate(mix_row):
                put_img(mix, grid_img, i + 1, j + 1)

        return Image.fromarray(grid_img)

    def evaluate(self, model, dataset, nsteps):
        """Write swap-grid images plus an HTML index for the current checkpoint.

        Args:
            model: forwarded to ``generate_mix_grid``.
            dataset: iterator forwarded to ``gather_images``.
            nsteps: training iteration count, rendered as e.g. ``"12k"``. When
                None, ``opt.resume_iter`` is used as the label instead.

        Returns:
            An empty dict; this evaluator produces no scalar metrics.
        """
        nsteps = self.opt.resume_iter if nsteps is None else str(round(nsteps / 1000)) + "k"
        savedir = os.path.join(self.output_dir(), "%s_%s" % (self.target_phase, nsteps))
        os.makedirs(savedir, exist_ok=True)
        webpage_title = "Swap Visualization of %s. iter=%s. phase=%s" % \
            (self.opt.name, str(nsteps), self.target_phase)
        webpage = util.HTML(savedir, webpage_title)
        images_per_grid = max(self.opt.swap_num_columns, self.opt.num_gpus)
        num_repeats = int(np.ceil(self.opt.swap_num_images / images_per_grid))
        for i in range(num_repeats):
            images, should_break = self.gather_images(dataset)
            if images is None:
                break
            mix_grid = self.generate_mix_grid(model, images)
            webpage.add_images([mix_grid], ["%04d.png" % i])
            if should_break:
                break
        webpage.save()
        return {}