import logging
from typing import Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from modules.Attention import Attention
from modules.AutoEncoders import ResBlock
from modules.Device import Device
from modules.Model import ModelPatcher
from modules.Utilities import util
from modules.cond import cast
class DiagonalGaussianDistribution(object):
    """#### Represents a diagonal Gaussian distribution parameterized by mean and log-variance.
    #### Attributes:
    - `parameters` (torch.Tensor): The concatenated mean and log-variance of the distribution.
    - `mean` (torch.Tensor): The mean of the distribution.
    - `logvar` (torch.Tensor): The log-variance of the distribution, clamped between -30.0 and 20.0.
    - `std` (torch.Tensor): The standard deviation of the distribution, computed as exp(0.5 * logvar).
    - `var` (torch.Tensor): The variance of the distribution, computed as exp(logvar).
    - `deterministic` (bool): If True, the distribution is treated as deterministic.
    #### Methods:
    - `sample() -> torch.Tensor`:
        Samples from the distribution using the reparameterization trick.
    - `mode() -> torch.Tensor`:
        Returns the mode of the distribution (its mean).
    - `kl(other: DiagonalGaussianDistribution = None) -> torch.Tensor`:
        Computes the Kullback-Leibler divergence against a standard normal distribution.
        The `other` argument is accepted for interface compatibility but is ignored by this
        simplified implementation.
    """
    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
    def sample(self) -> torch.Tensor:
        """#### Samples from the distribution using the reparameterization trick.
        #### Returns:
        - `torch.Tensor`: A sample from the distribution.
        """
        x = self.mean + self.std * torch.randn(self.mean.shape).to(
            device=self.parameters.device
        )
        return x
    def mode(self) -> torch.Tensor:
        """#### Returns the mode of the distribution, which for a Gaussian is its mean.
        #### Returns:
        - `torch.Tensor`: The mode of the distribution.
        """
        return self.mean
    def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
        """#### Computes the Kullback-Leibler divergence against a standard normal distribution.
        #### Args:
        - `other` (DiagonalGaussianDistribution, optional): Accepted for interface compatibility;
          this simplified implementation always measures KL against a standard normal.
        #### Returns:
        - `torch.Tensor`: The KL divergence, one value per batch element.
        """
        return 0.5 * torch.sum(
            torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
            dim=[1, 2, 3],
        )
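# Illustrative sketch (not executed on import): given a hypothetical 8-channel "moments"
# tensor `params` of shape (B, 8, H, W), i.e. 4 mean + 4 logvar channels:
#   dist = DiagonalGaussianDistribution(params)
#   z = dist.sample()   # reparameterized draw, shape (B, 4, H, W)
#   kl = dist.kl()      # KL vs. N(0, I), one value per batch element, shape (B,)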
class DiagonalGaussianRegularizer(torch.nn.Module):
    """#### Regularizer for diagonal Gaussian distributions."""
    def __init__(self, sample: bool = True):
        """#### Initialize the regularizer.
        #### Args:
        - `sample` (bool, optional): Whether to sample from the distribution. Defaults to True.
        """
        super().__init__()
        self.sample = sample
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """#### Forward pass for the regularizer.
        #### Args:
        - `z` (torch.Tensor): The input tensor.
        #### Returns:
        - `Tuple[torch.Tensor, dict]`: The regularized tensor and a log dictionary.
        """
        log = dict()
        posterior = DiagonalGaussianDistribution(z)
        if self.sample:
            z = posterior.sample()
        else:
            z = posterior.mode()
        kl_loss = posterior.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        log["kl_loss"] = kl_loss
        return z, log
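# Illustrative sketch (assumes a raw encoder output `moments` of shape (B, 8, H, W)):
#   reg = DiagonalGaussianRegularizer(sample=True)
#   z, log = reg(moments)   # z has shape (B, 4, H, W); log["kl_loss"] is a scalar tensor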
class AutoencodingEngine(nn.Module):
    """#### Class representing an autoencoding engine."""
    def __init__(self, encoder: nn.Module, decoder: nn.Module, regularizer: nn.Module, flux: bool = False):
        """#### Initialize the autoencoding engine.
        #### Args:
        - `encoder` (nn.Module): The encoder module.
        - `decoder` (nn.Module): The decoder module.
        - `regularizer` (nn.Module): The regularizer module.
        - `flux` (bool, optional): Whether the engine wraps a Flux-style VAE, which skips the
          quant/post-quant 1x1 convolutions. Defaults to False.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.regularization = regularizer
        if not flux:
            self.post_quant_conv = cast.disable_weight_init.Conv2d(4, 4, 1)
            self.quant_conv = cast.disable_weight_init.Conv2d(8, 8, 1)
    def get_last_layer(self):
        """#### Get the last layer of the decoder.
        #### Returns:
        - `nn.Module`: The last layer of the decoder.
        """
        return self.decoder.get_last_layer()
    def decode(self, z: torch.Tensor, flux: bool = False, **kwargs) -> torch.Tensor:
        """#### Decode the latent tensor.
        #### Args:
        - `z` (torch.Tensor): The latent tensor.
        - `flux` (bool, optional): Whether to skip the post-quant convolution (Flux-style VAE). Defaults to False.
        - `**kwargs`: Additional arguments passed to the decoder.
        #### Returns:
        - `torch.Tensor`: The decoded tensor.
        """
        if flux:
            x = self.decoder(z, **kwargs)
            return x
        dec = self.post_quant_conv(z)
        dec = self.decoder(dec, **kwargs)
        return dec
    def encode(
        self,
        x: torch.Tensor,
        return_reg_log: bool = False,
        unregularized: bool = False,
        flux: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        """#### Encode the input tensor.
        #### Args:
        - `x` (torch.Tensor): The input tensor.
        - `return_reg_log` (bool, optional): Whether to return the regularization log. Defaults to False.
        - `unregularized` (bool, optional): Whether to skip the regularizer and return the raw moments. Defaults to False.
        - `flux` (bool, optional): Whether to skip the quant convolution (Flux-style VAE). Defaults to False.
        #### Returns:
        - `Union[torch.Tensor, Tuple[torch.Tensor, dict]]`: The encoded tensor and optionally the regularization log.
        """
        z = self.encoder(x)
        if not flux:
            z = self.quant_conv(z)
        if unregularized:
            return z, dict()
        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z
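# Illustrative round-trip sketch (assumes `ddconfig` is the default SD1.x dict built in
# VAE.__init__ below; spatial size shrinks by the usual factor of 8):
#   engine = AutoencodingEngine(Encoder(**ddconfig), Decoder(**ddconfig), DiagonalGaussianRegularizer())
#   z = engine.encode(img)   # img: (B, 3, 512, 512) -> z: (B, 4, 64, 64)
#   rec = engine.decode(z)   # rec: (B, 3, 512, 512)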
ops = cast.disable_weight_init
if Device.xformers_enabled_vae():
    pass
def nonlinearity(x: torch.Tensor) -> torch.Tensor:
    """#### Apply the swish nonlinearity.
    #### Args:
    - `x` (torch.Tensor): The input tensor.
    #### Returns:
    - `torch.Tensor`: The output tensor.
    """
    return x * torch.sigmoid(x)
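# Note: swish(x) = x * sigmoid(x) is the same function as SiLU, so this is numerically
# equivalent to torch.nn.functional.silu(x).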
class Upsample(nn.Module):
    """#### Class representing an upsample layer."""
    def __init__(self, in_channels: int, with_conv: bool):
        """#### Initialize the upsample layer.
        #### Args:
        - `in_channels` (int): The number of input channels.
        - `with_conv` (bool): Whether to use convolution.
        """
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = ops.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the upsample layer.
        #### Args:
        - `x` (torch.Tensor): The input tensor.
        #### Returns:
        - `torch.Tensor`: The output tensor.
        """
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x
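# Illustrative shape check (sketch): Upsample(64, with_conv=True) maps (B, 64, 32, 32) to
# (B, 64, 64, 64); nearest-neighbor interpolation doubles H and W, and the optional 3x3
# convolution keeps the channel count unchanged.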
class Downsample(nn.Module):
    """#### Class representing a downsample layer."""
    def __init__(self, in_channels: int, with_conv: bool):
        """#### Initialize the downsample layer.
        #### Args:
        - `in_channels` (int): The number of input channels.
        - `with_conv` (bool): Whether to use convolution.
        """
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = ops.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the downsample layer.
        #### Args:
        - `x` (torch.Tensor): The input tensor.
        #### Returns:
        - `torch.Tensor`: The output tensor.
        """
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            # no conv was created in __init__, so fall back to average pooling
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
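# Illustrative shape check (sketch): with with_conv=True, the asymmetric (right/bottom)
# padding followed by a stride-2 3x3 convolution halves the spatial size, e.g.
# (B, 64, 64, 64) -> (B, 64, 32, 32).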
class Encoder(nn.Module):
    """#### Class representing an encoder."""
    def __init__(
        self,
        *,
        ch: int,
        out_ch: int,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        num_res_blocks: int,
        attn_resolutions: Tuple[int, ...],
        dropout: float = 0.0,
        resamp_with_conv: bool = True,
        in_channels: int,
        resolution: int,
        z_channels: int,
        double_z: bool = True,
        use_linear_attn: bool = False,
        attn_type: str = "vanilla",
        **ignore_kwargs,
    ):
        """#### Initialize the encoder.
        #### Args:
        - `ch` (int): The base number of channels.
        - `out_ch` (int): The number of output channels.
        - `ch_mult` (Tuple[int, ...], optional): Channel multiplier at each resolution. Defaults to (1, 2, 4, 8).
        - `num_res_blocks` (int): The number of residual blocks.
        - `attn_resolutions` (Tuple[int, ...]): The resolutions at which to apply attention.
        - `dropout` (float, optional): The dropout rate. Defaults to 0.0.
        - `resamp_with_conv` (bool, optional): Whether to use convolution for resampling. Defaults to True.
        - `in_channels` (int): The number of input channels.
        - `resolution` (int): The resolution of the input.
        - `z_channels` (int): The number of latent channels.
        - `double_z` (bool, optional): Whether to double the latent channels. Defaults to True.
        - `use_linear_attn` (bool, optional): Whether to use linear attention. Defaults to False.
        - `attn_type` (str, optional): The type of attention. Defaults to "vanilla".
        """
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = ops.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )
        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResBlock.ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResBlock.ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = Attention.make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResBlock.ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        # end
        self.norm_out = Attention.Normalize(block_in)
        self.conv_out = ops.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self._device = torch.device("cpu")
        self._dtype = torch.float32
    def to(self, device=None, dtype=None):
        """#### Move the encoder to a device and data type.
        #### Args:
        - `device` (torch.device, optional): The device to move to. Defaults to None.
        - `dtype` (torch.dtype, optional): The data type to move to. Defaults to None.
        #### Returns:
        - `nn.Module`: The encoder.
        """
        if device is not None:
            self._device = device
        if dtype is not None:
            self._dtype = dtype
        return super().to(device=device, dtype=dtype)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the encoder.
        #### Args:
        - `x` (torch.Tensor): The input tensor.
        #### Returns:
        - `torch.Tensor`: The encoded tensor.
        """
        if x.device != self._device or x.dtype != self._dtype:
            self.to(device=x.device, dtype=x.dtype)
        # timestep embedding
        temb = None
        # downsampling
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h, temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = self.down[i_level].downsample(h)
        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
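# Illustrative shape sketch (assumes `ddconfig` is the default SD1.x config built in
# VAE.__init__ below: ch=128, ch_mult=[1, 2, 4, 4], z_channels=4, double_z=True):
#   enc = Encoder(**ddconfig)
#   h = enc(img)   # img: (B, 3, 512, 512) -> h: (B, 8, 64, 64), i.e. mean and logvar stacked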
class Decoder(nn.Module):
    """#### Class representing a decoder."""
    def __init__(
        self,
        *,
        ch: int,
        out_ch: int,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        num_res_blocks: int,
        attn_resolutions: Tuple[int, ...],
        dropout: float = 0.0,
        resamp_with_conv: bool = True,
        in_channels: int,
        resolution: int,
        z_channels: int,
        give_pre_end: bool = False,
        tanh_out: bool = False,
        use_linear_attn: bool = False,
        conv_out_op: nn.Module = ops.Conv2d,
        resnet_op: nn.Module = ResBlock.ResnetBlock,
        attn_op: nn.Module = Attention.AttnBlock,
        **ignorekwargs,
    ):
        """#### Initialize the decoder.
        #### Args:
        - `ch` (int): The base number of channels.
        - `out_ch` (int): The number of output channels.
        - `ch_mult` (Tuple[int, ...], optional): Channel multiplier at each resolution. Defaults to (1, 2, 4, 8).
        - `num_res_blocks` (int): The number of residual blocks.
        - `attn_resolutions` (Tuple[int, ...]): The resolutions at which to apply attention.
        - `dropout` (float, optional): The dropout rate. Defaults to 0.0.
        - `resamp_with_conv` (bool, optional): Whether to use convolution for resampling. Defaults to True.
        - `in_channels` (int): The number of input channels.
        - `resolution` (int): The resolution of the input.
        - `z_channels` (int): The number of latent channels.
        - `give_pre_end` (bool, optional): Whether to return the output before the final output layer. Defaults to False.
        - `tanh_out` (bool, optional): Whether to use tanh activation at the output. Defaults to False.
        - `use_linear_attn` (bool, optional): Whether to use linear attention. Defaults to False.
        - `conv_out_op` (nn.Module, optional): The convolution output operation. Defaults to ops.Conv2d.
        - `resnet_op` (nn.Module, optional): The residual block operation. Defaults to ResBlock.ResnetBlock.
        - `attn_op` (nn.Module, optional): The attention block operation. Defaults to Attention.AttnBlock.
        """
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        # compute block_in and curr_res at lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        logging.debug(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
        )
        # z to block_in
        self.conv_in = ops.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = resnet_op(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = attn_op(block_in)
        self.mid.block_2 = resnet_op(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    resnet_op(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order
        # end
        self.norm_out = Attention.Normalize(block_in)
        self.conv_out = conv_out_op(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )
    def forward(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        """#### Forward pass for the decoder.
        #### Args:
        - `z` (torch.Tensor): The input tensor.
        - `**kwargs`: Additional arguments.
        #### Returns:
        - `torch.Tensor`: The output tensor.
        """
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape
        # timestep embedding
        temb = None
        # z to block_in
        h = self.conv_in(z)
        # middle
        h = self.mid.block_1(h, temb, **kwargs)
        h = self.mid.attn_1(h, **kwargs)
        h = self.mid.block_2(h, temb, **kwargs)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb, **kwargs)
            if i_level != 0:
                h = self.up[i_level].upsample(h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h, **kwargs)
        return h
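# Illustrative shape sketch (same SD1.x config as above): the decoder mirrors the encoder,
# doubling the spatial size at every level except the last:
#   dec = Decoder(**ddconfig)
#   rec = dec(z)   # z: (B, 4, 64, 64) -> rec: (B, 3, 512, 512)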
class VAE:
    """#### Class representing a Variational Autoencoder (VAE)."""
    def __init__(
        self,
        sd: Optional[dict] = None,
        device: Optional[torch.device] = None,
        config: Optional[dict] = None,
        dtype: Optional[torch.dtype] = None,
        flux: Optional[bool] = False,
    ):
        """#### Initialize the VAE.
        #### Args:
        - `sd` (dict, optional): The state dictionary. Defaults to None.
        - `device` (torch.device, optional): The device to use. Defaults to None.
        - `config` (dict, optional): The configuration dictionary. Defaults to None.
        - `dtype` (torch.dtype, optional): The data type. Defaults to None.
        - `flux` (bool, optional): Whether to build a Flux-style autoencoding engine. Defaults to False.
        """
        self.memory_used_encode = lambda shape, dtype: (
            1767 * shape[2] * shape[3]
        ) * Device.dtype_size(
            dtype
        )  # These are for AutoencoderKL and need tweaking (should be lower)
        self.memory_used_decode = lambda shape, dtype: (
            2178 * shape[2] * shape[3] * 64
        ) * Device.dtype_size(dtype)
        self.downscale_ratio = 8
        self.upscale_ratio = 8
        self.latent_channels = 4
        self.output_channels = 3
        self.process_input = lambda image: image * 2.0 - 1.0
        self.process_output = lambda image: torch.clamp(
            (image + 1.0) / 2.0, min=0.0, max=1.0
        )
        self.working_dtypes = [torch.bfloat16, torch.float32]
        if config is None:
            if "decoder.conv_in.weight" in sd:
                # default SD1.x/SD2.x VAE parameters
                ddconfig = {
                    "double_z": True,
                    "z_channels": 4,
                    "resolution": 256,
                    "in_channels": 3,
                    "out_ch": 3,
                    "ch": 128,
                    "ch_mult": [1, 2, 4, 4],
                    "num_res_blocks": 2,
                    "attn_resolutions": [],
                    "dropout": 0.0,
                }
                if (
                    "encoder.down.2.downsample.conv.weight" not in sd
                    and "decoder.up.3.upsample.conv.weight" not in sd
                ):  # Stable diffusion x4 upscaler VAE
                    ddconfig["ch_mult"] = [1, 2, 4]
                    self.downscale_ratio = 4
                    self.upscale_ratio = 4
                self.latent_channels = ddconfig["z_channels"] = sd[
                    "decoder.conv_in.weight"
                ].shape[1]
                # Initialize model
                self.first_stage_model = AutoencodingEngine(
                    Encoder(**ddconfig),
                    Decoder(**ddconfig),
                    DiagonalGaussianRegularizer(),
                    flux=flux,
                )
            else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.") | |
self.first_stage_model = None | |
return | |
        self.first_stage_model = self.first_stage_model.eval()
        m, u = self.first_stage_model.load_state_dict(sd, strict=False)
        if len(m) > 0:
            logging.warning("Missing VAE keys {}".format(m))
        if len(u) > 0:
            logging.debug("Leftover VAE keys {}".format(u))
        if device is None:
            device = Device.vae_device()
        self.device = device
        offload_device = Device.vae_offload_device()
        if dtype is None:
            dtype = Device.vae_dtype()
        self.vae_dtype = dtype
        self.first_stage_model.to(self.vae_dtype)
        self.output_device = Device.intermediate_device()
        self.patcher = ModelPatcher.ModelPatcher(
            self.first_stage_model,
            load_device=self.device,
            offload_device=offload_device,
        )
        logging.debug(
            "VAE load device: {}, offload device: {}, dtype: {}".format(
                self.device, offload_device, self.vae_dtype
            )
        )
    def vae_encode_crop_pixels(self, pixels: torch.Tensor) -> torch.Tensor:
        """#### Crop the input pixels so their spatial size is divisible by the VAE downscale ratio.
        #### Args:
        - `pixels` (torch.Tensor): The input pixel tensor of shape (B, H, W, C).
        #### Returns:
        - `torch.Tensor`: The cropped pixel tensor.
        """
        x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio
        y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio
        if pixels.shape[1] != x or pixels.shape[2] != y:
            pixels = pixels[:, :x, :y, :]
        return pixels
    def decode(self, samples_in: torch.Tensor, flux: bool = False) -> torch.Tensor:
        """#### Decode the latent samples to pixel samples.
        #### Args:
        - `samples_in` (torch.Tensor): The input latent samples.
        - `flux` (bool, optional): Whether to decode with a Flux-style VAE. Defaults to False.
        #### Returns:
        - `torch.Tensor`: The decoded pixel samples.
        """
        memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
        Device.load_models_gpu([self.patcher], memory_required=memory_used)
        free_memory = Device.get_free_memory(self.device)
        batch_number = int(free_memory / memory_used)
        batch_number = max(1, batch_number)
        pixel_samples = torch.empty(
            (
                samples_in.shape[0],
                3,
                round(samples_in.shape[2] * self.upscale_ratio),
                round(samples_in.shape[3] * self.upscale_ratio),
            ),
            device=self.output_device,
        )
        for x in range(0, samples_in.shape[0], batch_number):
            samples = (
                samples_in[x : x + batch_number].to(self.vae_dtype).to(self.device)
            )
            pixel_samples[x : x + batch_number] = self.process_output(
                self.first_stage_model.decode(samples, flux=flux)
                .to(self.output_device)
                .float()
            )
        pixel_samples = pixel_samples.to(self.output_device).movedim(1, -1)
        return pixel_samples
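    # Illustrative sketch (assumes `vae` was built from an SD1.x state dict and `latent`
    # is a dict like {"samples": torch.randn(1, 4, 64, 64)}):
    #   images = vae.decode(latent["samples"])   # -> (1, 512, 512, 3), values in [0, 1]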
    def encode(self, pixel_samples: torch.Tensor, flux: bool = False) -> torch.Tensor:
        """#### Encode the pixel samples to latent samples.
        #### Args:
        - `pixel_samples` (torch.Tensor): The input pixel samples.
        - `flux` (bool, optional): Whether to encode with a Flux-style VAE. Defaults to False.
        #### Returns:
        - `torch.Tensor`: The encoded latent samples.
        """
        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
        pixel_samples = pixel_samples.movedim(-1, 1)
        memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
        Device.load_models_gpu([self.patcher], memory_required=memory_used)
        free_memory = Device.get_free_memory(self.device)
        batch_number = int(free_memory / memory_used)
        batch_number = max(1, batch_number)
        samples = torch.empty(
            (
                pixel_samples.shape[0],
                self.latent_channels,
                round(pixel_samples.shape[2] // self.downscale_ratio),
                round(pixel_samples.shape[3] // self.downscale_ratio),
            ),
            device=self.output_device,
        )
        for x in range(0, pixel_samples.shape[0], batch_number):
            pixels_in = (
                self.process_input(pixel_samples[x : x + batch_number])
                .to(self.vae_dtype)
                .to(self.device)
            )
            samples[x : x + batch_number] = (
                self.first_stage_model.encode(pixels_in, flux=flux)
                .to(self.output_device)
                .float()
            )
        return samples
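    # Illustrative sketch (assumes `image` is a (1, 512, 512, 3) tensor with values in [0, 1]):
    #   latent = vae.encode(image)   # -> (1, 4, 64, 64) latent tensor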
    def get_sd(self):
        """#### Get the state dictionary.
        #### Returns:
        - `dict`: The state dictionary.
        """
        return self.first_stage_model.state_dict()
class VAEDecode:
    """#### Class for decoding VAE samples."""
    def decode(self, vae: VAE, samples: dict, flux: bool = False) -> Tuple[torch.Tensor]:
        """#### Decode the VAE samples.
        #### Args:
        - `vae` (VAE): The VAE instance.
        - `samples` (dict): The samples dictionary.
        - `flux` (bool, optional): Whether to decode with a Flux-style VAE. Defaults to False.
        #### Returns:
        - `Tuple[torch.Tensor]`: The decoded samples.
        """
        return (vae.decode(samples["samples"], flux=flux),)
class VAEEncode:
    """#### Class for encoding VAE samples."""
    def encode(self, vae: VAE, pixels: torch.Tensor, flux: bool = False) -> Tuple[dict]:
        """#### Encode the VAE samples.
        #### Args:
        - `vae` (VAE): The VAE instance.
        - `pixels` (torch.Tensor): The input pixel tensor.
        - `flux` (bool, optional): Whether to encode with a Flux-style VAE. Defaults to False.
        #### Returns:
        - `Tuple[dict]`: The encoded samples dictionary.
        """
        t = vae.encode(pixels[:, :, :, :3], flux=flux)
        return ({"samples": t},)
class VAELoader:
    """#### Class for loading VAEs."""
    # TODO: scale factor?
    def load_vae(self, vae_name):
        """#### Load the VAE.
        #### Args:
        - `vae_name`: The name of the VAE.
        #### Returns:
        - `Tuple[VAE]`: The VAE instance.
        """
        if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
            sd = self.load_taesd(vae_name)
        else:
            vae_path = "./_internal/vae/" + vae_name
            sd = util.load_torch_file(vae_path)
        vae = VAE(sd=sd)
        return (vae,)
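# Illustrative usage sketch (assumes a VAE checkpoint file exists under ./_internal/vae/;
# the file name below is only an example, not a bundled asset):
#   (vae,) = VAELoader().load_vae("vae-ft-mse-840000-ema-pruned.safetensors")
#   images = vae.decode(latent["samples"])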