# disentangled-image-editing-final-project/ContraCLIP/models/genforce/configs/stylegan_ffhq256_encoder_y.py
# python3.7
"""Configuration for training StyleGAN Encoder on FFHQ (256) dataset.

All settings are particularly used for one replica (GPU), such as `batch_size`
and `num_workers`.
"""
# Pre-trained weights: the GAN to be inverted and the VGG16 perceptual model.
gan_model_path = 'checkpoints/stylegan_ffhq256.pth'
perceptual_model_path = 'checkpoints/vgg16.pth'

runner_type = 'EncoderRunner'
gan_type = 'stylegan'
resolution = 256

# Per-replica (single GPU) batch sizes for training and validation.
batch_size = 12
val_batch_size = 25

# Total number of images to show during training (14 million).
total_img = 14_000_000

# Which latent space of the generator the encoder maps into ('y' here;
# NOTE(review): exact semantics of 'y' are defined by the runner — confirm).
space_of_latent = 'y'
# Training dataset is repeated at the beginning to avoid loading the dataset
# repeatedly at the end of each epoch. This can save some I/O time.
data = dict(
    num_workers=4,  # data-loading worker processes per replica
    repeat=500,     # times the training list is repeated up front (see above)
    # Both splits read image paths from list files rooted at `root_dir`.
    train=dict(root_dir='data/', data_format='list',
               image_list_path='data/ffhq/ffhq_train_list.txt',
               resolution=resolution, mirror=0.5),  # mirror: horizontal-flip prob.
    val=dict(root_dir='data/', data_format='list',
             image_list_path='./data/ffhq/ffhq_val_list.txt',
             resolution=resolution),
)
# Periodic hooks executed by the runner; intervals are in training iterations.
controllers = dict(
    RunningLogger=dict(every_n_iters=50),                        # console/log output
    Snapshoter=dict(every_n_iters=10000, first_iter=True, num=200),  # image snapshots
    Checkpointer=dict(every_n_iters=10000, first_iter=False),    # model checkpoints
)
# Per-network configuration: model construction kwargs, learning-rate schedule,
# and optimizer settings. The generator is frozen (no `lr`/`opt` entries).
modules = dict(
    discriminator=dict(
        model=dict(gan_type=gan_type, resolution=resolution),
        # Exponential step decay: lr *= 0.8 every `decay_step` iterations.
        # NOTE(review): the origin of 36458 is not evident from this file —
        # presumably derived from the dataset/batch size; confirm upstream.
        lr=dict(lr_type='ExpSTEP', decay_factor=0.8, decay_step=36458 // 2),
        opt=dict(opt_type='Adam', base_lr=1e-4, betas=(0.9, 0.99)),
        kwargs_train=dict(),
        kwargs_val=dict(),
    ),
    generator=dict(
        model=dict(gan_type=gan_type, resolution=resolution, repeat_w=True),
        kwargs_val=dict(randomize_noise=False),  # deterministic validation output
    ),
    encoder=dict(
        # 14 per-layer latent codes (matches a 256-res StyleGAN: 2*log2(256)-2),
        # split across 3 prediction heads handling 4 + 4 + 6 latents each.
        model=dict(gan_type=gan_type, resolution=resolution, network_depth=18,
                   latent_dim=[1024] * 8 + [512, 512, 256, 256, 128, 128],
                   num_latents_per_head=[4, 4, 6],
                   use_fpn=True,
                   fpn_channels=512,
                   use_sam=True,
                   sam_channels=512),
        lr=dict(lr_type='ExpSTEP', decay_factor=0.8, decay_step=36458 // 2),
        opt=dict(opt_type='Adam', base_lr=1e-4, betas=(0.9, 0.99)),
        kwargs_train=dict(),
        kwargs_val=dict(),
    ),
)
# Loss configuration: R1 gradient penalty for the discriminator, plus
# adversarial and VGG16-feature perceptual terms for the encoder.
loss = dict(
    type='EncoderLoss',
    d_loss_kwargs=dict(r1_gamma=10.0),
    e_loss_kwargs=dict(adv_lw=0.08, perceptual_lw=5e-5),  # loss weights
    perceptual_kwargs=dict(output_layer_idx=23,  # VGG16 layer to take features from
                           pretrained_weight_path=perceptual_model_path),
)