# Spaces: Runtime error
from einops import rearrange | |
from torch import nn | |
from diffusers.models import AutoencoderKL | |
class HFVAEWrapper(nn.Module):
    """Wrap a diffusers ``AutoencoderKL`` so it accepts image or video tensors.

    Inputs that are not 5-D are handed to the VAE unchanged. 5-D inputs are
    read as ``(b, c, t, h, w)``: the time axis is folded into the batch axis
    before the VAE call and unfolded again afterwards.

    Args:
        hfvae: Identifier or path passed to ``AutoencoderKL.from_pretrained``.
        scaling_factor: Latents are multiplied by this value in ``encode``
            and divided by it in ``decode``.
        cache_dir: Forwarded to ``AutoencoderKL.from_pretrained`` as
            ``cache_dir``.
    """

    def __init__(self, hfvae='mse', scaling_factor=0.18215, cache_dir='cache_dir'):
        super().__init__()
        # Plain float attribute: not a parameter/buffer, so it does not show
        # up in the module's state dict.
        self.scaling_factor = scaling_factor
        self.vae = AutoencoderKL.from_pretrained(hfvae, cache_dir=cache_dir)

    @staticmethod
    def _fold_time(x):
        """Fold the time axis of a 5-D ``(b, c, t, h, w)`` tensor into batch.

        Returns:
            ``(tensor, t)`` where ``t`` is the number of frames that were
            folded, or ``0`` when ``x`` is not 5-D and is returned untouched.
        """
        if x.ndim != 5:
            return x, 0
        frames = x.shape[2]
        folded = rearrange(x, 'b c t h w -> (b t) c h w').contiguous()
        return folded, frames

    def encode(self, x):  # b c h w  or  b c t h w
        """Encode ``x`` and return scaled latents sampled from the posterior.

        Args:
            x: 4-D ``(b, c, h, w)`` or 5-D ``(b, c, t, h, w)`` tensor.

        Returns:
            Sampled latents multiplied by ``self.scaling_factor``. For 5-D
            input the result is laid out as ``(b, c, t, h, w)``.
        """
        x, t = self._fold_time(x)
        # In-place scale of the freshly sampled latent, as the original did.
        x = self.vae.encode(x).latent_dist.sample().mul_(self.scaling_factor)
        if t:
            x = rearrange(x, '(b t) c h w -> b c t h w', t=t).contiguous()
        return x

    def decode(self, x):
        """Decode scaled latents with the wrapped VAE.

        Args:
            x: 4-D ``(b, c, h, w)`` or 5-D ``(b, c, t, h, w)`` latent tensor.

        Returns:
            The VAE's decoded ``sample``. NOTE: for 5-D input the result is
            laid out as ``(b, t, c, h, w)`` -- time before channels -- which
            differs from the ``(b, c, t, h, w)`` layout ``encode`` returns.
            The layout is kept as-is for existing callers.
        """
        x, t = self._fold_time(x)
        x = self.vae.decode(x / self.scaling_factor).sample
        if t:
            x = rearrange(x, '(b t) c h w -> b t c h w', t=t).contiguous()
        return x
class SDVAEWrapper(nn.Module):
    """Placeholder VAE wrapper; none of its methods are implemented.

    Construction itself raises ``NotImplementedError`` (after the base
    ``nn.Module`` initialiser has run), so no usable instance can be created.
    """

    def __init__(self):
        super().__init__()
        raise NotImplementedError()

    def encode(self, x):  # b c h w
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def decode(self, x):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()