Spaces:
Running
on
Zero
Running
on
Zero
from typing import Dict, Tuple | |
import torch | |
from modules.Device import Device | |
from modules.Utilities import util | |
class LatentFormat:
    """#### Base description of how a latent tensor is scaled in and out.

    The base transform is a pure multiplicative scale. Subclasses change the
    constants below, or override the two methods, to describe other latent
    spaces.

    #### Attributes:
        - `scale_factor` (float): Multiplier used by `process_in` and divisor
          used by `process_out`. Defaults to 1.0 (identity).
        - `latent_channels` (int): Channel count of this latent space.
          Defaults to 4.
    """

    scale_factor: float = 1.0
    latent_channels: int = 4

    def process_in(self, latent: torch.Tensor) -> torch.Tensor:
        """#### Scale a latent on the way in.

        #### Args:
            - `latent` (torch.Tensor): Latent tensor to transform.

        #### Returns:
            - `torch.Tensor`: `latent * scale_factor`.
        """
        factor = self.scale_factor
        return latent * factor

    def process_out(self, latent: torch.Tensor) -> torch.Tensor:
        """#### Undo `process_in`: divide the latent by the scale factor.

        #### Args:
            - `latent` (torch.Tensor): Latent tensor to transform.

        #### Returns:
            - `torch.Tensor`: `latent / scale_factor`.
        """
        factor = self.scale_factor
        return latent / factor
class SD15(LatentFormat):
    """#### Latent format for Stable Diffusion 1.5 (4-channel latents).

    It uses the inherited multiply-in / divide-out transform with a
    configurable scale.

    #### Args:
        - `scale_factor` (float, optional): Multiplier used by `process_in`
          and divisor used by `process_out`. Defaults to 0.18215.
    """

    latent_channels: int = 4

    def __init__(self, scale_factor: float = 0.18215):
        """#### Store the SD1.5 scale factor, RGB factors and decoder name.

        #### Args:
            - `scale_factor` (float, optional): Scale applied to latents.
              Defaults to 0.18215.
        """
        self.scale_factor = scale_factor
        # One [R, G, B] weight row per latent channel (4 rows, 4 channels).
        # Presumably used for a cheap latent-to-colour preview; the consumer
        # is not in this file.
        self.latent_rgb_factors = [
            [0.3512, 0.2297, 0.3227],
            [0.3250, 0.4974, 0.2350],
            [-0.2829, 0.1762, 0.2721],
            [-0.2120, -0.2616, -0.7177],
        ]
        self.taesd_decoder_name = "taesd_decoder"
class SD3(LatentFormat):
    """#### Latent format for Stable Diffusion 3 (16-channel latents).

    Unlike the base class, this format subtracts a constant shift before
    scaling on the way in, and adds it back after unscaling on the way out.
    """

    latent_channels = 16

    def __init__(self):
        """#### Set the SD3 scale/shift constants, RGB factors and decoder name."""
        self.scale_factor = 1.5305
        self.shift_factor = 0.0609
        # One [R, G, B] weight row per latent channel (16 rows, 16 channels).
        self.latent_rgb_factors = [
            [-0.0645, 0.0177, 0.1052],
            [0.0028, 0.0312, 0.0650],
            [0.1848, 0.0762, 0.0360],
            [0.0944, 0.0360, 0.0889],
            [0.0897, 0.0506, -0.0364],
            [-0.0020, 0.1203, 0.0284],
            [0.0855, 0.0118, 0.0283],
            [-0.0539, 0.0658, 0.1047],
            [-0.0057, 0.0116, 0.0700],
            [-0.0412, 0.0281, -0.0039],
            [0.1106, 0.1171, 0.1220],
            [-0.0248, 0.0682, -0.0481],
            [0.0815, 0.0846, 0.1207],
            [-0.0120, -0.0055, -0.0867],
            [-0.0749, -0.0634, -0.0456],
            [-0.1418, -0.1457, -0.1259],
        ]
        self.taesd_decoder_name = "taesd3_decoder"

    def process_in(self, latent: torch.Tensor) -> torch.Tensor:
        """#### Shift, then scale, a latent on the way in.

        #### Args:
            - `latent` (torch.Tensor): Latent tensor to transform.

        #### Returns:
            - `torch.Tensor`: `(latent - shift_factor) * scale_factor`.
        """
        centered = latent - self.shift_factor
        return centered * self.scale_factor

    def process_out(self, latent: torch.Tensor) -> torch.Tensor:
        """#### Invert `process_in`: unscale, then add the shift back.

        #### Args:
            - `latent` (torch.Tensor): Latent tensor to transform.

        #### Returns:
            - `torch.Tensor`: `latent / scale_factor + shift_factor`.
        """
        unscaled = latent / self.scale_factor
        return unscaled + self.shift_factor
