Aatricks's picture
Upload folder using huggingface_hub
d9a2e19 verified
from modules.Device import Device
import torch
from typing import List, Tuple, Any
def get_models_from_cond(cond: dict, model_type: str) -> List[object]:
    """#### Collect the values stored under one key across a condition.

    Every element of `cond` is checked with `model_type in element`; when the
    key is present, the value stored under it is collected. Order of
    appearance is kept and duplicates are not removed.

    #### Args:
        - `cond` (dict): The condition. It is iterated element by element and
          each element is indexed with `model_type`.
        - `model_type` (str): The key to look for in each element.

    #### Returns:
        - `List[object]`: The collected values, in iteration order.
    """
    return [entry[model_type] for entry in cond if model_type in entry]
def get_additional_models(conds: dict, dtype: torch.dtype) -> Tuple[List[object], int]:
    """#### Gather the extra models referenced by the conditioning.

    For every value in `conds`, the entries stored under `"control"` and
    `"gligen"` are collected. Control entries are de-duplicated with a `set`,
    then each one contributes its `get_models()` result and its
    `inference_memory_requirements(dtype)`. For gligen entries, element `[1]`
    of each entry is used as the model.

    #### Args:
        - `conds` (dict): The conditions, keyed by name; each value is scanned
          with `get_models_from_cond`.
        - `dtype` (torch.dtype): The data type forwarded to
          `inference_memory_requirements`.

    #### Returns:
        - `Tuple[List[object], int]`: The control models followed by the gligen
          models, and the summed inference memory of the control entries.
    """
    control_entries = []
    gligen_entries = []
    for cond in conds.values():
        control_entries.extend(get_models_from_cond(cond, "control"))
        gligen_entries.extend(get_models_from_cond(cond, "gligen"))

    # Each distinct control object is counted once, however many conditions
    # reference it.
    control_models = []
    inference_memory = 0
    for control in set(control_entries):
        control_models.extend(control.get_models())
        inference_memory += control.inference_memory_requirements(dtype)

    gligen_models = [entry[1] for entry in gligen_entries]
    return control_models + gligen_models, inference_memory
def prepare_sampling(
    model: object, noise_shape: Tuple[int], conds: dict, flux_enabled: bool = False
) -> Tuple[object, dict, List[object]]:
    """#### Load the model and its additional models ahead of sampling.

    Collects the additional models referenced by `conds`, computes two memory
    estimates from `model.memory_required`, and hands everything to
    `Device.load_models_gpu`.

    #### Args:
        - `model` (object): The model wrapper. It must provide `model_dtype()`,
          `memory_required(shape)` and a `model` attribute.
        - `noise_shape` (Tuple[int]): The shape of the noise; element 0 is
          treated as the batch size.
        - `conds` (dict): The conditions.
        - `flux_enabled` (bool, optional): Forwarded unchanged to
          `Device.load_models_gpu`. Defaults to False.

    #### Returns:
        - `Tuple[object, dict, List[object]]`: `model.model`, the `conds`
          object that was passed in, and the additional models.
    """
    extra_models, extra_memory = get_additional_models(conds, model.model_dtype())

    batch = noise_shape[0]
    trailing_dims = list(noise_shape[1:])

    # The preferred estimate uses twice the batch size (presumably to cover
    # batching two passes together; confirm against the sampler). The
    # minimum estimate uses the batch size as given.
    memory_required = model.memory_required([batch * 2] + trailing_dims) + extra_memory
    minimum_memory_required = model.memory_required([batch] + trailing_dims) + extra_memory

    Device.load_models_gpu(
        [model, *extra_models],
        memory_required=memory_required,
        minimum_memory_required=minimum_memory_required,
        flux_enabled=flux_enabled,
    )
    return model.model, conds, extra_models
def cleanup_additional_models(models: List[object]) -> None:
    """#### Run `cleanup()` on every model that offers it.

    Objects without a `cleanup` attribute are skipped silently.

    #### Args:
        - `models` (List[object]): The models to clean up; any iterable works.
    """
    for candidate in models:
        if not hasattr(candidate, "cleanup"):
            continue
        candidate.cleanup()
