|
import os, torch, numpy |
|
from torch.utils.data import TensorDataset |
|
|
|
def z_dataset_for_model(model, size=100, seed=1):
    """Return a TensorDataset holding a deterministic z sample for `model`.

    The sample is produced by ``z_sample_for_model(model, size, seed)`` and
    wrapped as the single tensor of a ``TensorDataset``.

    Args:
        model: the model whose input the sample should match; passed through
            to ``z_sample_for_model`` unchanged.
        size: number of z vectors (dataset length).
        seed: random seed forwarded to the sampler.
    """
    latents = z_sample_for_model(model, size, seed)
    return TensorDataset(latents)
|
|
|
def z_sample_for_model(model, size=100, seed=1):
    """Draw a deterministic batch of z inputs shaped to feed `model`.

    The shape is chosen as follows:

    * If `model` has an ``input_shape`` attribute, the result has shape
      ``(size,) + input_shape[1:]``. The first entry of ``input_shape`` is
      ignored (presumably a batch dimension -- confirm against the models
      that set this attribute).
    * Otherwise the first ``Conv2d``, ``ConvTranspose2d`` or ``Linear``
      module yielded by ``model.modules()`` decides: a convolutional layer
      gives a ``(size, in_channels, 1, 1)`` sample, a linear layer gives a
      ``(size, in_features)`` sample.

    Args:
        model: a ``torch.nn.Module`` (only ``input_shape`` and ``modules()``
            are used here).
        size: number of z vectors to generate.
        seed: random seed forwarded to ``standard_z_sample``.

    Returns:
        A tensor produced by ``standard_z_sample`` reshaped as described.

    Raises:
        IndexError: if ``input_shape`` has fewer than two entries, or if the
            model has no ``input_shape`` and contains no Conv2d,
            ConvTranspose2d or Linear module.
    """
    if hasattr(model, 'input_shape'):
        # tuple() lets input_shape be a list or torch.Size as well as a tuple,
        # so the concatenation with (size,) below cannot fail on type.
        shape = tuple(model.input_shape)
        # Draw enough values to fill every non-batch dimension. When all
        # dimensions after shape[1] are 1 this is exactly shape[1] values per
        # row, so those samples are unchanged; other shapes used to make
        # view() fail because too few values had been generated.
        depth = shape[1]
        for dim in shape[2:]:
            depth *= dim
        return standard_z_sample(size, depth, seed=seed).view(
            (size,) + shape[1:])

    # Inspect the first conv/linear layer to determine the input feature size.
    conv_types = (torch.nn.Conv2d, torch.nn.ConvTranspose2d)
    first_layer = next(
        (module for module in model.modules()
         if isinstance(module, conv_types + (torch.nn.Linear,))),
        None)
    if first_layer is None:
        # Same exception type as indexing an empty list, but with a message
        # that says what was actually missing.
        raise IndexError(
            'model has no input_shape and no Conv2d, ConvTranspose2d '
            'or Linear layer to infer the z size from')

    if isinstance(first_layer, conv_types):
        # Convolutional first layer: add two trailing singleton axes so the
        # sample is 4-d, (size, in_channels, 1, 1).
        return standard_z_sample(
            size, first_layer.in_channels, seed=seed)[:, :, None, None]
    # Linear first layer: plain 2-d (size, in_features) sample.
    return standard_z_sample(size, first_layer.in_features, seed=seed)
|
|
|
def standard_z_sample(size, depth, seed=1, device=None):
    """Generate a standard set of random z as a (size, depth) tensor.

    Values are standard-normal draws from a ``numpy.random.RandomState``
    seeded with `seed`, converted to a float32 torch tensor. Rows are filled
    in order from a single flat draw, so for a fixed seed and depth the first
    rows are always the same regardless of `size`.

    Args:
        size: number of z vectors (rows).
        depth: dimension of each z vector (columns).
        seed: seed for the NumPy random state.
        device: if not None, the result is moved there with ``Tensor.to``.

    Returns:
        A float32 tensor of shape (size, depth).
    """
    # A private RandomState (rather than a global seed) keeps the result
    # reproducible without touching any other random stream.
    rng = numpy.random.RandomState(seed)
    # One flat draw reshaped row-major: row i depends only on seed and depth.
    flat = rng.standard_normal(size * depth)
    result = torch.from_numpy(flat.reshape(size, depth)).float()
    if device is not None:
        result = result.to(device)
    return result
|
|
|
|