class Flux1(SD3):
    """#### Latent format for Flux.1 (16-channel latents).

    Flux.1 uses the same shift-then-scale transform as SD3. `process_in` and
    `process_out` are therefore inherited from `SD3` unchanged. The previous
    byte-for-byte copies of those methods were redundant and have been removed.
    Only the constants set in `__init__` differ.
    """

    latent_channels = 16

    def __init__(self):
        """#### Set the Flux.1 scale/shift constants, RGB factors and decoder name."""
        self.scale_factor = 0.3611
        self.shift_factor = 0.1159
        # One [R, G, B] weight row per latent channel (16 rows, 16 channels).
        self.latent_rgb_factors = [
            [-0.0404, 0.0159, 0.0609],
            [0.0043, 0.0298, 0.0850],
            [0.0328, -0.0749, -0.0503],
            [-0.0245, 0.0085, 0.0549],
            [0.0966, 0.0894, 0.0530],
            [0.0035, 0.0399, 0.0123],
            [0.0583, 0.1184, 0.1262],
            [-0.0191, -0.0206, -0.0306],
            [-0.0324, 0.0055, 0.1001],
            [0.0955, 0.0659, -0.0545],
            [-0.0504, 0.0231, -0.0013],
            [0.0500, -0.0008, -0.0088],
            [0.0982, 0.0941, 0.0976],
            [-0.1233, -0.0280, -0.0897],
            [-0.0005, -0.0530, -0.0020],
            [-0.1273, -0.0932, -0.0680],
        ]
        self.taesd_decoder_name = "taef1_decoder"
class EmptyLatentImage:
    """#### Factory for all-zero latent batches.

    The target device is looked up once, at construction, via
    `Device.intermediate_device()`.
    """

    def __init__(self):
        """#### Remember the device on which empty latents are allocated."""
        self.device = Device.intermediate_device()

    def generate(
        self, width: int, height: int, batch_size: int = 1
    ) -> Tuple[Dict[str, torch.Tensor]]:
        """#### Build a zero-filled 4-channel latent batch.

        The spatial size is the requested size floor-divided by 8.

        #### Args:
            - `width` (int): Requested image width; the latent width is `width // 8`.
            - `height` (int): Requested image height; the latent height is `height // 8`.
            - `batch_size` (int, optional): Number of latents. Defaults to 1.

        #### Returns:
            - `Tuple[Dict[str, torch.Tensor]]`: A one-element tuple holding
              `{"samples": tensor}`, where the tensor is zeros of shape
              `(batch_size, 4, height // 8, width // 8)` on `self.device`.
        """
        shape = (batch_size, 4, height // 8, width // 8)
        samples = torch.zeros(shape, device=self.device)
        return ({"samples": samples},)
def fix_empty_latent_channels(model, latent_image):
    """#### Adapt an all-zero latent to the model's expected channel count.

    The latent is resized only when both conditions hold:
    its channel dimension (dim 1) differs from
    `model.get_model_object("latent_format").latent_channels`, AND every
    element is zero. Any other latent is returned unchanged (the same object).

    #### Args:
        - `model` (Model): Object exposing `get_model_object("latent_format")`.
        - `latent_image` (torch.Tensor): Latent to check; channel axis is dim 1.

    #### Returns:
        - `torch.Tensor`: The original latent, or the result of
          `util.repeat_to_batch_size(latent_image, <channels>, dim=1)`.
    """
    target_channels = model.get_model_object("latent_format").latent_channels
    # Channel count already matches: nothing to fix.
    if target_channels == latent_image.shape[1]:
        return latent_image
    # Only a genuinely empty (all-zero) latent is safe to reshape; real
    # content with a mismatched channel count is left for the caller.
    if torch.count_nonzero(latent_image) != 0:
        return latent_image
    return util.repeat_to_batch_size(latent_image, target_channels, dim=1)