Spaces:
Runtime error
Runtime error
import os | |
from PIL import Image | |
import numpy as np | |
import torch | |
from swapae.evaluation import BaseEvaluator | |
import swapae.util as util | |
class SwapVisualizationEvaluator(BaseEvaluator):
    """Save HTML pages of grids that swap encoded codes between images.

    Each gathered image is encoded with ``model(..., command="encode")`` into a
    pair ``(sp, gl)``. In the saved grid, row 0 and column 0 show the original
    images, and cell (i+1, j+1) shows a decoding of image i's ``sp`` combined
    with the ``gl`` codes of all gathered images.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register this evaluator's command-line options on ``parser``.

        Args:
            parser: an argparse-style parser exposing ``add_argument``.
            is_train: unused here; kept for the evaluator-option interface.

        Returns:
            The same ``parser``.
        """
        parser.add_argument("--swap_num_columns", type=int, default=4,
                            help="number of images to be shown in the swap visualization grid. Setting this value will result in 4x4 swapping grid, with additional row and col for showing original images.")
        parser.add_argument("--swap_num_images", type=int, default=16,
                            help="total number of images to perform swapping. In the end, (swap_num_images / swap_num_columns) grid will be saved to disk")
        return parser

    def gather_images(self, dataset):
        """Collect single-image tensors from ``dataset`` for one grid.

        Pulls batches with ``next(dataset)`` and slices each sample of
        ``data["real_A"]`` (and of ``data["real_B"]`` when that key is
        present) into its own batch-of-one tensor.

        Args:
            dataset: an iterator yielding dicts with a ``"real_A"`` tensor and
                optionally a ``"real_B"`` tensor.

        Returns:
            A 2-tuple ``(images, exhausted)``. ``images`` is a list of at most
            ``max(opt.swap_num_columns, opt.num_gpus)`` tensors, or ``None`` if
            nothing could be gathered. ``exhausted`` is True when the iterator
            raised ``StopIteration``.
        """
        all_images = []
        num_images_to_gather = max(self.opt.swap_num_columns, self.opt.num_gpus)
        exhausted = False
        while len(all_images) < num_images_to_gather:
            try:
                data = next(dataset)
            except StopIteration:
                print("Exhausted the dataset at %s" % (self.opt.dataroot))
                exhausted = True
                break
            for i in range(data["real_A"].size(0)):
                all_images.append(data["real_A"][i:i + 1])
                if "real_B" in data:
                    all_images.append(data["real_B"][i:i + 1])
                if len(all_images) >= num_images_to_gather:
                    break
        if not all_images:
            # Two values, matching the caller's `images, should_break = ...`
            # unpacking (previously three values were returned here).
            return None, True
        # Appending real_B right after real_A can overshoot the target by one;
        # trim so the grid has exactly the requested number of columns.
        return all_images[:num_images_to_gather], exhausted

    def generate_mix_grid(self, model, images):
        """Build a square PIL image of all pairwise ``sp``/``gl`` swaps.

        Args:
            model: callable supporting ``command="encode"`` (returning an
                ``(sp, gl)`` pair) and ``command="decode"``.
            images: list of tensors, each with batch size 1.

        Returns:
            A ``PIL.Image.Image`` whose side is
            ``opt.load_size * (len(images) + 1)`` pixels.
        """
        sps, gls = [], []
        for image in images:
            assert image.size(0) == 1
            # Presumably expanded to num_gpus copies so a data-parallel model
            # gets at least one sample per device (confirm); only the first
            # copy's codes are kept.
            sp, gl = model(image.expand(self.opt.num_gpus, -1, -1, -1), command="encode")
            sps.append(sp[:1])
            gls.append(gl[:1])
        gl = torch.cat(gls, dim=0)

        cell = self.opt.load_size

        def put_img(img, canvas, row, col):
            # Center `img` (assumed H x W x C, no larger than `cell`) inside
            # grid cell (row, col) of `canvas`.
            h, w = img.shape[0], img.shape[1]
            start_x = int(cell * col + (cell - w) * 0.5)
            start_y = int(cell * row + (cell - h) * 0.5)
            canvas[start_y:start_y + h, start_x: start_x + w] = img

        grid_w = cell * (gl.size(0) + 1)
        grid_h = cell * (gl.size(0) + 1)
        # Canvas is filled with 1s; the top-left cell is never written.
        grid_img = np.ones((grid_h, grid_w, 3), dtype=np.uint8)

        # Originals along the top row and the left column.
        for i, image in enumerate(images):
            image_np = util.tensor2im(image, tile=False)[0]
            put_img(image_np, grid_img, 0, i + 1)
            put_img(image_np, grid_img, i + 1, 0)

        # Row i+1: image i's sp decoded with every gathered gl code.
        for i, sp in enumerate(sps):
            sp_for_current_row = sp.repeat(gl.size(0), 1, 1, 1)
            mix_row = model(sp_for_current_row, gl, command="decode")
            mix_row = util.tensor2im(mix_row, tile=False)
            for j, mix in enumerate(mix_row):
                put_img(mix, grid_img, i + 1, j + 1)

        final_grid = Image.fromarray(grid_img)
        return final_grid

    def evaluate(self, model, dataset, nsteps):
        """Write swap grids for up to ``opt.swap_num_images`` images to an HTML page.

        Args:
            model: model passed through to ``generate_mix_grid``.
            dataset: iterator consumed by ``gather_images``.
            nsteps: current iteration count used to label the output
                directory (rounded to thousands, e.g. ``"12k"``), or None to
                use ``opt.resume_iter`` instead.

        Returns:
            An empty dict; this evaluator produces no scalar metrics.
        """
        nsteps = self.opt.resume_iter if nsteps is None else str(round(nsteps / 1000)) + "k"
        savedir = os.path.join(self.output_dir(), "%s_%s" % (self.target_phase, nsteps))
        os.makedirs(savedir, exist_ok=True)
        webpage_title = "Swap Visualization of %s. iter=%s. phase=%s" % \
            (self.opt.name, str(nsteps), self.target_phase)
        webpage = util.HTML(savedir, webpage_title)
        num_repeats = int(np.ceil(self.opt.swap_num_images / max(self.opt.swap_num_columns, self.opt.num_gpus)))
        for i in range(num_repeats):
            images, should_break = self.gather_images(dataset)
            if images is None:
                break
            mix_grid = self.generate_mix_grid(model, images)
            webpage.add_images([mix_grid], ["%04d.png" % i])
            if should_break:
                break
        webpage.save()
        return {}