# TextureScraping / swapae / evaluation / swap_visualization_evaluator.py
# sunshineatnoon - "Add application file" (commit 1b2a9b1)
# raw | history | blame | 4.12 kB
import os
from PIL import Image
import numpy as np
import torch
from swapae.evaluation import BaseEvaluator
import swapae.util as util
class SwapVisualizationEvaluator(BaseEvaluator):
    """Evaluator that renders grids of code-swapped reconstructions.

    For a group of input images it encodes each image into two codes
    (``sp`` and ``gl`` as returned by ``model(..., command="encode")``;
    presumably the structure and global/texture codes -- confirm against
    the model), decodes every (``sp``, ``gl``) combination, and lays the
    results out in a square grid whose first row and first column show the
    original images. Each grid is added to an HTML page via ``util.HTML``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register the swap-visualization options on ``parser`` and return it.

        ``is_train`` is accepted for interface compatibility and is not used.
        """
        parser.add_argument("--swap_num_columns", type=int, default=4,
                            help="number of images to be shown in the swap visualization grid. Setting this value will result in 4x4 swapping grid, with additional row and col for showing original images.")
        parser.add_argument("--swap_num_images", type=int, default=16,
                            help="total number of images to perform swapping. In the end, (swap_num_images / swap_num_columns) grid will be saved to disk")
        return parser

    def gather_images(self, dataset):
        """Pull single-image tensors from ``dataset`` for one grid.

        Args:
            dataset: an iterator (consumed with ``next``) yielding dicts that
                contain a ``"real_A"`` batch and optionally a ``"real_B"``
                batch.

        Returns:
            A 2-tuple ``(images, exhausted)``. ``images`` is a list of
            batch-size-1 slices, or ``None`` when nothing could be gathered.
            ``exhausted`` is ``True`` when the iterator raised
            ``StopIteration`` (and always ``True`` when ``images`` is
            ``None``), telling the caller to stop asking for more.
        """
        all_images = []
        num_images_to_gather = max(self.opt.swap_num_columns, self.opt.num_gpus)
        exhausted = False
        while len(all_images) < num_images_to_gather:
            try:
                data = next(dataset)
            except StopIteration:
                print("Exhausted the dataset at %s" % (self.opt.dataroot))
                exhausted = True
                break
            for i in range(data["real_A"].size(0)):
                # Slice with i:i+1 to keep the leading batch dimension.
                all_images.append(data["real_A"][i:i + 1])
                if "real_B" in data:
                    all_images.append(data["real_B"][i:i + 1])
                # Checked after the A/B pair, so the list may overshoot the
                # target by one image when "real_B" is present.
                if len(all_images) >= num_images_to_gather:
                    break
        if not all_images:
            # Must be a 2-tuple: evaluate() unpacks exactly two values.
            # (The previous 3-tuple raised ValueError on an empty dataset.)
            return None, True
        return all_images, exhausted

    def generate_mix_grid(self, model, images):
        """Build the swap grid for ``images`` and return it as a PIL image.

        Args:
            model: callable supporting ``command="encode"`` (returns two
                codes per image) and ``command="decode"`` (takes both codes).
            images: non-empty list of tensors, each with batch size 1.

        Returns:
            A ``PIL.Image`` of ``(len(images) + 1)`` x ``(len(images) + 1)``
            cells, each ``opt.load_size`` pixels square. Cell ``(i+1, j+1)``
            is decoded from the ``sp`` code of image ``i`` and the ``gl``
            code of image ``j``.
        """
        sps, gls = [], []
        for image in images:
            assert image.size(0) == 1
            # Replicate the image across the batch so that every GPU gets a
            # sample, then keep only the first result.
            sp, gl = model(image.expand(self.opt.num_gpus, -1, -1, -1), command="encode")
            sps.append(sp[:1])
            gls.append(gl[:1])
        gl = torch.cat(gls, dim=0)
        num_images = gl.size(0)
        cell = self.opt.load_size

        def put_img(img, canvas, row, col):
            """Paste ``img`` centered inside grid cell (``row``, ``col``)."""
            h, w = img.shape[0], img.shape[1]
            start_x = int(cell * col + (cell - w) * 0.5)
            start_y = int(cell * row + (cell - h) * 0.5)
            canvas[start_y:start_y + h, start_x:start_x + w] = img

        grid_w = cell * (num_images + 1)
        grid_h = cell * (num_images + 1)
        grid_img = np.ones((grid_h, grid_w, 3), dtype=np.uint8)

        # Header row and header column show the unmodified inputs.
        for i, image in enumerate(images):
            image_np = util.tensor2im(image, tile=False)[0]
            put_img(image_np, grid_img, 0, i + 1)
            put_img(image_np, grid_img, i + 1, 0)

        # Row i keeps one sp code fixed and pairs it with every gl code.
        for i, sp in enumerate(sps):
            sp_for_current_row = sp.repeat(num_images, 1, 1, 1)
            mix_row = model(sp_for_current_row, gl, command="decode")
            mix_row = util.tensor2im(mix_row, tile=False)
            for j, mix in enumerate(mix_row):
                put_img(mix, grid_img, i + 1, j + 1)

        return Image.fromarray(grid_img)

    def evaluate(self, model, dataset, nsteps):
        """Generate swap grids and write them to an HTML page.

        Args:
            model: passed through to ``generate_mix_grid``.
            dataset: iterator passed through to ``gather_images``.
            nsteps: training step count used to label the output directory
                (rounded to thousands, e.g. ``"25k"``); when ``None``,
                ``opt.resume_iter`` is used instead.

        Returns:
            An empty dict; this evaluator reports no scalar metrics.
        """
        nsteps = self.opt.resume_iter if nsteps is None else str(round(nsteps / 1000)) + "k"
        savedir = os.path.join(self.output_dir(), "%s_%s" % (self.target_phase, nsteps))
        os.makedirs(savedir, exist_ok=True)
        webpage_title = "Swap Visualization of %s. iter=%s. phase=%s" % \
                        (self.opt.name, str(nsteps), self.target_phase)
        webpage = util.HTML(savedir, webpage_title)
        images_per_grid = max(self.opt.swap_num_columns, self.opt.num_gpus)
        num_repeats = int(np.ceil(self.opt.swap_num_images / images_per_grid))
        for i in range(num_repeats):
            images, should_break = self.gather_images(dataset)
            if images is None:
                break
            mix_grid = self.generate_mix_grid(model, images)
            webpage.add_images([mix_grid], ["%04d.png" % i])
            if should_break:
                break
        webpage.save()
        return {}