# Spaces: running on Zero.
#
# Refactor app.py to improve UI layout and rename weight download function;
# update import path for AutoEncoder in vae.py (commit f7f1ca1).
import argparse

import torch

from models.bsq_vae.flux_vqgan import AutoEncoder
def load_cnn(model, state_dict, prefix, expand=False, use_linear=False):
    """Copy weights whose keys start with ``prefix`` from ``state_dict`` into ``model``.

    Three remapping cases are handled, in order, for each source key:
      1. Direct name match: the de-prefixed key exists in the model's state
         dict.  With ``use_linear``, 1x1-conv attention weights
         (``.q/.k/.v/.proj_out.weight``) are squeezed to load into
         ``nn.Linear``.  With ``expand``, a 2D conv kernel is inflated to 3D
         by repeating it along the new depth dimension.
      2. Conv wrapper: source ``<name>.weight``/``<name>.bias`` maps onto the
         model's ``<name>.conv.weight``/``<name>.conv.bias`` (a ``Conv``
         wrapper class), again inflating 2D kernels to 3D on shape mismatch.
      3. Norm wrapper: source ``<name>.weight``/``<name>.bias`` maps onto
         ``<name>.norm.weight``/``<name>.norm.bias`` (a ``Normalize``
         wrapper class).

    Consumed source keys are deleted from ``state_dict`` in place after the
    scan, so the returned dict contains only the keys that were not loaded.

    Args:
        model: target ``nn.Module``.
        state_dict: source mapping of parameter name -> tensor; mutated.
        prefix: only keys starting with this prefix are considered.
        expand: inflate 2D conv kernels to 3D where the target is 3D.
        use_linear: load conv-style attention weights into linear layers.

    Returns:
        ``(model, state_dict, loaded_keys)`` where ``loaded_keys`` lists the
        fully-prefixed names that were successfully loaded.
    """
    delete_keys = []
    loaded_keys = []
    # state_dict() rebuilds an OrderedDict on every call; build it once.
    # Its tensors alias the live parameters, so copy_ still updates the model.
    model_sd = model.state_dict()
    for key in state_dict:
        if not key.startswith(prefix):
            continue
        _key = key[len(prefix):]
        if _key in model_sd:
            # load nn.Conv2d or nn.Linear to nn.Linear
            if use_linear and (".q.weight" in key or ".k.weight" in key or ".v.weight" in key or ".proj_out.weight" in key):
                load_weights = state_dict[key].squeeze()
            elif _key.endswith(".conv.weight") and expand:
                if model_sd[_key].shape == state_dict[key].shape:
                    # 2D cnn to 2D cnn
                    load_weights = state_dict[key]
                else:
                    # 2D cnn to 3D cnn: repeat the 2D kernel along the new depth axis
                    _expand_dim = model_sd[_key].shape[2]
                    load_weights = state_dict[key].unsqueeze(2).repeat(1, 1, _expand_dim, 1, 1)
            else:
                load_weights = state_dict[key]
            model_sd[_key].copy_(load_weights)
            delete_keys.append(key)
            loaded_keys.append(prefix + _key)
        # load nn.Conv2d to Conv class
        conv_list = ["conv"] if use_linear else ["conv", ".q.", ".k.", ".v.", ".proj_out.", ".nin_shortcut."]
        if any(k in _key for k in conv_list):
            if _key.endswith(".weight"):
                conv_key = _key.replace(".weight", ".conv.weight")
                if conv_key and conv_key in model_sd:
                    if model_sd[conv_key].shape == state_dict[key].shape:
                        # 2D cnn to 2D cnn
                        load_weights = state_dict[key]
                    else:
                        # 2D cnn to 3D cnn
                        _expand_dim = model_sd[conv_key].shape[2]
                        load_weights = state_dict[key].unsqueeze(2).repeat(1, 1, _expand_dim, 1, 1)
                    model_sd[conv_key].copy_(load_weights)
                    delete_keys.append(key)
                    loaded_keys.append(prefix + conv_key)
            if _key.endswith(".bias"):
                conv_key = _key.replace(".bias", ".conv.bias")
                if conv_key and conv_key in model_sd:
                    model_sd[conv_key].copy_(state_dict[key])
                    delete_keys.append(key)
                    loaded_keys.append(prefix + conv_key)
        # load nn.GroupNorm to Normalize class
        if "norm" in _key:
            if _key.endswith(".weight"):
                norm_key = _key.replace(".weight", ".norm.weight")
                if norm_key and norm_key in model_sd:
                    model_sd[norm_key].copy_(state_dict[key])
                    delete_keys.append(key)
                    loaded_keys.append(prefix + norm_key)
            if _key.endswith(".bias"):
                norm_key = _key.replace(".bias", ".norm.bias")
                if norm_key and norm_key in model_sd:
                    model_sd[norm_key].copy_(state_dict[key])
                    delete_keys.append(key)
                    loaded_keys.append(prefix + norm_key)
    for key in delete_keys:
        del state_dict[key]
    return model, state_dict, loaded_keys
def vae_model(vqgan_ckpt, schedule_mode, codebook_dim, codebook_size, test_mode=True, patch_size=16, encoder_ch_mult=None, decoder_ch_mult=None):
    """Build a BSQ ``AutoEncoder`` and optionally load pretrained weights.

    Args:
        vqgan_ckpt: path to a checkpoint file (loaded on CPU), or an
            already-loaded checkpoint mapping with ``"vae"``/``"ema"`` keys;
            anything that is not a ``str`` is used as-is.
        schedule_mode: quantizer schedule mode, forwarded to the model config.
        codebook_dim: codebook embedding dimension.
        codebook_size: number of codebook entries.
        test_mode: if True, put the model in eval mode and freeze all
            parameters before returning.
        patch_size: spatial patch size for the tokenizer.
        encoder_ch_mult: encoder channel multipliers; defaults to
            ``[1, 2, 4, 4, 4]``.
        decoder_ch_mult: decoder channel multipliers; defaults to
            ``[1, 2, 4, 4, 4]``.

    Returns:
        The constructed (and possibly weight-loaded, frozen) ``AutoEncoder``.
    """
    # None-sentinel instead of mutable list defaults; values match the old API.
    if encoder_ch_mult is None:
        encoder_ch_mult = [1, 2, 4, 4, 4]
    if decoder_ch_mult is None:
        decoder_ch_mult = [1, 2, 4, 4, 4]
    # Full training-time config namespace expected by AutoEncoder; only the
    # function arguments vary, everything else is fixed to these defaults.
    args = argparse.Namespace(
        vqgan_ckpt=vqgan_ckpt,
        sd_ckpt=None,
        inference_type='image',
        save='./imagenet_val_bsq',
        save_prediction=True,
        image_recon4video=False,
        junke_old=False,
        device='cuda',
        max_steps=1000000.0,
        log_every=1,
        visu_every=1000,
        ckpt_every=1000,
        default_root_dir='',
        compile='no',
        ema='no',
        lr=0.0001,
        beta1=0.9,
        beta2=0.95,
        warmup_steps=0,
        optim_type='Adam',
        disc_optim_type=None,
        lr_min=0.0,
        warmup_lr_init=0.0,
        max_grad_norm=1.0,
        max_grad_norm_disc=1.0,
        disable_sch=False,
        patch_size=patch_size,
        temporal_patch_size=4,
        embedding_dim=256,
        codebook_dim=codebook_dim,
        num_quantizers=8,
        quantizer_type='MultiScaleBSQ',
        use_vae=False,
        use_freq_enc=False,
        use_freq_dec=False,
        preserve_norm=False,
        ln_before_quant=False,
        ln_init_by_sqrt=False,
        use_pxsf=False,
        new_quant=True,
        use_decay_factor=False,
        mask_out=False,
        use_stochastic_depth=False,
        drop_rate=0.0,
        schedule_mode=schedule_mode,
        lr_drop=None,
        lr_drop_rate=0.1,
        keep_first_quant=False,
        keep_last_quant=False,
        remove_residual_detach=False,
        use_out_phi=False,
        use_out_phi_res=False,
        use_lecam_reg=False,
        lecam_weight=0.05,
        perceptual_model='vgg16',
        base_ch_disc=64,
        random_flip=False,
        flip_prob=0.5,
        flip_mode='stochastic',
        max_flip_lvl=1,
        not_load_optimizer=False,
        use_lecam_reg_zero=False,
        freeze_encoder=False,
        rm_downsample=False,
        random_flip_1lvl=False,
        flip_lvl_idx=0,
        drop_when_test=False,
        drop_lvl_idx=0,
        drop_lvl_num=1,
        disc_version='v1',
        magvit_disc=False,
        sigmoid_in_disc=False,
        activation_in_disc='leaky_relu',
        apply_blur=False,
        apply_noise=False,
        dis_warmup_steps=0,
        dis_lr_multiplier=1.0,
        dis_minlr_multiplier=False,
        disc_channels=64,
        disc_layers=3,
        discriminator_iter_start=0,
        disc_pretrain_iter=0,
        disc_optim_steps=1,
        disc_warmup=0,
        disc_pool='no',
        disc_pool_size=1000,
        advanced_disc=False,
        recon_loss_type='l1',
        video_perceptual_weight=0.0,
        image_gan_weight=1.0,
        video_gan_weight=1.0,
        image_disc_weight=0.0,
        video_disc_weight=0.0,
        l1_weight=4.0,
        gan_feat_weight=0.0,
        perceptual_weight=0.0,
        kl_weight=0.0,
        lfq_weight=0.0,
        entropy_loss_weight=0.1,
        commitment_loss_weight=0.25,
        diversity_gamma=1,
        norm_type='group',
        disc_loss_type='hinge',
        use_checkpoint=False,
        precision='fp32',
        encoder_dtype='fp32',
        upcast_attention='',
        upcast_tf32=False,
        tokenizer='flux',
        pretrained=None,
        pretrained_mode='full',
        inflation_pe=False,
        init_vgen='no',
        no_init_idis=False,
        init_idis='keep',
        init_vdis='no',
        enable_nan_detector=False,
        turn_on_profiler=False,
        profiler_scheduler_wait_steps=10,
        debug=True,
        video_logger=False,
        bytenas='',
        username='',
        seed=1234,
        vq_to_vae=False,
        load_not_strict=False,
        zero=0,
        bucket_cap_mb=40,
        manual_gc_interval=1000,
        data_path=[''],
        data_type=[''],
        dataset_list=['imagenet'],
        fps=-1,
        dataaug='resizecrop',
        multi_resolution=False,
        random_bucket_ratio=0.0,
        sequence_length=16,
        resolution=[256, 256],
        batch_size=[1],
        num_workers=0,
        image_channels=3,
        codebook_size=codebook_size,
        codebook_l2_norm=True,
        codebook_show_usage=True,
        commit_loss_beta=0.25,
        entropy_loss_ratio=0.0,
        base_ch=128,
        num_res_blocks=2,
        encoder_ch_mult=encoder_ch_mult,
        decoder_ch_mult=decoder_ch_mult,
        dropout_p=0.0,
        cnn_type='2d',
        cnn_version='v1',
        conv_in_out_2d='no',
        conv_inner_2d='no',
        res_conv_2d='no',
        cnn_attention='no',
        cnn_norm_axis='spatial',
        flux_weight=0,
        cycle_weight=0,
        cycle_feat_weight=0,
        cycle_gan_weight=0,
        cycle_loop=0,
        z_drop=0.0)
    vae = AutoEncoder(args)
    if isinstance(vqgan_ckpt, str):
        # weights_only=True avoids unpickling arbitrary objects from the file.
        state_dict = torch.load(args.vqgan_ckpt, map_location=torch.device("cpu"), weights_only=True)
    else:
        state_dict = args.vqgan_ckpt
    if state_dict:
        # Checkpoints store weights under "ema" or "vae" depending on training mode.
        ckpt_key = "ema" if args.ema == "yes" else "vae"
        vae, new_state_dict, loaded_keys = load_cnn(vae, state_dict[ckpt_key], prefix="", expand=False)
    if test_mode:
        vae.eval()
        for p in vae.parameters():
            p.requires_grad_(False)
    return vae