import copy
import logging

import gguf
import torch

from modules.Device import Device
from modules.Model import ModelPatcher
from modules.Utilities import util
from modules.clip import Clip
from modules.cond import cast

# Constants for torch-compatible quantization types
TORCH_COMPATIBLE_QTYPES = {
    None,
    gguf.GGMLQuantizationType.F32,
    gguf.GGMLQuantizationType.F16,
}


def is_torch_compatible(tensor: torch.Tensor) -> bool:
    """#### Check if a tensor is compatible with PyTorch operations.

    #### Args:
    - `tensor` (torch.Tensor): The tensor to check.

    #### Returns:
    - `bool`: Whether the tensor is torch-compatible.
    """
    return (
        tensor is None
        or getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES
    )


def is_quantized(tensor: torch.Tensor) -> bool:
    """#### Check if a tensor is quantized.

    #### Args:
    - `tensor` (torch.Tensor): The tensor to check.

    #### Returns:
    - `bool`: Whether the tensor is quantized.
    """
    return not is_torch_compatible(tensor)


def dequantize(
    data: torch.Tensor,
    qtype: gguf.GGMLQuantizationType,
    oshape: tuple,
    dtype: torch.dtype = None,
) -> torch.Tensor:
    """#### Dequantize tensor back to usable shape/dtype.

    #### Args:
    - `data` (torch.Tensor): The quantized data.
    - `qtype` (gguf.GGMLQuantizationType): The quantization type.
    - `oshape` (tuple): The output shape.
    - `dtype` (torch.dtype, optional): The output dtype. Defaults to None.

    #### Returns:
    - `torch.Tensor`: The dequantized tensor.
    """
    # Get block size and type size for the quantization format
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    dequantize_blocks = dequantize_functions[qtype]

    # Reshape data into blocks
    rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
    n_blocks = rows.numel() // type_size
    blocks = rows.reshape((n_blocks, type_size))

    # Dequantize blocks and reshape to target shape
    blocks = dequantize_blocks(blocks, block_size, type_size, dtype)
    return blocks.reshape(oshape)


def split_block_dims(blocks: torch.Tensor, *args) -> list:
    """#### Split blocks into dimensions.

    #### Args:
    - `blocks` (torch.Tensor): The blocks to split.
    - `*args`: The dimensions to split into.

    #### Returns:
    - `list`: The split blocks.
    """
    n_max = blocks.shape[1]
    dims = list(args) + [n_max - sum(args)]
    return torch.split(blocks, dims, dim=1)
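
# Worked example (a sketch, not executed at import time): a Q8_0 block is laid
# out as a 2-byte scale followed by 32 quantized bytes, so for a `blocks`
# tensor of shape (n_blocks, 34):
#
#   d, x = split_block_dims(blocks, 2)   # d: (n_blocks, 2), x: (n_blocks, 32)
#
# The trailing split is implicit: whatever remains after the explicit sizes
# (here 34 - 2 = 32 bytes of quantized values).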


# Legacy Quantization Functions
def dequantize_blocks_Q8_0(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype = None
) -> torch.Tensor:
    """#### Dequantize Q8_0 quantized blocks.

    #### Args:
    - `blocks` (torch.Tensor): The quantized blocks.
    - `block_size` (int): The block size.
    - `type_size` (int): The type size.
    - `dtype` (torch.dtype, optional): The output dtype. Defaults to None.

    #### Returns:
    - `torch.Tensor`: The dequantized blocks.
    """
    # Split blocks into scale and quantized values
    d, x = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)
    x = x.view(torch.int8)
    return d * x
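
# Q8_0 layout sketch: each 34-byte block stores one float16 scale `d` followed
# by 32 signed int8 quants `q`, and every dequantized value is simply d * q.
# For example (illustrative numbers), with d = 0.5 and q = [-4, 0, 3, ...] the
# block dequantizes to [-2.0, 0.0, 1.5, ...]. Broadcasting in `d * x` applies
# the per-block scale to all 32 values at once.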


# K Quants #
QK_K = 256
K_SCALE_SIZE = 12

# Mapping of quantization types to dequantization functions
dequantize_functions = {
    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
}


def dequantize_tensor(
    tensor: torch.Tensor, dtype: torch.dtype = None, dequant_dtype: torch.dtype = None
) -> torch.Tensor:
    """#### Dequantize a potentially quantized tensor.

    #### Args:
    - `tensor` (torch.Tensor): The tensor to dequantize.
    - `dtype` (torch.dtype, optional): Target dtype. Defaults to None.
    - `dequant_dtype` (torch.dtype, optional): Intermediate dequantization dtype. Defaults to None.

    #### Returns:
    - `torch.Tensor`: The dequantized tensor.
    """
    qtype = getattr(tensor, "tensor_type", None)
    oshape = getattr(tensor, "tensor_shape", tensor.shape)

    if qtype in TORCH_COMPATIBLE_QTYPES:
        return tensor.to(dtype)
    elif qtype in dequantize_functions:
        dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype
        return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype)
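
# Hypothetical usage sketch (not executed here): given a GGMLTensor `w` loaded
# by gguf_sd_loader, a layer would materialize it as fp16 with:
#
#   weight = dequantize_tensor(w, dtype=torch.float16, dequant_dtype=None)
#
# Note that qtypes missing from `dequantize_functions` (anything other than
# Q8_0 and the torch-native F16/F32 here) fall through and return None.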


class GGMLLayer(torch.nn.Module):
    """#### Base class for GGML quantized layers.

    Handles dynamic dequantization of weights during the forward pass.
    """

    comfy_cast_weights: bool = True
    dequant_dtype: torch.dtype = None
    patch_dtype: torch.dtype = None
    torch_compatible_tensor_types: set = {
        None,
        gguf.GGMLQuantizationType.F32,
        gguf.GGMLQuantizationType.F16,
    }

    def is_ggml_quantized(
        self, *, weight: torch.Tensor = None, bias: torch.Tensor = None
    ) -> bool:
        """#### Check if layer weights are GGML quantized.

        #### Args:
        - `weight` (torch.Tensor, optional): Weight tensor to check. Defaults to self.weight.
        - `bias` (torch.Tensor, optional): Bias tensor to check. Defaults to self.bias.

        #### Returns:
        - `bool`: Whether weights are quantized.
        """
        if weight is None:
            weight = self.weight
        if bias is None:
            bias = self.bias
        return is_quantized(weight) or is_quantized(bias)

    def _load_from_state_dict(
        self, state_dict: dict, prefix: str, *args, **kwargs
    ) -> None:
        """#### Load quantized weights from state dict.

        #### Args:
        - `state_dict` (dict): State dictionary.
        - `prefix` (str): Key prefix.
        - `*args`: Additional arguments.
        - `**kwargs`: Additional keyword arguments.
        """
        weight = state_dict.get(f"{prefix}weight")
        bias = state_dict.get(f"{prefix}bias")
        # Use modified loader for quantized or linear layers
        if self.is_ggml_quantized(weight=weight, bias=bias) or isinstance(
            self, torch.nn.Linear
        ):
            return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def ggml_load_from_state_dict(
        self,
        state_dict: dict,
        prefix: str,
        local_metadata: dict,
        strict: bool,
        missing_keys: list,
        unexpected_keys: list,
        error_msgs: list,
    ) -> None:
        """#### Load GGML quantized weights from state dict.

        #### Args:
        - `state_dict` (dict): State dictionary.
        - `prefix` (str): Key prefix.
        - `local_metadata` (dict): Local metadata.
        - `strict` (bool): Strict loading mode.
        - `missing_keys` (list): Keys missing from state dict.
        - `unexpected_keys` (list): Unexpected keys found.
        - `error_msgs` (list): Error messages.
        """
        prefix_len = len(prefix)
        for k, v in state_dict.items():
            if k[prefix_len:] == "weight":
                self.weight = torch.nn.Parameter(v, requires_grad=False)
            elif k[prefix_len:] == "bias" and v is not None:
                self.bias = torch.nn.Parameter(v, requires_grad=False)
            else:
                missing_keys.append(k)

    def _save_to_state_dict(self, *args, **kwargs) -> None:
        """#### Save layer state to state dict.

        #### Args:
        - `*args`: Additional arguments.
        - `**kwargs`: Additional keyword arguments.
        """
        if self.is_ggml_quantized():
            return self.ggml_save_to_state_dict(*args, **kwargs)
        return super()._save_to_state_dict(*args, **kwargs)

    def ggml_save_to_state_dict(
        self, destination: dict, prefix: str, keep_vars: bool
    ) -> None:
        """#### Save GGML layer state to state dict.

        #### Args:
        - `destination` (dict): Destination dictionary.
        - `prefix` (str): Key prefix.
        - `keep_vars` (bool): Whether to keep variables.
        """
        # Create fake tensors for VRAM estimation
        weight = torch.zeros_like(self.weight, device=torch.device("meta"))
        destination[prefix + "weight"] = weight
        if self.bias is not None:
            bias = torch.zeros_like(self.bias, device=torch.device("meta"))
            destination[prefix + "bias"] = bias
        return

    def get_weight(self, tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """#### Get dequantized weight tensor.

        #### Args:
        - `tensor` (torch.Tensor): Input tensor.
        - `dtype` (torch.dtype): Target dtype.

        #### Returns:
        - `torch.Tensor`: Dequantized tensor.
        """
        if tensor is None:
            return

        # Consolidate and load patches to GPU asynchronously
        patch_list = []
        device = tensor.device
        for function, patches, key in getattr(tensor, "patches", []):
            patch_list += move_patch_to_device(patches, device)

        # Dequantize tensor while patches load
        weight = dequantize_tensor(tensor, dtype, self.dequant_dtype)

        # Apply patches
        if patch_list:
            if self.patch_dtype is None:
                weight = function(patch_list, weight, key)
            else:
                # For testing, may degrade image quality
                patch_dtype = (
                    dtype if self.patch_dtype == "target" else self.patch_dtype
                )
                weight = function(patch_list, weight, key, patch_dtype)
        return weight

    def cast_bias_weight(
        self,
        input: torch.Tensor = None,
        dtype: torch.dtype = None,
        device: torch.device = None,
        bias_dtype: torch.dtype = None,
    ) -> tuple:
        """#### Cast layer weights and bias to target dtype/device.

        #### Args:
        - `input` (torch.Tensor, optional): Input tensor for type/device inference.
        - `dtype` (torch.dtype, optional): Target dtype.
        - `device` (torch.device, optional): Target device.
        - `bias_dtype` (torch.dtype, optional): Target bias dtype.

        #### Returns:
        - `tuple`: (cast_weight, cast_bias)
        """
        if input is not None:
            if dtype is None:
                dtype = getattr(input, "dtype", torch.float32)
            if bias_dtype is None:
                bias_dtype = dtype
            if device is None:
                device = input.device

        bias = None
        non_blocking = Device.device_supports_non_blocking(device)
        if self.bias is not None:
            bias = self.get_weight(self.bias.to(device), dtype)
            bias = cast.cast_to(
                bias, bias_dtype, device, non_blocking=non_blocking, copy=False
            )

        weight = self.get_weight(self.weight.to(device), dtype)
        weight = cast.cast_to(
            weight, dtype, device, non_blocking=non_blocking, copy=False
        )
        return weight, bias

    def forward_comfy_cast_weights(
        self, input: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        """#### Forward pass with weight casting.

        #### Args:
        - `input` (torch.Tensor): Input tensor.
        - `*args`: Additional arguments.
        - `**kwargs`: Additional keyword arguments.

        #### Returns:
        - `torch.Tensor`: Output tensor.
        """
        if self.is_ggml_quantized():
            return self.forward_ggml_cast_weights(input, *args, **kwargs)
        return super().forward_comfy_cast_weights(input, *args, **kwargs)


class GGMLOps(cast.manual_cast):
    """
    Dequantize weights on the fly before doing the compute
    """

    class Linear(GGMLLayer, cast.manual_cast.Linear):
        def __init__(
            self, in_features, out_features, bias=True, device=None, dtype=None
        ):
            """
            Initialize the Linear layer.

            Args:
                in_features (int): Number of input features.
                out_features (int): Number of output features.
                bias (bool, optional): If set to False, the layer will not learn an additive bias. Defaults to True.
                device (torch.device, optional): The device to store the layer's parameters. Defaults to None.
                dtype (torch.dtype, optional): The data type of the layer's parameters. Defaults to None.
            """
            torch.nn.Module.__init__(self)
            # TODO: better workaround for reserved memory spike on windows
            # Issue is with `torch.empty` still reserving the full memory for the layer
            # Windows doesn't over-commit memory so without this 24GB+ of pagefile is used
            self.in_features = in_features
            self.out_features = out_features
            self.weight = None
            self.bias = None

        def forward_ggml_cast_weights(self, input: torch.Tensor) -> torch.Tensor:
            """
            Forward pass with GGML cast weights.

            Args:
                input (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The output tensor.
            """
            weight, bias = self.cast_bias_weight(input)
            return torch.nn.functional.linear(input, weight, bias)

    class Embedding(GGMLLayer, cast.manual_cast.Embedding):
        def forward_ggml_cast_weights(
            self, input: torch.Tensor, out_dtype: torch.dtype = None
        ) -> torch.Tensor:
            """
            Forward pass with GGML cast weights for embedding.

            Args:
                input (torch.Tensor): The input tensor.
                out_dtype (torch.dtype, optional): The output data type. Defaults to None.

            Returns:
                torch.Tensor: The output tensor.
            """
            output_dtype = out_dtype
            if (
                self.weight.dtype == torch.float16
                or self.weight.dtype == torch.bfloat16
            ):
                out_dtype = None
            weight, _bias = self.cast_bias_weight(
                self, device=input.device, dtype=out_dtype
            )
            return torch.nn.functional.embedding(
                input,
                weight,
                self.padding_idx,
                self.max_norm,
                self.norm_type,
                self.scale_grad_by_freq,
                self.sparse,
            ).to(dtype=output_dtype)
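
# Hypothetical usage sketch: GGMLOps is passed as "custom_operations" in
# model_options so the model builder swaps these Linear/Embedding classes in
# for the regular ones, mirroring the UNet loader further below:
#
#   ops = GGMLOps()
#   model = ModelPatcher.load_diffusion_model_state_dict(
#       sd, model_options={"custom_operations": ops}
#   )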


def gguf_sd_loader_get_orig_shape(
    reader: gguf.GGUFReader, tensor_name: str
) -> torch.Size:
    """#### Get the original shape of a tensor from a GGUF reader.

    #### Args:
    - `reader` (gguf.GGUFReader): The GGUF reader.
    - `tensor_name` (str): The name of the tensor.

    #### Returns:
    - `torch.Size`: The original shape of the tensor.
    """
    field_key = f"comfy.gguf.orig_shape.{tensor_name}"
    field = reader.get_field(field_key)
    if field is None:
        return None
    # Has original shape metadata, so we try to decode it.
    if (
        len(field.types) != 2
        or field.types[0] != gguf.GGUFValueType.ARRAY
        or field.types[1] != gguf.GGUFValueType.INT32
    ):
        raise TypeError(
            f"Bad original shape metadata for {field_key}: Expected ARRAY of INT32, got {field.types}"
        )
    return torch.Size(tuple(int(field.parts[part_idx][0]) for part_idx in field.data))


class GGMLTensor(torch.Tensor):
    """
    Main tensor-like class for storing quantized weights
    """

    def __init__(self, *args, tensor_type, tensor_shape, patches=[], **kwargs):
        """
        Initialize the GGMLTensor.

        Args:
            *args: Variable length argument list.
            tensor_type: The type of the tensor.
            tensor_shape: The shape of the tensor.
            patches (list, optional): List of patches. Defaults to [].
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__()
        self.tensor_type = tensor_type
        self.tensor_shape = tensor_shape
        self.patches = patches

    def __new__(cls, *args, tensor_type, tensor_shape, patches=[], **kwargs):
        """
        Create a new instance of GGMLTensor.

        Args:
            *args: Variable length argument list.
            tensor_type: The type of the tensor.
            tensor_shape: The shape of the tensor.
            patches (list, optional): List of patches. Defaults to [].
            **kwargs: Arbitrary keyword arguments.

        Returns:
            GGMLTensor: A new instance of GGMLTensor.
        """
        return super().__new__(cls, *args, **kwargs)

    def to(self, *args, **kwargs):
        """
        Convert the tensor to a specified device and/or dtype.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            GGMLTensor: The converted tensor.
        """
        new = super().to(*args, **kwargs)
        new.tensor_type = getattr(self, "tensor_type", None)
        new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
        new.patches = getattr(self, "patches", []).copy()
        return new

    def clone(self, *args, **kwargs):
        """
        Clone the tensor.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            GGMLTensor: The cloned tensor.
        """
        return self

    def detach(self, *args, **kwargs):
        """
        Detach the tensor from the computation graph.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            GGMLTensor: The detached tensor.
        """
        return self

    def copy_(self, *args, **kwargs):
        """
        Copy the values from another tensor into this tensor.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            GGMLTensor: The tensor with copied values.
        """
        try:
            return super().copy_(*args, **kwargs)
        except Exception as e:
            print(f"ignoring 'copy_' on tensor: {e}")

    def __deepcopy__(self, *args, **kwargs):
        """
        Create a deep copy of the tensor.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            GGMLTensor: The deep copied tensor.
        """
        new = super().__deepcopy__(*args, **kwargs)
        new.tensor_type = getattr(self, "tensor_type", None)
        new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
        new.patches = getattr(self, "patches", []).copy()
        return new

    @property
    def shape(self):
        """
        Get the shape of the tensor.

        Returns:
            torch.Size: The shape of the tensor.
        """
        if not hasattr(self, "tensor_shape"):
            self.tensor_shape = self.size()
        return self.tensor_shape
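
# Hypothetical construction sketch (values are illustrative): GGMLTensor wraps
# the raw, still-quantized byte buffer while remembering the logical shape and
# qtype, e.g.:
#
#   t = GGMLTensor(
#       raw_bytes_tensor,
#       tensor_type=gguf.GGMLQuantizationType.Q8_0,
#       tensor_shape=torch.Size((320, 1280)),
#   )
#
# t.shape then reports the logical (320, 1280) shape rather than the shape of
# the packed byte buffer.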


def gguf_sd_loader(path: str, handle_prefix: str = "model.diffusion_model."):
    """#### Load a GGUF file into a state dict.

    #### Args:
    - `path` (str): The path to the GGUF file.
    - `handle_prefix` (str, optional): The prefix to handle. Defaults to "model.diffusion_model.".

    #### Returns:
    - `dict`: The loaded state dict.
    """
    reader = gguf.GGUFReader(path)

    # filter and strip prefix
    has_prefix = False
    if handle_prefix is not None:
        prefix_len = len(handle_prefix)
        tensor_names = set(tensor.name for tensor in reader.tensors)
        has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)

    tensors = []
    for tensor in reader.tensors:
        sd_key = tensor_name = tensor.name
        if has_prefix:
            if not tensor_name.startswith(handle_prefix):
                continue
            sd_key = tensor_name[prefix_len:]
        tensors.append((sd_key, tensor))

    # detect and verify architecture
    compat = None
    arch_str = None
    arch_field = reader.get_field("general.architecture")
    if arch_field is not None:
        if (
            len(arch_field.types) != 1
            or arch_field.types[0] != gguf.GGUFValueType.STRING
        ):
            raise TypeError(
                f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}"
            )
        arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
        if arch_str not in {"flux", "sd1", "sdxl", "t5", "t5encoder"}:
            raise ValueError(
                f"Unexpected architecture type in GGUF file: expected one of flux, sd1, sdxl, t5, t5encoder but got {arch_str!r}"
            )

    # main loading loop
    state_dict = {}
    qtype_dict = {}
    for sd_key, tensor in tensors:
        tensor_name = tensor.name
        tensor_type_str = str(tensor.tensor_type)
        torch_tensor = torch.from_numpy(tensor.data)  # mmap

        shape = gguf_sd_loader_get_orig_shape(reader, tensor_name)
        if shape is None:
            shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
            # Workaround for stable-diffusion.cpp SDXL detection.
            if compat == "sd.cpp" and arch_str == "sdxl":
                if any(
                    [
                        tensor_name.endswith(x)
                        for x in (".proj_in.weight", ".proj_out.weight")
                    ]
                ):
                    while len(shape) > 2 and shape[-1] == 1:
                        shape = shape[:-1]

        # add to state dict
        if tensor.tensor_type in {
            gguf.GGMLQuantizationType.F32,
            gguf.GGMLQuantizationType.F16,
        }:
            torch_tensor = torch_tensor.view(*shape)
        state_dict[sd_key] = GGMLTensor(
            torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape
        )
        qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1

    # sanity check debug print
    print("\nggml_sd_loader:")
    for k, v in qtype_dict.items():
        print(f" {k:30}{v:3}")
    return state_dict
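
# Hypothetical usage sketch (the file name is illustrative, not a bundled asset):
#
#   sd = gguf_sd_loader("./_internal/unet/flux1-dev-Q8_0.gguf")
#
# `sd` maps state-dict keys (with the handled prefix stripped) to GGMLTensor
# instances, and the debug print above summarizes how many tensors of each
# qtype were found.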


class GGUFModelPatcher(ModelPatcher.ModelPatcher):
    patch_on_device = False

    def unpatch_model(self, device_to=None, unpatch_weights=True):
        """
        Unpatch the model.

        Args:
            device_to (torch.device, optional): The device to move the model to. Defaults to None.
            unpatch_weights (bool, optional): Whether to unpatch the weights. Defaults to True.

        Returns:
            GGUFModelPatcher: The unpatched model.
        """
        if unpatch_weights:
            for p in self.model.parameters():
                if is_torch_compatible(p):
                    continue
                patches = getattr(p, "patches", [])
                if len(patches) > 0:
                    p.patches = []
        self.object_patches = {}
        # TODO: Find another way to not unload after patches
        return super().unpatch_model(
            device_to=device_to, unpatch_weights=unpatch_weights
        )

    mmap_released = False

    def load(self, *args, force_patch_weights=False, **kwargs):
        """
        Load the model.

        Args:
            *args: Variable length argument list.
            force_patch_weights (bool, optional): Whether to force patch weights. Defaults to False.
            **kwargs: Arbitrary keyword arguments.
        """
        super().load(*args, force_patch_weights=True, **kwargs)

        # make sure nothing stays linked to mmap after first load
        if not self.mmap_released:
            linked = []
            if kwargs.get("lowvram_model_memory", 0) > 0:
                for n, m in self.model.named_modules():
                    if hasattr(m, "weight"):
                        device = getattr(m.weight, "device", None)
                        if device == self.offload_device:
                            linked.append((n, m))
                            continue
                    if hasattr(m, "bias"):
                        device = getattr(m.bias, "device", None)
                        if device == self.offload_device:
                            linked.append((n, m))
                            continue
            if linked:
                print(f"Attempting to release mmap ({len(linked)})")
                for n, m in linked:
                    # TODO: possible to OOM, find better way to detach
                    m.to(self.load_device).to(self.offload_device)
            self.mmap_released = True

    def add_object_patch(self, name, obj):
        """
        Register an object patch under the given name.

        Args:
            name (str): The attribute path to patch.
            obj: The replacement object.
        """
        self.object_patches[name] = obj

    def clone(self, *args, **kwargs):
        """
        Clone the model patcher.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            GGUFModelPatcher: The cloned model patcher.
        """
        n = GGUFModelPatcher(
            self.model,
            self.load_device,
            self.offload_device,
            self.size,
            weight_inplace_update=self.weight_inplace_update,
        )
        n.patches = {}
        for k in self.patches:
            n.patches[k] = self.patches[k][:]
        n.patches_uuid = self.patches_uuid

        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        n.backup = self.backup
        n.object_patches_backup = self.object_patches_backup
        n.patch_on_device = getattr(self, "patch_on_device", False)
        return n


class UnetLoaderGGUF:
    def load_unet(
        self,
        unet_name: str,
        dequant_dtype: str = None,
        patch_dtype: str = None,
        patch_on_device: bool = None,
    ) -> tuple:
        """
        Load the UNet model.

        Args:
            unet_name (str): The name of the UNet model.
            dequant_dtype (str, optional): The dequantization data type. Defaults to None.
            patch_dtype (str, optional): The patch data type. Defaults to None.
            patch_on_device (bool, optional): Whether to patch on device. Defaults to None.

        Returns:
            tuple: The loaded model.
        """
        ops = GGMLOps()

        if dequant_dtype in ("default", None):
            ops.Linear.dequant_dtype = None
        elif dequant_dtype in ["target"]:
            ops.Linear.dequant_dtype = dequant_dtype
        else:
            ops.Linear.dequant_dtype = getattr(torch, dequant_dtype)

        if patch_dtype in ("default", None):
            ops.Linear.patch_dtype = None
        elif patch_dtype in ["target"]:
            ops.Linear.patch_dtype = patch_dtype
        else:
            ops.Linear.patch_dtype = getattr(torch, patch_dtype)

        unet_path = "./_internal/unet/" + unet_name
        sd = gguf_sd_loader(unet_path)
        model = ModelPatcher.load_diffusion_model_state_dict(
            sd, model_options={"custom_operations": ops}
        )
        if model is None:
            logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
            raise RuntimeError(
                "ERROR: Could not detect model type of: {}".format(unet_path)
            )
        model = GGUFModelPatcher.clone(model)
        model.patch_on_device = patch_on_device
        return (model,)
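
# Hypothetical usage sketch (the file name is illustrative; it must exist under
# ./_internal/unet/):
#
#   loader = UnetLoaderGGUF()
#   (model,) = loader.load_unet(
#       "flux1-dev-Q8_0.gguf", dequant_dtype="default", patch_dtype="default"
#   )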


clip_sd_map = {
    "enc.": "encoder.",
    ".blk.": ".block.",
    "token_embd": "shared",
    "output_norm": "final_layer_norm",
    "attn_q": "layer.0.SelfAttention.q",
    "attn_k": "layer.0.SelfAttention.k",
    "attn_v": "layer.0.SelfAttention.v",
    "attn_o": "layer.0.SelfAttention.o",
    "attn_norm": "layer.0.layer_norm",
    "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias",
    "ffn_up": "layer.1.DenseReluDense.wi_1",
    "ffn_down": "layer.1.DenseReluDense.wo",
    "ffn_gate": "layer.1.DenseReluDense.wi_0",
    "ffn_norm": "layer.1.layer_norm",
}
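
# Example of the key translation this map performs (plain string substitution,
# applied in order by gguf_clip_loader below):
#
#   "enc.blk.23.ffn_up.weight"
#       -> "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"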

clip_name_dict = {
    "stable_diffusion": Clip.CLIPType.STABLE_DIFFUSION,
    "sdxl": Clip.CLIPType.STABLE_DIFFUSION,
    "sd3": Clip.CLIPType.SD3,
    "flux": Clip.CLIPType.FLUX,
}


def gguf_clip_loader(path: str) -> dict:
    """#### Load a CLIP model from a GGUF file.

    #### Args:
    - `path` (str): The path to the GGUF file.

    #### Returns:
    - `dict`: The loaded CLIP model.
    """
    raw_sd = gguf_sd_loader(path)
    assert "enc.blk.23.ffn_up.weight" in raw_sd, "Invalid Text Encoder!"
    sd = {}
    for k, v in raw_sd.items():
        for s, d in clip_sd_map.items():
            k = k.replace(s, d)
        sd[k] = v
    return sd


class CLIPLoaderGGUF:
    def load_data(self, ckpt_paths: list) -> list:
        """
        Load data from checkpoint paths.

        Args:
            ckpt_paths (list): List of checkpoint paths.

        Returns:
            list: List of loaded data.
        """
        clip_data = []
        for p in ckpt_paths:
            if p.endswith(".gguf"):
                clip_data.append(gguf_clip_loader(p))
            else:
                sd = util.load_torch_file(p, safe_load=True)
                clip_data.append(
                    {
                        k: GGMLTensor(
                            v,
                            tensor_type=gguf.GGMLQuantizationType.F16,
                            tensor_shape=v.shape,
                        )
                        for k, v in sd.items()
                    }
                )
        return clip_data

    def load_patcher(self, clip_paths: list, clip_type: str, clip_data: list) -> Clip:
        """
        Load the model patcher.

        Args:
            clip_paths (list): List of clip paths.
            clip_type (str): The type of the clip.
            clip_data (list): List of clip data.

        Returns:
            Clip: The loaded clip.
        """
        clip = Clip.load_text_encoder_state_dicts(
            clip_type=clip_type,
            state_dicts=clip_data,
            model_options={
                "custom_operations": GGMLOps,
                "initial_device": Device.text_encoder_offload_device(),
            },
            embedding_directory="models/embeddings",
        )
        clip.patcher = GGUFModelPatcher.clone(clip.patcher)

        # for some reason this is just missing in some SAI checkpoints
        if getattr(clip.cond_stage_model, "clip_l", None) is not None:
            if (
                getattr(
                    clip.cond_stage_model.clip_l.transformer.text_projection.weight,
                    "tensor_shape",
                    None,
                )
                is None
            ):
                clip.cond_stage_model.clip_l.transformer.text_projection = (
                    cast.manual_cast.Linear(768, 768)
                )
        if getattr(clip.cond_stage_model, "clip_g", None) is not None:
            if (
                getattr(
                    clip.cond_stage_model.clip_g.transformer.text_projection.weight,
                    "tensor_shape",
                    None,
                )
                is None
            ):
                clip.cond_stage_model.clip_g.transformer.text_projection = (
                    cast.manual_cast.Linear(1280, 1280)
                )
        return clip


class DualCLIPLoaderGGUF(CLIPLoaderGGUF):
    def load_clip(self, clip_name1: str, clip_name2: str, type: str) -> tuple:
        """
        Load dual clips.

        Args:
            clip_name1 (str): The name of the first clip.
            clip_name2 (str): The name of the second clip.
            type (str): The type of the clip.

        Returns:
            tuple: The loaded clips.
        """
        clip_path1 = "./_internal/clip/" + clip_name1
        clip_path2 = "./_internal/clip/" + clip_name2
        clip_paths = (clip_path1, clip_path2)
        clip_type = clip_name_dict.get(type, Clip.CLIPType.STABLE_DIFFUSION)
        return (self.load_patcher(clip_paths, clip_type, self.load_data(clip_paths)),)
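
# Hypothetical usage sketch (file names are illustrative and resolve under
# ./_internal/clip/; .gguf files go through gguf_clip_loader, anything else
# through util.load_torch_file):
#
#   loader = DualCLIPLoaderGGUF()
#   (clip,) = loader.load_clip("clip_l.safetensors", "t5xxl-Q8_0.gguf", "flux")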


class CLIPTextEncodeFlux:
    def encode(
        self,
        clip: Clip,
        clip_l: str,
        t5xxl: str,
        guidance: float,
        flux_enabled: bool = False,
    ) -> tuple:
        """
        Encode text using CLIP and T5XXL.

        Args:
            clip (Clip): The clip object.
            clip_l (str): The CLIP-L prompt text.
            t5xxl (str): The T5XXL prompt text.
            guidance (float): The guidance value embedded into the conditioning.
            flux_enabled (bool, optional): Whether flux is enabled. Defaults to False.

        Returns:
            tuple: The encoded conditioning.
        """
        tokens = clip.tokenize(clip_l)
        tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]

        output = clip.encode_from_tokens(
            tokens, return_pooled=True, return_dict=True, flux_enabled=flux_enabled
        )
        cond = output.pop("cond")
        output["guidance"] = guidance
        return ([[cond, output]],)
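
# Hypothetical usage sketch, continuing from the loaders above (prompt text and
# guidance value are illustrative):
#
#   (conditioning,) = CLIPTextEncodeFlux().encode(
#       clip,
#       clip_l="a photo of a cat",
#       t5xxl="a photo of a cat",
#       guidance=3.5,
#       flux_enabled=True,
#   )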


class ConditioningZeroOut:
    def zero_out(self, conditioning: list) -> list:
        """
        Zero out the conditioning.

        Args:
            conditioning (list): The conditioning list.

        Returns:
            list: The zeroed out conditioning.
        """
        c = []
        for t in conditioning:
            d = t[1].copy()
            pooled_output = d.get("pooled_output", None)
            if pooled_output is not None:
                d["pooled_output"] = torch.zeros_like(pooled_output)
            n = [torch.zeros_like(t[0]), d]
            c.append(n)
        return (c,)