def cleanup_models(conds: dict, models: List[object]) -> None:
    """#### Clean up models once sampling has finished.

    First cleans up `models`, then every distinct object stored under
    `"control"` in any of the conditions.

    #### Args:
        - `conds` (dict): The conditions, keyed by name.
        - `models` (List[object]): The additional models to clean up.
    """
    cleanup_additional_models(models)

    controls = []
    for cond in conds.values():
        controls.extend(get_models_from_cond(cond, "control"))

    # De-duplicate so a control shared by several conditions is cleaned once.
    cleanup_additional_models(set(controls))
def cond_equal_size(c1: Any, c2: Any) -> bool:
    """#### Tell whether two conditions expose the same set of keys.

    The same object always matches itself. Otherwise only the results of
    `keys()` are compared; the values stored under those keys are not
    inspected.

    #### Args:
        - `c1` (Any): The first condition; must provide `keys()`.
        - `c2` (Any): The second condition; must provide `keys()`.

    #### Returns:
        - `bool`: True when `c1` is `c2` or their key sets are equal.
    """
    if c1 is not c2 and c1.keys() != c2.keys():
        return False
    return True
def can_concat_cond(c1: Any, c2: Any) -> bool:
    """#### Decide whether two conditions may be batched together.

    The checks run in this order and stop at the first failure:
    1. `input_x.shape` must be equal.
    2. `control` must be the very same object on both sides (or None on both).
    3. `patches` must be the very same object on both sides (or None on both).
    4. `cond_equal_size` must accept the two `conditioning` attributes.

    #### Args:
        - `c1` (Any): The first condition; needs `input_x`, `control`,
          `patches` and `conditioning` attributes.
        - `c2` (Any): The second condition, with the same attributes.

    #### Returns:
        - `bool`: Whether the conditions can be concatenated.
    """
    if c1.input_x.shape != c2.input_x.shape:
        return False

    # Identity, not equality: two equal-but-distinct objects do not match,
    # while None on both sides does.
    if c1.control is not c2.control:
        return False
    if c1.patches is not c2.patches:
        return False

    return cond_equal_size(c1.conditioning, c2.conditioning)
def cond_cat(c_list: List[dict]) -> dict:
    """#### Merge a list of conditions key by key.

    Values are grouped per key in order of appearance. For every key, the
    first value's `concat` method is called with the list of the remaining
    values, and its result becomes the merged entry.

    #### Args:
        - `c_list` (List[dict]): The conditions to merge; each value must
          provide `concat(others)`.

    #### Returns:
        - `dict`: One merged entry per key, in first-seen key order.
    """
    grouped = {}
    for cond in c_list:
        for key, value in cond.items():
            grouped.setdefault(key, []).append(value)

    return {key: values[0].concat(values[1:]) for key, values in grouped.items()}
def create_cond_with_same_area_if_none(conds: List[dict], c: dict) -> None:
    """#### Add a counterpart of `c` to `conds` when no entry has the same area.

    Does nothing when `c` has no `"area"`. Otherwise a donor entry is picked
    from `conds`:
    - an entry without `"area"` is used only while nothing else was picked;
    - an entry whose area encloses `c`'s area replaces an area-less pick, or an
      earlier pick whose `area[0] * area[1]` is strictly larger.

    If a donor exists and its area differs from `c`'s (or it has none), a
    shallow copy of `c` carrying a shallow copy of the donor's `"model_conds"`
    is appended to `conds` in place.

    #### Args:
        - `conds` (List[dict]): The list of conditions; may be extended.
        - `c` (dict): The condition to mirror.
    """
    if "area" not in c:
        return

    # Area layout is presumably (height, width, y, x); confirm with the code
    # that builds it. Only the index arithmetic below is relied upon here.
    target = c["area"]
    donor = None
    for candidate in conds:
        if "area" not in candidate:
            # Area-less entries are only a fallback for "nothing picked yet".
            if donor is None:
                donor = candidate
            continue

        area = candidate["area"]
        encloses_target = (
            target[2] >= area[2]
            and target[3] >= area[3]
            and area[0] + area[2] >= target[0] + target[2]
            and area[1] + area[3] >= target[1] + target[3]
        )
        if not encloses_target:
            continue

        if (
            donor is None
            or "area" not in donor
            or donor["area"][0] * donor["area"][1] > area[0] * area[1]
        ):
            donor = candidate

    if donor is None:
        return
    if "area" in donor and donor["area"] == target:
        # An entry with exactly this area already exists; nothing to add.
        return

    mirrored = c.copy()
    mirrored["model_conds"] = donor["model_conds"].copy()
    conds.append(mirrored)