from typing import Dict, Tuple
import torch
from modules.Device import Device
from modules.Utilities import util
class LatentFormat:
"""#### Base class for latent formats.
#### Attributes:
- `scale_factor` (float): The scale factor for the latent format.
#### Returns:
- `LatentFormat`: A latent format object.
"""
scale_factor: float = 1.0
latent_channels: int = 4
def process_in(self, latent: torch.Tensor) -> torch.Tensor:
"""#### Process the latent input, by multiplying it by the scale factor.
#### Args:
- `latent` (torch.Tensor): The latent tensor.
#### Returns:
- `torch.Tensor`: The processed latent tensor.
"""
return latent * self.scale_factor
def process_out(self, latent: torch.Tensor) -> torch.Tensor:
"""#### Process the latent output, by dividing it by the scale factor.
#### Args:
- `latent` (torch.Tensor): The latent tensor.
#### Returns:
- `torch.Tensor`: The processed latent tensor.
"""
return latent / self.scale_factor
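# Illustrative usage sketch (not part of the module API): process_in and
# process_out are inverses up to floating-point error, so mapping a latent
# into model space and back recovers the original tensor.
#
#   fmt = LatentFormat()
#   x = torch.randn(1, fmt.latent_channels, 64, 64)
#   assert torch.allclose(fmt.process_out(fmt.process_in(x)), x, atol=1e-6)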
class SD15(LatentFormat):
"""#### SD15 latent format.
#### Args:
- `LatentFormat` (LatentFormat): The base latent format class.
"""
latent_channels: int = 4
def __init__(self, scale_factor: float = 0.18215):
"""#### Initialize the SD15 latent format.
#### Args:
- `scale_factor` (float, optional): The scale factor. Defaults to 0.18215.
"""
self.scale_factor = scale_factor
self.latent_rgb_factors = [
# R G B
[0.3512, 0.2297, 0.3227],
[0.3250, 0.4974, 0.2350],
[-0.2829, 0.1762, 0.2721],
[-0.2120, -0.2616, -0.7177],
]
self.taesd_decoder_name = "taesd_decoder"
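# Illustrative sketch of how `latent_rgb_factors` is typically used (an
# assumption about intended usage, not code called by this module): the 4
# latent channels are projected to RGB with the [4, 3] matrix to build a cheap
# preview of a latent without running the full VAE decoder.
#
#   fmt = SD15()
#   latent = torch.randn(1, 4, 64, 64)
#   factors = torch.tensor(fmt.latent_rgb_factors)   # shape [4, 3]
#   rgb = latent[0].permute(1, 2, 0) @ factors       # shape [64, 64, 3]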
class SD3(LatentFormat):
latent_channels = 16
def __init__(self):
"""#### Initialize the SD3 latent format."""
self.scale_factor = 1.5305
self.shift_factor = 0.0609
self.latent_rgb_factors = [
[-0.0645, 0.0177, 0.1052],
[0.0028, 0.0312, 0.0650],
[0.1848, 0.0762, 0.0360],
[0.0944, 0.0360, 0.0889],
[0.0897, 0.0506, -0.0364],
[-0.0020, 0.1203, 0.0284],
[0.0855, 0.0118, 0.0283],
[-0.0539, 0.0658, 0.1047],
[-0.0057, 0.0116, 0.0700],
[-0.0412, 0.0281, -0.0039],
[0.1106, 0.1171, 0.1220],
[-0.0248, 0.0682, -0.0481],
[0.0815, 0.0846, 0.1207],
[-0.0120, -0.0055, -0.0867],
[-0.0749, -0.0634, -0.0456],
[-0.1418, -0.1457, -0.1259],
]
self.taesd_decoder_name = "taesd3_decoder"
def process_in(self, latent: torch.Tensor) -> torch.Tensor:
"""#### Process the latent input, by multiplying it by the scale factor and subtracting the shift factor.
#### Args:
- `latent` (torch.Tensor): The latent tensor.
#### Returns:
- `torch.Tensor`: The processed latent tensor.
"""
return (latent - self.shift_factor) * self.scale_factor
def process_out(self, latent: torch.Tensor) -> torch.Tensor:
"""#### Process the latent output, by dividing it by the scale factor and adding the shift factor.
#### Args:
- `latent` (torch.Tensor): The latent tensor.
#### Returns:
- `torch.Tensor`: The processed latent tensor.
"""
return (latent / self.scale_factor) + self.shift_factor
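# Illustrative sketch (not part of the module API): SD3 adds a shift term, so
# process_in maps x to (x - shift) * scale and process_out inverts it, giving
# an identity round trip up to floating-point error.
#
#   fmt = SD3()
#   x = torch.randn(1, fmt.latent_channels, 64, 64)
#   assert torch.allclose(fmt.process_out(fmt.process_in(x)), x, atol=1e-5)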
class Flux1(SD3):
latent_channels = 16
def __init__(self):
"""#### Initialize the Flux1 latent format."""
self.scale_factor = 0.3611
self.shift_factor = 0.1159
self.latent_rgb_factors = [
[-0.0404, 0.0159, 0.0609],
[0.0043, 0.0298, 0.0850],
[0.0328, -0.0749, -0.0503],
[-0.0245, 0.0085, 0.0549],
[0.0966, 0.0894, 0.0530],
[0.0035, 0.0399, 0.0123],
[0.0583, 0.1184, 0.1262],
[-0.0191, -0.0206, -0.0306],
[-0.0324, 0.0055, 0.1001],
[0.0955, 0.0659, -0.0545],
[-0.0504, 0.0231, -0.0013],
[0.0500, -0.0008, -0.0088],
[0.0982, 0.0941, 0.0976],
[-0.1233, -0.0280, -0.0897],
[-0.0005, -0.0530, -0.0020],
[-0.1273, -0.0932, -0.0680],
]
self.taesd_decoder_name = "taef1_decoder"
def process_in(self, latent: torch.Tensor) -> torch.Tensor:
"""#### Process the latent input, by multiplying it by the scale factor and subtracting the shift factor.
#### Args:
- `latent` (torch.Tensor): The latent tensor.
#### Returns:
- `torch.Tensor`: The processed latent tensor.
"""
return (latent - self.shift_factor) * self.scale_factor
def process_out(self, latent: torch.Tensor) -> torch.Tensor:
"""#### Process the latent output, by dividing it by the scale factor and adding the shift factor.
#### Args:
- `latent` (torch.Tensor): The latent tensor.
#### Returns:
- `torch.Tensor`: The processed latent tensor.
"""
return (latent / self.scale_factor) + self.shift_factor
class EmptyLatentImage:
"""#### A class to generate an empty latent image.
#### Args:
- `Device` (Device): The device to use for the latent image.
"""
def __init__(self):
"""#### Initialize the EmptyLatentImage class."""
self.device = Device.intermediate_device()
def generate(
self, width: int, height: int, batch_size: int = 1
) -> Tuple[Dict[str, torch.Tensor]]:
"""#### Generate an empty latent image
#### Args:
- `width` (int): The width of the latent image.
- `height` (int): The height of the latent image.
- `batch_size` (int, optional): The batch size. Defaults to 1.
#### Returns:
- `Tuple[Dict[str, torch.Tensor]]`: The generated latent image.
"""
latent = torch.zeros(
[batch_size, 4, height // 8, width // 8], device=self.device
)
return ({"samples": latent},)
def fix_empty_latent_channels(model, latent_image):
"""#### Fix the empty latent image channels.
#### Args:
- `model` (Model): The model object.
- `latent_image` (torch.Tensor): The latent image.
#### Returns:
- `torch.Tensor`: The fixed latent image.
"""
# Resize the empty latent image so it has the right number of channels
latent_channels = model.get_model_object("latent_format").latent_channels
if (
latent_channels != latent_image.shape[1]
and torch.count_nonzero(latent_image) == 0
):
latent_image = util.repeat_to_batch_size(latent_image, latent_channels, dim=1)
return latent_image
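# Illustrative sketch (`_DummyModel` is a hypothetical stand-in for a model
# wrapper exposing `get_model_object("latent_format")`, and the call path to
# fix_empty_latent_channels may differ depending on where it is defined): an
# all-zero 4-channel latent paired with a 16-channel format such as SD3 is
# repeated along dim=1 until it has the expected number of channels.
#
#   class _DummyModel:
#       def get_model_object(self, name):
#           return SD3()
#
#   empty = torch.zeros([1, 4, 64, 64])
#   fixed = fix_empty_latent_channels(_DummyModel(), empty)
#   fixed.shape  # torch.Size([1, 16, 64, 64])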