# NOTE: the lines originally here were residue from a scraped "Spaces" file
# view, not part of the module: status "Runtime error" (shown twice),
# file size 1,237 bytes, revision 810fa8c, and a line-number gutter (1-38).
from einops import rearrange
from torch import nn
from diffusers.models import AutoencoderKL
class HFVAEWrapper(nn.Module):
    """Wrap a pretrained ``AutoencoderKL`` so it accepts 4-D or 5-D tensors.

    4-D inputs (``b c h w``) are passed straight to the VAE.  5-D inputs
    (``b c t h w``) have their third axis folded into the batch axis before
    the VAE call, so every frame is processed as an independent image, and
    are unfolded again afterwards.

    Latents are multiplied by ``scaling_factor`` after encoding and divided
    by it before decoding.

    NOTE(review): for 5-D inputs ``encode`` returns ``b c t h w`` while
    ``decode`` returns ``b t c h w``.  The asymmetry is preserved from the
    original implementation because callers may depend on it -- confirm
    before changing either layout.
    """

    # Class-level fallback so instances created without running ``__init__``
    # (e.g. modules unpickled from before ``scaling_factor`` existed) keep
    # the historical constant.
    scaling_factor = 0.18215

    def __init__(self, hfvae='mse', scaling_factor=0.18215):
        """Load the pretrained VAE.

        Args:
            hfvae: Identifier forwarded unchanged to
                ``AutoencoderKL.from_pretrained``.
            scaling_factor: Multiplier applied to sampled latents in
                ``encode`` and divided out in ``decode``.  Defaults to the
                value that was previously hard-coded.
        """
        super().__init__()
        self.scaling_factor = scaling_factor
        self.vae = AutoencoderKL.from_pretrained(hfvae, cache_dir='cache_dir')

    @staticmethod
    def _fold_frames(x):
        """Fold the ``t`` axis of a 5-D ``b c t h w`` tensor into the batch.

        Returns:
            A ``(tensor, t)`` pair.  ``t`` is 0 when ``x`` is not 5-D, in
            which case ``x`` is returned untouched.
        """
        if x.ndim != 5:
            return x, 0
        t = x.shape[2]
        return rearrange(x, 'b c t h w -> (b t) c h w').contiguous(), t

    def encode(self, x):  # b c h w  (or b c t h w)
        """Encode ``x`` into scaled latents sampled from the VAE posterior.

        Args:
            x: 4-D ``b c h w`` or 5-D ``b c t h w`` tensor.

        Returns:
            Latents scaled by ``scaling_factor``; laid out as ``b c t h w``
            when the input was 5-D with a non-zero ``t``, otherwise 4-D.
        """
        x, t = self._fold_frames(x)
        # ``mul_`` scales the freshly sampled latent in place (no extra copy).
        x = self.vae.encode(x).latent_dist.sample().mul_(self.scaling_factor)
        if t != 0:
            x = rearrange(x, '(b t) c h w -> b c t h w', t=t).contiguous()
        return x

    def decode(self, x):
        """Decode scaled latents with the VAE.

        Args:
            x: 4-D ``b c h w`` or 5-D ``b c t h w`` latent tensor, scaled as
                produced by ``encode``.

        Returns:
            The VAE decoder's ``sample``; laid out as ``b t c h w`` (time
            before channels) when the input was 5-D with a non-zero ``t``,
            otherwise 4-D.
        """
        x, t = self._fold_frames(x)
        # Undo the scaling applied in ``encode`` before handing to the VAE.
        x = self.vae.decode(x / self.scaling_factor).sample
        if t != 0:
            x = rearrange(x, '(b t) c h w -> b t c h w', t=t).contiguous()
        return x
class SDVAEWrapper(nn.Module):
    """Placeholder VAE wrapper; nothing is implemented yet.

    Construction and both public methods raise ``NotImplementedError``.
    """

    def __init__(self):
        """Initialise the ``nn.Module`` base, then signal the missing implementation."""
        super().__init__()
        raise NotImplementedError()

    def encode(self, x):  # b c h w
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def decode(self, x):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()