Aatricks commited on
Commit
cfe609e
·
verified ·
1 Parent(s): b173bc1

Upload folder using huggingface_hub

Browse files
modules/Model/ModelBase.py CHANGED
@@ -56,7 +56,9 @@ class BaseModel(torch.nn.Module):
56
  **unet_config, device=device, operations=operations
57
  )
58
  self.model_type = model_type
59
- self.model_sampling = sampling.model_sampling(model_config, model_type, flux=flux)
 
 
60
 
61
  self.adm_channels = unet_config.get("adm_in_channels", None)
62
  if self.adm_channels is None:
@@ -93,26 +95,32 @@ class BaseModel(torch.nn.Module):
93
  """
94
  sigma = t
95
  xc = self.model_sampling.calculate_input(sigma, x)
96
- if c_concat is not None:
97
- xc = torch.cat([xc] + [c_concat], dim=1)
98
 
99
- context = c_crossattn
100
- dtype = self.get_dtype()
 
101
 
102
- if self.manual_cast_dtype is not None:
103
- dtype = self.manual_cast_dtype
 
 
 
 
104
 
 
105
  xc = xc.to(dtype)
106
  t = self.model_sampling.timestep(t).float()
107
- context = context.to(dtype)
 
 
108
  extra_conds = {}
109
- for o in kwargs:
110
- extra = kwargs[o]
111
- if hasattr(extra, "dtype"):
112
- if extra.dtype != torch.int and extra.dtype != torch.long:
113
- extra = extra.to(dtype)
114
- extra_conds[o] = extra
115
 
 
116
  model_output = self.diffusion_model(
117
  xc,
118
  t,
@@ -121,6 +129,7 @@ class BaseModel(torch.nn.Module):
121
  transformer_options=transformer_options,
122
  **extra_conds,
123
  ).float()
 
124
  return self.model_sampling.calculate_denoised(sigma, model_output, x)
125
 
126
  def get_dtype(self) -> torch.dtype:
 
56
  **unet_config, device=device, operations=operations
57
  )
58
  self.model_type = model_type
59
+ self.model_sampling = sampling.model_sampling(
60
+ model_config, model_type, flux=flux
61
+ )
62
 
63
  self.adm_channels = unet_config.get("adm_in_channels", None)
64
  if self.adm_channels is None:
 
95
  """
96
  sigma = t
97
  xc = self.model_sampling.calculate_input(sigma, x)
 
 
98
 
99
+ # Optimize concatenation operation by avoiding unnecessary list creation
100
+ if c_concat is not None:
101
+ xc = torch.cat((xc, c_concat), dim=1)
102
 
103
+ # Determine dtype once to avoid repeated calls to get_dtype()
104
+ dtype = (
105
+ self.manual_cast_dtype
106
+ if self.manual_cast_dtype is not None
107
+ else self.get_dtype()
108
+ )
109
 
110
+ # Batch operations to reduce overhead
111
  xc = xc.to(dtype)
112
  t = self.model_sampling.timestep(t).float()
113
+ context = c_crossattn.to(dtype) if c_crossattn is not None else None
114
+
115
+ # Process extra conditions more efficiently
116
  extra_conds = {}
117
+ for name, value in kwargs.items():
118
+ if hasattr(value, "dtype") and value.dtype not in (torch.int, torch.long):
119
+ extra_conds[name] = value.to(dtype)
120
+ else:
121
+ extra_conds[name] = value
 
122
 
123
+ # Run diffusion model and calculate denoised output
124
  model_output = self.diffusion_model(
125
  xc,
126
  t,
 
129
  transformer_options=transformer_options,
130
  **extra_conds,
131
  ).float()
132
+
133
  return self.model_sampling.calculate_denoised(sigma, model_output, x)
134
 
135
  def get_dtype(self) -> torch.dtype:
modules/NeuralNetwork/unet.py CHANGED
@@ -304,7 +304,9 @@ class UNetModel1(nn.Module):
304
  if num_heads_upsample == -1:
305
  num_heads_upsample = num_heads
306
  if num_head_channels == -1:
307
- assert num_heads != -1, "Either num_heads or num_head_channels has to be set"
 
 
308
 
309
  self.in_channels = in_channels
310
  self.model_channels = model_channels
@@ -684,36 +686,29 @@ class UNetModel1(nn.Module):
684
  transformer_options: Dict[str, Any] = {},
685
  **kwargs: Any,
686
  ) -> torch.Tensor:
687
- """#### Forward pass of the UNet model.
688
-
689
- #### Args:
690
- - `x` (torch.Tensor): The input tensor.
691
- - `timesteps` (Optional[torch.Tensor], optional): The timesteps tensor. Defaults to None.
692
- - `context` (Optional[torch.Tensor], optional): The context tensor. Defaults to None.
693
- - `y` (Optional[torch.Tensor], optional): The class labels tensor. Defaults to None.
694
- - `control` (Optional[torch.Tensor], optional): The control tensor. Defaults to None.
695
- - `transformer_options` (Dict[str, Any], optional): Options for the transformer. Defaults to {}.
696
- - `**kwargs` (Any): Additional keyword arguments.
697
-
698
- #### Returns:
699
- - `torch.Tensor`: The output tensor.
700
- """
701
  transformer_options["original_shape"] = list(x.shape)
702
  transformer_options["transformer_index"] = 0
703
- transformer_patches = transformer_options.get("patches", {})
704
 
 
705
  num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
706
  image_only_indicator = kwargs.get("image_only_indicator", None)
707
  time_context = kwargs.get("time_context", None)
708
 
709
- assert (y is not None) == (
710
- self.num_classes is not None
711
- ), "must specify y if and only if the model is class-conditional"
712
- hs = []
713
- t_emb = sampling_util.timestep_embedding(
714
- timesteps, self.model_channels
715
- ).to(x.dtype)
 
 
716
  emb = self.time_embed(t_emb)
 
 
 
717
  h = x
718
  for id, module in enumerate(self.input_blocks):
719
  transformer_options["block"] = ("input", id)
@@ -730,6 +725,7 @@ class UNetModel1(nn.Module):
730
  h = apply_control1(h, control, "input")
731
  hs.append(h)
732
 
 
733
  transformer_options["block"] = ("middle", 0)
734
  if self.middle_block is not None:
735
  h = ResBlock.forward_timestep_embed1(
@@ -744,17 +740,19 @@ class UNetModel1(nn.Module):
744
  )
745
  h = apply_control1(h, control, "middle")
746
 
 
747
  for id, module in enumerate(self.output_blocks):
748
  transformer_options["block"] = ("output", id)
749
  hsp = hs.pop()
750
  hsp = apply_control1(hsp, control, "output")
751
 
 
752
  h = torch.cat([h, hsp], dim=1)
753
- del hsp
754
- if len(hs) > 0:
755
- output_shape = hs[-1].shape
756
- else:
757
- output_shape = None
758
  h = ResBlock.forward_timestep_embed1(
759
  module,
760
  h,
@@ -766,11 +764,15 @@ class UNetModel1(nn.Module):
766
  num_video_frames=num_video_frames,
767
  image_only_indicator=image_only_indicator,
768
  )
 
 
769
  h = h.type(x.dtype)
770
  return self.out(h)
771
 
772
 
773
- def detect_unet_config(state_dict: Dict[str, torch.Tensor], key_prefix: str) -> Dict[str, Any]:
 
 
774
  """#### Detect the UNet configuration from a state dictionary.
775
 
776
  #### Args:
@@ -1017,7 +1019,9 @@ def detect_unet_config(state_dict: Dict[str, torch.Tensor], key_prefix: str) ->
1017
  // model_channels
1018
  )
1019
 
1020
- out = transformer.calculate_transformer_depth(prefix, state_dict_keys, state_dict)
 
 
1021
  if out is not None:
1022
  transformer_depth.append(out[0])
1023
  if context_dim is None:
@@ -1076,7 +1080,9 @@ def detect_unet_config(state_dict: Dict[str, torch.Tensor], key_prefix: str) ->
1076
  return unet_config
1077
 
1078
 
1079
- def model_config_from_unet_config(unet_config: Dict[str, Any], state_dict: Optional[Dict[str, torch.Tensor]] = None) -> Any:
 
 
1080
  """#### Get the model configuration from a UNet configuration.
1081
 
1082
  #### Args:
@@ -1096,7 +1102,11 @@ def model_config_from_unet_config(unet_config: Dict[str, Any], state_dict: Optio
1096
  return None
1097
 
1098
 
1099
- def model_config_from_unet(state_dict: Dict[str, torch.Tensor], unet_key_prefix: str, use_base_if_no_match: bool = False) -> Any:
 
 
 
 
1100
  """#### Get the model configuration from a UNet state dictionary.
1101
 
1102
  #### Args:
@@ -1117,7 +1127,11 @@ def model_config_from_unet(state_dict: Dict[str, torch.Tensor], unet_key_prefix:
1117
  def unet_dtype1(
1118
  device: Optional[torch.device] = None,
1119
  model_params: int = 0,
1120
- supported_dtypes: List[torch.dtype] = [torch.float16, torch.bfloat16, torch.float32],
 
 
 
 
1121
  ) -> torch.dtype:
1122
  """#### Get the dtype for the UNet model.
1123
 
@@ -1129,4 +1143,4 @@ def unet_dtype1(
1129
  #### Returns:
1130
  - `torch.dtype`: The dtype for the UNet model.
1131
  """
1132
- return torch.float16
 
304
  if num_heads_upsample == -1:
305
  num_heads_upsample = num_heads
306
  if num_head_channels == -1:
307
+ assert num_heads != -1, (
308
+ "Either num_heads or num_head_channels has to be set"
309
+ )
310
 
311
  self.in_channels = in_channels
312
  self.model_channels = model_channels
 
686
  transformer_options: Dict[str, Any] = {},
687
  **kwargs: Any,
688
  ) -> torch.Tensor:
689
+ """#### Forward pass of the UNet model with optimized calculations."""
690
+ # Setup transformer options (avoid unused variable)
 
 
 
 
 
 
 
 
 
 
 
 
691
  transformer_options["original_shape"] = list(x.shape)
692
  transformer_options["transformer_index"] = 0
 
693
 
694
+ # Extract kwargs efficiently
695
  num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
696
  image_only_indicator = kwargs.get("image_only_indicator", None)
697
  time_context = kwargs.get("time_context", None)
698
 
699
+ # Validation
700
+ assert (y is not None) == (self.num_classes is not None), (
701
+ "must specify y if and only if the model is class-conditional"
702
+ )
703
+
704
+ # Time embedding - optimize by computing with target dtype directly
705
+ t_emb = sampling_util.timestep_embedding(timesteps, self.model_channels).to(
706
+ x.dtype
707
+ )
708
  emb = self.time_embed(t_emb)
709
+
710
+ # Input blocks processing
711
+ hs = []
712
  h = x
713
  for id, module in enumerate(self.input_blocks):
714
  transformer_options["block"] = ("input", id)
 
725
  h = apply_control1(h, control, "input")
726
  hs.append(h)
727
 
728
+ # Middle block processing
729
  transformer_options["block"] = ("middle", 0)
730
  if self.middle_block is not None:
731
  h = ResBlock.forward_timestep_embed1(
 
740
  )
741
  h = apply_control1(h, control, "middle")
742
 
743
+ # Output blocks processing - optimize memory usage
744
  for id, module in enumerate(self.output_blocks):
745
  transformer_options["block"] = ("output", id)
746
  hsp = hs.pop()
747
  hsp = apply_control1(hsp, control, "output")
748
 
749
+ # Concatenate tensors
750
  h = torch.cat([h, hsp], dim=1)
751
+ del hsp # Free memory immediately
752
+
753
+ # Only calculate output shape when needed
754
+ output_shape = hs[-1].shape if hs else None
755
+
756
  h = ResBlock.forward_timestep_embed1(
757
  module,
758
  h,
 
764
  num_video_frames=num_video_frames,
765
  image_only_indicator=image_only_indicator,
766
  )
767
+
768
+ # Ensure output has correct dtype
769
  h = h.type(x.dtype)
770
  return self.out(h)
771
 
772
 
773
+ def detect_unet_config(
774
+ state_dict: Dict[str, torch.Tensor], key_prefix: str
775
+ ) -> Dict[str, Any]:
776
  """#### Detect the UNet configuration from a state dictionary.
777
 
778
  #### Args:
 
1019
  // model_channels
1020
  )
1021
 
1022
+ out = transformer.calculate_transformer_depth(
1023
+ prefix, state_dict_keys, state_dict
1024
+ )
1025
  if out is not None:
1026
  transformer_depth.append(out[0])
1027
  if context_dim is None:
 
1080
  return unet_config
1081
 
1082
 
1083
+ def model_config_from_unet_config(
1084
+ unet_config: Dict[str, Any], state_dict: Optional[Dict[str, torch.Tensor]] = None
1085
+ ) -> Any:
1086
  """#### Get the model configuration from a UNet configuration.
1087
 
1088
  #### Args:
 
1102
  return None
1103
 
1104
 
1105
+ def model_config_from_unet(
1106
+ state_dict: Dict[str, torch.Tensor],
1107
+ unet_key_prefix: str,
1108
+ use_base_if_no_match: bool = False,
1109
+ ) -> Any:
1110
  """#### Get the model configuration from a UNet state dictionary.
1111
 
1112
  #### Args:
 
1127
  def unet_dtype1(
1128
  device: Optional[torch.device] = None,
1129
  model_params: int = 0,
1130
+ supported_dtypes: List[torch.dtype] = [
1131
+ torch.float16,
1132
+ torch.bfloat16,
1133
+ torch.float32,
1134
+ ],
1135
  ) -> torch.dtype:
1136
  """#### Get the dtype for the UNet model.
1137
 
 
1143
  #### Returns:
1144
  - `torch.dtype`: The dtype for the UNet model.
1145
  """
1146
+ return torch.float16
modules/Utilities/util.py CHANGED
@@ -4,7 +4,6 @@ import itertools
4
  import logging
5
  import math
6
  import os
7
- import pickle
8
  import safetensors.torch
9
  import torch
10
 
@@ -120,6 +119,18 @@ def state_dict_prefix_replace(
120
  return out
121
 
122
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  def repeat_to_batch_size(
124
  tensor: torch.Tensor, batch_size: int, dim: int = 0
125
  ) -> torch.Tensor:
@@ -437,11 +448,11 @@ def tiled_scale_multidim(
437
 
438
  def get_upscale(dim: int, val: int) -> int:
439
  """#### Get the upscale value.
440
-
441
  #### Args:
442
  - `dim` (int): The dimension.
443
  - `val` (int): The value.
444
-
445
  #### Returns:
446
  - `int`: The upscaled value.
447
  """
@@ -453,11 +464,11 @@ def tiled_scale_multidim(
453
 
454
  def get_downscale(dim: int, val: int) -> int:
455
  """#### Get the downscale value.
456
-
457
  #### Args:
458
  - `dim` (int): The dimension.
459
  - `val` (int): The value.
460
-
461
  #### Returns:
462
  - `int`: The downscaled value.
463
  """
@@ -469,11 +480,11 @@ def tiled_scale_multidim(
469
 
470
  def get_upscale_pos(dim: int, val: int) -> int:
471
  """#### Get the upscaled position.
472
-
473
  #### Args:
474
  - `dim` (int): The dimension.
475
  - `val` (int): The value.
476
-
477
  #### Returns:
478
  - `int`: The upscaled position.
479
  """
@@ -485,11 +496,11 @@ def tiled_scale_multidim(
485
 
486
  def get_downscale_pos(dim: int, val: int) -> int:
487
  """#### Get the downscaled position.
488
-
489
  #### Args:
490
  - `dim` (int): The dimension.
491
  - `val` (int): The value.
492
-
493
  #### Returns:
494
  - `int`: The downscaled position.
495
  """
@@ -508,10 +519,10 @@ def tiled_scale_multidim(
508
 
509
  def mult_list_upscale(a: list) -> list:
510
  """#### Multiply a list by the upscale amount.
511
-
512
  #### Args:
513
  - `a` (list): The list.
514
-
515
  #### Returns:
516
  - `list`: The multiplied list.
517
  """
@@ -601,7 +612,7 @@ def tiled_scale(
601
  pbar: any = None,
602
  ):
603
  """#### Scale an image using a tiled approach.
604
-
605
  #### Args:
606
  - `samples` (torch.Tensor): The input samples.
607
  - `function` (function): The scaling function.
@@ -612,7 +623,7 @@ def tiled_scale(
612
  - `out_channels` (int, optional): The number of output channels. Defaults to 3.
613
  - `output_device` (str, optional): The output device. Defaults to "cpu".
614
  - `pbar` (any, optional): The progress bar. Defaults to None.
615
-
616
  #### Returns:
617
  - The scaled image.
618
  """
 
4
  import logging
5
  import math
6
  import os
 
7
  import safetensors.torch
8
  import torch
9
 
 
119
  return out
120
 
121
 
122
+ def lcm_of_list(numbers):
123
+ """Calculate LCM of a list of numbers more efficiently."""
124
+ if not numbers:
125
+ return 1
126
+
127
+ result = numbers[0]
128
+ for num in numbers[1:]:
129
+ result = torch.lcm(torch.tensor(result), torch.tensor(num)).item()
130
+
131
+ return result
132
+
133
+
134
  def repeat_to_batch_size(
135
  tensor: torch.Tensor, batch_size: int, dim: int = 0
136
  ) -> torch.Tensor:
 
448
 
449
  def get_upscale(dim: int, val: int) -> int:
450
  """#### Get the upscale value.
451
+
452
  #### Args:
453
  - `dim` (int): The dimension.
454
  - `val` (int): The value.
455
+
456
  #### Returns:
457
  - `int`: The upscaled value.
458
  """
 
464
 
465
  def get_downscale(dim: int, val: int) -> int:
466
  """#### Get the downscale value.
467
+
468
  #### Args:
469
  - `dim` (int): The dimension.
470
  - `val` (int): The value.
471
+
472
  #### Returns:
473
  - `int`: The downscaled value.
474
  """
 
480
 
481
  def get_upscale_pos(dim: int, val: int) -> int:
482
  """#### Get the upscaled position.
483
+
484
  #### Args:
485
  - `dim` (int): The dimension.
486
  - `val` (int): The value.
487
+
488
  #### Returns:
489
  - `int`: The upscaled position.
490
  """
 
496
 
497
  def get_downscale_pos(dim: int, val: int) -> int:
498
  """#### Get the downscaled position.
499
+
500
  #### Args:
501
  - `dim` (int): The dimension.
502
  - `val` (int): The value.
503
+
504
  #### Returns:
505
  - `int`: The downscaled position.
506
  """
 
519
 
520
  def mult_list_upscale(a: list) -> list:
521
  """#### Multiply a list by the upscale amount.
522
+
523
  #### Args:
524
  - `a` (list): The list.
525
+
526
  #### Returns:
527
  - `list`: The multiplied list.
528
  """
 
612
  pbar: any = None,
613
  ):
614
  """#### Scale an image using a tiled approach.
615
+
616
  #### Args:
617
  - `samples` (torch.Tensor): The input samples.
618
  - `function` (function): The scaling function.
 
623
  - `out_channels` (int, optional): The number of output channels. Defaults to 3.
624
  - `output_device` (str, optional): The output device. Defaults to "cpu".
625
  - `pbar` (any, optional): The progress bar. Defaults to None.
626
+
627
  #### Returns:
628
  - The scaled image.
629
  """
modules/cond/cond.py CHANGED
@@ -42,13 +42,13 @@ class CONDRegular:
42
  return self._copy_with(
43
  util.repeat_to_batch_size(self.cond, batch_size).to(device)
44
  )
45
-
46
  def can_concat(self, other: "CONDRegular") -> bool:
47
  """#### Check if conditions can be concatenated.
48
-
49
  #### Args:
50
  - `other` (CONDRegular): The other condition.
51
-
52
  #### Returns:
53
  - `bool`: True if conditions can be concatenated, False otherwise.
54
  """
@@ -58,10 +58,10 @@ class CONDRegular:
58
 
59
  def concat(self, others: list) -> torch.Tensor:
60
  """#### Concatenate conditions.
61
-
62
  #### Args:
63
  - `others` (list): The list of other conditions.
64
-
65
  #### Returns:
66
  - `torch.Tensor`: The concatenated conditions.
67
  """
@@ -76,11 +76,11 @@ class CONDCrossAttn(CONDRegular):
76
 
77
  def can_concat(self, other: "CONDRegular") -> bool:
78
  """#### Check if conditions can be concatenated.
79
-
80
  #### Args:
81
  - `other` (CONDRegular): The other condition.
82
-
83
- #### Returns:
84
  - `bool`: True if conditions can be concatenated, False otherwise.
85
  """
86
  s1 = self.cond.shape
@@ -96,31 +96,34 @@ class CONDCrossAttn(CONDRegular):
96
  ): # arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
97
  return False
98
  return True
99
-
100
- def concat(self, others: list) -> torch.Tensor:
101
- """#### Concatenate cross-attention conditions.
102
 
103
- #### Args:
104
- - `others` (list): The list of other conditions.
105
-
106
- #### Returns:
107
- - `torch.Tensor`: The concatenated conditions.
108
- """
109
  conds = [self.cond]
110
- crossattn_max_len = self.cond.shape[1]
 
 
111
  for x in others:
112
- c = x.cond
113
- crossattn_max_len = util.lcm(crossattn_max_len, c.shape[1])
114
- conds.append(c)
115
 
116
- out = []
117
- for c in conds:
118
- if c.shape[1] < crossattn_max_len:
119
- c = c.repeat(
120
- 1, crossattn_max_len // c.shape[1], 1
121
- ) # padding with repeat doesn't change result, but avoids an error on tensor shape
122
- out.append(c)
123
- return torch.cat(out)
 
 
 
 
 
 
 
 
 
124
 
125
 
126
  def convert_cond(cond: list) -> list:
@@ -277,8 +280,10 @@ def calc_cond_batch(
277
  out_c += output[o] * mult[o]
278
  out_cts += mult[o]
279
 
 
280
  for i in range(len(out_conds)):
281
- out_conds[i] /= out_counts[i]
 
282
 
283
  return out_conds
284
 
@@ -328,48 +333,75 @@ def encode_model_conds(
328
  conds[t] = x
329
  return conds
330
 
331
- def resolve_areas_and_cond_masks_multidim(conditions: list, dims: tuple, device: torch.device) -> None:
332
- """#### Resolve areas and condition masks for multidimensional conditions.
333
-
334
- #### Args:
335
- - `conditions` (list): The list of conditions.
336
- - `dims` (tuple): The dimensions.
337
- - `device` (torch.device): The device.
338
- """
339
- # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
340
- # While we're doing this, we can also resolve the mask device and scaling for performance reasons
341
  for i in range(len(conditions)):
342
  c = conditions[i]
 
343
  if "area" in c:
344
  area = c["area"]
345
  if area[0] == "percentage":
346
- modified = c.copy()
347
  a = area[1:]
348
  a_len = len(a) // 2
349
- area = ()
350
- for d in range(len(dims)):
351
- area += (max(1, round(a[d] * dims[d])),)
352
- for d in range(len(dims)):
353
- area += (round(a[d + a_len] * dims[d]),)
354
 
355
- modified["area"] = area
356
- c = modified
357
- conditions[i] = c
 
 
 
 
 
 
 
358
 
 
 
 
 
 
 
 
 
 
 
 
359
  if "mask" in c:
360
- mask = c["mask"]
361
- mask = mask.to(device=device)
362
  modified = c.copy()
 
 
 
363
  if len(mask.shape) == len(dims):
364
  mask = mask.unsqueeze(0)
 
 
365
  if mask.shape[1:] != dims:
366
- mask = torch.nn.functional.interpolate(
367
- mask.unsqueeze(1), size=dims, mode="bilinear", align_corners=False
368
- ).squeeze(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
 
370
  modified["mask"] = mask
371
  conditions[i] = modified
372
 
 
373
  def process_conds(
374
  model: object,
375
  noise: torch.Tensor,
@@ -442,4 +474,4 @@ def process_conds(
442
  positive, conds[k], "gligen", lambda cond_cnets, x: cond_cnets[x]
443
  )
444
 
445
- return conds
 
42
  return self._copy_with(
43
  util.repeat_to_batch_size(self.cond, batch_size).to(device)
44
  )
45
+
46
  def can_concat(self, other: "CONDRegular") -> bool:
47
  """#### Check if conditions can be concatenated.
48
+
49
  #### Args:
50
  - `other` (CONDRegular): The other condition.
51
+
52
  #### Returns:
53
  - `bool`: True if conditions can be concatenated, False otherwise.
54
  """
 
58
 
59
  def concat(self, others: list) -> torch.Tensor:
60
  """#### Concatenate conditions.
61
+
62
  #### Args:
63
  - `others` (list): The list of other conditions.
64
+
65
  #### Returns:
66
  - `torch.Tensor`: The concatenated conditions.
67
  """
 
76
 
77
  def can_concat(self, other: "CONDRegular") -> bool:
78
  """#### Check if conditions can be concatenated.
79
+
80
  #### Args:
81
  - `other` (CONDRegular): The other condition.
82
+
83
+ #### Returns:
84
  - `bool`: True if conditions can be concatenated, False otherwise.
85
  """
86
  s1 = self.cond.shape
 
96
  ): # arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
97
  return False
98
  return True
 
 
 
99
 
100
+ def concat(self, others: list) -> torch.Tensor:
101
+ """Optimized version of cross-attention condition concatenation."""
 
 
 
 
102
  conds = [self.cond]
103
+ shapes = [self.cond.shape[1]]
104
+
105
+ # Collect all conditions and their shapes
106
  for x in others:
107
+ conds.append(x.cond)
108
+ shapes.append(x.cond.shape[1])
 
109
 
110
+ # Calculate LCM more efficiently
111
+ crossattn_max_len = util.lcm_of_list(shapes)
112
+
113
+ # Process and concat in one step where possible
114
+ if all(c.shape[1] == shapes[0] for c in conds):
115
+ # All same length, simple concatenation
116
+ return torch.cat(conds)
117
+ else:
118
+ # Process conditions that need repeating
119
+ out = []
120
+ for c in conds:
121
+ if c.shape[1] < crossattn_max_len:
122
+ repeat_factor = crossattn_max_len // c.shape[1]
123
+ # Use repeat instead of individual operations
124
+ c = c.repeat(1, repeat_factor, 1)
125
+ out.append(c)
126
+ return torch.cat(out)
127
 
128
 
129
  def convert_cond(cond: list) -> list:
 
280
  out_c += output[o] * mult[o]
281
  out_cts += mult[o]
282
 
283
+ # Vectorize the division at the end
284
  for i in range(len(out_conds)):
285
+ # Inplace division is already efficient
286
+ out_conds[i].div_(out_counts[i]) # Using .div_ instead of /= for clarity
287
 
288
  return out_conds
289
 
 
333
  conds[t] = x
334
  return conds
335
 
336
+
337
+ def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
338
+ """Optimized version that processes areas and masks more efficiently"""
 
 
 
 
 
 
 
339
  for i in range(len(conditions)):
340
  c = conditions[i]
341
+ # Process area
342
  if "area" in c:
343
  area = c["area"]
344
  if area[0] == "percentage":
345
+ # Vectorized calculation of area dimensions
346
  a = area[1:]
347
  a_len = len(a) // 2
 
 
 
 
 
348
 
349
+ # Calculate all dimensions at once using tensor operations
350
+ dims_tensor = torch.tensor(dims, device="cpu")
351
+ first_part = torch.tensor(a[:a_len], device="cpu") * dims_tensor
352
+ second_part = torch.tensor(a[a_len:], device="cpu") * dims_tensor
353
+
354
+ # Convert to rounded integers and tuple
355
+ first_part = torch.max(
356
+ torch.ones_like(first_part), torch.round(first_part)
357
+ )
358
+ second_part = torch.round(second_part)
359
 
360
+ # Create the new area tuple
361
+ new_area = tuple(first_part.int().tolist()) + tuple(
362
+ second_part.int().tolist()
363
+ )
364
+
365
+ # Create a modified copy with the new area
366
+ modified = c.copy()
367
+ modified["area"] = new_area
368
+ conditions[i] = modified
369
+
370
+ # Process mask
371
  if "mask" in c:
 
 
372
  modified = c.copy()
373
+ mask = c["mask"].to(device=device)
374
+
375
+ # Combine dimension checks and unsqueeze operation
376
  if len(mask.shape) == len(dims):
377
  mask = mask.unsqueeze(0)
378
+
379
+ # Only interpolate if needed
380
  if mask.shape[1:] != dims:
381
+ # Optimize interpolation by ensuring mask is in the right format for the operation
382
+ if len(mask.shape) == 3 and mask.shape[0] == 1:
383
+ # Already in the right format for interpolation
384
+ mask = torch.nn.functional.interpolate(
385
+ mask.unsqueeze(1),
386
+ size=dims,
387
+ mode="bilinear",
388
+ align_corners=False,
389
+ ).squeeze(1)
390
+ else:
391
+ # Ensure mask is properly formatted for interpolation
392
+ mask = torch.nn.functional.interpolate(
393
+ mask
394
+ if len(mask.shape) > 3 and mask.shape[1] == 1
395
+ else mask.unsqueeze(1),
396
+ size=dims,
397
+ mode="bilinear",
398
+ align_corners=False,
399
+ ).squeeze(1)
400
 
401
  modified["mask"] = mask
402
  conditions[i] = modified
403
 
404
+
405
  def process_conds(
406
  model: object,
407
  noise: torch.Tensor,
 
474
  positive, conds[k], "gligen", lambda cond_cnets, x: cond_cnets[x]
475
  )
476
 
477
+ return conds
modules/sample/CFG.py CHANGED
@@ -30,10 +30,15 @@ def cfg_function(
30
  #### Returns:
31
  - `torch.Tensor`: The CFG result.
32
  """
 
33
  if "sampler_cfg_function" in model_options:
 
 
 
 
34
  args = {
35
- "cond": x - cond_pred,
36
- "uncond": x - uncond_pred,
37
  "cond_scale": cond_scale,
38
  "timestep": timestep,
39
  "input": x,
@@ -45,9 +50,18 @@ def cfg_function(
45
  }
46
  cfg_result = x - model_options["sampler_cfg_function"](args)
47
  else:
48
- cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
49
-
50
- for fn in model_options.get("sampler_post_cfg_function", []):
 
 
 
 
 
 
 
 
 
51
  args = {
52
  "denoised": cfg_result,
53
  "cond": cond,
@@ -59,7 +73,12 @@ def cfg_function(
59
  "model_options": model_options,
60
  "input": x,
61
  }
62
- cfg_result = fn(args)
 
 
 
 
 
63
 
64
  return cfg_result
65
 
@@ -89,21 +108,29 @@ def sampling_function(
89
  #### Returns:
90
  - `torch.Tensor`: The sampled tensor.
91
  """
92
- if (
93
- math.isclose(cond_scale, 1.0)
94
- and model_options.get("disable_cfg1_optimization", False) is False
95
- ):
96
- uncond_ = None
97
- else:
98
- uncond_ = uncond
 
 
99
 
 
100
  conds = [condo, uncond_]
101
- out = cond.calc_cond_batch(model, conds, x, timestep, model_options)
102
 
103
- for fn in model_options.get("sampler_pre_cfg_function", []):
 
 
 
 
 
 
104
  args = {
105
  "conds": conds,
106
- "conds_out": out,
107
  "cond_scale": cond_scale,
108
  "timestep": timestep,
109
  "input": x,
@@ -111,12 +138,20 @@ def sampling_function(
111
  "model": model,
112
  "model_options": model_options,
113
  }
114
- out = fn(args)
115
 
 
 
 
 
 
 
 
 
 
116
  return cfg_function(
117
  model,
118
- out[0],
119
- out[1],
120
  cond_scale,
121
  x,
122
  timestep,
@@ -128,6 +163,7 @@ def sampling_function(
128
 
129
  class CFGGuider:
130
  """#### Class for guiding the sampling process with CFG."""
 
131
  def __init__(self, model_patcher, flux=False):
132
  """#### Initialize the CFGGuider.
133
 
@@ -315,4 +351,4 @@ class CFGGuider:
315
  del self.inner_model
316
  del self.conds
317
  del self.loaded_models
318
- return output
 
30
  #### Returns:
31
  - `torch.Tensor`: The CFG result.
32
  """
33
+ # Check for custom sampler CFG function first
34
  if "sampler_cfg_function" in model_options:
35
+ # Precompute differences to avoid redundant operations
36
+ cond_diff = x - cond_pred
37
+ uncond_diff = x - uncond_pred
38
+
39
  args = {
40
+ "cond": cond_diff,
41
+ "uncond": uncond_diff,
42
  "cond_scale": cond_scale,
43
  "timestep": timestep,
44
  "input": x,
 
50
  }
51
  cfg_result = x - model_options["sampler_cfg_function"](args)
52
  else:
53
+ # Standard CFG calculation - optimized to avoid intermediate tensor allocation
54
+ # When cond_scale = 1.0, we can just return cond_pred without computation
55
+ if math.isclose(cond_scale, 1.0):
56
+ cfg_result = cond_pred
57
+ else:
58
+ # Fused operation: uncond_pred + (cond_pred - uncond_pred) * cond_scale
59
+ # Equivalent to: uncond_pred * (1 - cond_scale) + cond_pred * cond_scale
60
+ cfg_result = torch.lerp(uncond_pred, cond_pred, cond_scale)
61
+
62
+ # Apply post-CFG functions if any
63
+ post_cfg_functions = model_options.get("sampler_post_cfg_function", [])
64
+ if post_cfg_functions:
65
  args = {
66
  "denoised": cfg_result,
67
  "cond": cond,
 
73
  "model_options": model_options,
74
  "input": x,
75
  }
76
+
77
+ # Apply each post-CFG function in sequence
78
+ for fn in post_cfg_functions:
79
+ cfg_result = fn(args)
80
+ # Update the denoised result for the next function
81
+ args["denoised"] = cfg_result
82
 
83
  return cfg_result
84
 
 
108
  #### Returns:
109
  - `torch.Tensor`: The sampled tensor.
110
  """
111
+ # Optimize conditional logic for uncond
112
+ uncond_ = (
113
+ None
114
+ if (
115
+ math.isclose(cond_scale, 1.0)
116
+ and not model_options.get("disable_cfg1_optimization", False)
117
+ )
118
+ else uncond
119
+ )
120
 
121
+ # Create conditions list once
122
  conds = [condo, uncond_]
 
123
 
124
+ # Get model predictions for both conditions
125
+ cond_outputs = cond.calc_cond_batch(model, conds, x, timestep, model_options)
126
+
127
+ # Apply pre-CFG functions if any
128
+ pre_cfg_functions = model_options.get("sampler_pre_cfg_function", [])
129
+ if pre_cfg_functions:
130
+ # Create args dictionary once
131
  args = {
132
  "conds": conds,
133
+ "conds_out": cond_outputs,
134
  "cond_scale": cond_scale,
135
  "timestep": timestep,
136
  "input": x,
 
138
  "model": model,
139
  "model_options": model_options,
140
  }
 
141
 
142
+ # Apply each pre-CFG function
143
+ for fn in pre_cfg_functions:
144
+ cond_outputs = fn(args)
145
+ args["conds_out"] = cond_outputs
146
+
147
+ # Extract conditional and unconditional outputs explicitly for clarity
148
+ cond_pred, uncond_pred = cond_outputs[0], cond_outputs[1]
149
+
150
+ # Apply the CFG function
151
  return cfg_function(
152
  model,
153
+ cond_pred,
154
+ uncond_pred,
155
  cond_scale,
156
  x,
157
  timestep,
 
163
 
164
  class CFGGuider:
165
  """#### Class for guiding the sampling process with CFG."""
166
+
167
  def __init__(self, model_patcher, flux=False):
168
  """#### Initialize the CFGGuider.
169
 
 
351
  del self.inner_model
352
  del self.conds
353
  del self.loaded_models
354
+ return output
modules/sample/ksampler_util.py CHANGED
@@ -46,6 +46,7 @@ def pre_run_control(model: torch.nn.Module, conds: list) -> None:
46
 
47
  def percent_to_timestep_function(a):
48
  return s.percent_to_sigma(a)
 
49
  if "control" in x:
50
  x["control"].pre_run(model, percent_to_timestep_function)
51
 
@@ -96,9 +97,13 @@ def apply_empty_x_to_equal_area(
96
  uncond[temp[1]] = n
97
 
98
 
99
- def get_area_and_mult(
100
- conds: dict, x_in: torch.Tensor, timestep_in: int
101
- ) -> collections.namedtuple:
 
 
 
 
102
  """#### Get the area and multiplier.
103
 
104
  #### Args:
@@ -109,26 +114,39 @@ def get_area_and_mult(
109
  #### Returns:
110
  - `collections.namedtuple`: The area and multiplier.
111
  """
112
- area = (x_in.shape[2], x_in.shape[3], 0, 0)
113
- strength = 1.0
114
 
115
- input_x = x_in[:, :, area[2] : area[0] + area[2], area[3] : area[1] + area[3]]
116
- mask = torch.ones_like(input_x)
117
- mult = mask * strength
118
 
 
 
 
 
 
 
 
 
 
 
119
  conditioning = {}
120
  model_conds = conds["model_conds"]
 
 
 
 
121
  for c in model_conds:
122
  conditioning[c] = model_conds[c].process_cond(
123
- batch_size=x_in.shape[0], device=x_in.device, area=area
124
  )
125
 
 
126
  control = conds.get("control", None)
127
  patches = None
128
- cond_obj = collections.namedtuple(
129
- "cond_obj", ["input_x", "mult", "conditioning", "area", "control", "patches"]
130
- )
131
- return cond_obj(input_x, mult, conditioning, area, control, patches)
132
 
133
 
134
  def normal_scheduler(
@@ -158,6 +176,7 @@ def normal_scheduler(
158
  sigs += [0.0]
159
  return torch.FloatTensor(sigs)
160
 
 
161
  def simple_scheduler(model_sampling: torch.nn.Module, steps: int) -> torch.FloatTensor:
162
  """#### Create a simple scheduler.
163
 
@@ -176,21 +195,52 @@ def simple_scheduler(model_sampling: torch.nn.Module, steps: int) -> torch.Float
176
  sigs += [0.0]
177
  return torch.FloatTensor(sigs)
178
 
 
179
  # Implemented based on: https://arxiv.org/abs/2407.12173
180
  def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
181
- total_timesteps = (len(model_sampling.sigmas) - 1)
182
- ts = 1 - np.linspace(0, 1, steps, endpoint=False)
183
- ts = np.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
- sigs = []
186
- last_t = -1
187
- for t in ts:
188
- if t != last_t:
189
- sigs += [float(model_sampling.sigmas[int(t)])]
190
- last_t = t
191
- sigs += [0.0]
192
  return torch.FloatTensor(sigs)
193
 
 
194
  def calculate_sigmas(
195
  model_sampling: torch.nn.Module, scheduler_name: str, steps: int
196
  ) -> torch.Tensor:
 
46
 
47
  def percent_to_timestep_function(a):
48
  return s.percent_to_sigma(a)
49
+
50
  if "control" in x:
51
  x["control"].pre_run(model, percent_to_timestep_function)
52
 
 
97
  uncond[temp[1]] = n
98
 
99
 
100
+ # Define the namedtuple class once outside the function for reuse
101
+ CondObj = collections.namedtuple(
102
+ "cond_obj", ["input_x", "mult", "conditioning", "area", "control", "patches"]
103
+ )
104
+
105
+
106
+ def get_area_and_mult(conds: dict, x_in: torch.Tensor, timestep_in: int) -> CondObj:
107
  """#### Get the area and multiplier.
108
 
109
  #### Args:
 
114
  #### Returns:
115
  - `collections.namedtuple`: The area and multiplier.
116
  """
117
+ # Cache shape information to avoid repeated access
118
+ x_shape = x_in.shape
119
 
120
+ # Define area dimensions in one operation
121
+ area = (x_shape[2], x_shape[3], 0, 0)
 
122
 
123
+ # Extract input region efficiently
124
+ # Since area[2] and area[3] are 0, this is essentially taking the full tensor
125
+ # But we maintain the slice operation for consistency
126
+ input_x = x_in[:, :, : area[0], : area[1]]
127
+
128
+ # Create multiplier tensor directly without intermediate mask creation
129
+ # This avoids an unnecessary tensor allocation and multiplication
130
+ mult = torch.ones_like(input_x) # strength is 1.0, so just create ones directly
131
+
132
+ # Prepare conditioning dictionary with cached device and batch_size
133
  conditioning = {}
134
  model_conds = conds["model_conds"]
135
+ batch_size = x_shape[0]
136
+ device = x_in.device
137
+
138
+ # Process conditions with cached parameters
139
  for c in model_conds:
140
  conditioning[c] = model_conds[c].process_cond(
141
+ batch_size=batch_size, device=device, area=area
142
  )
143
 
144
+ # Get control directly without redundant variable assignment
145
  control = conds.get("control", None)
146
  patches = None
147
+
148
+ # Use the pre-defined namedtuple class instead of creating it every call
149
+ return CondObj(input_x, mult, conditioning, area, control, patches)
 
150
 
151
 
152
  def normal_scheduler(
 
176
  sigs += [0.0]
177
  return torch.FloatTensor(sigs)
178
 
179
+
180
  def simple_scheduler(model_sampling: torch.nn.Module, steps: int) -> torch.FloatTensor:
181
  """#### Create a simple scheduler.
182
 
 
195
  sigs += [0.0]
196
  return torch.FloatTensor(sigs)
197
 
198
+
199
  # Implemented based on: https://arxiv.org/abs/2407.12173
200
  def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
201
+ """Creates a beta scheduler for noise levels based on the beta distribution.
202
+
203
+ This optimized implementation efficiently computes sigmas using the beta
204
+ distribution and caches calculations where possible.
205
+
206
+ Args:
207
+ model_sampling: Model sampling module
208
+ steps: Number of steps
209
+ alpha: Alpha parameter for beta distribution
210
+ beta: Beta parameter for beta distribution
211
+
212
+ Returns:
213
+ torch.FloatTensor: Tensor of sigma values for each step
214
+ """
215
+ # Calculate total timesteps once
216
+ total_timesteps = len(model_sampling.sigmas) - 1
217
+
218
+ # Create a cache dictionary for reused values
219
+ model_sigmas = model_sampling.sigmas
220
+
221
+ # Generate evenly spaced values in [0,1) interval
222
+ ts_normalized = np.linspace(0, 1, steps, endpoint=False)
223
+
224
+ # Apply beta inverse CDF to get sampled time points - vectorized operation
225
+ ts_beta = scipy.stats.beta.ppf(1 - ts_normalized, alpha, beta)
226
+
227
+ # Scale to timestep indices and round to integers
228
+ ts_indices = np.rint(ts_beta * total_timesteps).astype(np.int32)
229
+
230
+ # Use numpy's unique function with return_index to efficiently find unique values
231
+ # while preserving order
232
+ unique_ts, indices = np.unique(ts_indices, return_index=True)
233
+ ordered_unique_ts = unique_ts[np.argsort(indices)]
234
+
235
+ # Map indices to sigma values efficiently
236
+ sigs = [float(model_sigmas[idx]) for idx in ordered_unique_ts]
237
+
238
+ # Add final sigma value of 0.0
239
+ sigs.append(0.0)
240
 
 
 
 
 
 
 
 
241
  return torch.FloatTensor(sigs)
242
 
243
+
244
  def calculate_sigmas(
245
  model_sampling: torch.nn.Module, scheduler_name: str, steps: int
246
  ) -> torch.Tensor:
modules/sample/samplers.py CHANGED
@@ -142,181 +142,15 @@ def sample_euler(
142
  return x
143
 
144
 
145
- @torch.no_grad()
146
- def sample_dpmpp_sde(
147
- model,
148
- x,
149
- sigmas,
150
- extra_args=None,
151
- callback=None,
152
- disable=None,
153
- eta=1.0,
154
- s_noise=1.0,
155
- noise_sampler=None,
156
- r=1 / 2,
157
- pipeline=False,
158
- seed=None,
159
- ):
160
- # Pre-calculate common values
161
- device = x.device
162
- global disable_gui
163
- disable_gui = pipeline
164
-
165
- if not disable_gui:
166
- from modules.AutoEncoders import taesd
167
- from modules.user import app_instance
168
-
169
- # Early return check
170
- if len(sigmas) <= 1:
171
- return x
172
-
173
- # Pre-allocate tensors and values
174
- s_in = torch.ones((x.shape[0],), device=device)
175
- n_steps = len(sigmas) - 1
176
- extra_args = {} if extra_args is None else extra_args
177
-
178
- # Define helper functions
179
- def sigma_fn(t):
180
- return (-t).exp()
181
-
182
- def t_fn(sigma):
183
- return -sigma.log()
184
-
185
- # Initialize noise sampler
186
- if noise_sampler is None:
187
- noise_sampler = sampling_util.BrownianTreeNoiseSampler(
188
- x, sigmas[sigmas > 0].min(), sigmas.max(), seed=seed, cpu=True
189
- )
190
-
191
- for i in trange(n_steps, disable=disable):
192
- if (
193
- not pipeline
194
- and hasattr(app_instance.app, "interrupt_flag")
195
- and app_instance.app.interrupt_flag
196
- ):
197
- return x
198
-
199
- if not pipeline:
200
- app_instance.app.progress.set(i / n_steps)
201
-
202
- # Model inference
203
- denoised = model(x, sigmas[i] * s_in, **extra_args)
204
-
205
- if callback is not None:
206
- callback({"x": x, "i": i, "sigma": sigmas[i], "denoised": denoised})
207
-
208
- if sigmas[i + 1] == 0:
209
- # Single fused Euler step
210
- x = x + util.to_d(x, sigmas[i], denoised) * (sigmas[i + 1] - sigmas[i])
211
- else:
212
- # Fused DPM-Solver++ steps
213
- t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
214
- s = t + (t_next - t) * r
215
-
216
- # Step 1 - Combined calculations
217
- sd, su = sampling_util.get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
218
- s_ = t_fn(sd)
219
- x_2 = (
220
- (sigma_fn(s_) / sigma_fn(t)) * x
221
- - (t - s_).expm1() * denoised
222
- + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
223
- )
224
-
225
- denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
226
-
227
- # Step 2 - Combined calculations
228
- sd, su = sampling_util.get_ancestral_step(
229
- sigma_fn(t), sigma_fn(t_next), eta
230
- )
231
- t_next_ = t_fn(sd)
232
-
233
- # Final update in single calculation
234
- x = (
235
- (sigma_fn(t_next_) / sigma_fn(t)) * x
236
- - (t - t_next_).expm1()
237
- * ((1 - 1 / (2 * r)) * denoised + (1 / (2 * r)) * denoised_2)
238
- + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
239
- )
240
-
241
- # Preview updates
242
- if not pipeline and app_instance.app.previewer_var.get() and i % 5 == 0:
243
- threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
244
-
245
- return x
246
-
247
-
248
- @torch.no_grad()
249
- def sample_dpmpp_2m(
250
- model,
251
- x,
252
- sigmas,
253
- extra_args=None,
254
- callback=None,
255
- disable=None,
256
- pipeline=False,
257
  ):
258
- """DPM-Solver++(2M) sampler with optimizations"""
259
- # Pre-calculate common values and setup
260
- device = x.device
261
- global disable_gui
262
- disable_gui = pipeline
263
-
264
- if not disable_gui:
265
- from modules.AutoEncoders import taesd
266
- from modules.user import app_instance
267
-
268
- # Pre-allocate tensors and transform sigmas
269
- s_in = torch.ones((x.shape[0],), device=device)
270
- t_steps = -torch.log(sigmas) # Fused calculation
271
-
272
- # Pre-calculate all needed values in one go
273
- sigma_steps = torch.exp(-t_steps) # Fused calculation
274
- ratios = sigma_steps[1:] / sigma_steps[:-1]
275
- h_steps = t_steps[1:] - t_steps[:-1]
276
-
277
- old_denoised = None
278
- extra_args = {} if extra_args is None else extra_args
279
-
280
- for i in trange(len(sigmas) - 1, disable=disable):
281
- if (
282
- not pipeline
283
- and hasattr(app_instance.app, "interrupt_flag")
284
- and app_instance.app.interrupt_flag
285
- ):
286
- return x
287
-
288
- if not pipeline:
289
- app_instance.app.progress.set(i / (len(sigmas) - 1))
290
-
291
- # Fused model inference and update calculations
292
- denoised = model(x, sigmas[i] * s_in, **extra_args)
293
-
294
- if callback is not None:
295
- callback(
296
- {
297
- "x": x,
298
- "i": i,
299
- "sigma": sigmas[i],
300
- "sigma_hat": sigmas[i],
301
- "denoised": denoised,
302
- }
303
- )
304
-
305
- # Combined update step
306
- x = ratios[i] * x - (-h_steps[i]).expm1() * (
307
- denoised
308
- if old_denoised is None or sigmas[i + 1] == 0
309
- else (1 + h_steps[i - 1] / (2 * h_steps[i])) * denoised
310
- - (h_steps[i - 1] / (2 * h_steps[i])) * old_denoised
311
- )
312
-
313
- old_denoised = denoised
314
-
315
- # Preview updates
316
- if not pipeline and app_instance.app.previewer_var.get() and i % 5 == 0:
317
- threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
318
-
319
- return x
320
 
321
 
322
  @torch.no_grad()
@@ -354,17 +188,26 @@ def sample_dpmpp_2m_cfgpp(
354
  ratios = sigma_steps[1:] / sigma_steps[:-1]
355
  h_steps = t_steps[1:] - t_steps[:-1]
356
 
357
- # CFG++ scheduling
358
- def get_cfg_scale(step):
359
- # Linear scheduling from cfg_scale to cfg_min
360
- progress = step / n_steps
361
- return cfg_scale + (cfg_min - cfg_scale) * progress
362
 
363
  old_denoised = None
364
  old_uncond_denoised = None
365
  extra_args = {} if extra_args is None else extra_args
366
 
367
- for i in trange(len(sigmas) - 1, disable=disable):
 
 
 
 
 
 
 
 
 
 
 
368
  if (
369
  not pipeline
370
  and hasattr(app_instance.app, "interrupt_flag")
@@ -373,20 +216,10 @@ def sample_dpmpp_2m_cfgpp(
373
  return x
374
 
375
  if not pipeline:
376
- app_instance.app.progress.set(i / (len(sigmas) - 1))
377
-
378
- # Get current CFG scale
379
- current_cfg = get_cfg_scale(i)
380
-
381
- def post_cfg_function(args):
382
- nonlocal old_uncond_denoised
383
- old_uncond_denoised = args["uncond_denoised"]
384
- return args["denoised"]
385
 
386
- model_options = extra_args.get("model_options", {}).copy()
387
- extra_args["model_options"] = set_model_options_post_cfg_function(
388
- model_options, post_cfg_function, disable_cfg1_optimization=True
389
- )
390
 
391
  # Fused model inference and update calculations
392
  denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -406,27 +239,29 @@ def sample_dpmpp_2m_cfgpp(
406
  }
407
  )
408
 
409
- # CFG++ update step
410
  if old_uncond_denoised is None or sigmas[i + 1] == 0:
411
- # First step or last step - regular update
412
- cfg_denoised = uncond_denoised + (denoised - uncond_denoised) * current_cfg
413
  else:
414
- # CFG++ combination with momentum
415
- x0_coeff = cfg_x0_scale * current_cfg
416
- s_coeff = cfg_s_scale * current_cfg
417
-
418
- # Momentum terms
419
  h_ratio = h_steps[i - 1] / (2 * h_steps[i])
420
- momentum = (1 + h_ratio) * denoised - h_ratio * old_denoised
421
- uncond_momentum = (
422
- 1 + h_ratio
423
- ) * uncond_denoised - h_ratio * old_uncond_denoised
424
 
425
- # Combined update
426
- cfg_denoised = uncond_momentum + (momentum - uncond_momentum) * x0_coeff
 
 
 
427
 
428
- # Apply update
429
- x = ratios[i] * x - (-h_steps[i]).expm1() * cfg_denoised
 
 
 
 
 
 
430
 
431
  old_denoised = denoised
432
  old_uncond_denoised = uncond_denoised
@@ -438,17 +273,6 @@ def sample_dpmpp_2m_cfgpp(
438
  return x
439
 
440
 
441
- def set_model_options_post_cfg_function(
442
- model_options, post_cfg_function, disable_cfg1_optimization=False
443
- ):
444
- model_options["sampler_post_cfg_function"] = model_options.get(
445
- "sampler_post_cfg_function", []
446
- ) + [post_cfg_function]
447
- if disable_cfg1_optimization:
448
- model_options["disable_cfg1_optimization"] = True
449
- return model_options
450
-
451
-
452
  @torch.no_grad()
453
  def sample_dpmpp_sde_cfgpp(
454
  model,
@@ -572,7 +396,6 @@ def sample_dpmpp_sde_cfgpp(
572
  else:
573
  # CFG++ with momentum
574
  x0_coeff = cfg_x0_scale * current_cfg
575
- s_coeff = cfg_s_scale * current_cfg
576
 
577
  # Calculate momentum terms
578
  h_ratio = (t - s_) / (2 * (t - t_next))
 
142
  return x
143
 
144
 
145
+ def set_model_options_post_cfg_function(
146
+ model_options, post_cfg_function, disable_cfg1_optimization=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  ):
148
+ model_options["sampler_post_cfg_function"] = model_options.get(
149
+ "sampler_post_cfg_function", []
150
+ ) + [post_cfg_function]
151
+ if disable_cfg1_optimization:
152
+ model_options["disable_cfg1_optimization"] = True
153
+ return model_options
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
 
156
  @torch.no_grad()
 
188
  ratios = sigma_steps[1:] / sigma_steps[:-1]
189
  h_steps = t_steps[1:] - t_steps[:-1]
190
 
191
+ # Pre-calculate CFG schedule for the entire sampling process
192
+ steps = torch.arange(n_steps, device=device)
193
+ cfg_values = cfg_scale + (cfg_min - cfg_scale) * (steps / n_steps)
 
 
194
 
195
  old_denoised = None
196
  old_uncond_denoised = None
197
  extra_args = {} if extra_args is None else extra_args
198
 
199
+ # Define post-CFG function once outside the loop
200
+ def post_cfg_function(args):
201
+ nonlocal old_uncond_denoised
202
+ old_uncond_denoised = args["uncond_denoised"]
203
+ return args["denoised"]
204
+
205
+ model_options = extra_args.get("model_options", {}).copy()
206
+ extra_args["model_options"] = set_model_options_post_cfg_function(
207
+ model_options, post_cfg_function, disable_cfg1_optimization=True
208
+ )
209
+
210
+ for i in trange(n_steps, disable=disable):
211
  if (
212
  not pipeline
213
  and hasattr(app_instance.app, "interrupt_flag")
 
216
  return x
217
 
218
  if not pipeline:
219
+ app_instance.app.progress.set(i / n_steps)
 
 
 
 
 
 
 
 
220
 
221
+ # Use pre-calculated CFG scale
222
+ current_cfg = cfg_values[i]
 
 
223
 
224
  # Fused model inference and update calculations
225
  denoised = model(x, sigmas[i] * s_in, **extra_args)
 
239
  }
240
  )
241
 
242
+ # CFG++ update step using optimized operations
243
  if old_uncond_denoised is None or sigmas[i + 1] == 0:
244
+ # First step or last step - use torch.lerp for efficient interpolation
245
+ cfg_denoised = torch.lerp(uncond_denoised, denoised, current_cfg)
246
  else:
247
+ # Fused momentum calculations
 
 
 
 
248
  h_ratio = h_steps[i - 1] / (2 * h_steps[i])
249
+ h_ratio_plus_1 = 1 + h_ratio
 
 
 
250
 
251
+ # Use fused multiply-add operations for momentum terms
252
+ momentum = torch.addcmul(denoised * h_ratio_plus_1, old_denoised, -h_ratio)
253
+ uncond_momentum = torch.addcmul(
254
+ uncond_denoised * h_ratio_plus_1, old_uncond_denoised, -h_ratio
255
+ )
256
 
257
+ # Optimized interpolation for CFG++ update
258
+ cfg_denoised = torch.lerp(
259
+ uncond_momentum, momentum, current_cfg * cfg_x0_scale
260
+ )
261
+
262
+ # Apply update with pre-calculated expm1
263
+ h_expm1 = torch.expm1(-h_steps[i])
264
+ x = ratios[i] * x - h_expm1 * cfg_denoised
265
 
266
  old_denoised = denoised
267
  old_uncond_denoised = uncond_denoised
 
273
  return x
274
 
275
 
 
 
 
 
 
 
 
 
 
 
 
276
  @torch.no_grad()
277
  def sample_dpmpp_sde_cfgpp(
278
  model,
 
396
  else:
397
  # CFG++ with momentum
398
  x0_coeff = cfg_x0_scale * current_cfg
 
399
 
400
  # Calculate momentum terms
401
  h_ratio = (t - s_) / (2 * (t - t_next))
modules/sample/sampling.py CHANGED
@@ -76,6 +76,7 @@ class EPS:
76
  if max_denoise:
77
  noise = noise * torch.sqrt(1.0 + sigma**2.0)
78
  else:
 
79
  noise = noise * sigma
80
 
81
  noise += latent_image
@@ -513,153 +514,22 @@ def ksampler(
513
  #### Returns:
514
  - `KSAMPLER`: The KSAMPLER object.
515
  """
516
- if sampler_name == "dpmpp_2m":
517
-
518
- def dpmpp_2m_function(
519
- model: torch.nn.Module,
520
- noise: torch.Tensor,
521
- sigmas: torch.Tensor,
522
- extra_args: dict,
523
- callback: callable,
524
- disable: bool,
525
- pipeline: bool,
526
- **extra_options,
527
- ) -> torch.Tensor:
528
- sigma_min = sigmas[-1]
529
- if sigma_min == 0:
530
- sigma_min = sigmas[-2]
531
- return samplers.sample_dpmpp_2m(
532
- model,
533
- noise,
534
- sigmas,
535
- extra_args=extra_args,
536
- callback=callback,
537
- disable=disable,
538
- pipeline=pipeline,
539
- **extra_options,
540
- )
541
-
542
- sampler_function = dpmpp_2m_function
543
-
544
- elif sampler_name == "dpmpp_2m_cfgpp":
545
-
546
- def dpmpp_2m_dy_function(
547
- model: torch.nn.Module,
548
- noise: torch.Tensor,
549
- sigmas: torch.Tensor,
550
- extra_args: dict,
551
- callback: callable,
552
- disable: bool,
553
- pipeline: bool,
554
- **extra_options,
555
- ) -> torch.Tensor:
556
- sigma_min = sigmas[-1]
557
- if sigma_min == 0:
558
- sigma_min = sigmas[-2]
559
- return samplers.sample_dpmpp_2m_cfgpp(
560
- model,
561
- noise,
562
- sigmas,
563
- extra_args=extra_args,
564
- callback=callback,
565
- disable=disable,
566
- pipeline=pipeline,
567
- **extra_options,
568
- )
569
-
570
- sampler_function = dpmpp_2m_dy_function
571
-
572
- elif sampler_name == "dpmpp_sde":
573
-
574
- def dpmpp_sde_function(
575
- model: torch.nn.Module,
576
- noise: torch.Tensor,
577
- sigmas: torch.Tensor,
578
- extra_args: dict,
579
- callback: callable,
580
- disable: bool,
581
- pipeline: bool,
582
- **extra_options,
583
- ) -> torch.Tensor:
584
- return samplers.sample_dpmpp_sde(
585
- model,
586
- noise,
587
- sigmas,
588
- extra_args=extra_args,
589
- callback=callback,
590
- disable=disable,
591
- pipeline=pipeline,
592
- **extra_options,
593
- )
594
-
595
- sampler_function = dpmpp_sde_function
596
 
597
  elif sampler_name == "euler_ancestral":
598
-
599
- def euler_ancestral_function(
600
- model: torch.nn.Module,
601
- noise: torch.Tensor,
602
- sigmas: torch.Tensor,
603
- extra_args: dict,
604
- callback: callable,
605
- disable: bool,
606
- pipeline: bool,
607
- ) -> torch.Tensor:
608
- return samplers.sample_euler_ancestral(
609
- model,
610
- noise,
611
- sigmas,
612
- extra_args=extra_args,
613
- callback=callback,
614
- disable=disable,
615
- pipeline=pipeline,
616
- **extra_options,
617
- )
618
-
619
- sampler_function = euler_ancestral_function
620
 
621
  elif sampler_name == "dpmpp_sde_cfgpp":
622
-
623
- def dpmpp_sde_dy_function(
624
- model: torch.nn.Module,
625
- noise: torch.Tensor,
626
- sigmas: torch.Tensor,
627
- extra_args: dict,
628
- callback: callable,
629
- disable: bool,
630
- pipeline: bool,
631
- **extra_options,
632
- ) -> torch.Tensor:
633
- return samplers.sample_dpmpp_sde_cfgpp(
634
- model,
635
- noise,
636
- sigmas,
637
- extra_args=extra_args,
638
- callback=callback,
639
- disable=disable,
640
- pipeline=pipeline,
641
- **extra_options,
642
- )
643
-
644
- sampler_function = dpmpp_sde_dy_function
645
 
646
  elif sampler_name == "euler":
 
647
 
648
- def euler_function(
649
- model, noise, sigmas, extra_args, callback, disable, pipeline=False
650
- ):
651
- return samplers.sample_euler(
652
- model,
653
- noise,
654
- sigmas,
655
- extra_args=extra_args,
656
- callback=callback,
657
- disable=disable,
658
- pipeline=pipeline,
659
- **extra_options,
660
- )
661
-
662
- sampler_function = euler_function
663
 
664
  return KSAMPLER(sampler_function, extra_options, inpaint_options)
665
 
@@ -734,49 +604,49 @@ def sampler_object(name: str, pipeline: bool = False) -> KSAMPLER:
734
  return sampler
735
 
736
 
737
- class KSampler1:
738
- """#### Class for KSampler1."""
739
 
740
  def __init__(
741
  self,
742
- model: torch.nn.Module,
743
- steps: int,
744
- device,
745
  sampler: str = None,
746
  scheduler: str = None,
747
- denoise: float = None,
748
  model_options: dict = {},
749
  pipeline: bool = False,
750
  ):
751
- """#### Initialize the KSampler1 class.
752
-
753
- #### Args:
754
- - `model` (torch.nn.Module): The model.
755
- - `steps` (int): The number of steps.
756
- - `device` (torch.device): The device.
757
- - `sampler` (str, optional): The sampler name. Defaults to None.
758
- - `scheduler` (str, optional): The scheduler name. Defaults to None.
759
- - `denoise` (float, optional): The denoise factor. Defaults to None.
760
- - `model_options` (dict, optional): The model options. Defaults to {}.
761
- - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
762
  """
763
  self.model = model
764
- self.device = device
765
  self.scheduler = scheduler
766
- self.sampler = sampler
767
- self.set_steps(steps, denoise)
768
  self.denoise = denoise
769
  self.model_options = model_options
770
  self.pipeline = pipeline
771
 
 
 
 
772
  def calculate_sigmas(self, steps: int) -> torch.Tensor:
773
- """#### Calculate the sigmas for the given steps.
774
 
775
- #### Args:
776
- - `steps` (int): The number of steps.
777
 
778
- #### Returns:
779
- - `torch.Tensor`: The calculated sigmas.
780
  """
781
  sigmas = ksampler_util.calculate_sigmas(
782
  self.model.get_model_object("model_sampling"), self.scheduler, steps
@@ -784,11 +654,11 @@ class KSampler1:
784
  return sigmas
785
 
786
  def set_steps(self, steps: int, denoise: float = None):
787
- """#### Set the steps and calculate the sigmas.
788
 
789
- #### Args:
790
- - `steps` (int): The number of steps.
791
- - `denoise` (float, optional): The denoise factor. Defaults to None.
792
  """
793
  self.steps = steps
794
  if denoise is None or denoise > 0.9999:
@@ -801,7 +671,29 @@ class KSampler1:
801
  sigmas = self.calculate_sigmas(new_steps).to(self.device)
802
  self.sigmas = sigmas[-(steps + 1) :]
803
 
804
- def sample(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
805
  self,
806
  noise: torch.Tensor,
807
  positive: torch.Tensor,
@@ -816,48 +708,45 @@ class KSampler1:
816
  callback: callable = None,
817
  disable_pbar: bool = False,
818
  seed: int = None,
819
- pipeline: bool = False,
820
  flux: bool = False,
821
  ) -> torch.Tensor:
822
- """#### Sample using the KSampler1.
823
-
824
- #### Args:
825
- - `noise` (torch.Tensor): The noise tensor.
826
- - `positive` (torch.Tensor): The positive tensor.
827
- - `negative` (torch.Tensor): The negative tensor.
828
- - `cfg` (float): The CFG value.
829
- - `latent_image` (torch.Tensor, optional): The latent image tensor. Defaults to None.
830
- - `start_step` (int, optional): The start step. Defaults to None.
831
- - `last_step` (int, optional): The last step. Defaults to None.
832
- - `force_full_denoise` (bool, optional): Whether to force full denoise. Defaults to False.
833
- - `denoise_mask` (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
834
- - `sigmas` (torch.Tensor, optional): The sigmas tensor. Defaults to None.
835
- - `callback` (callable, optional): The callback function. Defaults to None.
836
- - `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
837
- - `seed` (int, optional): The seed value. Defaults to None.
838
- - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
839
-
840
- #### Returns:
841
- - `torch.Tensor`: The sampled tensor.
842
  """
 
 
 
843
  if sigmas is None:
844
  sigmas = self.sigmas
845
 
846
- if last_step is not None and last_step < (len(sigmas) - 1):
847
- sigmas = sigmas[: last_step + 1]
848
- if force_full_denoise:
849
- sigmas[-1] = 0
850
 
851
- if start_step is not None:
852
- if start_step < (len(sigmas) - 1):
853
- sigmas = sigmas[start_step:]
 
854
  else:
855
- if latent_image is not None:
856
- return latent_image
857
- else:
858
- return torch.zeros_like(noise)
859
 
860
- sampler = sampler_object(self.sampler, pipeline=pipeline)
861
 
862
  return sample(
863
  self.model,
@@ -866,7 +755,7 @@ class KSampler1:
866
  negative,
867
  cfg,
868
  self.device,
869
- sampler,
870
  sigmas,
871
  self.model_options,
872
  latent_image=latent_image,
@@ -874,11 +763,117 @@ class KSampler1:
874
  callback=callback,
875
  disable_pbar=disable_pbar,
876
  seed=seed,
877
- pipeline=pipeline,
878
  flux=flux,
879
  )
880
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
881
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
882
  def sample1(
883
  model: torch.nn.Module,
884
  noise: torch.Tensor,
@@ -902,37 +897,37 @@ def sample1(
902
  pipeline: bool = False,
903
  flux: bool = False,
904
  ) -> torch.Tensor:
905
- """#### Sample using the given parameters.
906
-
907
- #### Args:
908
- - `model` (torch.nn.Module): The model.
909
- - `noise` (torch.Tensor): The noise tensor.
910
- - `steps` (int): The number of steps.
911
- - `cfg` (float): The CFG value.
912
- - `sampler_name` (str): The sampler name.
913
- - `scheduler` (str): The scheduler name.
914
- - `positive` (torch.Tensor): The positive tensor.
915
- - `negative` (torch.Tensor): The negative tensor.
916
- - `latent_image` (torch.Tensor): The latent image tensor.
917
- - `denoise` (float, optional): The denoise factor. Defaults to 1.0.
918
- - `disable_noise` (bool, optional): Whether to disable noise. Defaults to False.
919
- - `start_step` (int, optional): The start step. Defaults to None.
920
- - `last_step` (int, optional): The last step. Defaults to None.
921
- - `force_full_denoise` (bool, optional): Whether to force full denoise. Defaults to False.
922
- - `noise_mask` (torch.Tensor, optional): The noise mask tensor. Defaults to None.
923
- - `sigmas` (torch.Tensor, optional): The sigmas tensor. Defaults to None.
924
- - `callback` (callable, optional): The callback function. Defaults to None.
925
- - `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
926
- - `seed` (int, optional): The seed value. Defaults to None.
927
- - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
928
-
929
- #### Returns:
930
- - `torch.Tensor`: The sampled tensor.
 
931
  """
932
- sampler = KSampler1(
933
- model,
934
  steps=steps,
935
- device=model.load_device,
936
  sampler=sampler_name,
937
  scheduler=scheduler,
938
  denoise=denoise,
@@ -940,7 +935,7 @@ def sample1(
940
  pipeline=pipeline,
941
  )
942
 
943
- samples = sampler.sample(
944
  noise,
945
  positive,
946
  negative,
@@ -954,147 +949,12 @@ def sample1(
954
  callback=callback,
955
  disable_pbar=disable_pbar,
956
  seed=seed,
957
- pipeline=pipeline,
958
  flux=flux,
959
  )
960
  samples = samples.to(Device.intermediate_device())
961
  return samples
962
 
963
 
964
- def common_ksampler(
965
- model: torch.nn.Module,
966
- seed: int,
967
- steps: int,
968
- cfg: float,
969
- sampler_name: str,
970
- scheduler: str,
971
- positive: torch.Tensor,
972
- negative: torch.Tensor,
973
- latent: dict,
974
- denoise: float = 1.0,
975
- disable_noise: bool = False,
976
- start_step: int = None,
977
- last_step: int = None,
978
- force_full_denoise: bool = False,
979
- pipeline: bool = False,
980
- flux: bool = False,
981
- ) -> tuple:
982
- """#### Common ksampler function.
983
-
984
- #### Args:
985
- - `model` (torch.nn.Module): The model.
986
- - `seed` (int): The seed value.
987
- - `steps` (int): The number of steps.
988
- - `cfg` (float): The CFG value.
989
- - `sampler_name` (str): The sampler name.
990
- - `scheduler` (str): The scheduler name.
991
- - `positive` (torch.Tensor): The positive tensor.
992
- - `negative` (torch.Tensor): The negative tensor.
993
- - `latent` (dict): The latent dictionary.
994
- - `denoise` (float, optional): The denoise factor. Defaults to 1.0.
995
- - `disable_noise` (bool, optional): Whether to disable noise. Defaults to False.
996
- - `start_step` (int, optional): The start step. Defaults to None.
997
- - `last_step` (int, optional): The last step. Defaults to None.
998
- - `force_full_denoise` (bool, optional): Whether to force full denoise. Defaults to False.
999
- - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
1000
-
1001
- #### Returns:
1002
- - `tuple`: The output tuple containing the latent dictionary and samples.
1003
- """
1004
- latent_image = latent["samples"]
1005
- latent_image = Latent.fix_empty_latent_channels(model, latent_image)
1006
-
1007
- if disable_noise:
1008
- noise = torch.zeros(
1009
- latent_image.size(),
1010
- dtype=latent_image.dtype,
1011
- layout=latent_image.layout,
1012
- device="cpu",
1013
- )
1014
- else:
1015
- batch_inds = latent["batch_index"] if "batch_index" in latent else None
1016
- noise = ksampler_util.prepare_noise(latent_image, seed, batch_inds)
1017
-
1018
- noise_mask = None
1019
- if "noise_mask" in latent:
1020
- noise_mask = latent["noise_mask"]
1021
- samples = sample1(
1022
- model,
1023
- noise,
1024
- steps,
1025
- cfg,
1026
- sampler_name,
1027
- scheduler,
1028
- positive,
1029
- negative,
1030
- latent_image,
1031
- denoise=denoise,
1032
- disable_noise=disable_noise,
1033
- start_step=start_step,
1034
- last_step=last_step,
1035
- force_full_denoise=force_full_denoise,
1036
- noise_mask=noise_mask,
1037
- seed=seed,
1038
- pipeline=pipeline,
1039
- flux=flux,
1040
- )
1041
- out = latent.copy()
1042
- out["samples"] = samples
1043
- return (out,)
1044
-
1045
-
1046
- class KSampler2:
1047
- """#### Class for KSampler2."""
1048
-
1049
- def sample(
1050
- self,
1051
- model: torch.nn.Module,
1052
- seed: int,
1053
- steps: int,
1054
- cfg: float,
1055
- sampler_name: str,
1056
- scheduler: str,
1057
- positive: torch.Tensor,
1058
- negative: torch.Tensor,
1059
- latent_image: torch.Tensor,
1060
- denoise: float = 1.0,
1061
- pipeline: bool = False,
1062
- flux: bool = False,
1063
- ) -> tuple:
1064
- """#### Sample using the KSampler2.
1065
-
1066
- #### Args:
1067
- - `model` (torch.nn.Module): The model.
1068
- - `seed` (int): The seed value.
1069
- - `steps` (int): The number of steps.
1070
- - `cfg` (float): The CFG value.
1071
- - `sampler_name` (str): The sampler name.
1072
- - `scheduler` (str): The scheduler name.
1073
- - `positive` (torch.Tensor): The positive tensor.
1074
- - `negative` (torch.Tensor): The negative tensor.
1075
- - `latent_image` (torch.Tensor): The latent image tensor.
1076
- - `denoise` (float, optional): The denoise factor. Defaults to 1.0.
1077
- - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
1078
-
1079
- #### Returns:
1080
- - `tuple`: The output tuple containing the latent dictionary and samples.
1081
- """
1082
- return common_ksampler(
1083
- model,
1084
- seed,
1085
- steps,
1086
- cfg,
1087
- sampler_name,
1088
- scheduler,
1089
- positive,
1090
- negative,
1091
- latent_image,
1092
- denoise=denoise,
1093
- pipeline=pipeline,
1094
- flux=flux,
1095
- )
1096
-
1097
-
1098
  class ModelType(Enum):
1099
  """#### Enum for Model Types."""
1100
 
@@ -1187,3 +1047,86 @@ def sample_custom(
1187
  )
1188
  samples = samples.to(Device.intermediate_device())
1189
  return samples
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  if max_denoise:
77
  noise = noise * torch.sqrt(1.0 + sigma**2.0)
78
  else:
79
+ sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
80
  noise = noise * sigma
81
 
82
  noise += latent_image
 
514
  #### Returns:
515
  - `KSAMPLER`: The KSAMPLER object.
516
  """
517
+ if sampler_name == "dpmpp_2m_cfgpp":
518
+ sampler_function = samplers.sample_dpmpp_2m_cfgpp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
519
 
520
  elif sampler_name == "euler_ancestral":
521
+ sampler_function = samplers.sample_euler_ancestral
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522
 
523
  elif sampler_name == "dpmpp_sde_cfgpp":
524
+ sampler_function = samplers.sample_dpmpp_sde_cfgpp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
 
526
  elif sampler_name == "euler":
527
+ sampler_function = samplers.sample_euler
528
 
529
+ else:
530
+ # Default fallback
531
+ sampler_function = samplers.sample_euler
532
+ print(f"Warning: Unknown sampler '{sampler_name}', falling back to euler")
 
 
 
 
 
 
 
 
 
 
 
533
 
534
  return KSAMPLER(sampler_function, extra_options, inpaint_options)
535
 
 
604
  return sampler
605
 
606
 
607
+ class KSampler:
608
+ """A unified sampler class that replaces both KSampler1 and KSampler2."""
609
 
610
  def __init__(
611
  self,
612
+ model: torch.nn.Module = None,
613
+ steps: int = None,
 
614
  sampler: str = None,
615
  scheduler: str = None,
616
+ denoise: float = 1.0,
617
  model_options: dict = {},
618
  pipeline: bool = False,
619
  ):
620
+ """Initialize the KSampler class.
621
+
622
+ Args:
623
+ model (torch.nn.Module, optional): The model to use for sampling. Required for direct sampling.
624
+ steps (int, optional): The number of steps. Required for direct sampling.
625
+ sampler (str, optional): The sampler name. Defaults to None.
626
+ scheduler (str, optional): The scheduler name. Defaults to None.
627
+ denoise (float, optional): The denoise factor. Defaults to 1.0.
628
+ model_options (dict, optional): The model options. Defaults to {}.
629
+ pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
 
630
  """
631
  self.model = model
632
+ self.device = model.load_device if model is not None else None
633
  self.scheduler = scheduler
634
+ self.sampler_name = sampler
 
635
  self.denoise = denoise
636
  self.model_options = model_options
637
  self.pipeline = pipeline
638
 
639
+ if model is not None and steps is not None:
640
+ self.set_steps(steps, denoise)
641
+
642
  def calculate_sigmas(self, steps: int) -> torch.Tensor:
643
+ """Calculate the sigmas for the given steps.
644
 
645
+ Args:
646
+ steps (int): The number of steps.
647
 
648
+ Returns:
649
+ torch.Tensor: The calculated sigmas.
650
  """
651
  sigmas = ksampler_util.calculate_sigmas(
652
  self.model.get_model_object("model_sampling"), self.scheduler, steps
 
654
  return sigmas
655
 
656
  def set_steps(self, steps: int, denoise: float = None):
657
+ """Set the steps and calculate the sigmas.
658
 
659
+ Args:
660
+ steps (int): The number of steps.
661
+ denoise (float, optional): The denoise factor. Defaults to None.
662
  """
663
  self.steps = steps
664
  if denoise is None or denoise > 0.9999:
 
671
  sigmas = self.calculate_sigmas(new_steps).to(self.device)
672
  self.sigmas = sigmas[-(steps + 1) :]
673
 
674
+ def _process_sigmas(self, sigmas, start_step, last_step, force_full_denoise):
675
+ """Process sigmas based on start_step and last_step.
676
+
677
+ Args:
678
+ sigmas (torch.Tensor): The sigmas tensor.
679
+ start_step (int, optional): The start step. Defaults to None.
680
+ last_step (int, optional): The last step. Defaults to None.
681
+ force_full_denoise (bool): Whether to force full denoise.
682
+
683
+ Returns:
684
+ torch.Tensor: The processed sigmas.
685
+ """
686
+ if last_step is not None and last_step < (len(sigmas) - 1):
687
+ sigmas = sigmas[: last_step + 1]
688
+ if force_full_denoise:
689
+ sigmas[-1] = 0
690
+
691
+ if start_step is not None and start_step < (len(sigmas) - 1):
692
+ sigmas = sigmas[start_step:]
693
+
694
+ return sigmas
695
+
696
+ def direct_sample(
697
  self,
698
  noise: torch.Tensor,
699
  positive: torch.Tensor,
 
708
  callback: callable = None,
709
  disable_pbar: bool = False,
710
  seed: int = None,
 
711
  flux: bool = False,
712
  ) -> torch.Tensor:
713
+ """Sample directly with the initialized model and parameters.
714
+
715
+ Args:
716
+ noise (torch.Tensor): The noise tensor.
717
+ positive (torch.Tensor): The positive tensor.
718
+ negative (torch.Tensor): The negative tensor.
719
+ cfg (float): The CFG value.
720
+ latent_image (torch.Tensor, optional): The latent image tensor. Defaults to None.
721
+ start_step (int, optional): The start step. Defaults to None.
722
+ last_step (int, optional): The last step. Defaults to None.
723
+ force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
724
+ denoise_mask (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
725
+ sigmas (torch.Tensor, optional): The sigmas tensor. Defaults to None.
726
+ callback (callable, optional): The callback function. Defaults to None.
727
+ disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
728
+ seed (int, optional): The seed value. Defaults to None.
729
+ flux (bool, optional): Whether to use flux mode. Defaults to False.
730
+
731
+ Returns:
732
+ torch.Tensor: The sampled tensor.
733
  """
734
+ if self.model is None:
735
+ raise ValueError("Model must be provided for direct sampling")
736
+
737
  if sigmas is None:
738
  sigmas = self.sigmas
739
 
740
+ sigmas = self._process_sigmas(sigmas, start_step, last_step, force_full_denoise)
 
 
 
741
 
742
+ # Early return if needed
743
+ if start_step is not None and start_step >= (len(sigmas) - 1):
744
+ if latent_image is not None:
745
+ return latent_image
746
  else:
747
+ return torch.zeros_like(noise)
 
 
 
748
 
749
+ sampler_obj = sampler_object(self.sampler_name, pipeline=self.pipeline)
750
 
751
  return sample(
752
  self.model,
 
755
  negative,
756
  cfg,
757
  self.device,
758
+ sampler_obj,
759
  sigmas,
760
  self.model_options,
761
  latent_image=latent_image,
 
763
  callback=callback,
764
  disable_pbar=disable_pbar,
765
  seed=seed,
766
+ pipeline=self.pipeline,
767
  flux=flux,
768
  )
769
 
770
+ def sample(
771
+ self,
772
+ model: torch.nn.Module = None,
773
+ seed: int = None,
774
+ steps: int = None,
775
+ cfg: float = None,
776
+ sampler_name: str = None,
777
+ scheduler: str = None,
778
+ positive: torch.Tensor = None,
779
+ negative: torch.Tensor = None,
780
+ latent_image: torch.Tensor = None,
781
+ denoise: float = None,
782
+ start_step: int = None,
783
+ last_step: int = None,
784
+ force_full_denoise: bool = False,
785
+ noise_mask: torch.Tensor = None,
786
+ callback: callable = None,
787
+ disable_pbar: bool = False,
788
+ disable_noise: bool = False,
789
+ pipeline: bool = False,
790
+ flux: bool = False,
791
+ ) -> tuple:
792
+ """Unified sampling interface that works both as direct sampling and through the common_ksampler.
793
+
794
+ This method can be used in two ways:
795
+ 1. If model is provided, it will create a temporary sampler and use that
796
+ 2. If model is None, it will use the pre-initialized model and parameters
797
+
798
+ Args:
799
+ model (torch.nn.Module, optional): The model to use for sampling. If None, uses pre-initialized model.
800
+ seed (int, optional): The seed value.
801
+ steps (int, optional): The number of steps. If None, uses pre-initialized steps.
802
+ cfg (float, optional): The CFG value.
803
+ sampler_name (str, optional): The sampler name. If None, uses pre-initialized sampler.
804
+ scheduler (str, optional): The scheduler name. If None, uses pre-initialized scheduler.
805
+ positive (torch.Tensor, optional): The positive tensor.
806
+ negative (torch.Tensor, optional): The negative tensor.
807
+ latent_image (torch.Tensor, optional): The latent image tensor.
808
+ denoise (float, optional): The denoise factor. If None, uses pre-initialized denoise.
809
+ start_step (int, optional): The start step. Defaults to None.
810
+ last_step (int, optional): The last step. Defaults to None.
811
+ force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
812
+ noise_mask (torch.Tensor, optional): The noise mask tensor. Defaults to None.
813
+ callback (callable, optional): The callback function. Defaults to None.
814
+ disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
815
+ disable_noise (bool, optional): Whether to disable noise. Defaults to False.
816
+ pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
817
+ flux (bool, optional): Whether to use flux mode. Defaults to False.
818
+
819
+ Returns:
820
+ tuple: The output tuple containing either (latent_dict,) or the sampled tensor.
821
+ """
822
+ # Case 1: Use pre-initialized model for direct sampling
823
+ if model is None:
824
+ if latent_image is None:
825
+ raise ValueError(
826
+ "latent_image must be provided when using pre-initialized model"
827
+ )
828
+
829
+ return (
830
+ self.direct_sample(
831
+ None, # noise will be generated in common_ksampler
832
+ positive,
833
+ negative,
834
+ cfg,
835
+ latent_image,
836
+ start_step,
837
+ last_step,
838
+ force_full_denoise,
839
+ noise_mask,
840
+ None, # sigmas will use pre-calculated ones
841
+ callback,
842
+ disable_pbar,
843
+ seed,
844
+ flux,
845
+ ),
846
+ )
847
+
848
+ # Case 2: Use common_ksampler approach with provided model
849
+ else:
850
+ # For backwards compatibility with KSampler2 usage pattern
851
+ if isinstance(latent_image, dict):
852
+ latent = latent_image
853
+ else:
854
+ latent = {"samples": latent_image}
855
 
856
+ return common_ksampler(
857
+ model,
858
+ seed,
859
+ steps,
860
+ cfg,
861
+ sampler_name or self.sampler_name,
862
+ scheduler or self.scheduler,
863
+ positive,
864
+ negative,
865
+ latent,
866
+ denoise or self.denoise,
867
+ disable_noise,
868
+ start_step,
869
+ last_step,
870
+ force_full_denoise,
871
+ pipeline or self.pipeline,
872
+ flux,
873
+ )
874
+
875
+
876
+ # Refactor sample1 to use KSampler directly
877
  def sample1(
878
  model: torch.nn.Module,
879
  noise: torch.Tensor,
 
897
  pipeline: bool = False,
898
  flux: bool = False,
899
  ) -> torch.Tensor:
900
+ """Sample using the given parameters with the unified KSampler.
901
+
902
+ Args:
903
+ model (torch.nn.Module): The model.
904
+ noise (torch.Tensor): The noise tensor.
905
+ steps (int): The number of steps.
906
+ cfg (float): The CFG value.
907
+ sampler_name (str): The sampler name.
908
+ scheduler (str): The scheduler name.
909
+ positive (torch.Tensor): The positive tensor.
910
+ negative (torch.Tensor): The negative tensor.
911
+ latent_image (torch.Tensor): The latent image tensor.
912
+ denoise (float, optional): The denoise factor. Defaults to 1.0.
913
+ disable_noise (bool, optional): Whether to disable noise. Defaults to False.
914
+ start_step (int, optional): The start step. Defaults to None.
915
+ last_step (int, optional): The last step. Defaults to None.
916
+ force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
917
+ noise_mask (torch.Tensor, optional): The noise mask tensor. Defaults to None.
918
+ sigmas (torch.Tensor, optional): The sigmas tensor. Defaults to None.
919
+ callback (callable, optional): The callback function. Defaults to None.
920
+ disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
921
+ seed (int, optional): The seed value. Defaults to None.
922
+ pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
923
+ flux (bool, optional): Whether to use flux mode. Defaults to False.
924
+
925
+ Returns:
926
+ torch.Tensor: The sampled tensor.
927
  """
928
+ sampler = KSampler(
929
+ model=model,
930
  steps=steps,
 
931
  sampler=sampler_name,
932
  scheduler=scheduler,
933
  denoise=denoise,
 
935
  pipeline=pipeline,
936
  )
937
 
938
+ samples = sampler.direct_sample(
939
  noise,
940
  positive,
941
  negative,
 
949
  callback=callback,
950
  disable_pbar=disable_pbar,
951
  seed=seed,
 
952
  flux=flux,
953
  )
954
  samples = samples.to(Device.intermediate_device())
955
  return samples
956
 
957
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
958
  class ModelType(Enum):
959
  """#### Enum for Model Types."""
960
 
 
1047
  )
1048
  samples = samples.to(Device.intermediate_device())
1049
  return samples
1050
+
1051
+
1052
+ def common_ksampler(
1053
+ model: torch.nn.Module,
1054
+ seed: int,
1055
+ steps: int,
1056
+ cfg: float,
1057
+ sampler_name: str,
1058
+ scheduler: str,
1059
+ positive: torch.Tensor,
1060
+ negative: torch.Tensor,
1061
+ latent: dict,
1062
+ denoise: float = 1.0,
1063
+ disable_noise: bool = False,
1064
+ start_step: int = None,
1065
+ last_step: int = None,
1066
+ force_full_denoise: bool = False,
1067
+ pipeline: bool = False,
1068
+ flux: bool = False,
1069
+ ) -> tuple:
1070
+ """Common ksampler function.
1071
+
1072
+ Args:
1073
+ model (torch.nn.Module): The model.
1074
+ seed (int): The seed value.
1075
+ steps (int): The number of steps.
1076
+ cfg (float): The CFG value.
1077
+ sampler_name (str): The sampler name.
1078
+ scheduler (str): The scheduler name.
1079
+ positive (torch.Tensor): The positive tensor.
1080
+ negative (torch.Tensor): The negative tensor.
1081
+ latent (dict): The latent dictionary.
1082
+ denoise (float, optional): The denoise factor. Defaults to 1.0.
1083
+ disable_noise (bool, optional): Whether to disable noise. Defaults to False.
1084
+ start_step (int, optional): The start step. Defaults to None.
1085
+ last_step (int, optional): The last step. Defaults to None.
1086
+ force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
1087
+ pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
1088
+ flux (bool, optional): Whether to use flux mode. Defaults to False.
1089
+
1090
+ Returns:
1091
+ tuple: The output tuple containing the latent dictionary and samples.
1092
+ """
1093
+ latent_image = latent["samples"]
1094
+ latent_image = Latent.fix_empty_latent_channels(model, latent_image)
1095
+
1096
+ if disable_noise:
1097
+ noise = torch.zeros(
1098
+ latent_image.size(),
1099
+ dtype=latent_image.dtype,
1100
+ layout=latent_image.layout,
1101
+ device="cpu",
1102
+ )
1103
+ else:
1104
+ batch_inds = latent["batch_index"] if "batch_index" in latent else None
1105
+ noise = ksampler_util.prepare_noise(latent_image, seed, batch_inds)
1106
+
1107
+ noise_mask = None
1108
+ if "noise_mask" in latent:
1109
+ noise_mask = latent["noise_mask"]
1110
+ samples = sample1(
1111
+ model,
1112
+ noise,
1113
+ steps,
1114
+ cfg,
1115
+ sampler_name,
1116
+ scheduler,
1117
+ positive,
1118
+ negative,
1119
+ latent_image,
1120
+ denoise=denoise,
1121
+ disable_noise=disable_noise,
1122
+ start_step=start_step,
1123
+ last_step=last_step,
1124
+ force_full_denoise=force_full_denoise,
1125
+ noise_mask=noise_mask,
1126
+ seed=seed,
1127
+ pipeline=pipeline,
1128
+ flux=flux,
1129
+ )
1130
+ out = latent.copy()
1131
+ out["samples"] = samples
1132
+ return (out,)
modules/user/GUI.py CHANGED
@@ -449,7 +449,9 @@ class App(tk.Tk):
449
  img_tensor = img_tensor.unsqueeze(0)
450
  self.interrupt_flag = False
451
  self.sampler = (
452
- "dpmpp_sde" if not self.prioritize_speed_var.get() else "dpmpp_2m"
 
 
453
  )
454
  with torch.inference_mode():
455
  (
@@ -612,7 +614,7 @@ class App(tk.Tk):
612
  )
613
  self.cliptextencode = Clip.CLIPTextEncode()
614
  self.emptylatentimage = Latent.EmptyLatentImage()
615
- self.ksampler_instance = sampling.KSampler2()
616
  self.vaedecode = VariationalAE.VAEDecode()
617
  self.latent_upscale = upscale.LatentUpscale()
618
  self.upscalemodelloader = USDU_upscaler.UpscaleModelLoader()
@@ -637,7 +639,9 @@ class App(tk.Tk):
637
  self.generation_threads.append(current_thread)
638
  self.interrupt_flag = False
639
  self.sampler = (
640
- "dpmpp_sde" if not self.prioritize_speed_var.get() else "dpmpp_2m"
 
 
641
  )
642
  try:
643
  # Disable generate button during generation
@@ -955,7 +959,7 @@ class App(tk.Tk):
955
  unetloadergguf = Quantizer.UnetLoaderGGUF()
956
  cliptextencodeflux = Quantizer.CLIPTextEncodeFlux()
957
  conditioningzeroout = Quantizer.ConditioningZeroOut()
958
- ksampler = sampling.KSampler2()
959
  vaedecode = VariationalAE.VAEDecode()
960
  unetloadergguf_10 = unetloadergguf.load_unet(
961
  unet_name="flux1-dev-Q8_0.gguf"
 
449
  img_tensor = img_tensor.unsqueeze(0)
450
  self.interrupt_flag = False
451
  self.sampler = (
452
+ "dpmpp_sde_cfgpp"
453
+ if not self.prioritize_speed_var.get()
454
+ else "dpmpp_2m_cfgpp"
455
  )
456
  with torch.inference_mode():
457
  (
 
614
  )
615
  self.cliptextencode = Clip.CLIPTextEncode()
616
  self.emptylatentimage = Latent.EmptyLatentImage()
617
+ self.ksampler_instance = sampling.KSampler()
618
  self.vaedecode = VariationalAE.VAEDecode()
619
  self.latent_upscale = upscale.LatentUpscale()
620
  self.upscalemodelloader = USDU_upscaler.UpscaleModelLoader()
 
639
  self.generation_threads.append(current_thread)
640
  self.interrupt_flag = False
641
  self.sampler = (
642
+ "dpmpp_sde_cfgpp"
643
+ if not self.prioritize_speed_var.get()
644
+ else "dpmpp_2m_cfgpp"
645
  )
646
  try:
647
  # Disable generate button during generation
 
959
  unetloadergguf = Quantizer.UnetLoaderGGUF()
960
  cliptextencodeflux = Quantizer.CLIPTextEncodeFlux()
961
  conditioningzeroout = Quantizer.ConditioningZeroOut()
962
+ ksampler = sampling.KSampler()
963
  vaedecode = VariationalAE.VAEDecode()
964
  unetloadergguf_10 = unetloadergguf.load_unet(
965
  unet_name="flux1-dev-Q8_0.gguf"
modules/user/pipeline.py CHANGED
@@ -92,7 +92,7 @@ def pipeline(
92
  hidiffoptimizer = msw_msa_attention.ApplyMSWMSAAttentionSimple()
93
  cliptextencode = Clip.CLIPTextEncode()
94
  emptylatentimage = Latent.EmptyLatentImage()
95
- ksampler_instance = sampling.KSampler2()
96
  vaedecode = VariationalAE.VAEDecode()
97
  saveimage = ImageSaver.SaveImage()
98
  latent_upscale = upscale.LatentUpscale()
@@ -187,7 +187,7 @@ def pipeline(
187
  unetloadergguf = Quantizer.UnetLoaderGGUF()
188
  cliptextencodeflux = Quantizer.CLIPTextEncodeFlux()
189
  conditioningzeroout = Quantizer.ConditioningZeroOut()
190
- ksampler = sampling.KSampler2()
191
  unetloadergguf_10 = unetloadergguf.load_unet(
192
  unet_name="flux1-dev-Q8_0.gguf"
193
  )
@@ -283,10 +283,10 @@ def pipeline(
283
  )
284
  else:
285
  applystablefast_158 = loraloader_274
286
- fb_cache = fbcache_nodes.ApplyFBCacheOnModel()
287
- applystablefast_158 = fb_cache.patch(
288
- applystablefast_158, "diffusion_model", 0.120
289
- )
290
 
291
  ksampler_239 = ksampler_instance.sample(
292
  seed=seed,
 
92
  hidiffoptimizer = msw_msa_attention.ApplyMSWMSAAttentionSimple()
93
  cliptextencode = Clip.CLIPTextEncode()
94
  emptylatentimage = Latent.EmptyLatentImage()
95
+ ksampler_instance = sampling.KSampler()
96
  vaedecode = VariationalAE.VAEDecode()
97
  saveimage = ImageSaver.SaveImage()
98
  latent_upscale = upscale.LatentUpscale()
 
187
  unetloadergguf = Quantizer.UnetLoaderGGUF()
188
  cliptextencodeflux = Quantizer.CLIPTextEncodeFlux()
189
  conditioningzeroout = Quantizer.ConditioningZeroOut()
190
+ ksampler = sampling.KSampler()
191
  unetloadergguf_10 = unetloadergguf.load_unet(
192
  unet_name="flux1-dev-Q8_0.gguf"
193
  )
 
283
  )
284
  else:
285
  applystablefast_158 = loraloader_274
286
+ # fb_cache = fbcache_nodes.ApplyFBCacheOnModel()
287
+ # applystablefast_158 = fb_cache.patch(
288
+ # applystablefast_158, "diffusion_model", 0.120
289
+ # )
290
 
291
  ksampler_239 = ksampler_instance.sample(
292
  seed=seed,