File size: 6,468 Bytes
d9a2e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
from modules.Device import Device
import torch
from typing import List, Tuple, Any


def get_models_from_cond(cond: dict, model_type: str) -> List[object]:
    """#### Collect the values stored under `model_type` in a condition.

    Every item yielded by iterating `cond` is tested with
    `model_type in item`; when the test passes, `item[model_type]` is
    collected. Items without the key are skipped.

    #### Args:
        - `cond` (dict): The condition to scan.
        - `model_type` (str): The key to look up in each item.

    #### Returns:
        - `List[object]`: The collected values, in iteration order.
    """
    return [entry[model_type] for entry in cond if model_type in entry]


def get_additional_models(conds: dict, dtype: torch.dtype) -> Tuple[List[object], int]:
    """#### Gather the extra models referenced by the conditioning.

    Scans every value of `conds` for "control" and "gligen" entries.
    Control entries are de-duplicated through a `set`; each distinct one
    contributes its `get_models()` result and its
    `inference_memory_requirements(dtype)`. For gligen entries, the element
    at index 1 is taken as the model.

    #### Args:
        - `conds` (dict): The conditions; each value is scanned with
          `get_models_from_cond`.
        - `dtype` (torch.dtype): The data type passed to
          `inference_memory_requirements`.

    #### Returns:
        - `Tuple[List[object], int]`: The control models followed by the
          gligen models, and the summed inference memory of the control nets.
    """
    control_entries = []
    gligen_entries = []
    for key in conds:
        control_entries.extend(get_models_from_cond(conds[key], "control"))
        gligen_entries.extend(get_models_from_cond(conds[key], "gligen"))

    control_models = []
    inference_memory = 0
    # A set removes control nets that are referenced by several conditions.
    for control_net in set(control_entries):
        control_models.extend(control_net.get_models())
        inference_memory += control_net.inference_memory_requirements(dtype)

    gligen_models = [entry[1] for entry in gligen_entries]
    return control_models + gligen_models, inference_memory


def prepare_sampling(
    model: object, noise_shape: Tuple[int], conds: dict, flux_enabled: bool = False
) -> Tuple[object, dict, List[object]]:
    """#### Load the model and its conditioning helpers before sampling.

    Collects the additional models referenced by `conds`, estimates the
    memory needed, and hands everything to `Device.load_models_gpu`.

    #### Args:
        - `model` (object): The model wrapper; must provide `model_dtype()`,
          `memory_required(shape)` and a `model` attribute.
        - `noise_shape` (Tuple[int]): The shape of the noise; index 0 is
          treated as the batch dimension.
        - `conds` (dict): The conditions; returned unchanged.
        - `flux_enabled` (bool, optional): Forwarded to
          `Device.load_models_gpu`. Defaults to False.

    #### Returns:
        - `Tuple[object, dict, List[object]]`: `model.model` (read after
          loading), `conds`, and the additional models.
    """
    extra_models, extra_memory = get_additional_models(conds, model.model_dtype())

    batch_size = noise_shape[0]
    trailing_dims = list(noise_shape[1:])
    # The full estimate uses a doubled batch, the minimum uses the batch as given.
    # NOTE(review): presumably the doubling covers cond + uncond evaluated together — confirm.
    full_estimate = (
        model.memory_required([batch_size * 2] + trailing_dims) + extra_memory
    )
    minimum_estimate = (
        model.memory_required([batch_size] + trailing_dims) + extra_memory
    )

    Device.load_models_gpu(
        [model] + extra_models,
        memory_required=full_estimate,
        minimum_memory_required=minimum_estimate,
        flux_enabled=flux_enabled,
    )
    return model.model, conds, extra_models

def cleanup_additional_models(models: List[object]) -> None:
    """#### Invoke `cleanup()` on every model that provides it.

    Models without a `cleanup` attribute are silently skipped.

    #### Args:
        - `models` (List[object]): The models to clean up.
    """
    # The generator is lazy, so each model is checked and then cleaned up in turn.
    cleanable = (candidate for candidate in models if hasattr(candidate, "cleanup"))
    for candidate in cleanable:
        candidate.cleanup()

