Spaces:
Running
on
Zero
Running
on
Zero
Upload folder using huggingface_hub
Browse files- modules/Model/ModelBase.py +23 -14
- modules/NeuralNetwork/unet.py +48 -34
- modules/Utilities/util.py +24 -13
- modules/cond/cond.py +87 -55
- modules/sample/CFG.py +56 -20
- modules/sample/ksampler_util.py +73 -23
- modules/sample/samplers.py +44 -221
- modules/sample/sampling.py +313 -370
- modules/user/GUI.py +8 -4
- modules/user/pipeline.py +6 -6
modules/Model/ModelBase.py
CHANGED
@@ -56,7 +56,9 @@ class BaseModel(torch.nn.Module):
|
|
56 |
**unet_config, device=device, operations=operations
|
57 |
)
|
58 |
self.model_type = model_type
|
59 |
-
self.model_sampling = sampling.model_sampling(
|
|
|
|
|
60 |
|
61 |
self.adm_channels = unet_config.get("adm_in_channels", None)
|
62 |
if self.adm_channels is None:
|
@@ -93,26 +95,32 @@ class BaseModel(torch.nn.Module):
|
|
93 |
"""
|
94 |
sigma = t
|
95 |
xc = self.model_sampling.calculate_input(sigma, x)
|
96 |
-
if c_concat is not None:
|
97 |
-
xc = torch.cat([xc] + [c_concat], dim=1)
|
98 |
|
99 |
-
|
100 |
-
|
|
|
101 |
|
102 |
-
|
103 |
-
|
|
|
|
|
|
|
|
|
104 |
|
|
|
105 |
xc = xc.to(dtype)
|
106 |
t = self.model_sampling.timestep(t).float()
|
107 |
-
context =
|
|
|
|
|
108 |
extra_conds = {}
|
109 |
-
for
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
extra_conds[o] = extra
|
115 |
|
|
|
116 |
model_output = self.diffusion_model(
|
117 |
xc,
|
118 |
t,
|
@@ -121,6 +129,7 @@ class BaseModel(torch.nn.Module):
|
|
121 |
transformer_options=transformer_options,
|
122 |
**extra_conds,
|
123 |
).float()
|
|
|
124 |
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
125 |
|
126 |
def get_dtype(self) -> torch.dtype:
|
|
|
56 |
**unet_config, device=device, operations=operations
|
57 |
)
|
58 |
self.model_type = model_type
|
59 |
+
self.model_sampling = sampling.model_sampling(
|
60 |
+
model_config, model_type, flux=flux
|
61 |
+
)
|
62 |
|
63 |
self.adm_channels = unet_config.get("adm_in_channels", None)
|
64 |
if self.adm_channels is None:
|
|
|
95 |
"""
|
96 |
sigma = t
|
97 |
xc = self.model_sampling.calculate_input(sigma, x)
|
|
|
|
|
98 |
|
99 |
+
# Optimize concatenation operation by avoiding unnecessary list creation
|
100 |
+
if c_concat is not None:
|
101 |
+
xc = torch.cat((xc, c_concat), dim=1)
|
102 |
|
103 |
+
# Determine dtype once to avoid repeated calls to get_dtype()
|
104 |
+
dtype = (
|
105 |
+
self.manual_cast_dtype
|
106 |
+
if self.manual_cast_dtype is not None
|
107 |
+
else self.get_dtype()
|
108 |
+
)
|
109 |
|
110 |
+
# Batch operations to reduce overhead
|
111 |
xc = xc.to(dtype)
|
112 |
t = self.model_sampling.timestep(t).float()
|
113 |
+
context = c_crossattn.to(dtype) if c_crossattn is not None else None
|
114 |
+
|
115 |
+
# Process extra conditions more efficiently
|
116 |
extra_conds = {}
|
117 |
+
for name, value in kwargs.items():
|
118 |
+
if hasattr(value, "dtype") and value.dtype not in (torch.int, torch.long):
|
119 |
+
extra_conds[name] = value.to(dtype)
|
120 |
+
else:
|
121 |
+
extra_conds[name] = value
|
|
|
122 |
|
123 |
+
# Run diffusion model and calculate denoised output
|
124 |
model_output = self.diffusion_model(
|
125 |
xc,
|
126 |
t,
|
|
|
129 |
transformer_options=transformer_options,
|
130 |
**extra_conds,
|
131 |
).float()
|
132 |
+
|
133 |
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
134 |
|
135 |
def get_dtype(self) -> torch.dtype:
|
modules/NeuralNetwork/unet.py
CHANGED
@@ -304,7 +304,9 @@ class UNetModel1(nn.Module):
|
|
304 |
if num_heads_upsample == -1:
|
305 |
num_heads_upsample = num_heads
|
306 |
if num_head_channels == -1:
|
307 |
-
assert num_heads != -1,
|
|
|
|
|
308 |
|
309 |
self.in_channels = in_channels
|
310 |
self.model_channels = model_channels
|
@@ -684,36 +686,29 @@ class UNetModel1(nn.Module):
|
|
684 |
transformer_options: Dict[str, Any] = {},
|
685 |
**kwargs: Any,
|
686 |
) -> torch.Tensor:
|
687 |
-
"""#### Forward pass of the UNet model.
|
688 |
-
|
689 |
-
#### Args:
|
690 |
-
- `x` (torch.Tensor): The input tensor.
|
691 |
-
- `timesteps` (Optional[torch.Tensor], optional): The timesteps tensor. Defaults to None.
|
692 |
-
- `context` (Optional[torch.Tensor], optional): The context tensor. Defaults to None.
|
693 |
-
- `y` (Optional[torch.Tensor], optional): The class labels tensor. Defaults to None.
|
694 |
-
- `control` (Optional[torch.Tensor], optional): The control tensor. Defaults to None.
|
695 |
-
- `transformer_options` (Dict[str, Any], optional): Options for the transformer. Defaults to {}.
|
696 |
-
- `**kwargs` (Any): Additional keyword arguments.
|
697 |
-
|
698 |
-
#### Returns:
|
699 |
-
- `torch.Tensor`: The output tensor.
|
700 |
-
"""
|
701 |
transformer_options["original_shape"] = list(x.shape)
|
702 |
transformer_options["transformer_index"] = 0
|
703 |
-
transformer_patches = transformer_options.get("patches", {})
|
704 |
|
|
|
705 |
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
|
706 |
image_only_indicator = kwargs.get("image_only_indicator", None)
|
707 |
time_context = kwargs.get("time_context", None)
|
708 |
|
709 |
-
|
710 |
-
|
711 |
-
|
712 |
-
|
713 |
-
|
714 |
-
|
715 |
-
).to(
|
|
|
|
|
716 |
emb = self.time_embed(t_emb)
|
|
|
|
|
|
|
717 |
h = x
|
718 |
for id, module in enumerate(self.input_blocks):
|
719 |
transformer_options["block"] = ("input", id)
|
@@ -730,6 +725,7 @@ class UNetModel1(nn.Module):
|
|
730 |
h = apply_control1(h, control, "input")
|
731 |
hs.append(h)
|
732 |
|
|
|
733 |
transformer_options["block"] = ("middle", 0)
|
734 |
if self.middle_block is not None:
|
735 |
h = ResBlock.forward_timestep_embed1(
|
@@ -744,17 +740,19 @@ class UNetModel1(nn.Module):
|
|
744 |
)
|
745 |
h = apply_control1(h, control, "middle")
|
746 |
|
|
|
747 |
for id, module in enumerate(self.output_blocks):
|
748 |
transformer_options["block"] = ("output", id)
|
749 |
hsp = hs.pop()
|
750 |
hsp = apply_control1(hsp, control, "output")
|
751 |
|
|
|
752 |
h = torch.cat([h, hsp], dim=1)
|
753 |
-
del hsp
|
754 |
-
|
755 |
-
|
756 |
-
else
|
757 |
-
|
758 |
h = ResBlock.forward_timestep_embed1(
|
759 |
module,
|
760 |
h,
|
@@ -766,11 +764,15 @@ class UNetModel1(nn.Module):
|
|
766 |
num_video_frames=num_video_frames,
|
767 |
image_only_indicator=image_only_indicator,
|
768 |
)
|
|
|
|
|
769 |
h = h.type(x.dtype)
|
770 |
return self.out(h)
|
771 |
|
772 |
|
773 |
-
def detect_unet_config(
|
|
|
|
|
774 |
"""#### Detect the UNet configuration from a state dictionary.
|
775 |
|
776 |
#### Args:
|
@@ -1017,7 +1019,9 @@ def detect_unet_config(state_dict: Dict[str, torch.Tensor], key_prefix: str) ->
|
|
1017 |
// model_channels
|
1018 |
)
|
1019 |
|
1020 |
-
out = transformer.calculate_transformer_depth(
|
|
|
|
|
1021 |
if out is not None:
|
1022 |
transformer_depth.append(out[0])
|
1023 |
if context_dim is None:
|
@@ -1076,7 +1080,9 @@ def detect_unet_config(state_dict: Dict[str, torch.Tensor], key_prefix: str) ->
|
|
1076 |
return unet_config
|
1077 |
|
1078 |
|
1079 |
-
def model_config_from_unet_config(
|
|
|
|
|
1080 |
"""#### Get the model configuration from a UNet configuration.
|
1081 |
|
1082 |
#### Args:
|
@@ -1096,7 +1102,11 @@ def model_config_from_unet_config(unet_config: Dict[str, Any], state_dict: Optio
|
|
1096 |
return None
|
1097 |
|
1098 |
|
1099 |
-
def model_config_from_unet(
|
|
|
|
|
|
|
|
|
1100 |
"""#### Get the model configuration from a UNet state dictionary.
|
1101 |
|
1102 |
#### Args:
|
@@ -1117,7 +1127,11 @@ def model_config_from_unet(state_dict: Dict[str, torch.Tensor], unet_key_prefix:
|
|
1117 |
def unet_dtype1(
|
1118 |
device: Optional[torch.device] = None,
|
1119 |
model_params: int = 0,
|
1120 |
-
supported_dtypes: List[torch.dtype] = [
|
|
|
|
|
|
|
|
|
1121 |
) -> torch.dtype:
|
1122 |
"""#### Get the dtype for the UNet model.
|
1123 |
|
@@ -1129,4 +1143,4 @@ def unet_dtype1(
|
|
1129 |
#### Returns:
|
1130 |
- `torch.dtype`: The dtype for the UNet model.
|
1131 |
"""
|
1132 |
-
return torch.float16
|
|
|
304 |
if num_heads_upsample == -1:
|
305 |
num_heads_upsample = num_heads
|
306 |
if num_head_channels == -1:
|
307 |
+
assert num_heads != -1, (
|
308 |
+
"Either num_heads or num_head_channels has to be set"
|
309 |
+
)
|
310 |
|
311 |
self.in_channels = in_channels
|
312 |
self.model_channels = model_channels
|
|
|
686 |
transformer_options: Dict[str, Any] = {},
|
687 |
**kwargs: Any,
|
688 |
) -> torch.Tensor:
|
689 |
+
"""#### Forward pass of the UNet model with optimized calculations."""
|
690 |
+
# Setup transformer options (avoid unused variable)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
691 |
transformer_options["original_shape"] = list(x.shape)
|
692 |
transformer_options["transformer_index"] = 0
|
|
|
693 |
|
694 |
+
# Extract kwargs efficiently
|
695 |
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
|
696 |
image_only_indicator = kwargs.get("image_only_indicator", None)
|
697 |
time_context = kwargs.get("time_context", None)
|
698 |
|
699 |
+
# Validation
|
700 |
+
assert (y is not None) == (self.num_classes is not None), (
|
701 |
+
"must specify y if and only if the model is class-conditional"
|
702 |
+
)
|
703 |
+
|
704 |
+
# Time embedding - optimize by computing with target dtype directly
|
705 |
+
t_emb = sampling_util.timestep_embedding(timesteps, self.model_channels).to(
|
706 |
+
x.dtype
|
707 |
+
)
|
708 |
emb = self.time_embed(t_emb)
|
709 |
+
|
710 |
+
# Input blocks processing
|
711 |
+
hs = []
|
712 |
h = x
|
713 |
for id, module in enumerate(self.input_blocks):
|
714 |
transformer_options["block"] = ("input", id)
|
|
|
725 |
h = apply_control1(h, control, "input")
|
726 |
hs.append(h)
|
727 |
|
728 |
+
# Middle block processing
|
729 |
transformer_options["block"] = ("middle", 0)
|
730 |
if self.middle_block is not None:
|
731 |
h = ResBlock.forward_timestep_embed1(
|
|
|
740 |
)
|
741 |
h = apply_control1(h, control, "middle")
|
742 |
|
743 |
+
# Output blocks processing - optimize memory usage
|
744 |
for id, module in enumerate(self.output_blocks):
|
745 |
transformer_options["block"] = ("output", id)
|
746 |
hsp = hs.pop()
|
747 |
hsp = apply_control1(hsp, control, "output")
|
748 |
|
749 |
+
# Concatenate tensors
|
750 |
h = torch.cat([h, hsp], dim=1)
|
751 |
+
del hsp # Free memory immediately
|
752 |
+
|
753 |
+
# Only calculate output shape when needed
|
754 |
+
output_shape = hs[-1].shape if hs else None
|
755 |
+
|
756 |
h = ResBlock.forward_timestep_embed1(
|
757 |
module,
|
758 |
h,
|
|
|
764 |
num_video_frames=num_video_frames,
|
765 |
image_only_indicator=image_only_indicator,
|
766 |
)
|
767 |
+
|
768 |
+
# Ensure output has correct dtype
|
769 |
h = h.type(x.dtype)
|
770 |
return self.out(h)
|
771 |
|
772 |
|
773 |
+
def detect_unet_config(
|
774 |
+
state_dict: Dict[str, torch.Tensor], key_prefix: str
|
775 |
+
) -> Dict[str, Any]:
|
776 |
"""#### Detect the UNet configuration from a state dictionary.
|
777 |
|
778 |
#### Args:
|
|
|
1019 |
// model_channels
|
1020 |
)
|
1021 |
|
1022 |
+
out = transformer.calculate_transformer_depth(
|
1023 |
+
prefix, state_dict_keys, state_dict
|
1024 |
+
)
|
1025 |
if out is not None:
|
1026 |
transformer_depth.append(out[0])
|
1027 |
if context_dim is None:
|
|
|
1080 |
return unet_config
|
1081 |
|
1082 |
|
1083 |
+
def model_config_from_unet_config(
|
1084 |
+
unet_config: Dict[str, Any], state_dict: Optional[Dict[str, torch.Tensor]] = None
|
1085 |
+
) -> Any:
|
1086 |
"""#### Get the model configuration from a UNet configuration.
|
1087 |
|
1088 |
#### Args:
|
|
|
1102 |
return None
|
1103 |
|
1104 |
|
1105 |
+
def model_config_from_unet(
|
1106 |
+
state_dict: Dict[str, torch.Tensor],
|
1107 |
+
unet_key_prefix: str,
|
1108 |
+
use_base_if_no_match: bool = False,
|
1109 |
+
) -> Any:
|
1110 |
"""#### Get the model configuration from a UNet state dictionary.
|
1111 |
|
1112 |
#### Args:
|
|
|
1127 |
def unet_dtype1(
|
1128 |
device: Optional[torch.device] = None,
|
1129 |
model_params: int = 0,
|
1130 |
+
supported_dtypes: List[torch.dtype] = [
|
1131 |
+
torch.float16,
|
1132 |
+
torch.bfloat16,
|
1133 |
+
torch.float32,
|
1134 |
+
],
|
1135 |
) -> torch.dtype:
|
1136 |
"""#### Get the dtype for the UNet model.
|
1137 |
|
|
|
1143 |
#### Returns:
|
1144 |
- `torch.dtype`: The dtype for the UNet model.
|
1145 |
"""
|
1146 |
+
return torch.float16
|
modules/Utilities/util.py
CHANGED
@@ -4,7 +4,6 @@ import itertools
|
|
4 |
import logging
|
5 |
import math
|
6 |
import os
|
7 |
-
import pickle
|
8 |
import safetensors.torch
|
9 |
import torch
|
10 |
|
@@ -120,6 +119,18 @@ def state_dict_prefix_replace(
|
|
120 |
return out
|
121 |
|
122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
def repeat_to_batch_size(
|
124 |
tensor: torch.Tensor, batch_size: int, dim: int = 0
|
125 |
) -> torch.Tensor:
|
@@ -437,11 +448,11 @@ def tiled_scale_multidim(
|
|
437 |
|
438 |
def get_upscale(dim: int, val: int) -> int:
|
439 |
"""#### Get the upscale value.
|
440 |
-
|
441 |
#### Args:
|
442 |
- `dim` (int): The dimension.
|
443 |
- `val` (int): The value.
|
444 |
-
|
445 |
#### Returns:
|
446 |
- `int`: The upscaled value.
|
447 |
"""
|
@@ -453,11 +464,11 @@ def tiled_scale_multidim(
|
|
453 |
|
454 |
def get_downscale(dim: int, val: int) -> int:
|
455 |
"""#### Get the downscale value.
|
456 |
-
|
457 |
#### Args:
|
458 |
- `dim` (int): The dimension.
|
459 |
- `val` (int): The value.
|
460 |
-
|
461 |
#### Returns:
|
462 |
- `int`: The downscaled value.
|
463 |
"""
|
@@ -469,11 +480,11 @@ def tiled_scale_multidim(
|
|
469 |
|
470 |
def get_upscale_pos(dim: int, val: int) -> int:
|
471 |
"""#### Get the upscaled position.
|
472 |
-
|
473 |
#### Args:
|
474 |
- `dim` (int): The dimension.
|
475 |
- `val` (int): The value.
|
476 |
-
|
477 |
#### Returns:
|
478 |
- `int`: The upscaled position.
|
479 |
"""
|
@@ -485,11 +496,11 @@ def tiled_scale_multidim(
|
|
485 |
|
486 |
def get_downscale_pos(dim: int, val: int) -> int:
|
487 |
"""#### Get the downscaled position.
|
488 |
-
|
489 |
#### Args:
|
490 |
- `dim` (int): The dimension.
|
491 |
- `val` (int): The value.
|
492 |
-
|
493 |
#### Returns:
|
494 |
- `int`: The downscaled position.
|
495 |
"""
|
@@ -508,10 +519,10 @@ def tiled_scale_multidim(
|
|
508 |
|
509 |
def mult_list_upscale(a: list) -> list:
|
510 |
"""#### Multiply a list by the upscale amount.
|
511 |
-
|
512 |
#### Args:
|
513 |
- `a` (list): The list.
|
514 |
-
|
515 |
#### Returns:
|
516 |
- `list`: The multiplied list.
|
517 |
"""
|
@@ -601,7 +612,7 @@ def tiled_scale(
|
|
601 |
pbar: any = None,
|
602 |
):
|
603 |
"""#### Scale an image using a tiled approach.
|
604 |
-
|
605 |
#### Args:
|
606 |
- `samples` (torch.Tensor): The input samples.
|
607 |
- `function` (function): The scaling function.
|
@@ -612,7 +623,7 @@ def tiled_scale(
|
|
612 |
- `out_channels` (int, optional): The number of output channels. Defaults to 3.
|
613 |
- `output_device` (str, optional): The output device. Defaults to "cpu".
|
614 |
- `pbar` (any, optional): The progress bar. Defaults to None.
|
615 |
-
|
616 |
#### Returns:
|
617 |
- The scaled image.
|
618 |
"""
|
|
|
4 |
import logging
|
5 |
import math
|
6 |
import os
|
|
|
7 |
import safetensors.torch
|
8 |
import torch
|
9 |
|
|
|
119 |
return out
|
120 |
|
121 |
|
122 |
+
def lcm_of_list(numbers):
|
123 |
+
"""Calculate LCM of a list of numbers more efficiently."""
|
124 |
+
if not numbers:
|
125 |
+
return 1
|
126 |
+
|
127 |
+
result = numbers[0]
|
128 |
+
for num in numbers[1:]:
|
129 |
+
result = torch.lcm(torch.tensor(result), torch.tensor(num)).item()
|
130 |
+
|
131 |
+
return result
|
132 |
+
|
133 |
+
|
134 |
def repeat_to_batch_size(
|
135 |
tensor: torch.Tensor, batch_size: int, dim: int = 0
|
136 |
) -> torch.Tensor:
|
|
|
448 |
|
449 |
def get_upscale(dim: int, val: int) -> int:
|
450 |
"""#### Get the upscale value.
|
451 |
+
|
452 |
#### Args:
|
453 |
- `dim` (int): The dimension.
|
454 |
- `val` (int): The value.
|
455 |
+
|
456 |
#### Returns:
|
457 |
- `int`: The upscaled value.
|
458 |
"""
|
|
|
464 |
|
465 |
def get_downscale(dim: int, val: int) -> int:
|
466 |
"""#### Get the downscale value.
|
467 |
+
|
468 |
#### Args:
|
469 |
- `dim` (int): The dimension.
|
470 |
- `val` (int): The value.
|
471 |
+
|
472 |
#### Returns:
|
473 |
- `int`: The downscaled value.
|
474 |
"""
|
|
|
480 |
|
481 |
def get_upscale_pos(dim: int, val: int) -> int:
|
482 |
"""#### Get the upscaled position.
|
483 |
+
|
484 |
#### Args:
|
485 |
- `dim` (int): The dimension.
|
486 |
- `val` (int): The value.
|
487 |
+
|
488 |
#### Returns:
|
489 |
- `int`: The upscaled position.
|
490 |
"""
|
|
|
496 |
|
497 |
def get_downscale_pos(dim: int, val: int) -> int:
|
498 |
"""#### Get the downscaled position.
|
499 |
+
|
500 |
#### Args:
|
501 |
- `dim` (int): The dimension.
|
502 |
- `val` (int): The value.
|
503 |
+
|
504 |
#### Returns:
|
505 |
- `int`: The downscaled position.
|
506 |
"""
|
|
|
519 |
|
520 |
def mult_list_upscale(a: list) -> list:
|
521 |
"""#### Multiply a list by the upscale amount.
|
522 |
+
|
523 |
#### Args:
|
524 |
- `a` (list): The list.
|
525 |
+
|
526 |
#### Returns:
|
527 |
- `list`: The multiplied list.
|
528 |
"""
|
|
|
612 |
pbar: any = None,
|
613 |
):
|
614 |
"""#### Scale an image using a tiled approach.
|
615 |
+
|
616 |
#### Args:
|
617 |
- `samples` (torch.Tensor): The input samples.
|
618 |
- `function` (function): The scaling function.
|
|
|
623 |
- `out_channels` (int, optional): The number of output channels. Defaults to 3.
|
624 |
- `output_device` (str, optional): The output device. Defaults to "cpu".
|
625 |
- `pbar` (any, optional): The progress bar. Defaults to None.
|
626 |
+
|
627 |
#### Returns:
|
628 |
- The scaled image.
|
629 |
"""
|
modules/cond/cond.py
CHANGED
@@ -42,13 +42,13 @@ class CONDRegular:
|
|
42 |
return self._copy_with(
|
43 |
util.repeat_to_batch_size(self.cond, batch_size).to(device)
|
44 |
)
|
45 |
-
|
46 |
def can_concat(self, other: "CONDRegular") -> bool:
|
47 |
"""#### Check if conditions can be concatenated.
|
48 |
-
|
49 |
#### Args:
|
50 |
- `other` (CONDRegular): The other condition.
|
51 |
-
|
52 |
#### Returns:
|
53 |
- `bool`: True if conditions can be concatenated, False otherwise.
|
54 |
"""
|
@@ -58,10 +58,10 @@ class CONDRegular:
|
|
58 |
|
59 |
def concat(self, others: list) -> torch.Tensor:
|
60 |
"""#### Concatenate conditions.
|
61 |
-
|
62 |
#### Args:
|
63 |
- `others` (list): The list of other conditions.
|
64 |
-
|
65 |
#### Returns:
|
66 |
- `torch.Tensor`: The concatenated conditions.
|
67 |
"""
|
@@ -76,11 +76,11 @@ class CONDCrossAttn(CONDRegular):
|
|
76 |
|
77 |
def can_concat(self, other: "CONDRegular") -> bool:
|
78 |
"""#### Check if conditions can be concatenated.
|
79 |
-
|
80 |
#### Args:
|
81 |
- `other` (CONDRegular): The other condition.
|
82 |
-
|
83 |
-
#### Returns:
|
84 |
- `bool`: True if conditions can be concatenated, False otherwise.
|
85 |
"""
|
86 |
s1 = self.cond.shape
|
@@ -96,31 +96,34 @@ class CONDCrossAttn(CONDRegular):
|
|
96 |
): # arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
97 |
return False
|
98 |
return True
|
99 |
-
|
100 |
-
def concat(self, others: list) -> torch.Tensor:
|
101 |
-
"""#### Concatenate cross-attention conditions.
|
102 |
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
#### Returns:
|
107 |
-
- `torch.Tensor`: The concatenated conditions.
|
108 |
-
"""
|
109 |
conds = [self.cond]
|
110 |
-
|
|
|
|
|
111 |
for x in others:
|
112 |
-
|
113 |
-
|
114 |
-
conds.append(c)
|
115 |
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
|
126 |
def convert_cond(cond: list) -> list:
|
@@ -277,8 +280,10 @@ def calc_cond_batch(
|
|
277 |
out_c += output[o] * mult[o]
|
278 |
out_cts += mult[o]
|
279 |
|
|
|
280 |
for i in range(len(out_conds)):
|
281 |
-
|
|
|
282 |
|
283 |
return out_conds
|
284 |
|
@@ -328,48 +333,75 @@ def encode_model_conds(
|
|
328 |
conds[t] = x
|
329 |
return conds
|
330 |
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
#### Args:
|
335 |
-
- `conditions` (list): The list of conditions.
|
336 |
-
- `dims` (tuple): The dimensions.
|
337 |
-
- `device` (torch.device): The device.
|
338 |
-
"""
|
339 |
-
# We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
|
340 |
-
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
|
341 |
for i in range(len(conditions)):
|
342 |
c = conditions[i]
|
|
|
343 |
if "area" in c:
|
344 |
area = c["area"]
|
345 |
if area[0] == "percentage":
|
346 |
-
|
347 |
a = area[1:]
|
348 |
a_len = len(a) // 2
|
349 |
-
area = ()
|
350 |
-
for d in range(len(dims)):
|
351 |
-
area += (max(1, round(a[d] * dims[d])),)
|
352 |
-
for d in range(len(dims)):
|
353 |
-
area += (round(a[d + a_len] * dims[d]),)
|
354 |
|
355 |
-
|
356 |
-
|
357 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
358 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
359 |
if "mask" in c:
|
360 |
-
mask = c["mask"]
|
361 |
-
mask = mask.to(device=device)
|
362 |
modified = c.copy()
|
|
|
|
|
|
|
363 |
if len(mask.shape) == len(dims):
|
364 |
mask = mask.unsqueeze(0)
|
|
|
|
|
365 |
if mask.shape[1:] != dims:
|
366 |
-
mask
|
367 |
-
|
368 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
369 |
|
370 |
modified["mask"] = mask
|
371 |
conditions[i] = modified
|
372 |
|
|
|
373 |
def process_conds(
|
374 |
model: object,
|
375 |
noise: torch.Tensor,
|
@@ -442,4 +474,4 @@ def process_conds(
|
|
442 |
positive, conds[k], "gligen", lambda cond_cnets, x: cond_cnets[x]
|
443 |
)
|
444 |
|
445 |
-
return conds
|
|
|
42 |
return self._copy_with(
|
43 |
util.repeat_to_batch_size(self.cond, batch_size).to(device)
|
44 |
)
|
45 |
+
|
46 |
def can_concat(self, other: "CONDRegular") -> bool:
|
47 |
"""#### Check if conditions can be concatenated.
|
48 |
+
|
49 |
#### Args:
|
50 |
- `other` (CONDRegular): The other condition.
|
51 |
+
|
52 |
#### Returns:
|
53 |
- `bool`: True if conditions can be concatenated, False otherwise.
|
54 |
"""
|
|
|
58 |
|
59 |
def concat(self, others: list) -> torch.Tensor:
|
60 |
"""#### Concatenate conditions.
|
61 |
+
|
62 |
#### Args:
|
63 |
- `others` (list): The list of other conditions.
|
64 |
+
|
65 |
#### Returns:
|
66 |
- `torch.Tensor`: The concatenated conditions.
|
67 |
"""
|
|
|
76 |
|
77 |
def can_concat(self, other: "CONDRegular") -> bool:
|
78 |
"""#### Check if conditions can be concatenated.
|
79 |
+
|
80 |
#### Args:
|
81 |
- `other` (CONDRegular): The other condition.
|
82 |
+
|
83 |
+
#### Returns:
|
84 |
- `bool`: True if conditions can be concatenated, False otherwise.
|
85 |
"""
|
86 |
s1 = self.cond.shape
|
|
|
96 |
): # arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
97 |
return False
|
98 |
return True
|
|
|
|
|
|
|
99 |
|
100 |
+
def concat(self, others: list) -> torch.Tensor:
|
101 |
+
"""Optimized version of cross-attention condition concatenation."""
|
|
|
|
|
|
|
|
|
102 |
conds = [self.cond]
|
103 |
+
shapes = [self.cond.shape[1]]
|
104 |
+
|
105 |
+
# Collect all conditions and their shapes
|
106 |
for x in others:
|
107 |
+
conds.append(x.cond)
|
108 |
+
shapes.append(x.cond.shape[1])
|
|
|
109 |
|
110 |
+
# Calculate LCM more efficiently
|
111 |
+
crossattn_max_len = util.lcm_of_list(shapes)
|
112 |
+
|
113 |
+
# Process and concat in one step where possible
|
114 |
+
if all(c.shape[1] == shapes[0] for c in conds):
|
115 |
+
# All same length, simple concatenation
|
116 |
+
return torch.cat(conds)
|
117 |
+
else:
|
118 |
+
# Process conditions that need repeating
|
119 |
+
out = []
|
120 |
+
for c in conds:
|
121 |
+
if c.shape[1] < crossattn_max_len:
|
122 |
+
repeat_factor = crossattn_max_len // c.shape[1]
|
123 |
+
# Use repeat instead of individual operations
|
124 |
+
c = c.repeat(1, repeat_factor, 1)
|
125 |
+
out.append(c)
|
126 |
+
return torch.cat(out)
|
127 |
|
128 |
|
129 |
def convert_cond(cond: list) -> list:
|
|
|
280 |
out_c += output[o] * mult[o]
|
281 |
out_cts += mult[o]
|
282 |
|
283 |
+
# Vectorize the division at the end
|
284 |
for i in range(len(out_conds)):
|
285 |
+
# Inplace division is already efficient
|
286 |
+
out_conds[i].div_(out_counts[i]) # Using .div_ instead of /= for clarity
|
287 |
|
288 |
return out_conds
|
289 |
|
|
|
333 |
conds[t] = x
|
334 |
return conds
|
335 |
|
336 |
+
|
337 |
+
def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
|
338 |
+
"""Optimized version that processes areas and masks more efficiently"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
339 |
for i in range(len(conditions)):
|
340 |
c = conditions[i]
|
341 |
+
# Process area
|
342 |
if "area" in c:
|
343 |
area = c["area"]
|
344 |
if area[0] == "percentage":
|
345 |
+
# Vectorized calculation of area dimensions
|
346 |
a = area[1:]
|
347 |
a_len = len(a) // 2
|
|
|
|
|
|
|
|
|
|
|
348 |
|
349 |
+
# Calculate all dimensions at once using tensor operations
|
350 |
+
dims_tensor = torch.tensor(dims, device="cpu")
|
351 |
+
first_part = torch.tensor(a[:a_len], device="cpu") * dims_tensor
|
352 |
+
second_part = torch.tensor(a[a_len:], device="cpu") * dims_tensor
|
353 |
+
|
354 |
+
# Convert to rounded integers and tuple
|
355 |
+
first_part = torch.max(
|
356 |
+
torch.ones_like(first_part), torch.round(first_part)
|
357 |
+
)
|
358 |
+
second_part = torch.round(second_part)
|
359 |
|
360 |
+
# Create the new area tuple
|
361 |
+
new_area = tuple(first_part.int().tolist()) + tuple(
|
362 |
+
second_part.int().tolist()
|
363 |
+
)
|
364 |
+
|
365 |
+
# Create a modified copy with the new area
|
366 |
+
modified = c.copy()
|
367 |
+
modified["area"] = new_area
|
368 |
+
conditions[i] = modified
|
369 |
+
|
370 |
+
# Process mask
|
371 |
if "mask" in c:
|
|
|
|
|
372 |
modified = c.copy()
|
373 |
+
mask = c["mask"].to(device=device)
|
374 |
+
|
375 |
+
# Combine dimension checks and unsqueeze operation
|
376 |
if len(mask.shape) == len(dims):
|
377 |
mask = mask.unsqueeze(0)
|
378 |
+
|
379 |
+
# Only interpolate if needed
|
380 |
if mask.shape[1:] != dims:
|
381 |
+
# Optimize interpolation by ensuring mask is in the right format for the operation
|
382 |
+
if len(mask.shape) == 3 and mask.shape[0] == 1:
|
383 |
+
# Already in the right format for interpolation
|
384 |
+
mask = torch.nn.functional.interpolate(
|
385 |
+
mask.unsqueeze(1),
|
386 |
+
size=dims,
|
387 |
+
mode="bilinear",
|
388 |
+
align_corners=False,
|
389 |
+
).squeeze(1)
|
390 |
+
else:
|
391 |
+
# Ensure mask is properly formatted for interpolation
|
392 |
+
mask = torch.nn.functional.interpolate(
|
393 |
+
mask
|
394 |
+
if len(mask.shape) > 3 and mask.shape[1] == 1
|
395 |
+
else mask.unsqueeze(1),
|
396 |
+
size=dims,
|
397 |
+
mode="bilinear",
|
398 |
+
align_corners=False,
|
399 |
+
).squeeze(1)
|
400 |
|
401 |
modified["mask"] = mask
|
402 |
conditions[i] = modified
|
403 |
|
404 |
+
|
405 |
def process_conds(
|
406 |
model: object,
|
407 |
noise: torch.Tensor,
|
|
|
474 |
positive, conds[k], "gligen", lambda cond_cnets, x: cond_cnets[x]
|
475 |
)
|
476 |
|
477 |
+
return conds
|
modules/sample/CFG.py
CHANGED
@@ -30,10 +30,15 @@ def cfg_function(
|
|
30 |
#### Returns:
|
31 |
- `torch.Tensor`: The CFG result.
|
32 |
"""
|
|
|
33 |
if "sampler_cfg_function" in model_options:
|
|
|
|
|
|
|
|
|
34 |
args = {
|
35 |
-
"cond":
|
36 |
-
"uncond":
|
37 |
"cond_scale": cond_scale,
|
38 |
"timestep": timestep,
|
39 |
"input": x,
|
@@ -45,9 +50,18 @@ def cfg_function(
|
|
45 |
}
|
46 |
cfg_result = x - model_options["sampler_cfg_function"](args)
|
47 |
else:
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
args = {
|
52 |
"denoised": cfg_result,
|
53 |
"cond": cond,
|
@@ -59,7 +73,12 @@ def cfg_function(
|
|
59 |
"model_options": model_options,
|
60 |
"input": x,
|
61 |
}
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
return cfg_result
|
65 |
|
@@ -89,21 +108,29 @@ def sampling_function(
|
|
89 |
#### Returns:
|
90 |
- `torch.Tensor`: The sampled tensor.
|
91 |
"""
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
|
|
|
|
99 |
|
|
|
100 |
conds = [condo, uncond_]
|
101 |
-
out = cond.calc_cond_batch(model, conds, x, timestep, model_options)
|
102 |
|
103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
args = {
|
105 |
"conds": conds,
|
106 |
-
"conds_out":
|
107 |
"cond_scale": cond_scale,
|
108 |
"timestep": timestep,
|
109 |
"input": x,
|
@@ -111,12 +138,20 @@ def sampling_function(
|
|
111 |
"model": model,
|
112 |
"model_options": model_options,
|
113 |
}
|
114 |
-
out = fn(args)
|
115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
return cfg_function(
|
117 |
model,
|
118 |
-
|
119 |
-
|
120 |
cond_scale,
|
121 |
x,
|
122 |
timestep,
|
@@ -128,6 +163,7 @@ def sampling_function(
|
|
128 |
|
129 |
class CFGGuider:
|
130 |
"""#### Class for guiding the sampling process with CFG."""
|
|
|
131 |
def __init__(self, model_patcher, flux=False):
|
132 |
"""#### Initialize the CFGGuider.
|
133 |
|
@@ -315,4 +351,4 @@ class CFGGuider:
|
|
315 |
del self.inner_model
|
316 |
del self.conds
|
317 |
del self.loaded_models
|
318 |
-
return output
|
|
|
30 |
#### Returns:
|
31 |
- `torch.Tensor`: The CFG result.
|
32 |
"""
|
33 |
+
# Check for custom sampler CFG function first
|
34 |
if "sampler_cfg_function" in model_options:
|
35 |
+
# Precompute differences to avoid redundant operations
|
36 |
+
cond_diff = x - cond_pred
|
37 |
+
uncond_diff = x - uncond_pred
|
38 |
+
|
39 |
args = {
|
40 |
+
"cond": cond_diff,
|
41 |
+
"uncond": uncond_diff,
|
42 |
"cond_scale": cond_scale,
|
43 |
"timestep": timestep,
|
44 |
"input": x,
|
|
|
50 |
}
|
51 |
cfg_result = x - model_options["sampler_cfg_function"](args)
|
52 |
else:
|
53 |
+
# Standard CFG calculation - optimized to avoid intermediate tensor allocation
|
54 |
+
# When cond_scale = 1.0, we can just return cond_pred without computation
|
55 |
+
if math.isclose(cond_scale, 1.0):
|
56 |
+
cfg_result = cond_pred
|
57 |
+
else:
|
58 |
+
# Fused operation: uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
59 |
+
# Equivalent to: uncond_pred * (1 - cond_scale) + cond_pred * cond_scale
|
60 |
+
cfg_result = torch.lerp(uncond_pred, cond_pred, cond_scale)
|
61 |
+
|
62 |
+
# Apply post-CFG functions if any
|
63 |
+
post_cfg_functions = model_options.get("sampler_post_cfg_function", [])
|
64 |
+
if post_cfg_functions:
|
65 |
args = {
|
66 |
"denoised": cfg_result,
|
67 |
"cond": cond,
|
|
|
73 |
"model_options": model_options,
|
74 |
"input": x,
|
75 |
}
|
76 |
+
|
77 |
+
# Apply each post-CFG function in sequence
|
78 |
+
for fn in post_cfg_functions:
|
79 |
+
cfg_result = fn(args)
|
80 |
+
# Update the denoised result for the next function
|
81 |
+
args["denoised"] = cfg_result
|
82 |
|
83 |
return cfg_result
|
84 |
|
|
|
108 |
#### Returns:
|
109 |
- `torch.Tensor`: The sampled tensor.
|
110 |
"""
|
111 |
+
# Optimize conditional logic for uncond
|
112 |
+
uncond_ = (
|
113 |
+
None
|
114 |
+
if (
|
115 |
+
math.isclose(cond_scale, 1.0)
|
116 |
+
and not model_options.get("disable_cfg1_optimization", False)
|
117 |
+
)
|
118 |
+
else uncond
|
119 |
+
)
|
120 |
|
121 |
+
# Create conditions list once
|
122 |
conds = [condo, uncond_]
|
|
|
123 |
|
124 |
+
# Get model predictions for both conditions
|
125 |
+
cond_outputs = cond.calc_cond_batch(model, conds, x, timestep, model_options)
|
126 |
+
|
127 |
+
# Apply pre-CFG functions if any
|
128 |
+
pre_cfg_functions = model_options.get("sampler_pre_cfg_function", [])
|
129 |
+
if pre_cfg_functions:
|
130 |
+
# Create args dictionary once
|
131 |
args = {
|
132 |
"conds": conds,
|
133 |
+
"conds_out": cond_outputs,
|
134 |
"cond_scale": cond_scale,
|
135 |
"timestep": timestep,
|
136 |
"input": x,
|
|
|
138 |
"model": model,
|
139 |
"model_options": model_options,
|
140 |
}
|
|
|
141 |
|
142 |
+
# Apply each pre-CFG function
|
143 |
+
for fn in pre_cfg_functions:
|
144 |
+
cond_outputs = fn(args)
|
145 |
+
args["conds_out"] = cond_outputs
|
146 |
+
|
147 |
+
# Extract conditional and unconditional outputs explicitly for clarity
|
148 |
+
cond_pred, uncond_pred = cond_outputs[0], cond_outputs[1]
|
149 |
+
|
150 |
+
# Apply the CFG function
|
151 |
return cfg_function(
|
152 |
model,
|
153 |
+
cond_pred,
|
154 |
+
uncond_pred,
|
155 |
cond_scale,
|
156 |
x,
|
157 |
timestep,
|
|
|
163 |
|
164 |
class CFGGuider:
|
165 |
"""#### Class for guiding the sampling process with CFG."""
|
166 |
+
|
167 |
def __init__(self, model_patcher, flux=False):
|
168 |
"""#### Initialize the CFGGuider.
|
169 |
|
|
|
351 |
del self.inner_model
|
352 |
del self.conds
|
353 |
del self.loaded_models
|
354 |
+
return output
|
modules/sample/ksampler_util.py
CHANGED
@@ -46,6 +46,7 @@ def pre_run_control(model: torch.nn.Module, conds: list) -> None:
|
|
46 |
|
47 |
def percent_to_timestep_function(a):
|
48 |
return s.percent_to_sigma(a)
|
|
|
49 |
if "control" in x:
|
50 |
x["control"].pre_run(model, percent_to_timestep_function)
|
51 |
|
@@ -96,9 +97,13 @@ def apply_empty_x_to_equal_area(
|
|
96 |
uncond[temp[1]] = n
|
97 |
|
98 |
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
|
|
102 |
"""#### Get the area and multiplier.
|
103 |
|
104 |
#### Args:
|
@@ -109,26 +114,39 @@ def get_area_and_mult(
|
|
109 |
#### Returns:
|
110 |
- `collections.namedtuple`: The area and multiplier.
|
111 |
"""
|
112 |
-
|
113 |
-
|
114 |
|
115 |
-
|
116 |
-
|
117 |
-
mult = mask * strength
|
118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
conditioning = {}
|
120 |
model_conds = conds["model_conds"]
|
|
|
|
|
|
|
|
|
121 |
for c in model_conds:
|
122 |
conditioning[c] = model_conds[c].process_cond(
|
123 |
-
batch_size=
|
124 |
)
|
125 |
|
|
|
126 |
control = conds.get("control", None)
|
127 |
patches = None
|
128 |
-
|
129 |
-
|
130 |
-
)
|
131 |
-
return cond_obj(input_x, mult, conditioning, area, control, patches)
|
132 |
|
133 |
|
134 |
def normal_scheduler(
|
@@ -158,6 +176,7 @@ def normal_scheduler(
|
|
158 |
sigs += [0.0]
|
159 |
return torch.FloatTensor(sigs)
|
160 |
|
|
|
161 |
def simple_scheduler(model_sampling: torch.nn.Module, steps: int) -> torch.FloatTensor:
|
162 |
"""#### Create a simple scheduler.
|
163 |
|
@@ -176,21 +195,52 @@ def simple_scheduler(model_sampling: torch.nn.Module, steps: int) -> torch.Float
|
|
176 |
sigs += [0.0]
|
177 |
return torch.FloatTensor(sigs)
|
178 |
|
|
|
179 |
# Implemented based on: https://arxiv.org/abs/2407.12173
|
180 |
def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
|
181 |
-
|
182 |
-
|
183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
|
185 |
-
sigs = []
|
186 |
-
last_t = -1
|
187 |
-
for t in ts:
|
188 |
-
if t != last_t:
|
189 |
-
sigs += [float(model_sampling.sigmas[int(t)])]
|
190 |
-
last_t = t
|
191 |
-
sigs += [0.0]
|
192 |
return torch.FloatTensor(sigs)
|
193 |
|
|
|
194 |
def calculate_sigmas(
|
195 |
model_sampling: torch.nn.Module, scheduler_name: str, steps: int
|
196 |
) -> torch.Tensor:
|
|
|
46 |
|
47 |
def percent_to_timestep_function(a):
|
48 |
return s.percent_to_sigma(a)
|
49 |
+
|
50 |
if "control" in x:
|
51 |
x["control"].pre_run(model, percent_to_timestep_function)
|
52 |
|
|
|
97 |
uncond[temp[1]] = n
|
98 |
|
99 |
|
100 |
+
# Define the namedtuple class once outside the function for reuse
|
101 |
+
CondObj = collections.namedtuple(
|
102 |
+
"cond_obj", ["input_x", "mult", "conditioning", "area", "control", "patches"]
|
103 |
+
)
|
104 |
+
|
105 |
+
|
106 |
+
def get_area_and_mult(conds: dict, x_in: torch.Tensor, timestep_in: int) -> CondObj:
|
107 |
"""#### Get the area and multiplier.
|
108 |
|
109 |
#### Args:
|
|
|
114 |
#### Returns:
|
115 |
- `collections.namedtuple`: The area and multiplier.
|
116 |
"""
|
117 |
+
# Cache shape information to avoid repeated access
|
118 |
+
x_shape = x_in.shape
|
119 |
|
120 |
+
# Define area dimensions in one operation
|
121 |
+
area = (x_shape[2], x_shape[3], 0, 0)
|
|
|
122 |
|
123 |
+
# Extract input region efficiently
|
124 |
+
# Since area[2] and area[3] are 0, this is essentially taking the full tensor
|
125 |
+
# But we maintain the slice operation for consistency
|
126 |
+
input_x = x_in[:, :, : area[0], : area[1]]
|
127 |
+
|
128 |
+
# Create multiplier tensor directly without intermediate mask creation
|
129 |
+
# This avoids an unnecessary tensor allocation and multiplication
|
130 |
+
mult = torch.ones_like(input_x) # strength is 1.0, so just create ones directly
|
131 |
+
|
132 |
+
# Prepare conditioning dictionary with cached device and batch_size
|
133 |
conditioning = {}
|
134 |
model_conds = conds["model_conds"]
|
135 |
+
batch_size = x_shape[0]
|
136 |
+
device = x_in.device
|
137 |
+
|
138 |
+
# Process conditions with cached parameters
|
139 |
for c in model_conds:
|
140 |
conditioning[c] = model_conds[c].process_cond(
|
141 |
+
batch_size=batch_size, device=device, area=area
|
142 |
)
|
143 |
|
144 |
+
# Get control directly without redundant variable assignment
|
145 |
control = conds.get("control", None)
|
146 |
patches = None
|
147 |
+
|
148 |
+
# Use the pre-defined namedtuple class instead of creating it every call
|
149 |
+
return CondObj(input_x, mult, conditioning, area, control, patches)
|
|
|
150 |
|
151 |
|
152 |
def normal_scheduler(
|
|
|
176 |
sigs += [0.0]
|
177 |
return torch.FloatTensor(sigs)
|
178 |
|
179 |
+
|
180 |
def simple_scheduler(model_sampling: torch.nn.Module, steps: int) -> torch.FloatTensor:
|
181 |
"""#### Create a simple scheduler.
|
182 |
|
|
|
195 |
sigs += [0.0]
|
196 |
return torch.FloatTensor(sigs)
|
197 |
|
198 |
+
|
199 |
# Implemented based on: https://arxiv.org/abs/2407.12173
|
200 |
def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
|
201 |
+
"""Creates a beta scheduler for noise levels based on the beta distribution.
|
202 |
+
|
203 |
+
This optimized implementation efficiently computes sigmas using the beta
|
204 |
+
distribution and caches calculations where possible.
|
205 |
+
|
206 |
+
Args:
|
207 |
+
model_sampling: Model sampling module
|
208 |
+
steps: Number of steps
|
209 |
+
alpha: Alpha parameter for beta distribution
|
210 |
+
beta: Beta parameter for beta distribution
|
211 |
+
|
212 |
+
Returns:
|
213 |
+
torch.FloatTensor: Tensor of sigma values for each step
|
214 |
+
"""
|
215 |
+
# Calculate total timesteps once
|
216 |
+
total_timesteps = len(model_sampling.sigmas) - 1
|
217 |
+
|
218 |
+
# Create a cache dictionary for reused values
|
219 |
+
model_sigmas = model_sampling.sigmas
|
220 |
+
|
221 |
+
# Generate evenly spaced values in [0,1) interval
|
222 |
+
ts_normalized = np.linspace(0, 1, steps, endpoint=False)
|
223 |
+
|
224 |
+
# Apply beta inverse CDF to get sampled time points - vectorized operation
|
225 |
+
ts_beta = scipy.stats.beta.ppf(1 - ts_normalized, alpha, beta)
|
226 |
+
|
227 |
+
# Scale to timestep indices and round to integers
|
228 |
+
ts_indices = np.rint(ts_beta * total_timesteps).astype(np.int32)
|
229 |
+
|
230 |
+
# Use numpy's unique function with return_index to efficiently find unique values
|
231 |
+
# while preserving order
|
232 |
+
unique_ts, indices = np.unique(ts_indices, return_index=True)
|
233 |
+
ordered_unique_ts = unique_ts[np.argsort(indices)]
|
234 |
+
|
235 |
+
# Map indices to sigma values efficiently
|
236 |
+
sigs = [float(model_sigmas[idx]) for idx in ordered_unique_ts]
|
237 |
+
|
238 |
+
# Add final sigma value of 0.0
|
239 |
+
sigs.append(0.0)
|
240 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
return torch.FloatTensor(sigs)
|
242 |
|
243 |
+
|
244 |
def calculate_sigmas(
|
245 |
model_sampling: torch.nn.Module, scheduler_name: str, steps: int
|
246 |
) -> torch.Tensor:
|
modules/sample/samplers.py
CHANGED
@@ -142,181 +142,15 @@ def sample_euler(
|
|
142 |
return x
|
143 |
|
144 |
|
145 |
-
|
146 |
-
|
147 |
-
model,
|
148 |
-
x,
|
149 |
-
sigmas,
|
150 |
-
extra_args=None,
|
151 |
-
callback=None,
|
152 |
-
disable=None,
|
153 |
-
eta=1.0,
|
154 |
-
s_noise=1.0,
|
155 |
-
noise_sampler=None,
|
156 |
-
r=1 / 2,
|
157 |
-
pipeline=False,
|
158 |
-
seed=None,
|
159 |
-
):
|
160 |
-
# Pre-calculate common values
|
161 |
-
device = x.device
|
162 |
-
global disable_gui
|
163 |
-
disable_gui = pipeline
|
164 |
-
|
165 |
-
if not disable_gui:
|
166 |
-
from modules.AutoEncoders import taesd
|
167 |
-
from modules.user import app_instance
|
168 |
-
|
169 |
-
# Early return check
|
170 |
-
if len(sigmas) <= 1:
|
171 |
-
return x
|
172 |
-
|
173 |
-
# Pre-allocate tensors and values
|
174 |
-
s_in = torch.ones((x.shape[0],), device=device)
|
175 |
-
n_steps = len(sigmas) - 1
|
176 |
-
extra_args = {} if extra_args is None else extra_args
|
177 |
-
|
178 |
-
# Define helper functions
|
179 |
-
def sigma_fn(t):
|
180 |
-
return (-t).exp()
|
181 |
-
|
182 |
-
def t_fn(sigma):
|
183 |
-
return -sigma.log()
|
184 |
-
|
185 |
-
# Initialize noise sampler
|
186 |
-
if noise_sampler is None:
|
187 |
-
noise_sampler = sampling_util.BrownianTreeNoiseSampler(
|
188 |
-
x, sigmas[sigmas > 0].min(), sigmas.max(), seed=seed, cpu=True
|
189 |
-
)
|
190 |
-
|
191 |
-
for i in trange(n_steps, disable=disable):
|
192 |
-
if (
|
193 |
-
not pipeline
|
194 |
-
and hasattr(app_instance.app, "interrupt_flag")
|
195 |
-
and app_instance.app.interrupt_flag
|
196 |
-
):
|
197 |
-
return x
|
198 |
-
|
199 |
-
if not pipeline:
|
200 |
-
app_instance.app.progress.set(i / n_steps)
|
201 |
-
|
202 |
-
# Model inference
|
203 |
-
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
204 |
-
|
205 |
-
if callback is not None:
|
206 |
-
callback({"x": x, "i": i, "sigma": sigmas[i], "denoised": denoised})
|
207 |
-
|
208 |
-
if sigmas[i + 1] == 0:
|
209 |
-
# Single fused Euler step
|
210 |
-
x = x + util.to_d(x, sigmas[i], denoised) * (sigmas[i + 1] - sigmas[i])
|
211 |
-
else:
|
212 |
-
# Fused DPM-Solver++ steps
|
213 |
-
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
214 |
-
s = t + (t_next - t) * r
|
215 |
-
|
216 |
-
# Step 1 - Combined calculations
|
217 |
-
sd, su = sampling_util.get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
|
218 |
-
s_ = t_fn(sd)
|
219 |
-
x_2 = (
|
220 |
-
(sigma_fn(s_) / sigma_fn(t)) * x
|
221 |
-
- (t - s_).expm1() * denoised
|
222 |
-
+ noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
|
223 |
-
)
|
224 |
-
|
225 |
-
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
|
226 |
-
|
227 |
-
# Step 2 - Combined calculations
|
228 |
-
sd, su = sampling_util.get_ancestral_step(
|
229 |
-
sigma_fn(t), sigma_fn(t_next), eta
|
230 |
-
)
|
231 |
-
t_next_ = t_fn(sd)
|
232 |
-
|
233 |
-
# Final update in single calculation
|
234 |
-
x = (
|
235 |
-
(sigma_fn(t_next_) / sigma_fn(t)) * x
|
236 |
-
- (t - t_next_).expm1()
|
237 |
-
* ((1 - 1 / (2 * r)) * denoised + (1 / (2 * r)) * denoised_2)
|
238 |
-
+ noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
|
239 |
-
)
|
240 |
-
|
241 |
-
# Preview updates
|
242 |
-
if not pipeline and app_instance.app.previewer_var.get() and i % 5 == 0:
|
243 |
-
threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
|
244 |
-
|
245 |
-
return x
|
246 |
-
|
247 |
-
|
248 |
-
@torch.no_grad()
|
249 |
-
def sample_dpmpp_2m(
|
250 |
-
model,
|
251 |
-
x,
|
252 |
-
sigmas,
|
253 |
-
extra_args=None,
|
254 |
-
callback=None,
|
255 |
-
disable=None,
|
256 |
-
pipeline=False,
|
257 |
):
|
258 |
-
""
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
if not disable_gui:
|
265 |
-
from modules.AutoEncoders import taesd
|
266 |
-
from modules.user import app_instance
|
267 |
-
|
268 |
-
# Pre-allocate tensors and transform sigmas
|
269 |
-
s_in = torch.ones((x.shape[0],), device=device)
|
270 |
-
t_steps = -torch.log(sigmas) # Fused calculation
|
271 |
-
|
272 |
-
# Pre-calculate all needed values in one go
|
273 |
-
sigma_steps = torch.exp(-t_steps) # Fused calculation
|
274 |
-
ratios = sigma_steps[1:] / sigma_steps[:-1]
|
275 |
-
h_steps = t_steps[1:] - t_steps[:-1]
|
276 |
-
|
277 |
-
old_denoised = None
|
278 |
-
extra_args = {} if extra_args is None else extra_args
|
279 |
-
|
280 |
-
for i in trange(len(sigmas) - 1, disable=disable):
|
281 |
-
if (
|
282 |
-
not pipeline
|
283 |
-
and hasattr(app_instance.app, "interrupt_flag")
|
284 |
-
and app_instance.app.interrupt_flag
|
285 |
-
):
|
286 |
-
return x
|
287 |
-
|
288 |
-
if not pipeline:
|
289 |
-
app_instance.app.progress.set(i / (len(sigmas) - 1))
|
290 |
-
|
291 |
-
# Fused model inference and update calculations
|
292 |
-
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
293 |
-
|
294 |
-
if callback is not None:
|
295 |
-
callback(
|
296 |
-
{
|
297 |
-
"x": x,
|
298 |
-
"i": i,
|
299 |
-
"sigma": sigmas[i],
|
300 |
-
"sigma_hat": sigmas[i],
|
301 |
-
"denoised": denoised,
|
302 |
-
}
|
303 |
-
)
|
304 |
-
|
305 |
-
# Combined update step
|
306 |
-
x = ratios[i] * x - (-h_steps[i]).expm1() * (
|
307 |
-
denoised
|
308 |
-
if old_denoised is None or sigmas[i + 1] == 0
|
309 |
-
else (1 + h_steps[i - 1] / (2 * h_steps[i])) * denoised
|
310 |
-
- (h_steps[i - 1] / (2 * h_steps[i])) * old_denoised
|
311 |
-
)
|
312 |
-
|
313 |
-
old_denoised = denoised
|
314 |
-
|
315 |
-
# Preview updates
|
316 |
-
if not pipeline and app_instance.app.previewer_var.get() and i % 5 == 0:
|
317 |
-
threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
|
318 |
-
|
319 |
-
return x
|
320 |
|
321 |
|
322 |
@torch.no_grad()
|
@@ -354,17 +188,26 @@ def sample_dpmpp_2m_cfgpp(
|
|
354 |
ratios = sigma_steps[1:] / sigma_steps[:-1]
|
355 |
h_steps = t_steps[1:] - t_steps[:-1]
|
356 |
|
357 |
-
# CFG
|
358 |
-
|
359 |
-
|
360 |
-
progress = step / n_steps
|
361 |
-
return cfg_scale + (cfg_min - cfg_scale) * progress
|
362 |
|
363 |
old_denoised = None
|
364 |
old_uncond_denoised = None
|
365 |
extra_args = {} if extra_args is None else extra_args
|
366 |
|
367 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
if (
|
369 |
not pipeline
|
370 |
and hasattr(app_instance.app, "interrupt_flag")
|
@@ -373,20 +216,10 @@ def sample_dpmpp_2m_cfgpp(
|
|
373 |
return x
|
374 |
|
375 |
if not pipeline:
|
376 |
-
app_instance.app.progress.set(i /
|
377 |
-
|
378 |
-
# Get current CFG scale
|
379 |
-
current_cfg = get_cfg_scale(i)
|
380 |
-
|
381 |
-
def post_cfg_function(args):
|
382 |
-
nonlocal old_uncond_denoised
|
383 |
-
old_uncond_denoised = args["uncond_denoised"]
|
384 |
-
return args["denoised"]
|
385 |
|
386 |
-
|
387 |
-
|
388 |
-
model_options, post_cfg_function, disable_cfg1_optimization=True
|
389 |
-
)
|
390 |
|
391 |
# Fused model inference and update calculations
|
392 |
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
@@ -406,27 +239,29 @@ def sample_dpmpp_2m_cfgpp(
|
|
406 |
}
|
407 |
)
|
408 |
|
409 |
-
# CFG++ update step
|
410 |
if old_uncond_denoised is None or sigmas[i + 1] == 0:
|
411 |
-
# First step or last step -
|
412 |
-
cfg_denoised = uncond_denoised
|
413 |
else:
|
414 |
-
#
|
415 |
-
x0_coeff = cfg_x0_scale * current_cfg
|
416 |
-
s_coeff = cfg_s_scale * current_cfg
|
417 |
-
|
418 |
-
# Momentum terms
|
419 |
h_ratio = h_steps[i - 1] / (2 * h_steps[i])
|
420 |
-
|
421 |
-
uncond_momentum = (
|
422 |
-
1 + h_ratio
|
423 |
-
) * uncond_denoised - h_ratio * old_uncond_denoised
|
424 |
|
425 |
-
#
|
426 |
-
|
|
|
|
|
|
|
427 |
|
428 |
-
|
429 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
430 |
|
431 |
old_denoised = denoised
|
432 |
old_uncond_denoised = uncond_denoised
|
@@ -438,17 +273,6 @@ def sample_dpmpp_2m_cfgpp(
|
|
438 |
return x
|
439 |
|
440 |
|
441 |
-
def set_model_options_post_cfg_function(
|
442 |
-
model_options, post_cfg_function, disable_cfg1_optimization=False
|
443 |
-
):
|
444 |
-
model_options["sampler_post_cfg_function"] = model_options.get(
|
445 |
-
"sampler_post_cfg_function", []
|
446 |
-
) + [post_cfg_function]
|
447 |
-
if disable_cfg1_optimization:
|
448 |
-
model_options["disable_cfg1_optimization"] = True
|
449 |
-
return model_options
|
450 |
-
|
451 |
-
|
452 |
@torch.no_grad()
|
453 |
def sample_dpmpp_sde_cfgpp(
|
454 |
model,
|
@@ -572,7 +396,6 @@ def sample_dpmpp_sde_cfgpp(
|
|
572 |
else:
|
573 |
# CFG++ with momentum
|
574 |
x0_coeff = cfg_x0_scale * current_cfg
|
575 |
-
s_coeff = cfg_s_scale * current_cfg
|
576 |
|
577 |
# Calculate momentum terms
|
578 |
h_ratio = (t - s_) / (2 * (t - t_next))
|
|
|
142 |
return x
|
143 |
|
144 |
|
145 |
+
def set_model_options_post_cfg_function(
|
146 |
+
model_options, post_cfg_function, disable_cfg1_optimization=False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
):
|
148 |
+
model_options["sampler_post_cfg_function"] = model_options.get(
|
149 |
+
"sampler_post_cfg_function", []
|
150 |
+
) + [post_cfg_function]
|
151 |
+
if disable_cfg1_optimization:
|
152 |
+
model_options["disable_cfg1_optimization"] = True
|
153 |
+
return model_options
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
|
156 |
@torch.no_grad()
|
|
|
188 |
ratios = sigma_steps[1:] / sigma_steps[:-1]
|
189 |
h_steps = t_steps[1:] - t_steps[:-1]
|
190 |
|
191 |
+
# Pre-calculate CFG schedule for the entire sampling process
|
192 |
+
steps = torch.arange(n_steps, device=device)
|
193 |
+
cfg_values = cfg_scale + (cfg_min - cfg_scale) * (steps / n_steps)
|
|
|
|
|
194 |
|
195 |
old_denoised = None
|
196 |
old_uncond_denoised = None
|
197 |
extra_args = {} if extra_args is None else extra_args
|
198 |
|
199 |
+
# Define post-CFG function once outside the loop
|
200 |
+
def post_cfg_function(args):
|
201 |
+
nonlocal old_uncond_denoised
|
202 |
+
old_uncond_denoised = args["uncond_denoised"]
|
203 |
+
return args["denoised"]
|
204 |
+
|
205 |
+
model_options = extra_args.get("model_options", {}).copy()
|
206 |
+
extra_args["model_options"] = set_model_options_post_cfg_function(
|
207 |
+
model_options, post_cfg_function, disable_cfg1_optimization=True
|
208 |
+
)
|
209 |
+
|
210 |
+
for i in trange(n_steps, disable=disable):
|
211 |
if (
|
212 |
not pipeline
|
213 |
and hasattr(app_instance.app, "interrupt_flag")
|
|
|
216 |
return x
|
217 |
|
218 |
if not pipeline:
|
219 |
+
app_instance.app.progress.set(i / n_steps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
|
221 |
+
# Use pre-calculated CFG scale
|
222 |
+
current_cfg = cfg_values[i]
|
|
|
|
|
223 |
|
224 |
# Fused model inference and update calculations
|
225 |
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
|
|
239 |
}
|
240 |
)
|
241 |
|
242 |
+
# CFG++ update step using optimized operations
|
243 |
if old_uncond_denoised is None or sigmas[i + 1] == 0:
|
244 |
+
# First step or last step - use torch.lerp for efficient interpolation
|
245 |
+
cfg_denoised = torch.lerp(uncond_denoised, denoised, current_cfg)
|
246 |
else:
|
247 |
+
# Fused momentum calculations
|
|
|
|
|
|
|
|
|
248 |
h_ratio = h_steps[i - 1] / (2 * h_steps[i])
|
249 |
+
h_ratio_plus_1 = 1 + h_ratio
|
|
|
|
|
|
|
250 |
|
251 |
+
# Use fused multiply-add operations for momentum terms
|
252 |
+
momentum = torch.addcmul(denoised * h_ratio_plus_1, old_denoised, -h_ratio)
|
253 |
+
uncond_momentum = torch.addcmul(
|
254 |
+
uncond_denoised * h_ratio_plus_1, old_uncond_denoised, -h_ratio
|
255 |
+
)
|
256 |
|
257 |
+
# Optimized interpolation for CFG++ update
|
258 |
+
cfg_denoised = torch.lerp(
|
259 |
+
uncond_momentum, momentum, current_cfg * cfg_x0_scale
|
260 |
+
)
|
261 |
+
|
262 |
+
# Apply update with pre-calculated expm1
|
263 |
+
h_expm1 = torch.expm1(-h_steps[i])
|
264 |
+
x = ratios[i] * x - h_expm1 * cfg_denoised
|
265 |
|
266 |
old_denoised = denoised
|
267 |
old_uncond_denoised = uncond_denoised
|
|
|
273 |
return x
|
274 |
|
275 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
@torch.no_grad()
|
277 |
def sample_dpmpp_sde_cfgpp(
|
278 |
model,
|
|
|
396 |
else:
|
397 |
# CFG++ with momentum
|
398 |
x0_coeff = cfg_x0_scale * current_cfg
|
|
|
399 |
|
400 |
# Calculate momentum terms
|
401 |
h_ratio = (t - s_) / (2 * (t - t_next))
|
modules/sample/sampling.py
CHANGED
@@ -76,6 +76,7 @@ class EPS:
|
|
76 |
if max_denoise:
|
77 |
noise = noise * torch.sqrt(1.0 + sigma**2.0)
|
78 |
else:
|
|
|
79 |
noise = noise * sigma
|
80 |
|
81 |
noise += latent_image
|
@@ -513,153 +514,22 @@ def ksampler(
|
|
513 |
#### Returns:
|
514 |
- `KSAMPLER`: The KSAMPLER object.
|
515 |
"""
|
516 |
-
if sampler_name == "
|
517 |
-
|
518 |
-
def dpmpp_2m_function(
|
519 |
-
model: torch.nn.Module,
|
520 |
-
noise: torch.Tensor,
|
521 |
-
sigmas: torch.Tensor,
|
522 |
-
extra_args: dict,
|
523 |
-
callback: callable,
|
524 |
-
disable: bool,
|
525 |
-
pipeline: bool,
|
526 |
-
**extra_options,
|
527 |
-
) -> torch.Tensor:
|
528 |
-
sigma_min = sigmas[-1]
|
529 |
-
if sigma_min == 0:
|
530 |
-
sigma_min = sigmas[-2]
|
531 |
-
return samplers.sample_dpmpp_2m(
|
532 |
-
model,
|
533 |
-
noise,
|
534 |
-
sigmas,
|
535 |
-
extra_args=extra_args,
|
536 |
-
callback=callback,
|
537 |
-
disable=disable,
|
538 |
-
pipeline=pipeline,
|
539 |
-
**extra_options,
|
540 |
-
)
|
541 |
-
|
542 |
-
sampler_function = dpmpp_2m_function
|
543 |
-
|
544 |
-
elif sampler_name == "dpmpp_2m_cfgpp":
|
545 |
-
|
546 |
-
def dpmpp_2m_dy_function(
|
547 |
-
model: torch.nn.Module,
|
548 |
-
noise: torch.Tensor,
|
549 |
-
sigmas: torch.Tensor,
|
550 |
-
extra_args: dict,
|
551 |
-
callback: callable,
|
552 |
-
disable: bool,
|
553 |
-
pipeline: bool,
|
554 |
-
**extra_options,
|
555 |
-
) -> torch.Tensor:
|
556 |
-
sigma_min = sigmas[-1]
|
557 |
-
if sigma_min == 0:
|
558 |
-
sigma_min = sigmas[-2]
|
559 |
-
return samplers.sample_dpmpp_2m_cfgpp(
|
560 |
-
model,
|
561 |
-
noise,
|
562 |
-
sigmas,
|
563 |
-
extra_args=extra_args,
|
564 |
-
callback=callback,
|
565 |
-
disable=disable,
|
566 |
-
pipeline=pipeline,
|
567 |
-
**extra_options,
|
568 |
-
)
|
569 |
-
|
570 |
-
sampler_function = dpmpp_2m_dy_function
|
571 |
-
|
572 |
-
elif sampler_name == "dpmpp_sde":
|
573 |
-
|
574 |
-
def dpmpp_sde_function(
|
575 |
-
model: torch.nn.Module,
|
576 |
-
noise: torch.Tensor,
|
577 |
-
sigmas: torch.Tensor,
|
578 |
-
extra_args: dict,
|
579 |
-
callback: callable,
|
580 |
-
disable: bool,
|
581 |
-
pipeline: bool,
|
582 |
-
**extra_options,
|
583 |
-
) -> torch.Tensor:
|
584 |
-
return samplers.sample_dpmpp_sde(
|
585 |
-
model,
|
586 |
-
noise,
|
587 |
-
sigmas,
|
588 |
-
extra_args=extra_args,
|
589 |
-
callback=callback,
|
590 |
-
disable=disable,
|
591 |
-
pipeline=pipeline,
|
592 |
-
**extra_options,
|
593 |
-
)
|
594 |
-
|
595 |
-
sampler_function = dpmpp_sde_function
|
596 |
|
597 |
elif sampler_name == "euler_ancestral":
|
598 |
-
|
599 |
-
def euler_ancestral_function(
|
600 |
-
model: torch.nn.Module,
|
601 |
-
noise: torch.Tensor,
|
602 |
-
sigmas: torch.Tensor,
|
603 |
-
extra_args: dict,
|
604 |
-
callback: callable,
|
605 |
-
disable: bool,
|
606 |
-
pipeline: bool,
|
607 |
-
) -> torch.Tensor:
|
608 |
-
return samplers.sample_euler_ancestral(
|
609 |
-
model,
|
610 |
-
noise,
|
611 |
-
sigmas,
|
612 |
-
extra_args=extra_args,
|
613 |
-
callback=callback,
|
614 |
-
disable=disable,
|
615 |
-
pipeline=pipeline,
|
616 |
-
**extra_options,
|
617 |
-
)
|
618 |
-
|
619 |
-
sampler_function = euler_ancestral_function
|
620 |
|
621 |
elif sampler_name == "dpmpp_sde_cfgpp":
|
622 |
-
|
623 |
-
def dpmpp_sde_dy_function(
|
624 |
-
model: torch.nn.Module,
|
625 |
-
noise: torch.Tensor,
|
626 |
-
sigmas: torch.Tensor,
|
627 |
-
extra_args: dict,
|
628 |
-
callback: callable,
|
629 |
-
disable: bool,
|
630 |
-
pipeline: bool,
|
631 |
-
**extra_options,
|
632 |
-
) -> torch.Tensor:
|
633 |
-
return samplers.sample_dpmpp_sde_cfgpp(
|
634 |
-
model,
|
635 |
-
noise,
|
636 |
-
sigmas,
|
637 |
-
extra_args=extra_args,
|
638 |
-
callback=callback,
|
639 |
-
disable=disable,
|
640 |
-
pipeline=pipeline,
|
641 |
-
**extra_options,
|
642 |
-
)
|
643 |
-
|
644 |
-
sampler_function = dpmpp_sde_dy_function
|
645 |
|
646 |
elif sampler_name == "euler":
|
|
|
647 |
|
648 |
-
|
649 |
-
|
650 |
-
|
651 |
-
|
652 |
-
model,
|
653 |
-
noise,
|
654 |
-
sigmas,
|
655 |
-
extra_args=extra_args,
|
656 |
-
callback=callback,
|
657 |
-
disable=disable,
|
658 |
-
pipeline=pipeline,
|
659 |
-
**extra_options,
|
660 |
-
)
|
661 |
-
|
662 |
-
sampler_function = euler_function
|
663 |
|
664 |
return KSAMPLER(sampler_function, extra_options, inpaint_options)
|
665 |
|
@@ -734,49 +604,49 @@ def sampler_object(name: str, pipeline: bool = False) -> KSAMPLER:
|
|
734 |
return sampler
|
735 |
|
736 |
|
737 |
-
class
|
738 |
-
"""
|
739 |
|
740 |
def __init__(
|
741 |
self,
|
742 |
-
model: torch.nn.Module,
|
743 |
-
steps: int,
|
744 |
-
device,
|
745 |
sampler: str = None,
|
746 |
scheduler: str = None,
|
747 |
-
denoise: float =
|
748 |
model_options: dict = {},
|
749 |
pipeline: bool = False,
|
750 |
):
|
751 |
-
"""
|
752 |
-
|
753 |
-
|
754 |
-
|
755 |
-
|
756 |
-
|
757 |
-
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
- `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
|
762 |
"""
|
763 |
self.model = model
|
764 |
-
self.device =
|
765 |
self.scheduler = scheduler
|
766 |
-
self.
|
767 |
-
self.set_steps(steps, denoise)
|
768 |
self.denoise = denoise
|
769 |
self.model_options = model_options
|
770 |
self.pipeline = pipeline
|
771 |
|
|
|
|
|
|
|
772 |
def calculate_sigmas(self, steps: int) -> torch.Tensor:
|
773 |
-
"""
|
774 |
|
775 |
-
|
776 |
-
|
777 |
|
778 |
-
|
779 |
-
|
780 |
"""
|
781 |
sigmas = ksampler_util.calculate_sigmas(
|
782 |
self.model.get_model_object("model_sampling"), self.scheduler, steps
|
@@ -784,11 +654,11 @@ class KSampler1:
|
|
784 |
return sigmas
|
785 |
|
786 |
def set_steps(self, steps: int, denoise: float = None):
|
787 |
-
"""
|
788 |
|
789 |
-
|
790 |
-
|
791 |
-
|
792 |
"""
|
793 |
self.steps = steps
|
794 |
if denoise is None or denoise > 0.9999:
|
@@ -801,7 +671,29 @@ class KSampler1:
|
|
801 |
sigmas = self.calculate_sigmas(new_steps).to(self.device)
|
802 |
self.sigmas = sigmas[-(steps + 1) :]
|
803 |
|
804 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
805 |
self,
|
806 |
noise: torch.Tensor,
|
807 |
positive: torch.Tensor,
|
@@ -816,48 +708,45 @@ class KSampler1:
|
|
816 |
callback: callable = None,
|
817 |
disable_pbar: bool = False,
|
818 |
seed: int = None,
|
819 |
-
pipeline: bool = False,
|
820 |
flux: bool = False,
|
821 |
) -> torch.Tensor:
|
822 |
-
"""
|
823 |
-
|
824 |
-
|
825 |
-
|
826 |
-
|
827 |
-
|
828 |
-
|
829 |
-
|
830 |
-
|
831 |
-
|
832 |
-
|
833 |
-
|
834 |
-
|
835 |
-
|
836 |
-
|
837 |
-
|
838 |
-
|
839 |
-
|
840 |
-
|
841 |
-
|
842 |
"""
|
|
|
|
|
|
|
843 |
if sigmas is None:
|
844 |
sigmas = self.sigmas
|
845 |
|
846 |
-
|
847 |
-
sigmas = sigmas[: last_step + 1]
|
848 |
-
if force_full_denoise:
|
849 |
-
sigmas[-1] = 0
|
850 |
|
851 |
-
|
852 |
-
|
853 |
-
|
|
|
854 |
else:
|
855 |
-
|
856 |
-
return latent_image
|
857 |
-
else:
|
858 |
-
return torch.zeros_like(noise)
|
859 |
|
860 |
-
|
861 |
|
862 |
return sample(
|
863 |
self.model,
|
@@ -866,7 +755,7 @@ class KSampler1:
|
|
866 |
negative,
|
867 |
cfg,
|
868 |
self.device,
|
869 |
-
|
870 |
sigmas,
|
871 |
self.model_options,
|
872 |
latent_image=latent_image,
|
@@ -874,11 +763,117 @@ class KSampler1:
|
|
874 |
callback=callback,
|
875 |
disable_pbar=disable_pbar,
|
876 |
seed=seed,
|
877 |
-
pipeline=pipeline,
|
878 |
flux=flux,
|
879 |
)
|
880 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
881 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
882 |
def sample1(
|
883 |
model: torch.nn.Module,
|
884 |
noise: torch.Tensor,
|
@@ -902,37 +897,37 @@ def sample1(
|
|
902 |
pipeline: bool = False,
|
903 |
flux: bool = False,
|
904 |
) -> torch.Tensor:
|
905 |
-
"""
|
906 |
-
|
907 |
-
|
908 |
-
|
909 |
-
|
910 |
-
|
911 |
-
|
912 |
-
|
913 |
-
|
914 |
-
|
915 |
-
|
916 |
-
|
917 |
-
|
918 |
-
|
919 |
-
|
920 |
-
|
921 |
-
|
922 |
-
|
923 |
-
|
924 |
-
|
925 |
-
|
926 |
-
|
927 |
-
|
928 |
-
|
929 |
-
|
930 |
-
|
|
|
931 |
"""
|
932 |
-
sampler =
|
933 |
-
model,
|
934 |
steps=steps,
|
935 |
-
device=model.load_device,
|
936 |
sampler=sampler_name,
|
937 |
scheduler=scheduler,
|
938 |
denoise=denoise,
|
@@ -940,7 +935,7 @@ def sample1(
|
|
940 |
pipeline=pipeline,
|
941 |
)
|
942 |
|
943 |
-
samples = sampler.
|
944 |
noise,
|
945 |
positive,
|
946 |
negative,
|
@@ -954,147 +949,12 @@ def sample1(
|
|
954 |
callback=callback,
|
955 |
disable_pbar=disable_pbar,
|
956 |
seed=seed,
|
957 |
-
pipeline=pipeline,
|
958 |
flux=flux,
|
959 |
)
|
960 |
samples = samples.to(Device.intermediate_device())
|
961 |
return samples
|
962 |
|
963 |
|
964 |
-
def common_ksampler(
|
965 |
-
model: torch.nn.Module,
|
966 |
-
seed: int,
|
967 |
-
steps: int,
|
968 |
-
cfg: float,
|
969 |
-
sampler_name: str,
|
970 |
-
scheduler: str,
|
971 |
-
positive: torch.Tensor,
|
972 |
-
negative: torch.Tensor,
|
973 |
-
latent: dict,
|
974 |
-
denoise: float = 1.0,
|
975 |
-
disable_noise: bool = False,
|
976 |
-
start_step: int = None,
|
977 |
-
last_step: int = None,
|
978 |
-
force_full_denoise: bool = False,
|
979 |
-
pipeline: bool = False,
|
980 |
-
flux: bool = False,
|
981 |
-
) -> tuple:
|
982 |
-
"""#### Common ksampler function.
|
983 |
-
|
984 |
-
#### Args:
|
985 |
-
- `model` (torch.nn.Module): The model.
|
986 |
-
- `seed` (int): The seed value.
|
987 |
-
- `steps` (int): The number of steps.
|
988 |
-
- `cfg` (float): The CFG value.
|
989 |
-
- `sampler_name` (str): The sampler name.
|
990 |
-
- `scheduler` (str): The scheduler name.
|
991 |
-
- `positive` (torch.Tensor): The positive tensor.
|
992 |
-
- `negative` (torch.Tensor): The negative tensor.
|
993 |
-
- `latent` (dict): The latent dictionary.
|
994 |
-
- `denoise` (float, optional): The denoise factor. Defaults to 1.0.
|
995 |
-
- `disable_noise` (bool, optional): Whether to disable noise. Defaults to False.
|
996 |
-
- `start_step` (int, optional): The start step. Defaults to None.
|
997 |
-
- `last_step` (int, optional): The last step. Defaults to None.
|
998 |
-
- `force_full_denoise` (bool, optional): Whether to force full denoise. Defaults to False.
|
999 |
-
- `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
|
1000 |
-
|
1001 |
-
#### Returns:
|
1002 |
-
- `tuple`: The output tuple containing the latent dictionary and samples.
|
1003 |
-
"""
|
1004 |
-
latent_image = latent["samples"]
|
1005 |
-
latent_image = Latent.fix_empty_latent_channels(model, latent_image)
|
1006 |
-
|
1007 |
-
if disable_noise:
|
1008 |
-
noise = torch.zeros(
|
1009 |
-
latent_image.size(),
|
1010 |
-
dtype=latent_image.dtype,
|
1011 |
-
layout=latent_image.layout,
|
1012 |
-
device="cpu",
|
1013 |
-
)
|
1014 |
-
else:
|
1015 |
-
batch_inds = latent["batch_index"] if "batch_index" in latent else None
|
1016 |
-
noise = ksampler_util.prepare_noise(latent_image, seed, batch_inds)
|
1017 |
-
|
1018 |
-
noise_mask = None
|
1019 |
-
if "noise_mask" in latent:
|
1020 |
-
noise_mask = latent["noise_mask"]
|
1021 |
-
samples = sample1(
|
1022 |
-
model,
|
1023 |
-
noise,
|
1024 |
-
steps,
|
1025 |
-
cfg,
|
1026 |
-
sampler_name,
|
1027 |
-
scheduler,
|
1028 |
-
positive,
|
1029 |
-
negative,
|
1030 |
-
latent_image,
|
1031 |
-
denoise=denoise,
|
1032 |
-
disable_noise=disable_noise,
|
1033 |
-
start_step=start_step,
|
1034 |
-
last_step=last_step,
|
1035 |
-
force_full_denoise=force_full_denoise,
|
1036 |
-
noise_mask=noise_mask,
|
1037 |
-
seed=seed,
|
1038 |
-
pipeline=pipeline,
|
1039 |
-
flux=flux,
|
1040 |
-
)
|
1041 |
-
out = latent.copy()
|
1042 |
-
out["samples"] = samples
|
1043 |
-
return (out,)
|
1044 |
-
|
1045 |
-
|
1046 |
-
class KSampler2:
    """#### Thin wrapper that forwards a sampling request to `common_ksampler`."""

    def sample(
        self,
        model: torch.nn.Module,
        seed: int,
        steps: int,
        cfg: float,
        sampler_name: str,
        scheduler: str,
        positive: torch.Tensor,
        negative: torch.Tensor,
        latent_image: torch.Tensor,
        denoise: float = 1.0,
        pipeline: bool = False,
        flux: bool = False,
    ) -> tuple:
        """#### Run one sampling pass through `common_ksampler`.

        #### Args:
            - `model` (torch.nn.Module): The model.
            - `seed` (int): The seed value.
            - `steps` (int): The number of steps.
            - `cfg` (float): The CFG value.
            - `sampler_name` (str): The sampler name.
            - `scheduler` (str): The scheduler name.
            - `positive` (torch.Tensor): The positive conditioning.
            - `negative` (torch.Tensor): The negative conditioning.
            - `latent_image` (torch.Tensor): Forwarded unchanged as the `latent`
              argument of `common_ksampler`.
            - `denoise` (float, optional): The denoise factor. Defaults to 1.0.
            - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
            - `flux` (bool, optional): Whether to use flux mode. Defaults to False.

        #### Returns:
            - `tuple`: Whatever `common_ksampler` returns.
        """
        # Positional part mirrors the leading parameters of common_ksampler.
        leading_args = (
            model,
            seed,
            steps,
            cfg,
            sampler_name,
            scheduler,
            positive,
            negative,
            latent_image,
        )
        # The remaining options are passed by keyword, exactly as before.
        options = {"denoise": denoise, "pipeline": pipeline, "flux": flux}
        return common_ksampler(*leading_args, **options)
|
1096 |
-
|
1097 |
-
|
1098 |
class ModelType(Enum):
|
1099 |
"""#### Enum for Model Types."""
|
1100 |
|
@@ -1187,3 +1047,86 @@ def sample_custom(
|
|
1187 |
)
|
1188 |
samples = samples.to(Device.intermediate_device())
|
1189 |
return samples
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
if max_denoise:
|
77 |
noise = noise * torch.sqrt(1.0 + sigma**2.0)
|
78 |
else:
|
79 |
+
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
80 |
noise = noise * sigma
|
81 |
|
82 |
noise += latent_image
|
|
|
514 |
#### Returns:
|
515 |
- `KSAMPLER`: The KSAMPLER object.
|
516 |
"""
|
517 |
+
if sampler_name == "dpmpp_2m_cfgpp":
|
518 |
+
sampler_function = samplers.sample_dpmpp_2m_cfgpp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
519 |
|
520 |
elif sampler_name == "euler_ancestral":
|
521 |
+
sampler_function = samplers.sample_euler_ancestral
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
522 |
|
523 |
elif sampler_name == "dpmpp_sde_cfgpp":
|
524 |
+
sampler_function = samplers.sample_dpmpp_sde_cfgpp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
525 |
|
526 |
elif sampler_name == "euler":
|
527 |
+
sampler_function = samplers.sample_euler
|
528 |
|
529 |
+
else:
|
530 |
+
# Default fallback
|
531 |
+
sampler_function = samplers.sample_euler
|
532 |
+
print(f"Warning: Unknown sampler '{sampler_name}', falling back to euler")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
533 |
|
534 |
return KSAMPLER(sampler_function, extra_options, inpaint_options)
|
535 |
|
|
|
604 |
return sampler
|
605 |
|
606 |
|
607 |
+
class KSampler:
|
608 |
+
"""A unified sampler class that replaces both KSampler1 and KSampler2."""
|
609 |
|
610 |
def __init__(
|
611 |
self,
|
612 |
+
model: torch.nn.Module = None,
|
613 |
+
steps: int = None,
|
|
|
614 |
sampler: str = None,
|
615 |
scheduler: str = None,
|
616 |
+
denoise: float = 1.0,
|
617 |
model_options: dict = {},
|
618 |
pipeline: bool = False,
|
619 |
):
|
620 |
+
"""Initialize the KSampler class.
|
621 |
+
|
622 |
+
Args:
|
623 |
+
model (torch.nn.Module, optional): The model to use for sampling. Required for direct sampling.
|
624 |
+
steps (int, optional): The number of steps. Required for direct sampling.
|
625 |
+
sampler (str, optional): The sampler name. Defaults to None.
|
626 |
+
scheduler (str, optional): The scheduler name. Defaults to None.
|
627 |
+
denoise (float, optional): The denoise factor. Defaults to 1.0.
|
628 |
+
model_options (dict, optional): The model options. Defaults to {}.
|
629 |
+
pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
|
|
|
630 |
"""
|
631 |
self.model = model
|
632 |
+
self.device = model.load_device if model is not None else None
|
633 |
self.scheduler = scheduler
|
634 |
+
self.sampler_name = sampler
|
|
|
635 |
self.denoise = denoise
|
636 |
self.model_options = model_options
|
637 |
self.pipeline = pipeline
|
638 |
|
639 |
+
if model is not None and steps is not None:
|
640 |
+
self.set_steps(steps, denoise)
|
641 |
+
|
642 |
def calculate_sigmas(self, steps: int) -> torch.Tensor:
|
643 |
+
"""Calculate the sigmas for the given steps.
|
644 |
|
645 |
+
Args:
|
646 |
+
steps (int): The number of steps.
|
647 |
|
648 |
+
Returns:
|
649 |
+
torch.Tensor: The calculated sigmas.
|
650 |
"""
|
651 |
sigmas = ksampler_util.calculate_sigmas(
|
652 |
self.model.get_model_object("model_sampling"), self.scheduler, steps
|
|
|
654 |
return sigmas
|
655 |
|
656 |
def set_steps(self, steps: int, denoise: float = None):
|
657 |
+
"""Set the steps and calculate the sigmas.
|
658 |
|
659 |
+
Args:
|
660 |
+
steps (int): The number of steps.
|
661 |
+
denoise (float, optional): The denoise factor. Defaults to None.
|
662 |
"""
|
663 |
self.steps = steps
|
664 |
if denoise is None or denoise > 0.9999:
|
|
|
671 |
sigmas = self.calculate_sigmas(new_steps).to(self.device)
|
672 |
self.sigmas = sigmas[-(steps + 1) :]
|
673 |
|
674 |
+
def _process_sigmas(self, sigmas, start_step, last_step, force_full_denoise):
|
675 |
+
"""Process sigmas based on start_step and last_step.
|
676 |
+
|
677 |
+
Args:
|
678 |
+
sigmas (torch.Tensor): The sigmas tensor.
|
679 |
+
start_step (int, optional): The start step. Defaults to None.
|
680 |
+
last_step (int, optional): The last step. Defaults to None.
|
681 |
+
force_full_denoise (bool): Whether to force full denoise.
|
682 |
+
|
683 |
+
Returns:
|
684 |
+
torch.Tensor: The processed sigmas.
|
685 |
+
"""
|
686 |
+
if last_step is not None and last_step < (len(sigmas) - 1):
|
687 |
+
sigmas = sigmas[: last_step + 1]
|
688 |
+
if force_full_denoise:
|
689 |
+
sigmas[-1] = 0
|
690 |
+
|
691 |
+
if start_step is not None and start_step < (len(sigmas) - 1):
|
692 |
+
sigmas = sigmas[start_step:]
|
693 |
+
|
694 |
+
return sigmas
|
695 |
+
|
696 |
+
def direct_sample(
|
697 |
self,
|
698 |
noise: torch.Tensor,
|
699 |
positive: torch.Tensor,
|
|
|
708 |
callback: callable = None,
|
709 |
disable_pbar: bool = False,
|
710 |
seed: int = None,
|
|
|
711 |
flux: bool = False,
|
712 |
) -> torch.Tensor:
|
713 |
+
"""Sample directly with the initialized model and parameters.
|
714 |
+
|
715 |
+
Args:
|
716 |
+
noise (torch.Tensor): The noise tensor.
|
717 |
+
positive (torch.Tensor): The positive tensor.
|
718 |
+
negative (torch.Tensor): The negative tensor.
|
719 |
+
cfg (float): The CFG value.
|
720 |
+
latent_image (torch.Tensor, optional): The latent image tensor. Defaults to None.
|
721 |
+
start_step (int, optional): The start step. Defaults to None.
|
722 |
+
last_step (int, optional): The last step. Defaults to None.
|
723 |
+
force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
|
724 |
+
denoise_mask (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
|
725 |
+
sigmas (torch.Tensor, optional): The sigmas tensor. Defaults to None.
|
726 |
+
callback (callable, optional): The callback function. Defaults to None.
|
727 |
+
disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
|
728 |
+
seed (int, optional): The seed value. Defaults to None.
|
729 |
+
flux (bool, optional): Whether to use flux mode. Defaults to False.
|
730 |
+
|
731 |
+
Returns:
|
732 |
+
torch.Tensor: The sampled tensor.
|
733 |
"""
|
734 |
+
if self.model is None:
|
735 |
+
raise ValueError("Model must be provided for direct sampling")
|
736 |
+
|
737 |
if sigmas is None:
|
738 |
sigmas = self.sigmas
|
739 |
|
740 |
+
sigmas = self._process_sigmas(sigmas, start_step, last_step, force_full_denoise)
|
|
|
|
|
|
|
741 |
|
742 |
+
# Early return if needed
|
743 |
+
if start_step is not None and start_step >= (len(sigmas) - 1):
|
744 |
+
if latent_image is not None:
|
745 |
+
return latent_image
|
746 |
else:
|
747 |
+
return torch.zeros_like(noise)
|
|
|
|
|
|
|
748 |
|
749 |
+
sampler_obj = sampler_object(self.sampler_name, pipeline=self.pipeline)
|
750 |
|
751 |
return sample(
|
752 |
self.model,
|
|
|
755 |
negative,
|
756 |
cfg,
|
757 |
self.device,
|
758 |
+
sampler_obj,
|
759 |
sigmas,
|
760 |
self.model_options,
|
761 |
latent_image=latent_image,
|
|
|
763 |
callback=callback,
|
764 |
disable_pbar=disable_pbar,
|
765 |
seed=seed,
|
766 |
+
pipeline=self.pipeline,
|
767 |
flux=flux,
|
768 |
)
|
769 |
|
770 |
+
def sample(
|
771 |
+
self,
|
772 |
+
model: torch.nn.Module = None,
|
773 |
+
seed: int = None,
|
774 |
+
steps: int = None,
|
775 |
+
cfg: float = None,
|
776 |
+
sampler_name: str = None,
|
777 |
+
scheduler: str = None,
|
778 |
+
positive: torch.Tensor = None,
|
779 |
+
negative: torch.Tensor = None,
|
780 |
+
latent_image: torch.Tensor = None,
|
781 |
+
denoise: float = None,
|
782 |
+
start_step: int = None,
|
783 |
+
last_step: int = None,
|
784 |
+
force_full_denoise: bool = False,
|
785 |
+
noise_mask: torch.Tensor = None,
|
786 |
+
callback: callable = None,
|
787 |
+
disable_pbar: bool = False,
|
788 |
+
disable_noise: bool = False,
|
789 |
+
pipeline: bool = False,
|
790 |
+
flux: bool = False,
|
791 |
+
) -> tuple:
|
792 |
+
"""Unified sampling interface that works both as direct sampling and through the common_ksampler.
|
793 |
+
|
794 |
+
This method can be used in two ways:
|
795 |
+
1. If model is provided, it will create a temporary sampler and use that
|
796 |
+
2. If model is None, it will use the pre-initialized model and parameters
|
797 |
+
|
798 |
+
Args:
|
799 |
+
model (torch.nn.Module, optional): The model to use for sampling. If None, uses pre-initialized model.
|
800 |
+
seed (int, optional): The seed value.
|
801 |
+
steps (int, optional): The number of steps. If None, uses pre-initialized steps.
|
802 |
+
cfg (float, optional): The CFG value.
|
803 |
+
sampler_name (str, optional): The sampler name. If None, uses pre-initialized sampler.
|
804 |
+
scheduler (str, optional): The scheduler name. If None, uses pre-initialized scheduler.
|
805 |
+
positive (torch.Tensor, optional): The positive tensor.
|
806 |
+
negative (torch.Tensor, optional): The negative tensor.
|
807 |
+
latent_image (torch.Tensor, optional): The latent image tensor.
|
808 |
+
denoise (float, optional): The denoise factor. If None, uses pre-initialized denoise.
|
809 |
+
start_step (int, optional): The start step. Defaults to None.
|
810 |
+
last_step (int, optional): The last step. Defaults to None.
|
811 |
+
force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
|
812 |
+
noise_mask (torch.Tensor, optional): The noise mask tensor. Defaults to None.
|
813 |
+
callback (callable, optional): The callback function. Defaults to None.
|
814 |
+
disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
|
815 |
+
disable_noise (bool, optional): Whether to disable noise. Defaults to False.
|
816 |
+
pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
|
817 |
+
flux (bool, optional): Whether to use flux mode. Defaults to False.
|
818 |
+
|
819 |
+
Returns:
|
820 |
+
tuple: The output tuple containing either (latent_dict,) or the sampled tensor.
|
821 |
+
"""
|
822 |
+
# Case 1: Use pre-initialized model for direct sampling
|
823 |
+
if model is None:
|
824 |
+
if latent_image is None:
|
825 |
+
raise ValueError(
|
826 |
+
"latent_image must be provided when using pre-initialized model"
|
827 |
+
)
|
828 |
+
|
829 |
+
return (
|
830 |
+
self.direct_sample(
|
831 |
+
None, # noise will be generated in common_ksampler
|
832 |
+
positive,
|
833 |
+
negative,
|
834 |
+
cfg,
|
835 |
+
latent_image,
|
836 |
+
start_step,
|
837 |
+
last_step,
|
838 |
+
force_full_denoise,
|
839 |
+
noise_mask,
|
840 |
+
None, # sigmas will use pre-calculated ones
|
841 |
+
callback,
|
842 |
+
disable_pbar,
|
843 |
+
seed,
|
844 |
+
flux,
|
845 |
+
),
|
846 |
+
)
|
847 |
+
|
848 |
+
# Case 2: Use common_ksampler approach with provided model
|
849 |
+
else:
|
850 |
+
# For backwards compatibility with KSampler2 usage pattern
|
851 |
+
if isinstance(latent_image, dict):
|
852 |
+
latent = latent_image
|
853 |
+
else:
|
854 |
+
latent = {"samples": latent_image}
|
855 |
|
856 |
+
return common_ksampler(
|
857 |
+
model,
|
858 |
+
seed,
|
859 |
+
steps,
|
860 |
+
cfg,
|
861 |
+
sampler_name or self.sampler_name,
|
862 |
+
scheduler or self.scheduler,
|
863 |
+
positive,
|
864 |
+
negative,
|
865 |
+
latent,
|
866 |
+
denoise or self.denoise,
|
867 |
+
disable_noise,
|
868 |
+
start_step,
|
869 |
+
last_step,
|
870 |
+
force_full_denoise,
|
871 |
+
pipeline or self.pipeline,
|
872 |
+
flux,
|
873 |
+
)
|
874 |
+
|
875 |
+
|
876 |
+
# Refactor sample1 to use KSampler directly
|
877 |
def sample1(
|
878 |
model: torch.nn.Module,
|
879 |
noise: torch.Tensor,
|
|
|
897 |
pipeline: bool = False,
|
898 |
flux: bool = False,
|
899 |
) -> torch.Tensor:
|
900 |
+
"""Sample using the given parameters with the unified KSampler.
|
901 |
+
|
902 |
+
Args:
|
903 |
+
model (torch.nn.Module): The model.
|
904 |
+
noise (torch.Tensor): The noise tensor.
|
905 |
+
steps (int): The number of steps.
|
906 |
+
cfg (float): The CFG value.
|
907 |
+
sampler_name (str): The sampler name.
|
908 |
+
scheduler (str): The scheduler name.
|
909 |
+
positive (torch.Tensor): The positive tensor.
|
910 |
+
negative (torch.Tensor): The negative tensor.
|
911 |
+
latent_image (torch.Tensor): The latent image tensor.
|
912 |
+
denoise (float, optional): The denoise factor. Defaults to 1.0.
|
913 |
+
disable_noise (bool, optional): Whether to disable noise. Defaults to False.
|
914 |
+
start_step (int, optional): The start step. Defaults to None.
|
915 |
+
last_step (int, optional): The last step. Defaults to None.
|
916 |
+
force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
|
917 |
+
noise_mask (torch.Tensor, optional): The noise mask tensor. Defaults to None.
|
918 |
+
sigmas (torch.Tensor, optional): The sigmas tensor. Defaults to None.
|
919 |
+
callback (callable, optional): The callback function. Defaults to None.
|
920 |
+
disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
|
921 |
+
seed (int, optional): The seed value. Defaults to None.
|
922 |
+
pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
|
923 |
+
flux (bool, optional): Whether to use flux mode. Defaults to False.
|
924 |
+
|
925 |
+
Returns:
|
926 |
+
torch.Tensor: The sampled tensor.
|
927 |
"""
|
928 |
+
sampler = KSampler(
|
929 |
+
model=model,
|
930 |
steps=steps,
|
|
|
931 |
sampler=sampler_name,
|
932 |
scheduler=scheduler,
|
933 |
denoise=denoise,
|
|
|
935 |
pipeline=pipeline,
|
936 |
)
|
937 |
|
938 |
+
samples = sampler.direct_sample(
|
939 |
noise,
|
940 |
positive,
|
941 |
negative,
|
|
|
949 |
callback=callback,
|
950 |
disable_pbar=disable_pbar,
|
951 |
seed=seed,
|
|
|
952 |
flux=flux,
|
953 |
)
|
954 |
samples = samples.to(Device.intermediate_device())
|
955 |
return samples
|
956 |
|
957 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
958 |
class ModelType(Enum):
|
959 |
"""#### Enum for Model Types."""
|
960 |
|
|
|
1047 |
)
|
1048 |
samples = samples.to(Device.intermediate_device())
|
1049 |
return samples
|
1050 |
+
|
1051 |
+
|
1052 |
+
def common_ksampler(
    model: torch.nn.Module,
    seed: int,
    steps: int,
    cfg: float,
    sampler_name: str,
    scheduler: str,
    positive: torch.Tensor,
    negative: torch.Tensor,
    latent: dict,
    denoise: float = 1.0,
    disable_noise: bool = False,
    start_step: int = None,
    last_step: int = None,
    force_full_denoise: bool = False,
    pipeline: bool = False,
    flux: bool = False,
) -> tuple:
    """Prepare noise for a latent dictionary and run it through ``sample1``.

    Args:
        model (torch.nn.Module): The model.
        seed (int): The seed value.
        steps (int): The number of steps.
        cfg (float): The CFG value.
        sampler_name (str): The sampler name.
        scheduler (str): The scheduler name.
        positive (torch.Tensor): The positive tensor.
        negative (torch.Tensor): The negative tensor.
        latent (dict): The latent dictionary. Must contain ``"samples"``; the
            optional ``"batch_index"`` and ``"noise_mask"`` entries are used when present.
        denoise (float, optional): The denoise factor. Defaults to 1.0.
        disable_noise (bool, optional): Use an all-zero noise tensor. Defaults to False.
        start_step (int, optional): The start step. Defaults to None.
        last_step (int, optional): The last step. Defaults to None.
        force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
        pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
        flux (bool, optional): Whether to use flux mode. Defaults to False.

    Returns:
        tuple: One-element tuple holding a shallow copy of ``latent`` whose
        ``"samples"`` entry is replaced by the sampled tensor.
    """
    latent_image = Latent.fix_empty_latent_channels(model, latent["samples"])

    if not disable_noise:
        # Seeded noise, optionally tied to specific batch indices.
        noise = ksampler_util.prepare_noise(
            latent_image, seed, latent.get("batch_index")
        )
    else:
        # Zero noise with the latent's size, dtype and layout, created on the CPU.
        noise = torch.zeros(
            latent_image.size(),
            dtype=latent_image.dtype,
            layout=latent_image.layout,
            device="cpu",
        )

    sampled = sample1(
        model,
        noise,
        steps,
        cfg,
        sampler_name,
        scheduler,
        positive,
        negative,
        latent_image,
        denoise=denoise,
        disable_noise=disable_noise,
        start_step=start_step,
        last_step=last_step,
        force_full_denoise=force_full_denoise,
        noise_mask=latent.get("noise_mask"),
        seed=seed,
        pipeline=pipeline,
        flux=flux,
    )

    # Copy only after sampling so the input dictionary itself is left untouched.
    result = latent.copy()
    result["samples"] = sampled
    return (result,)
|
modules/user/GUI.py
CHANGED
@@ -449,7 +449,9 @@ class App(tk.Tk):
|
|
449 |
img_tensor = img_tensor.unsqueeze(0)
|
450 |
self.interrupt_flag = False
|
451 |
self.sampler = (
|
452 |
-
"
|
|
|
|
|
453 |
)
|
454 |
with torch.inference_mode():
|
455 |
(
|
@@ -612,7 +614,7 @@ class App(tk.Tk):
|
|
612 |
)
|
613 |
self.cliptextencode = Clip.CLIPTextEncode()
|
614 |
self.emptylatentimage = Latent.EmptyLatentImage()
|
615 |
-
self.ksampler_instance = sampling.
|
616 |
self.vaedecode = VariationalAE.VAEDecode()
|
617 |
self.latent_upscale = upscale.LatentUpscale()
|
618 |
self.upscalemodelloader = USDU_upscaler.UpscaleModelLoader()
|
@@ -637,7 +639,9 @@ class App(tk.Tk):
|
|
637 |
self.generation_threads.append(current_thread)
|
638 |
self.interrupt_flag = False
|
639 |
self.sampler = (
|
640 |
-
"
|
|
|
|
|
641 |
)
|
642 |
try:
|
643 |
# Disable generate button during generation
|
@@ -955,7 +959,7 @@ class App(tk.Tk):
|
|
955 |
unetloadergguf = Quantizer.UnetLoaderGGUF()
|
956 |
cliptextencodeflux = Quantizer.CLIPTextEncodeFlux()
|
957 |
conditioningzeroout = Quantizer.ConditioningZeroOut()
|
958 |
-
ksampler = sampling.
|
959 |
vaedecode = VariationalAE.VAEDecode()
|
960 |
unetloadergguf_10 = unetloadergguf.load_unet(
|
961 |
unet_name="flux1-dev-Q8_0.gguf"
|
|
|
449 |
img_tensor = img_tensor.unsqueeze(0)
|
450 |
self.interrupt_flag = False
|
451 |
self.sampler = (
|
452 |
+
"dpmpp_sde_cfgpp"
|
453 |
+
if not self.prioritize_speed_var.get()
|
454 |
+
else "dpmpp_2m_cfgpp"
|
455 |
)
|
456 |
with torch.inference_mode():
|
457 |
(
|
|
|
614 |
)
|
615 |
self.cliptextencode = Clip.CLIPTextEncode()
|
616 |
self.emptylatentimage = Latent.EmptyLatentImage()
|
617 |
+
self.ksampler_instance = sampling.KSampler()
|
618 |
self.vaedecode = VariationalAE.VAEDecode()
|
619 |
self.latent_upscale = upscale.LatentUpscale()
|
620 |
self.upscalemodelloader = USDU_upscaler.UpscaleModelLoader()
|
|
|
639 |
self.generation_threads.append(current_thread)
|
640 |
self.interrupt_flag = False
|
641 |
self.sampler = (
|
642 |
+
"dpmpp_sde_cfgpp"
|
643 |
+
if not self.prioritize_speed_var.get()
|
644 |
+
else "dpmpp_2m_cfgpp"
|
645 |
)
|
646 |
try:
|
647 |
# Disable generate button during generation
|
|
|
959 |
unetloadergguf = Quantizer.UnetLoaderGGUF()
|
960 |
cliptextencodeflux = Quantizer.CLIPTextEncodeFlux()
|
961 |
conditioningzeroout = Quantizer.ConditioningZeroOut()
|
962 |
+
ksampler = sampling.KSampler()
|
963 |
vaedecode = VariationalAE.VAEDecode()
|
964 |
unetloadergguf_10 = unetloadergguf.load_unet(
|
965 |
unet_name="flux1-dev-Q8_0.gguf"
|
modules/user/pipeline.py
CHANGED
@@ -92,7 +92,7 @@ def pipeline(
|
|
92 |
hidiffoptimizer = msw_msa_attention.ApplyMSWMSAAttentionSimple()
|
93 |
cliptextencode = Clip.CLIPTextEncode()
|
94 |
emptylatentimage = Latent.EmptyLatentImage()
|
95 |
-
ksampler_instance = sampling.
|
96 |
vaedecode = VariationalAE.VAEDecode()
|
97 |
saveimage = ImageSaver.SaveImage()
|
98 |
latent_upscale = upscale.LatentUpscale()
|
@@ -187,7 +187,7 @@ def pipeline(
|
|
187 |
unetloadergguf = Quantizer.UnetLoaderGGUF()
|
188 |
cliptextencodeflux = Quantizer.CLIPTextEncodeFlux()
|
189 |
conditioningzeroout = Quantizer.ConditioningZeroOut()
|
190 |
-
ksampler = sampling.
|
191 |
unetloadergguf_10 = unetloadergguf.load_unet(
|
192 |
unet_name="flux1-dev-Q8_0.gguf"
|
193 |
)
|
@@ -283,10 +283,10 @@ def pipeline(
|
|
283 |
)
|
284 |
else:
|
285 |
applystablefast_158 = loraloader_274
|
286 |
-
fb_cache = fbcache_nodes.ApplyFBCacheOnModel()
|
287 |
-
applystablefast_158 = fb_cache.patch(
|
288 |
-
|
289 |
-
)
|
290 |
|
291 |
ksampler_239 = ksampler_instance.sample(
|
292 |
seed=seed,
|
|
|
92 |
hidiffoptimizer = msw_msa_attention.ApplyMSWMSAAttentionSimple()
|
93 |
cliptextencode = Clip.CLIPTextEncode()
|
94 |
emptylatentimage = Latent.EmptyLatentImage()
|
95 |
+
ksampler_instance = sampling.KSampler()
|
96 |
vaedecode = VariationalAE.VAEDecode()
|
97 |
saveimage = ImageSaver.SaveImage()
|
98 |
latent_upscale = upscale.LatentUpscale()
|
|
|
187 |
unetloadergguf = Quantizer.UnetLoaderGGUF()
|
188 |
cliptextencodeflux = Quantizer.CLIPTextEncodeFlux()
|
189 |
conditioningzeroout = Quantizer.ConditioningZeroOut()
|
190 |
+
ksampler = sampling.KSampler()
|
191 |
unetloadergguf_10 = unetloadergguf.load_unet(
|
192 |
unet_name="flux1-dev-Q8_0.gguf"
|
193 |
)
|
|
|
283 |
)
|
284 |
else:
|
285 |
applystablefast_158 = loraloader_274
|
286 |
+
# fb_cache = fbcache_nodes.ApplyFBCacheOnModel()
|
287 |
+
# applystablefast_158 = fb_cache.patch(
|
288 |
+
# applystablefast_158, "diffusion_model", 0.120
|
289 |
+
# )
|
290 |
|
291 |
ksampler_239 = ksampler_instance.sample(
|
292 |
seed=seed,
|