Spaces:
Runtime error
Runtime error
from torch import nn | |
import yaml | |
import torch | |
from omegaconf import OmegaConf | |
from .vqgan import VQModel, GumbelVQ | |
def load_config(config_path, display=False): | |
config = OmegaConf.load(config_path) | |
if display: | |
print(yaml.dump(OmegaConf.to_container(config))) | |
return config | |
def load_vqgan(config, ckpt_path=None, is_gumbel=False): | |
if is_gumbel: | |
model = GumbelVQ(**config.model.params) | |
else: | |
model = VQModel(**config.model.params) | |
if ckpt_path is not None: | |
sd = torch.load(ckpt_path, map_location="cpu")["state_dict"] | |
missing, unexpected = model.load_state_dict(sd, strict=False) | |
return model.eval() | |
class SDVQVAEWrapper(nn.Module): | |
def __init__(self, name): | |
super(SDVQVAEWrapper, self).__init__() | |
raise NotImplementedError | |
def encode(self, x): # b c h w | |
raise NotImplementedError | |
def decode(self, x): | |
raise NotImplementedError | |