def cleanup_models(conds: dict, models: List[object]) -> None:
    """#### Release models once sampling has finished.

    First cleans up `models`, then cleans up every distinct "control" entry
    found in the values of `conds`.

    #### Args:
        - `conds` (dict): The conditions; each value is scanned for
          "control" entries.
        - `models` (List[object]): The additional models to clean up.
    """
    cleanup_additional_models(models)

    control_entries = [
        control_net
        for key in conds
        for control_net in get_models_from_cond(conds[key], "control")
    ]
    # De-duplicate so a control net shared by several conditions is cleaned once.
    cleanup_additional_models(set(control_entries))


def cond_equal_size(c1: Any, c2: Any) -> bool:
    """#### Check whether two conditions are compatible.

    Two conditions are accepted when they are the very same object or when
    their `keys()` compare equal. The values themselves are not inspected.

    #### Args:
        - `c1` (Any): The first condition.
        - `c2` (Any): The second condition.

    #### Returns:
        - `bool`: True if the conditions are identical or share the same keys.
    """
    return c1 is c2 or c1.keys() == c2.keys()


def can_concat_cond(c1: Any, c2: Any) -> bool:
    """#### Decide whether two conditions may be batched together.

    The conditions can be concatenated when all of the following hold:
    their `input_x` shapes are equal, their `control` attributes are the
    same object (or both None), their `patches` attributes are the same
    object (or both None), and `cond_equal_size` accepts their
    `conditioning`.

    #### Args:
        - `c1` (Any): The first condition.
        - `c2` (Any): The second condition.

    #### Returns:
        - `bool`: Whether the conditions can be concatenated.
    """
    if c1.input_x.shape != c2.input_x.shape:
        return False

    # Identity comparison covers every case: both None, exactly one None,
    # and two non-None values that must be the very same object.
    for attribute in ("control", "patches"):
        if getattr(c1, attribute) is not getattr(c2, attribute):
            return False

    return cond_equal_size(c1.conditioning, c2.conditioning)


def cond_cat(c_list: List[dict]) -> dict:
    """#### Merge a list of conditions key by key.

    Values are grouped by key in the order they are encountered. For each
    key, the first value's `concat` method is called with the list of the
    remaining values (an empty list when the key occurs only once).

    #### Args:
        - `c_list` (List[dict]): The conditions to merge.

    #### Returns:
        - `dict`: One concatenated value per key seen in `c_list`.
    """
    grouped = {}
    for cond in c_list:
        for key in cond:
            grouped.setdefault(key, []).append(cond[key])

    return {key: values[0].concat(values[1:]) for key, values in grouped.items()}


def create_cond_with_same_area_if_none(conds: List[dict], c: dict) -> None:
    """#### Add a counterpart of `c` to `conds` when no entry has its area.

    Does nothing if `c` has no "area". Otherwise a donor is picked from
    `conds`: preferably the enclosing entry with the smallest
    `area[0] * area[1]`, falling back to the first entry without an "area".
    If a donor exists and its area is not equal to `c`'s, a shallow copy of
    `c` carrying a copy of the donor's "model_conds" is appended to `conds`.

    An entry encloses `c` when its indices 2 and 3 are not greater than
    `c`'s and its `area[0] + area[2]` / `area[1] + area[3]` are not smaller
    than `c`'s.
    NOTE(review): presumably the layout is (height, width, y, x) — confirm.

    #### Args:
        - `conds` (List[dict]): The list of conditions; extended in place.
        - `c` (dict): The condition whose area needs a counterpart.
    """
    if "area" not in c:
        return

    target = c["area"]
    donor = None
    for candidate in conds:
        if "area" not in candidate:
            # An area-less entry is only a fallback, taken when nothing is picked yet.
            if donor is None:
                donor = candidate
            continue

        area = candidate["area"]
        encloses_target = (
            target[2] >= area[2]
            and target[3] >= area[3]
            and area[0] + area[2] >= target[0] + target[2]
            and area[1] + area[3] >= target[1] + target[3]
        )
        if not encloses_target:
            continue

        if donor is None or "area" not in donor:
            # Any enclosing area beats nothing and beats an area-less fallback.
            donor = candidate
        elif donor["area"][0] * donor["area"][1] > area[0] * area[1]:
            donor = candidate

    if donor is None:
        return
    if "area" in donor and donor["area"] == target:
        # An entry with exactly this area already exists; nothing to add.
        return

    derived = c.copy()
    derived["model_conds"] = donor["model_conds"].copy()
    conds += [derived]