fffiloni's picture
Upload 244 files
b3f324b verified
raw
history blame contribute delete
945 Bytes
from torch import nn
import yaml
import torch
from omegaconf import OmegaConf
from .vqgan import VQModel, GumbelVQ
def load_config(config_path, display=False):
    """Load a configuration file through OmegaConf.

    Args:
        config_path: Location of the configuration file, passed straight to
            ``OmegaConf.load``.
        display: When truthy, the loaded configuration is converted to a
            plain container and printed as a YAML dump.

    Returns:
        The object produced by ``OmegaConf.load``.
    """
    cfg = OmegaConf.load(config_path)
    if not display:
        return cfg

    # Convert to a plain container first so yaml can serialise it.
    plain_cfg = OmegaConf.to_container(cfg)
    print(yaml.dump(plain_cfg))
    return cfg
def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    """Build a VQGAN model from ``config`` and optionally load its weights.

    Args:
        config: Configuration object; ``config.model.params`` is unpacked as
            keyword arguments for the model constructor.
        ckpt_path: Optional checkpoint path. When given, the file is read with
            ``torch.load`` (mapped to CPU) and its ``"state_dict"`` entry is
            loaded into the model non-strictly.
        is_gumbel: If truthy, ``GumbelVQ`` is instantiated; otherwise
            ``VQModel``.

    Returns:
        The value returned by ``model.eval()``.

    Raises:
        KeyError: If the loaded checkpoint mapping has no ``"state_dict"``
            entry.

    Warns:
        UserWarning: If the checkpoint's keys do not match the model's
            parameters exactly (some missing and/or some unexpected).
    """
    import warnings

    model_cls = GumbelVQ if is_gumbel else VQModel
    model = model_cls(**config.model.params)

    if ckpt_path is not None:
        # NOTE(security): torch.load may unpickle arbitrary objects (depending
        # on the torch version's ``weights_only`` default); only pass
        # checkpoints from a trusted source.
        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        missing, unexpected = model.load_state_dict(sd, strict=False)
        # strict=False never raises on key mismatches, so surface them here;
        # otherwise a wrong checkpoint silently leaves weights uninitialised.
        if missing or unexpected:
            warnings.warn(
                f"load_vqgan: checkpoint {ckpt_path!r} did not match the model "
                f"exactly ({len(missing)} missing key(s), "
                f"{len(unexpected)} unexpected key(s)); "
                f"missing={list(missing)}, unexpected={list(unexpected)}",
                stacklevel=2,
            )

    return model.eval()
class SDVQVAEWrapper(nn.Module):
    """Placeholder wrapper module; none of its operations are implemented.

    Constructing an instance, as well as calling ``encode`` or ``decode``,
    raises ``NotImplementedError``.
    """

    def __init__(self, name):
        """Initialise the ``nn.Module`` base, then raise ``NotImplementedError``.

        Args:
            name: Accepted but not used by the current stub.
        """
        super().__init__()
        raise NotImplementedError

    def encode(self, x):
        """Not implemented; always raises ``NotImplementedError``.

        The original inline note on ``x`` reads "b c h w" — presumably the
        expected tensor layout; confirm once an implementation exists.
        """
        raise NotImplementedError

    def decode(self, x):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError