"""Miscellaneous utility functions."""

import os
import glob
import pickle
import re
import numpy as np
from collections import defaultdict
import PIL.Image
import dnnlib

import config
from training import dataset

#----------------------------------------------------------------------------
# Convenience wrappers for pickle that can read local files or URLs.

def open_file_or_url(file_or_url):
    if dnnlib.util.is_url(file_or_url):
        return dnnlib.util.open_url(file_or_url, cache_dir=config.cache_dir)
    return open(file_or_url, 'rb')

def load_pkl(file_or_url):
    with open_file_or_url(file_or_url) as file:
        return pickle.load(file, encoding='latin1') # latin1 allows loading pickles written by Python 2.

def save_pkl(obj, filename):
    with open(filename, 'wb') as file:
        pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
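
# Illustrative round-trip (the filename and dict contents are assumptions for
# the example; load_pkl also accepts an http(s):// URL):
#
#   save_pkl({'step': 1000}, 'snapshot.pkl')
#   state = load_pkl('snapshot.pkl')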

#----------------------------------------------------------------------------
# Image utils.

def adjust_dynamic_range(data, drange_in, drange_out):
    if drange_in != drange_out:
        # Linear map from drange_in to drange_out: data * scale + bias.
        scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (np.float32(drange_in[1]) - np.float32(drange_in[0]))
        bias = np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale
        data = data * scale + bias
    return data
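
# Worked example: mapping uint8 pixels in [0, 255] to the [-1, 1] range gives
# scale = (1 - (-1)) / (255 - 0) = 2/255 and bias = -1 - 0 * scale = -1,
# so 0 -> -1.0 and 255 -> 1.0:
#
#   x = adjust_dynamic_range(np.array([0, 128, 255]), [0, 255], [-1, 1])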

def create_image_grid(images, grid_size=None):
    assert images.ndim == 3 or images.ndim == 4
    num, img_w, img_h = images.shape[0], images.shape[-1], images.shape[-2]

    # Choose grid dimensions, defaulting to a roughly square layout.
    if grid_size is not None:
        grid_w, grid_h = tuple(grid_size)
    else:
        grid_w = max(int(np.ceil(np.sqrt(num))), 1)
        grid_h = max((num - 1) // grid_w + 1, 1)

    # Paste each image into its cell, filling the grid row by row.
    grid = np.zeros(list(images.shape[1:-2]) + [grid_h * img_h, grid_w * img_w], dtype=images.dtype)
    for idx in range(num):
        x = (idx % grid_w) * img_w
        y = (idx // grid_w) * img_h
        grid[..., y : y + img_h, x : x + img_w] = images[idx]
    return grid
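
# Illustrative use: tiling a minibatch of 16 CHW images into a 4x4 grid
# (the shapes are assumptions for the example):
#
#   batch = np.zeros([16, 3, 64, 64], dtype=np.uint8)  # NCHW
#   grid = create_image_grid(batch, grid_size=(4, 4))  # -> shape [3, 256, 256]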

def convert_to_pil_image(image, drange=[0,1]):
    assert image.ndim == 2 or image.ndim == 3
    if image.ndim == 3:
        if image.shape[0] == 1:
            image = image[0] # grayscale CHW -> HW
        else:
            image = image.transpose(1, 2, 0) # CHW -> HWC

    image = adjust_dynamic_range(image, drange, [0,255])
    image = np.rint(image).clip(0, 255).astype(np.uint8)
    fmt = 'RGB' if image.ndim == 3 else 'L'
    return PIL.Image.fromarray(image, fmt)

def save_image(image, filename, drange=[0,1], quality=95):
    img = convert_to_pil_image(image, drange)
    if filename.endswith('.jpg'): # suffix check, so e.g. 'x.jpg.png' is not treated as JPEG
        img.save(filename, "JPEG", quality=quality, optimize=True)
    else:
        img.save(filename)

def save_image_grid(images, filename, drange=[0,1], grid_size=None):
    convert_to_pil_image(create_image_grid(images, grid_size), drange).save(filename)
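
# Illustrative use (the filename and drange are assumptions for the example):
#
#   save_image_grid(batch, 'fakes.png', drange=[-1, 1], grid_size=(4, 4))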

#----------------------------------------------------------------------------
# Locating results from previous training runs.

def locate_run_dir(run_id_or_run_dir):
    # Accept an existing directory path directly.
    if isinstance(run_id_or_run_dir, str):
        if os.path.isdir(run_id_or_run_dir):
            return run_id_or_run_dir
        converted = dnnlib.submission.submit.convert_path(run_id_or_run_dir)
        if os.path.isdir(converted):
            return converted

    # Otherwise, search the result dir for a unique subdir matching '<run_id>-*'.
    run_dir_pattern = re.compile('^0*%s-' % str(run_id_or_run_dir))
    for search_dir in ['']:
        full_search_dir = config.result_dir if search_dir == '' else os.path.normpath(os.path.join(config.result_dir, search_dir))
        run_dir = os.path.join(full_search_dir, str(run_id_or_run_dir))
        if os.path.isdir(run_dir):
            return run_dir
        run_dirs = sorted(glob.glob(os.path.join(full_search_dir, '*')))
        run_dirs = [run_dir for run_dir in run_dirs if run_dir_pattern.match(os.path.basename(run_dir))]
        run_dirs = [run_dir for run_dir in run_dirs if os.path.isdir(run_dir)]
        if len(run_dirs) == 1:
            return run_dirs[0]
    raise IOError('Cannot locate result subdir for run', run_id_or_run_dir)
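
# Illustrative inputs (the paths are assumptions for the example):
#
#   locate_run_dir('results/00005-sgan-ffhq')  # explicit directory
#   locate_run_dir(5)                          # run id; matches '00005-*' under config.result_dir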

def list_network_pkls(run_id_or_run_dir, include_final=True):
    run_dir = locate_run_dir(run_id_or_run_dir)
    pkls = sorted(glob.glob(os.path.join(run_dir, 'network-*.pkl')))
    # 'network-final.pkl' sorts before the snapshots; move it to the end
    # (or drop it) so the list stays in training order.
    if len(pkls) >= 1 and os.path.basename(pkls[0]) == 'network-final.pkl':
        if include_final:
            pkls.append(pkls[0])
        del pkls[0]
    return pkls

def locate_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl=None):
    # Accept an explicit pickle path from either argument.
    for candidate in [snapshot_or_network_pkl, run_id_or_run_dir_or_network_pkl]:
        if isinstance(candidate, str):
            if os.path.isfile(candidate):
                return candidate
            converted = dnnlib.submission.submit.convert_path(candidate)
            if os.path.isfile(converted):
                return converted

    # No snapshot specified => use the last pickle of the run.
    pkls = list_network_pkls(run_id_or_run_dir_or_network_pkl)
    if len(pkls) >= 1 and snapshot_or_network_pkl is None:
        return pkls[-1]

    # Otherwise, match the snapshot number in 'network-snapshot-<number>.pkl'.
    for pkl in pkls:
        try:
            name = os.path.splitext(os.path.basename(pkl))[0]
            number = int(name.split('-')[-1])
            if number == snapshot_or_network_pkl:
                return pkl
        except (ValueError, IndexError):
            pass
    raise IOError('Cannot locate network pkl for snapshot', snapshot_or_network_pkl)
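
# Illustrative calls (the run id and snapshot number are assumptions):
#
#   locate_network_pkl(5)          # last (or final) pickle of run 5
#   locate_network_pkl(5, 10000)   # 'network-snapshot-010000.pkl' of run 5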

def get_id_string_for_network_pkl(network_pkl):
    p = network_pkl.replace('.pkl', '').replace('\\', '/').split('/')
    return '-'.join(p[max(len(p) - 2, 0):])
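
# Illustrative input/output (the path is an assumption for the example):
#
#   get_id_string_for_network_pkl('results/00005-sgan-ffhq/network-snapshot-010000.pkl')
#   # -> '00005-sgan-ffhq-network-snapshot-010000'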

#----------------------------------------------------------------------------
# Loading data from previous training runs.

def load_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl=None):
    return load_pkl(locate_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl))

def parse_config_for_previous_run(run_id):
    run_dir = locate_run_dir(run_id)

    # Parse config.txt: rewrite lines of the form "'name': {...}," into
    # executable assignments "name = {...}" and evaluate the ones we need.
    cfg = defaultdict(dict)
    with open(os.path.join(run_dir, 'config.txt'), 'rt') as f:
        for line in f:
            line = re.sub(r"^{?\s*'(\w+)':\s*{(.*)(},|}})$", r"\1 = {\2}", line.strip())
            if line.startswith('dataset =') or line.startswith('train ='):
                exec(line, cfg, cfg)

    # Handle legacy options: rename fields that older versions of the code
    # stored under different names or with different value conventions.
    if 'file_pattern' in cfg['dataset']:
        cfg['dataset']['tfrecord_dir'] = cfg['dataset'].pop('file_pattern').replace('-r??.tfrecords', '')
    if 'mirror_augment' in cfg['dataset']:
        cfg['train']['mirror_augment'] = cfg['dataset'].pop('mirror_augment')
    if 'max_labels' in cfg['dataset']:
        v = cfg['dataset'].pop('max_labels')
        if v is None: v = 0
        if v == 'all': v = 'full'
        cfg['dataset']['max_label_size'] = v
    if 'max_images' in cfg['dataset']:
        cfg['dataset'].pop('max_images')
    return cfg
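
# Illustrative transformation performed by the re.sub above: a config.txt line
#
#   'dataset': {'tfrecord_dir': 'ffhq'},
#
# becomes "dataset = {'tfrecord_dir': 'ffhq'}", which exec() then evaluates
# into cfg['dataset']. (The dict contents are an assumption for the example.)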

def load_dataset_for_previous_run(run_id, **kwargs):
    cfg = parse_config_for_previous_run(run_id)
    cfg['dataset'].update(kwargs)
    dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **cfg['dataset'])
    mirror_augment = cfg['train'].get('mirror_augment', False)
    return dataset_obj, mirror_augment

def apply_mirror_augment(minibatch):
    # Flip a random ~50% of the images horizontally. Assumes NCHW layout,
    # so reversing the last axis mirrors each image left-right.
    mask = np.random.rand(minibatch.shape[0]) < 0.5
    minibatch = np.array(minibatch) # copy so the caller's array is not modified
    minibatch[mask] = minibatch[mask, :, :, ::-1]
    return minibatch
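
# Illustrative use on an NCHW minibatch (the shape is an assumption):
#
#   batch = np.random.rand(8, 3, 64, 64)
#   augmented = apply_mirror_augment(batch)  # ~half the images mirrored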

#----------------------------------------------------------------------------
# Size and contents of the image snapshot grids that are exported
# periodically during training.

def setup_snapshot_image_grid(G, training_set,
    size   = '1080p',   # '1080p' or '4k' = target display resolution for the grid.
    layout = 'random'): # 'random', 'row_per_class', 'col_per_class', or 'class4x4'.

    # Select grid size so that the full grid roughly fills the target display.
    gw = 1; gh = 1
    if size == '1080p':
        gw = np.clip(1920 // G.output_shape[3], 3, 32)
        gh = np.clip(1080 // G.output_shape[2], 2, 32)
    if size == '4k':
        gw = np.clip(3840 // G.output_shape[3], 7, 32)
        gh = np.clip(2160 // G.output_shape[2], 4, 32)

    # Initialize data arrays.
    reals = np.zeros([gw * gh] + training_set.shape, dtype=training_set.dtype)
    labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype)
    latents = np.random.randn(gw * gh, *G.input_shape[1:])

    # Random layout.
    if layout == 'random':
        reals[:], labels[:] = training_set.get_minibatch_np(gw * gh)

    # Class-conditional layouts: fill blocks of the grid with images grouped
    # by class label, drawing samples until every block is full.
    class_layouts = dict(row_per_class=[gw,1], col_per_class=[1,gh], class4x4=[4,4])
    if layout in class_layouts:
        bw, bh = class_layouts[layout]
        nw = (gw - 1) // bw + 1
        nh = (gh - 1) // bh + 1
        blocks = [[] for _i in range(nw * nh)]
        for _iter in range(1000000):
            real, label = training_set.get_minibatch_np(1)
            idx = np.argmax(label[0])
            while idx < len(blocks) and len(blocks[idx]) >= bw * bh:
                idx += training_set.label_size
            if idx < len(blocks):
                blocks[idx].append((real, label))
                if all(len(block) >= bw * bh for block in blocks):
                    break
        for i, block in enumerate(blocks):
            for j, (real, label) in enumerate(block):
                x = (i % nw) * bw + j % bw
                y = (i // nw) * bh + j // bw
                if x < gw and y < gh:
                    reals[x + y * gw] = real[0]
                    labels[x + y * gw] = label[0]

    return (gw, gh), reals, labels, latents
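
# Illustrative use at the start of training (G and training_set come from the
# training loop; the call pattern and drange are assumptions about typical usage):
#
#   grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(G, training_set)
#   save_image_grid(grid_reals, 'reals.png', drange=training_set.dynamic_range, grid_size=grid_size)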

#----------------------------------------------------------------------------