from typing import Optional, Union

import torch
import torch.nn as nn

from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKLOutput
from diffusers.models.autoencoders.vae import DecoderOutput


class PixelMixer(nn.Module):
    """Invertible pixel "autoencoder": encode trades spatial resolution for
    channels via PixelUnshuffle, decode inverts it via PixelShuffle."""

    def __init__(self, in_channels, downscale_factor):
        super().__init__()
        self.downscale_factor = downscale_factor
        self.in_channels = in_channels

    def forward(self, x):
        # Round trip: decode(encode(x)) reproduces x exactly.
        latent = self.encode(x)
        out = self.decode(latent)
        return out

    def encode(self, x):
        # (B, C, H, W) -> (B, C * r**2, H / r, W / r), with r = downscale_factor
        return torch.nn.functional.pixel_unshuffle(x, self.downscale_factor)

    def decode(self, x):
        # (B, C * r**2, H / r, W / r) -> (B, C, H, W)
        return torch.nn.functional.pixel_shuffle(x, self.downscale_factor)
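
# Illustrative shape check (illustrative numbers only, assuming a 512x512 RGB
# input and downscale_factor=8, per the PixelUnshuffle contract):
#   (1, 3, 512, 512) --encode--> (1, 3 * 8 * 8 = 192, 64, 64)
#                    --decode--> (1, 3, 512, 512)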


class Config:
    """Mimics the attributes of diffusers' AutoencoderKL config so downstream
    code can read them; the block type and channel tuples are dummy
    placeholders kept only for interface compatibility."""

    in_channels = 3
    out_channels = 3
    down_block_types = ('1', '1', '1', '1')
    up_block_types = ('1', '1', '1', '1')
    block_out_channels = (1, 1, 1, 1)
    latent_channels = 192
    norm_num_groups = 32
    sample_size = 512

    scaling_factor = 1.8
    shift_factor = -0.123

    def __getitem__(self, x):
        # Support dict-style access (config['latent_channels']) alongside
        # attribute access.
        return getattr(self, x)
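
# Quick check of the dict-style access above (hypothetical usage):
#   cfg = Config()
#   assert cfg['latent_channels'] == cfg.latent_channels == 192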


class AutoencoderPixelMixer(nn.Module):
    """Drop-in stand-in for a diffusers VAE that mixes pixels with
    PixelUnshuffle/PixelShuffle instead of using a learned latent space."""

    def __init__(self, in_channels=3, downscale_factor=8):
        super().__init__()
        self.mixer = PixelMixer(in_channels, downscale_factor)
        self._dtype = torch.float32
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.config = Config()

        # latent_channels = in_channels * downscale_factor**2; the hardcoded
        # values below assume in_channels == 3.
        if downscale_factor == 8:
            self.config.block_out_channels = (1, 1, 1, 1)
            self.config.latent_channels = 192
        elif downscale_factor == 16:
            self.config.block_out_channels = (1, 1, 1, 1, 1)
            self.config.latent_channels = 768
        else:
            raise ValueError(
                f"downscale_factor {downscale_factor} not supported")

    @property
    def dtype(self):
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        self._dtype = value

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, value):
        self._device = value

    def to(self, *args, **kwargs):
        # Keep the tracked dtype/device in sync with .to(...); note that
        # positional arguments (e.g. .to("cuda")) are not inspected here.
        if 'dtype' in kwargs:
            self._dtype = kwargs['dtype']
        if 'device' in kwargs:
            self._device = kwargs['device']
        return super().to(*args, **kwargs)

    def enable_xformers_memory_efficient_attention(self):
        pass

    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        h = self.mixer.encode(x)

        if not return_dict:
            return (h,)

        class FakeDist:
            """Deterministic stand-in for the VAE posterior: sample() and
            mode() both return the fixed latent, so code written against
            DiagonalGaussianDistribution keeps working."""

            def __init__(self, x):
                self._sample = x

            def sample(self, generator=None):
                return self._sample

            def mode(self):
                return self._sample

        return AutoencoderKLOutput(latent_dist=FakeDist(h))
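
# Minimal usage sketch under the diffusers-style calling convention (the
# latent is deterministic, so sample() and mode() agree):
#   vae = AutoencoderPixelMixer(in_channels=3, downscale_factor=8)
#   latents = vae.encode(images).latent_dist.sample()
#   recon = vae.decode(latents).sample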

    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        dec = self.mixer.decode(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def _set_gradient_checkpointing(self, module, value=False):
        pass

    # The methods below exist so pipelines written against diffusers'
    # AutoencoderKL can call them; they are deliberate no-ops because pixel
    # (un)shuffling needs no tiling, slicing or attention.
    def enable_tiling(self, use_tiling: bool = True):
        pass

    def disable_tiling(self):
        pass

    def enable_slicing(self):
        pass

    def disable_slicing(self):
        pass

    def set_use_memory_efficient_attention_xformers(self, value: bool = True):
        pass

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
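
# Sketch of how a diffusion pipeline would typically wrap this autoencoder
# with the config's scaling/shift factors (assumed convention, mirroring how
# diffusers pipelines use AutoencoderKL; not part of this module's API):
#   latents = (vae.encode(img).latent_dist.sample()
#              - vae.config.shift_factor) * vae.config.scaling_factor
#   image = vae.decode(latents / vae.config.scaling_factor
#                      + vae.config.shift_factor).sample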


if __name__ == '__main__':
    import os

    from PIL import Image
    import torchvision.transforms as transforms

    user_path = os.path.expanduser('~')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    input_path = os.path.join(user_path, "Pictures/test/test.jpg")
    # Write to a separate file so the round trip does not clobber the input.
    output_path = os.path.join(user_path, "Pictures/test/test_out.jpg")

    img = Image.open(input_path)
    img_tensor = transforms.ToTensor()(img)
    img_tensor = img_tensor.unsqueeze(0).to(device=device, dtype=dtype)
    print("input_shape: ", list(img_tensor.shape))

    vae = PixelMixer(in_channels=3, downscale_factor=8)
    latent = vae.encode(img_tensor)
    print("latent_shape: ", list(latent.shape))
    out_tensor = vae.decode(latent)
    print("out_shape: ", list(out_tensor.shape))

    # The shuffle/unshuffle round trip is exact, so the MSE should be ~0.
    mse_loss = nn.MSELoss()
    mse = mse_loss(img_tensor, out_tensor)
    print("roundtrip_loss: ", mse.item())

    out_img = transforms.ToPILImage()(out_tensor.squeeze(0))
    out_img.save(output_path)