Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| from torch import nn | |
| import yaml | |
| import torch | |
| from omegaconf import OmegaConf | |
| from .vqgan import VQModel, GumbelVQ | |
def load_config(config_path, display=False):
    """Load an OmegaConf configuration file.

    Args:
        config_path: Path to the configuration file, passed to ``OmegaConf.load``.
        display: When true, print the configuration (converted to a plain
            container) as YAML before returning it.

    Returns:
        The object returned by ``OmegaConf.load``.
    """
    cfg = OmegaConf.load(config_path)
    if not display:
        return cfg
    # Convert to plain dicts/lists first so yaml.dump can serialize it.
    plain = OmegaConf.to_container(cfg)
    print(yaml.dump(plain))
    return cfg
def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    """Build a VQGAN model from *config* and optionally load checkpoint weights.

    Args:
        config: Config object; ``config.model.params`` is unpacked as keyword
            arguments to the model constructor.
        ckpt_path: Optional path to a checkpoint. When given, the object stored
            under its ``"state_dict"`` key is loaded into the model with
            ``strict=False`` (weights are mapped to CPU).
        is_gumbel: Build a ``GumbelVQ`` instead of a ``VQModel``.

    Returns:
        The model after ``model.eval()``.

    Raises:
        KeyError: If the loaded checkpoint has no ``"state_dict"`` entry.

    Warns:
        UserWarning: If non-strict loading reports missing or unexpected keys.
    """
    import warnings

    model_cls = GumbelVQ if is_gumbel else VQModel
    model = model_cls(**config.model.params)

    if ckpt_path is not None:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code -- only pass trusted checkpoints here.
        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        missing, unexpected = model.load_state_dict(sd, strict=False)
        # strict=False hides key mismatches; without this warning a checkpoint
        # that does not match the config silently leaves parameters untouched.
        if missing or unexpected:
            warnings.warn(
                f"load_vqgan: {len(missing)} missing and {len(unexpected)} "
                f"unexpected keys when loading {ckpt_path!r}",
                stacklevel=2,
            )
    return model.eval()
class SDVQVAEWrapper(nn.Module):
    """Placeholder module for a Stable Diffusion VQ-VAE wrapper.

    Not implemented: construction raises ``NotImplementedError`` right after
    the ``nn.Module`` base initializer runs, and ``encode``/``decode`` raise
    ``NotImplementedError`` as well.
    """

    def __init__(self, name):
        super().__init__()
        raise NotImplementedError

    def encode(self, x):
        """Encode *x* (original note: ``b c h w`` layout) -- not implemented."""
        raise NotImplementedError

    def decode(self, x):
        """Decode *x* -- not implemented."""
        raise NotImplementedError
 
			
