# Source: Aatricks (uploaded via huggingface_hub, commit cfe609e, verified)
import math

import torch

from modules.cond import cond_util
from modules.Device import Device
from modules.sample import ksampler_util
from modules.Utilities import util
class CONDRegular:
    """#### Class representing a regular condition.

    Wraps a single condition tensor and provides batching helpers used by
    the sampler: resizing to a batch size, compatibility checks, and
    concatenation with other conditions of the same kind.
    """

    def __init__(self, cond: torch.Tensor):
        """#### Initialize the CONDRegular class.
        #### Args:
            - `cond` (torch.Tensor): The condition tensor.
        """
        self.cond = cond

    def _copy_with(self, cond: torch.Tensor) -> "CONDRegular":
        """#### Build a new condition of the same concrete class around `cond`.
        #### Args:
            - `cond` (torch.Tensor): The new condition tensor.
        #### Returns:
            - `CONDRegular`: The copied condition (same subclass as `self`).
        """
        return self.__class__(cond)

    def process_cond(
        self, batch_size: int, device: torch.device, **kwargs
    ) -> "CONDRegular":
        """#### Resize the condition to the batch size and move it to `device`.
        #### Args:
            - `batch_size` (int): The batch size.
            - `device` (torch.device): The device.
        #### Returns:
            - `CONDRegular`: The processed condition.
        """
        resized = util.repeat_to_batch_size(self.cond, batch_size)
        return self._copy_with(resized.to(device))

    def can_concat(self, other: "CONDRegular") -> bool:
        """#### Check if conditions can be concatenated.
        #### Args:
            - `other` (CONDRegular): The other condition.
        #### Returns:
            - `bool`: True if conditions can be concatenated, False otherwise.
        """
        # Plain conditions are only batchable when shapes match exactly.
        return self.cond.shape == other.cond.shape

    def concat(self, others: list) -> torch.Tensor:
        """#### Concatenate this condition with `others` along the batch dim.
        #### Args:
            - `others` (list): The list of other conditions.
        #### Returns:
            - `torch.Tensor`: The concatenated conditions.
        """
        return torch.cat([self.cond] + [o.cond for o in others])
class CONDCrossAttn(CONDRegular):
    """#### Class representing a cross-attention condition.

    Unlike `CONDRegular`, cross-attention conditions with different token
    counts can still be batched together by repeating the shorter one up to
    the least common multiple of their lengths.
    """

    def can_concat(self, other: "CONDRegular") -> bool:
        """#### Check if conditions can be concatenated.
        #### Args:
            - `other` (CONDRegular): The other condition.
        #### Returns:
            - `bool`: True if conditions can be concatenated, False otherwise.
        """
        s1 = self.cond.shape
        s2 = other.cond.shape
        if s1 != s2:
            if s1[0] != s2[0] or s1[2] != s2[2]:  # these 2 cases should not happen
                return False
            # BUG FIX: torch.lcm only accepts Tensor arguments; shape entries
            # are plain Python ints, so the original torch.lcm call raised
            # TypeError whenever the shapes differed. math.lcm handles ints.
            mult_min = math.lcm(s1[1], s2[1])
            diff = mult_min // min(s1[1], s2[1])
            if (
                diff > 4
            ):  # arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
                return False
        return True

    def concat(self, others: list) -> torch.Tensor:
        """Optimized version of cross-attention condition concatenation.

        Repeats shorter conditions along the token dimension up to the LCM of
        all token lengths so every tensor agrees before torch.cat.

        #### Args:
            - `others` (list): The list of other conditions.
        #### Returns:
            - `torch.Tensor`: The concatenated conditions.
        """
        conds = [self.cond]
        shapes = [self.cond.shape[1]]
        # Collect all conditions and their token lengths
        for x in others:
            conds.append(x.cond)
            shapes.append(x.cond.shape[1])
        # Target token length every cond must be repeated to
        crossattn_max_len = util.lcm_of_list(shapes)
        if all(c.shape[1] == shapes[0] for c in conds):
            # All same length, simple concatenation
            return torch.cat(conds)
        else:
            # Repeat each shorter condition up to the common length.
            # crossattn_max_len is an LCM, so the division is always exact.
            out = []
            for c in conds:
                if c.shape[1] < crossattn_max_len:
                    repeat_factor = crossattn_max_len // c.shape[1]
                    c = c.repeat(1, repeat_factor, 1)
                out.append(c)
            return torch.cat(out)
