import os

# Must be set before numpy/scipy are imported: keeps the Intel Fortran
# runtime from installing its own Ctrl-C handler, so KeyboardInterrupt
# reaches the sampling loops below.
os.environ['FOR_DISABLE_CONSOLE_CTRL_HANDLER'] = '1'
|
|
import argparse
import datetime
import json
import re
import sys
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import scipy
import scipy.linalg  # needed explicitly; not guaranteed to load via 'import scipy'
from scipy.cluster.vq import kmeans
import torch
from tqdm import trange

from netdissect.nethook import InstrumentedModel
from config import Config
from estimators import get_estimator
from models import get_instrumented_model
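
# netdissect.nethook, config, estimators, and models are modules local to
# this repository; InstrumentedModel wraps a generator so that activations
# of named layers can be retained and read back after a (partial) forward.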
|
|
# Fixed seeds for the different randomized phases, for reproducibility
SEED_SAMPLING = 1
SEED_RANDOM_DIRS = 2
SEED_LINREG = 3
SEED_VISUALIZATION = 5

B = 20          # batch size; replaced in compute() when config.batch_size is unset
n_clusters = 500
|
|
# Compute random unit-norm basis directions with a fixed seed
def get_random_dirs(components, dimensions):
    gen = np.random.RandomState(seed=SEED_RANDOM_DIRS)
    dirs = gen.normal(size=(components, dimensions))
    dirs /= np.sqrt(np.sum(dirs**2, axis=1, keepdims=True))
    return dirs.astype(np.float32)
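
# e.g. get_random_dirs(80, 512) returns an (80, 512) float32 array whose
# rows are unit-length directions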
|
|
# Find a batch size that uses at most about half of the available VRAM
def get_max_batch_size(inst, device, layer_name=None):
    inst.remove_edits()

    # Reset memory usage statistics (deprecated in newer PyTorch, where
    # torch.cuda.reset_peak_memory_stats covers both counters)
    torch.cuda.reset_max_memory_cached(device)
    torch.cuda.reset_max_memory_allocated(device)
    total_mem = torch.cuda.get_device_properties(device).total_memory

    B_max = 20

    # Measure peak memory usage at increasing batch sizes
    for i in range(2, B_max, 2):
        z = inst.model.sample_latent(n_samples=i)
        if layer_name:
            inst.model.partial_forward(z, layer_name)
        else:
            inst.model.forward(z)

        maxmem = torch.cuda.max_memory_allocated(device)
        del z

        # Stop once the peak exceeds half of the total VRAM
        if maxmem > 0.5*total_mem:
            print('Batch size {:d}: memory usage {:.0f}MB'.format(i, maxmem / 1e6))
            return i

    return B_max
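
# Note: the probe above relies on CUDA statistics and will fail on CPU-only
# setups; there, pass an explicit config.batch_size so compute() never calls it.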
|
|
def linreg_lstsq(comp_np, mean_np, stdev_np, inst, config):
    print('Performing least squares regression', flush=True)

    torch.manual_seed(SEED_LINREG)
    np.random.seed(SEED_LINREG)

    comp = torch.from_numpy(comp_np).float().to(inst.model.device)
    mean = torch.from_numpy(mean_np).float().to(inst.model.device)
    stdev = torch.from_numpy(stdev_np).float().to(inst.model.device)

    n_samp = max(10_000, config.n) // B * B  # make divisible by batch size
    n_comp = comp.shape[0]
    latent_dims = inst.model.get_latent_dims()

    # Find the linear map M that takes the standardized component coordinates
    # of a sample's activations back to its latent vector: coords @ M = z.
    # Row i of A holds the coordinates of sample i, row i of Z its latent.
    A = np.zeros((n_samp, n_comp), dtype=np.float32)
    Z = np.zeros((n_samp, latent_dims), dtype=np.float32)

    # Project tensor X onto the components, return per-sample coordinates
    def project(X, comp):
        N = X.shape[0]
        K = comp.shape[0]
        coords = torch.bmm(comp.expand([N]+[-1]*comp.ndim), X.view(N, -1, 1))
        return coords.reshape(N, K)

    for i in trange(n_samp // B, desc='Collecting samples', ascii=True):
        z = inst.model.sample_latent(B)
        inst.model.partial_forward(z, config.layer)
        act = inst.retained_features()[config.layer].reshape(B, -1)

        # Center, project onto the basis, scale to unit variance
        act = act - mean
        coords = project(act, comp)
        coords_scaled = coords / stdev

        A[i*B:(i+1)*B] = coords_scaled.detach().cpu().numpy()
        Z[i*B:(i+1)*B] = z.detach().cpu().numpy().reshape(B, -1)

    # Solve min_M ||A @ M - Z||_2; 'gelsd' is LAPACK's divide-and-conquer
    # SVD driver, a good default for large least squares problems
    M_t = scipy.linalg.lstsq(A, Z, lapack_driver='gelsd')[0]

    # The rows of the solution are the latent-space directions
    Z_comp = M_t[:n_comp, :]
    Z_mean = np.mean(Z, axis=0, keepdims=True)

    return Z_comp, Z_mean
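
# Illustrative shapes: with a 512-dimensional latent space and n_comp=80
# components, A is (n_samp, 80), Z is (n_samp, 512), and M_t = lstsq(A, Z)
# is (80, 512); each row of M_t is the latent direction of one component.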
|
|
def regression(comp, mean, stdev, inst, config):
    # Sanity check: the projection in linreg_lstsq assumes an orthonormal basis
    M = np.dot(comp, comp.T)
    if not np.allclose(M, np.identity(M.shape[0])):
        det = np.linalg.det(M)
        print(f'WARNING: Computed basis is not orthonormal (determinant={det})')

    return linreg_lstsq(comp, mean, stdev, inst, config)
|
|
def compute(config, dump_name, instrumented_model):
    global B

    timestamp = lambda : datetime.datetime.now().strftime("%d.%m %H:%M")
    print(f'[{timestamp()}] Computing', dump_name.name)

    # Ensure reproducibility
    torch.manual_seed(0)
    np.random.seed(0)

    # Speed up backend
    torch.backends.cudnn.benchmark = True

    has_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if has_gpu else 'cpu')
    layer_key = config.layer

    if instrumented_model is None:
        inst = get_instrumented_model(config.model, config.output_class, layer_key, device)
        model = inst.model
    else:
        print('Reusing InstrumentedModel instance')
        inst = instrumented_model
        model = inst.model
        inst.remove_edits()
        model.set_output_class(config.output_class)

    # StyleGAN only: operate in the W latent space instead of Z
    if config.use_w:
        print('Using W latent space')
        model.use_w()

    # Probe the feature shape of the chosen layer with a single sample
    inst.retain_layer(layer_key)
    model.partial_forward(model.sample_latent(1), layer_key)
    sample_shape = inst.retained_features()[layer_key].shape
    sample_dims = np.prod(sample_shape)
    print('Feature shape:', sample_shape)

    input_shape = inst.model.get_latent_shape()
    input_dims = inst.model.get_latent_dims()

    config.components = min(config.components, sample_dims)
    transformer = get_estimator(config.estimator, config.components, config.sparsity)
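
    # NB: the estimator is used through a small implicit interface:
    # 'batch_support' (bool), 'fit(X)' / 'fit_partial(X)', 'get_components()'
    # returning (components, stdevs, var_ratios), and 'get_param_str()' for
    # the cache filename. This summary is inferred from the call sites in
    # this file, not from the estimators module itself.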
|
    X = None
    X_global_mean = None

    # Probe a suitable batch size unless one was given explicitly
    B = config.batch_size or get_max_batch_size(inst, device, layer_key)

    # Make N divisible by the batch size (batched parallelism)
    N = config.n // B * B

    # Warn if a non-batched estimator will exceed the assumed ~20 GB RAM
    # budget; float64 item size is used since the estimator may keep a
    # double-precision copy of the data
    target_bytes = 20 * 1_000_000_000
    feat_size_bytes = sample_dims * np.dtype('float64').itemsize
    N_limit_RAM = np.floor_divide(target_bytes, feat_size_bytes)
    if not transformer.batch_support and N > N_limit_RAM:
        print('WARNING: estimator does not support batching, '
              'given config will use {:.1f} GB memory.'.format(feat_size_bytes / 1_000_000_000 * N))

    # 32-bit LAPACK indexing limits the matrix sizes usable with ICA
    if config.estimator == 'ica':
        lapack_max_N = np.floor_divide(np.iinfo(np.int32).max // 4, sample_dims)
        if N > lapack_max_N:
            raise RuntimeError(f'Matrices too large for ICA, please use N <= {lapack_max_N}')

    print('B={}, N={}, dims={}, N/dims={:.1f}'.format(B, N, sample_dims, N/sample_dims), flush=True)

    # Outer batch size: feed the estimator reasonably large chunks at a time
    NB = max(B, 2_000, 3*config.components)

    # Non-batched estimators see all samples at once, so collect them first
    samples = None
    if not transformer.batch_support:
        samples = np.zeros((N + NB, sample_dims), dtype=np.float32)

    torch.manual_seed(config.seed or SEED_SAMPLING)
    np.random.seed(config.seed or SEED_SAMPLING)

    # Pre-sample all latents up front; n_lat is rounded up to a multiple of B
    # large enough to cover N + NB samples
    n_lat = ((N + NB - 1) // B + 1) * B
    latents = np.zeros((n_lat, *input_shape[1:]), dtype=np.float32)
    with torch.no_grad():
        for i in trange(n_lat // B, desc='Sampling latents'):
            latents[i*B:(i+1)*B] = model.sample_latent(n_samples=B).cpu().numpy()

    # For mapping-network layers in W space the activations are the latents
    # themselves, so both the forward pass and the regression can be skipped
    samples_are_latents = layer_key in ['g_mapping', 'style'] and inst.model.latent_space_name() == 'W'
|
    canceled = False
    try:
        X = np.ones((NB, sample_dims), dtype=np.float32)
        action = 'Fitting' if transformer.batch_support else 'Collecting'
        for gi in trange(0, N, NB, desc=f'{action} batches (NB={NB})', ascii=True):
            for mb in range(0, NB, B):
                z = torch.from_numpy(latents[gi+mb:gi+mb+B]).to(device)

                if samples_are_latents:
                    # Layer is the latent space itself, no forward pass needed
                    batch = z.reshape((B, -1))
                else:
                    # Evaluate the model up to the retained layer only
                    with torch.no_grad():
                        model.partial_forward(z, layer_key)
                    batch = inst.retained_features()[layer_key].reshape((B, -1))

                space_left = min(B, NB - mb)
                X[mb:mb+space_left] = batch.cpu().numpy()[:space_left]

            if transformer.batch_support:
                # Incremental fit; estimator may signal convergence by returning False
                if not transformer.fit_partial(X.reshape(-1, sample_dims)):
                    break
            else:
                samples[gi:gi+NB, :] = X.copy()
    except KeyboardInterrupt:
        if not transformer.batch_support:
            sys.exit(1)  # no partial result worth saving

        # Save the state of the incremental fit under an adjusted sample count
        dump_name = dump_name.parent / dump_name.name.replace(f'n{N}', f'n{gi}')
        print(f'Saving current state to "{dump_name.name}" before exiting')
        canceled = True
|
    if not transformer.batch_support:
        X = samples
        X_global_mean = X.mean(axis=0, keepdims=True, dtype=np.float32)
        X -= X_global_mean

        print(f'[{timestamp()}] Fitting whole batch')
        t_start_fit = datetime.datetime.now()

        transformer.fit(X)

        print(f'[{timestamp()}] Done in {datetime.datetime.now() - t_start_fit}')
        assert np.all(np.abs(transformer.transformer.mean_) < 1e-3), 'Mean of normalized data should be zero'
    else:
        # Batched estimators track the data mean themselves; apply it to the
        # last minibatch buffer, which is reused for the baseline stdevs below
        X_global_mean = transformer.transformer.mean_.reshape((1, sample_dims))
        X = X.reshape(-1, sample_dims)
        X -= X_global_mean

    X_comp, X_stdev, X_var_ratio = transformer.get_components()

    assert X_comp.shape[1] == sample_dims \
        and X_comp.shape[0] == config.components \
        and X_global_mean.shape[1] == sample_dims \
        and X_stdev.shape[0] == config.components, 'Invalid shape'
|
    # The components live in activation space; map them back to the latent
    # space unless the activations already are latents
    if samples_are_latents:
        Z_comp = X_comp
        Z_global_mean = X_global_mean
    else:
        Z_comp, Z_global_mean = regression(X_comp, X_global_mean, X_stdev, inst, config)

    # Normalize latent directions to unit length
    Z_comp /= np.linalg.norm(Z_comp, axis=-1, keepdims=True)

    # Baseline: stdev of the activations along random unit directions
    random_dirs = get_random_dirs(config.components, np.prod(sample_shape))
    n_rand_samples = min(5000, X.shape[0])
    X_view = X[:n_rand_samples, :].T
    assert np.shares_memory(X_view, X), "Error: slice produced copy"
    X_stdev_random = np.dot(random_dirs, X_view).std(axis=1)

    # Restore the original feature and latent shapes
    X_comp = X_comp.reshape(-1, *sample_shape)
    X_global_mean = X_global_mean.reshape(sample_shape)
    Z_comp = Z_comp.reshape(-1, *input_shape)
    Z_global_mean = Z_global_mean.reshape(input_shape)

    # Stdev of the latent distribution along each direction (meaningful in W)
    lat_stdev = np.ones_like(X_stdev)
    if config.use_w:
        samples = model.sample_latent(5000).reshape(5000, input_dims).detach().cpu().numpy()
        coords = np.dot(Z_comp.reshape(-1, input_dims), samples.T)
        lat_stdev = coords.std(axis=1)
|
    os.makedirs(dump_name.parent, exist_ok=True)
    np.savez_compressed(dump_name, **{
        'act_comp': X_comp.astype(np.float32),    # components in activation space
        'act_mean': X_global_mean.astype(np.float32),
        'act_stdev': X_stdev.astype(np.float32),
        'lat_comp': Z_comp.astype(np.float32),    # same directions in latent space
        'lat_mean': Z_global_mean.astype(np.float32),
        'lat_stdev': lat_stdev.astype(np.float32),
        'var_ratio': X_var_ratio.astype(np.float32),
        'random_stdevs': X_stdev_random.astype(np.float32),
    })

    if canceled:
        sys.exit(1)

    # Don't free the model if it was passed in from the outside
    if instrumented_model is None:
        inst.close()
        del inst
        del model

    # Free up CPU and GPU memory before returning
    del X
    del X_comp
    del random_dirs
    del batch
    del samples
    del latents
    torch.cuda.empty_cache()
|
|
def get_or_compute(config, model=None, submit_config=None, force_recompute=False):
    # Default to the directory of this file when no submit_config is given
    if submit_config is None:
        wrkdir = str(Path(__file__).parent.resolve())
        submit_config = SimpleNamespace(run_dir_root=wrkdir, run_dir=wrkdir)

    return _compute(submit_config, config, model, force_recompute)
|
|
def _compute(submit_config, config, model=None, force_recompute=False):
    basedir = Path(submit_config.run_dir)
    outdir = basedir / 'out'

    if config.n is None:
        raise RuntimeError('Must specify number of samples with -n=XXX')

    if model and not isinstance(model, InstrumentedModel):
        raise RuntimeError('Passed model has to be wrapped in "InstrumentedModel"')

    if config.use_w and 'StyleGAN' not in config.model:
        raise RuntimeError(f'Cannot change latent space of non-StyleGAN model {config.model}')

    # The cache filename encodes every parameter that affects the result
    transformer = get_estimator(config.estimator, config.components, config.sparsity)
    dump_name = "{}-{}_{}_{}_n{}{}{}.npz".format(
        config.model.lower(),
        config.output_class.replace(' ', '_'),
        config.layer.lower(),
        transformer.get_param_str(),
        config.n,
        '_w' if config.use_w else '',
        f'_seed{config.seed}' if config.seed else ''
    )

    dump_path = basedir / 'cache' / 'components' / dump_name

    if not dump_path.is_file() or force_recompute:
        print('Not cached')
        t_start = datetime.datetime.now()
        compute(config, dump_path, model)
        print('Total time:', datetime.datetime.now() - t_start)

    return dump_path
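

# Example usage (a minimal sketch, not part of the original module; it
# assumes this file is importable as 'decomposition' and that Config
# accepts the attributes read above as keyword arguments):
#
#   from decomposition import get_or_compute
#   cfg = Config(model='StyleGAN2', output_class='ffhq', layer='style',
#                estimator='ipca', components=80, n=300_000, use_w=True,
#                sparsity=1.0, batch_size=None, seed=None)
#   dump_path = get_or_compute(cfg)   # Path to the cached .npz
#   data = np.load(dump_path)         # keys: act_comp, act_mean, act_stdev,
#                                     #       lat_comp, lat_mean, lat_stdev,
#                                     #       var_ratio, random_stdevs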