from modules.Device import Device import torch from typing import List, Tuple, Any def get_models_from_cond(cond: dict, model_type: str) -> List[object]: """#### Get models from a condition. #### Args: - `cond` (dict): The condition. - `model_type` (str): The model type. #### Returns: - `List[object]`: The list of models. """ models = [] for c in cond: if model_type in c: models += [c[model_type]] return models def get_additional_models(conds: dict, dtype: torch.dtype) -> Tuple[List[object], int]: """#### Load additional models in conditioning. #### Args: - `conds` (dict): The conditions. - `dtype` (torch.dtype): The data type. #### Returns: - `Tuple[List[object], int]`: The list of models and the inference memory. """ cnets = [] gligen = [] for k in conds: cnets += get_models_from_cond(conds[k], "control") gligen += get_models_from_cond(conds[k], "gligen") control_nets = set(cnets) inference_memory = 0 control_models = [] for m in control_nets: control_models += m.get_models() inference_memory += m.inference_memory_requirements(dtype) gligen = [x[1] for x in gligen] models = control_models + gligen return models, inference_memory def prepare_sampling( model: object, noise_shape: Tuple[int], conds: dict, flux_enabled: bool = False ) -> Tuple[object, dict, List[object]]: """#### Prepare the model for sampling. #### Args: - `model` (object): The model. - `noise_shape` (Tuple[int]): The shape of the noise. - `conds` (dict): The conditions. - `flux_enabled` (bool, optional): Whether flux is enabled. Defaults to False. #### Returns: - `Tuple[object, dict, List[object]]`: The prepared model, conditions, and additional models. """ real_model = None models, inference_memory = get_additional_models(conds, model.model_dtype()) memory_required = ( model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory ) minimum_memory_required = ( model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory ) Device.load_models_gpu( [model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, flux_enabled=flux_enabled, ) real_model = model.model return real_model, conds, models def cleanup_additional_models(models: List[object]) -> None: """#### Clean up additional models. #### Args: - `models` (List[object]): The list of models. """ for m in models: if hasattr(m, "cleanup"): m.cleanup() def cleanup_models(conds: dict, models: List[object]) -> None: """#### Clean up the models after sampling. #### Args: - `conds` (dict): The conditions. - `models` (List[object]): The list of models. """ cleanup_additional_models(models) control_cleanup = [] for k in conds: control_cleanup += get_models_from_cond(conds[k], "control") cleanup_additional_models(set(control_cleanup)) def cond_equal_size(c1: Any, c2: Any) -> bool: """#### Check if two conditions have equal size. #### Args: - `c1` (Any): The first condition. - `c2` (Any): The second condition. #### Returns: - `bool`: Whether the conditions have equal size. """ if c1 is c2: return True if c1.keys() != c2.keys(): return False return True def can_concat_cond(c1: Any, c2: Any) -> bool: """#### Check if two conditions can be concatenated. #### Args: - `c1` (Any): The first condition. - `c2` (Any): The second condition. #### Returns: - `bool`: Whether the conditions can be concatenated. """ if c1.input_x.shape != c2.input_x.shape: return False def objects_concatable(obj1, obj2): """#### Check if two objects can be concatenated.""" if (obj1 is None) != (obj2 is None): return False if obj1 is not None: if obj1 is not obj2: return False return True if not objects_concatable(c1.control, c2.control): return False if not objects_concatable(c1.patches, c2.patches): return False return cond_equal_size(c1.conditioning, c2.conditioning) def cond_cat(c_list: List[dict]) -> dict: """#### Concatenate a list of conditions. #### Args: - `c_list` (List[dict]): The list of conditions. #### Returns: - `dict`: The concatenated conditions. """ temp = {} for x in c_list: for k in x: cur = temp.get(k, []) cur.append(x[k]) temp[k] = cur out = {} for k in temp: conds = temp[k] out[k] = conds[0].concat(conds[1:]) return out def create_cond_with_same_area_if_none(conds: List[dict], c: dict) -> None: """#### Create a condition with the same area if none exists. #### Args: - `conds` (List[dict]): The list of conditions. - `c` (dict): The condition. """ if "area" not in c: return c_area = c["area"] smallest = None for x in conds: if "area" in x: a = x["area"] if c_area[2] >= a[2] and c_area[3] >= a[3]: if a[0] + a[2] >= c_area[0] + c_area[2]: if a[1] + a[3] >= c_area[1] + c_area[3]: if smallest is None: smallest = x elif "area" not in smallest: smallest = x else: if smallest["area"][0] * smallest["area"][1] > a[0] * a[1]: smallest = x else: if smallest is None: smallest = x if smallest is None: return if "area" in smallest: if smallest["area"] == c_area: return out = c.copy() out["model_conds"] = smallest[ "model_conds" ].copy() conds += [out]