def convert_cond(cond: list) -> list:
    """#### Convert conditions to cross-attention conditions.
    #### Args:
        - `cond` (list): The list of conditions, each a (tensor, dict) pair.
    #### Returns:
        - `list`: The converted condition dicts.
    """
    converted = []
    for item in cond:
        # Work on a shallow copy so the caller's dict is left untouched.
        entry = item[1].copy()
        model_conds = entry.get("model_conds", {})
        cross_attn = item[0]
        if cross_attn is not None:
            model_conds["c_crossattn"] = CONDCrossAttn(cross_attn)
            entry["cross_attn"] = cross_attn
        entry["model_conds"] = model_conds
        converted.append(entry)
    return converted
def calc_cond_batch(
    model: object,
    conds: list,
    x_in: torch.Tensor,
    timestep: torch.Tensor,
    model_options: dict,
) -> list:
    """#### Calculate the condition batch.

    Gathers every runnable condition entry, batches together entries that are
    compatible (and fit the memory estimate), runs the model once per batch,
    and accumulates the weighted outputs into one tensor per condition list.

    #### Args:
        - `model` (object): The model; must expose `memory_required` and `apply_model`.
        - `conds` (list): The list of condition lists (e.g. [positive, negative]).
        - `x_in` (torch.Tensor): The input latent tensor.
        - `timestep` (torch.Tensor): The timestep tensor.
        - `model_options` (dict): The model options (may carry
          "transformer_options" and "model_function_wrapper").
    #### Returns:
        - `list`: One accumulated output tensor per entry in `conds`.
    """
    out_conds = []
    out_counts = []
    # Work queue of (processed_cond, source_index) pairs still to be run.
    to_run = []
    for i in range(len(conds)):
        out_conds.append(torch.zeros_like(x_in))
        # Tiny epsilon keeps the final division safe where nothing contributed.
        out_counts.append(torch.ones_like(x_in) * 1e-37)
        cond = conds[i]
        if cond is not None:
            for x in cond:
                # Resolves the spatial area/mask/multiplier for this entry;
                # None means the entry doesn't apply at this timestep.
                p = ksampler_util.get_area_and_mult(x, x_in, timestep)
                if p is None:
                    continue
                to_run += [(p, i)]
    while len(to_run) > 0:
        first = to_run[0]
        first_shape = first[0][0].shape
        to_batch_temp = []
        # Collect indices of all pending entries batchable with the first one.
        for x in range(len(to_run)):
            if cond_util.can_concat_cond(to_run[x][0], first[0]):
                to_batch_temp += [x]
        # Reverse so entries are popped from the end, keeping earlier
        # indices in to_run valid during the pop loop below.
        to_batch_temp.reverse()
        to_batch = to_batch_temp[:1]
        free_memory = Device.get_free_memory(x_in.device)
        # Try progressively smaller batch sizes (full, 1/2, 1/3, ...) until
        # the model's memory estimate (with 1.5x headroom) fits.
        for i in range(1, len(to_batch_temp) + 1):
            batch_amount = to_batch_temp[: len(to_batch_temp) // i]
            input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
            if model.memory_required(input_shape) * 1.5 < free_memory:
                to_batch = batch_amount
                break
        input_x = []
        mult = []
        c = []
        cond_or_uncond = []
        area = []
        control = None
        patches = None
        for x in to_batch:
            o = to_run.pop(x)
            p = o[0]
            input_x.append(p.input_x)
            mult.append(p.mult)
            c.append(p.conditioning)
            area.append(p.area)
            cond_or_uncond.append(o[1])
            # NOTE(review): control/patches keep only the LAST entry's values;
            # presumably all batched entries share them — confirm upstream.
            control = p.control
            patches = p.patches
        batch_chunks = len(cond_or_uncond)
        input_x = torch.cat(input_x)
        c = cond_util.cond_cat(c)
        timestep_ = torch.cat([timestep] * batch_chunks)
        if control is not None:
            c["control"] = control.get_control(
                input_x, timestep_, c, len(cond_or_uncond)
            )
        transformer_options = {}
        if "transformer_options" in model_options:
            transformer_options = model_options["transformer_options"].copy()
        if patches is not None:
            # Merge per-cond patches into any patches already present in the
            # model options, concatenating lists for duplicate keys.
            if "patches" in transformer_options:
                cur_patches = transformer_options["patches"].copy()
                for p in patches:
                    if p in cur_patches:
                        cur_patches[p] = cur_patches[p] + patches[p]
                    else:
                        cur_patches[p] = patches[p]
                transformer_options["patches"] = cur_patches
            else:
                transformer_options["patches"] = patches
        transformer_options["cond_or_uncond"] = cond_or_uncond[:]
        transformer_options["sigmas"] = timestep
        c["transformer_options"] = transformer_options
        # Optional wrapper lets callers intercept the raw model call.
        if "model_function_wrapper" in model_options:
            output = model_options["model_function_wrapper"](
                model.apply_model,
                {
                    "input": input_x,
                    "timestep": timestep_,
                    "c": c,
                    "cond_or_uncond": cond_or_uncond,
                },
            ).chunk(batch_chunks)
        else:
            output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
        # Scatter each chunk back into its source condition's accumulator.
        for o in range(batch_chunks):
            cond_index = cond_or_uncond[o]
            a = area[o]
            if a is None:
                # Full-frame condition: accumulate everywhere.
                out_conds[cond_index] += output[o] * mult[o]
                out_counts[cond_index] += mult[o]
            else:
                # Area-restricted condition: narrow views accumulate in place.
                # `a` holds dims sizes first, then offsets (a[i], a[i + dims]).
                out_c = out_conds[cond_index]
                out_cts = out_counts[cond_index]
                dims = len(a) // 2
                for i in range(dims):
                    out_c = out_c.narrow(i + 2, a[i + dims], a[i])
                    out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
                out_c += output[o] * mult[o]
                out_cts += mult[o]
    # Vectorize the division at the end
    for i in range(len(out_conds)):
        # Inplace division is already efficient
        out_conds[i].div_(out_counts[i])  # Using .div_ instead of /= for clarity
    return out_conds
def encode_model_conds(
    model_function: callable,
    conds: list,
    noise: torch.Tensor,
    device: torch.device,
    prompt_type: str,
    **kwargs,
) -> list:
    """#### Encode model conditions.

    Calls `model_function` once per condition dict with the condition's own
    params (plus device/noise/size defaults) and merges the returned entries
    into that condition's "model_conds" dict.

    #### Args:
        - `model_function` (callable): The model's extra-conds function.
        - `conds` (list): The list of condition dicts (mutated in place).
        - `noise` (torch.Tensor): The noise tensor; its spatial dims provide
          default width/height (scaled by 8).
        - `device` (torch.device): The device.
        - `prompt_type` (str): The prompt type (e.g. "positive"/"negative").
        - `**kwargs`: Extra parameters forwarded when not already present.
    #### Returns:
        - `list`: The encoded model conditions (same list object as `conds`).
    """
    for t in range(len(conds)):
        x = conds[t]
        params = x.copy()
        params["device"] = device
        params["noise"] = noise
        # BUG FIX: previously a `width=None` entry was injected whenever the
        # noise tensor had no width dimension (rank < 4), which could mask the
        # model function's own default. Only set "width" when derivable.
        if len(noise.shape) >= 4:  # TODO: 8 multiple should be set by the model
            params["width"] = params.get("width", noise.shape[3] * 8)
        params["height"] = params.get("height", noise.shape[2] * 8)
        params["prompt_type"] = params.get("prompt_type", prompt_type)
        # Forward extra kwargs without overriding per-cond values.
        for k in kwargs:
            if k not in params:
                params[k] = kwargs[k]
        out = model_function(**params)
        # Replace the entry with a copy so shared dicts are not mutated.
        x = x.copy()
        model_conds = x["model_conds"].copy()
        for k in out:
            model_conds[k] = out[k]
        x["model_conds"] = model_conds
        conds[t] = x
    return conds
def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
    """Optimized version that processes areas and masks more efficiently.

    #### Args:
        - `conditions` (list): Condition dicts; entries with "area" or "mask"
          keys are replaced in place by modified copies.
        - `dims` (tuple): Spatial dims of the latent (e.g. `noise.shape[2:]`).
        - `device` (torch.device): Device masks are moved to.
    """
    for i in range(len(conditions)):
        c = conditions[i]
        # Process area: resolve percentage specs into absolute sizes/offsets.
        if "area" in c:
            area = c["area"]
            if area[0] == "percentage":
                a = area[1:]
                a_len = len(a) // 2
                # Vectorized: first half are fractional sizes, second half
                # fractional offsets, each scaled by the matching dim.
                dims_tensor = torch.tensor(dims, device="cpu")
                first_part = torch.tensor(a[:a_len], device="cpu") * dims_tensor
                second_part = torch.tensor(a[a_len:], device="cpu") * dims_tensor
                # Sizes are clamped to at least 1; offsets just rounded.
                first_part = torch.max(
                    torch.ones_like(first_part), torch.round(first_part)
                )
                second_part = torch.round(second_part)
                new_area = tuple(first_part.int().tolist()) + tuple(
                    second_part.int().tolist()
                )
                modified = c.copy()
                modified["area"] = new_area
                conditions[i] = modified
                # BUG FIX: keep working from the updated copy. Previously the
                # mask branch below copied the stale original dict, silently
                # discarding the resolved area when both keys were present.
                c = modified
        # Process mask: move to device and resample to the latent dims.
        if "mask" in c:
            modified = c.copy()
            mask = c["mask"].to(device=device)
            # Add a leading channel dim when the mask matches dims exactly.
            if len(mask.shape) == len(dims):
                mask = mask.unsqueeze(0)
            # Only interpolate if the spatial size differs from the latent.
            if mask.shape[1:] != dims:
                if len(mask.shape) == 3 and mask.shape[0] == 1:
                    # (1, H, W) -> add N dim, resample, drop it again.
                    mask = torch.nn.functional.interpolate(
                        mask.unsqueeze(1),
                        size=dims,
                        mode="bilinear",
                        align_corners=False,
                    ).squeeze(1)
                else:
                    # Ensure a channel dim of 1 before resampling.
                    mask = torch.nn.functional.interpolate(
                        mask
                        if len(mask.shape) > 3 and mask.shape[1] == 1
                        else mask.unsqueeze(1),
                        size=dims,
                        mode="bilinear",
                        align_corners=False,
                    ).squeeze(1)
            modified["mask"] = mask
            conditions[i] = modified
def process_conds(
    model: object,
    noise: torch.Tensor,
    conds: dict,
    device: torch.device,
    latent_image: torch.Tensor = None,
    denoise_mask: torch.Tensor = None,
    seed: int = None,
) -> dict:
    """#### Process conditions.
    #### Args:
        - `model` (object): The model.
        - `noise` (torch.Tensor): The noise tensor.
        - `conds` (dict): The conditions, keyed by prompt type.
        - `device` (torch.device): The device.
        - `latent_image` (torch.Tensor, optional): The latent image tensor. Defaults to None.
        - `denoise_mask` (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
        - `seed` (int, optional): The seed. Defaults to None.
    #### Returns:
        - `dict`: The processed conditions.
    """
    # Shallow-copy each list, then resolve percentage areas and masks.
    for key in conds:
        conds[key] = conds[key][:]
        resolve_areas_and_cond_masks_multidim(conds[key], noise.shape[2:], device)

    for key in conds:
        ksampler_util.calculate_start_end_timesteps(model, conds[key])

    # Let the model attach its extra conditioning entries, if it has any.
    if hasattr(model, "extra_conds"):
        for key in conds:
            conds[key] = encode_model_conds(
                model.extra_conds,
                conds[key],
                noise,
                device,
                key,
                latent_image=latent_image,
                denoise_mask=denoise_mask,
                seed=seed,
            )

    # make sure each cond area has an opposite one with the same area
    for key in conds:
        for entry in conds[key]:
            for other_key in conds:
                if other_key != key:
                    cond_util.create_cond_with_same_area_if_none(conds[other_key], entry)

    for key in conds:
        ksampler_util.pre_run_control(model, conds[key])

    if "positive" in conds:
        positive = conds["positive"]
        for key in conds:
            if key == "positive":
                continue
            controlled = [
                entry
                for entry in positive
                if entry.get("control_apply_to_uncond", False) is True
            ]
            ksampler_util.apply_empty_x_to_equal_area(
                controlled,
                conds[key],
                "control",
                lambda cond_cnets, x: cond_cnets[x],
            )
            ksampler_util.apply_empty_x_to_equal_area(
                positive, conds[key], "gligen", lambda cond_cnets, x: cond_cnets[x]
            )
    return conds