File size: 19,285 Bytes
d9a2e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfe609e
 
 
 
 
 
 
 
 
 
 
 
d9a2e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfe609e
d9a2e19
 
 
cfe609e
d9a2e19
 
 
 
 
 
 
 
 
 
 
cfe609e
d9a2e19
 
 
cfe609e
d9a2e19
 
 
 
 
 
 
 
 
 
 
cfe609e
d9a2e19
 
 
cfe609e
d9a2e19
 
 
 
 
 
 
 
 
 
 
cfe609e
d9a2e19
 
 
cfe609e
d9a2e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfe609e
d9a2e19
 
cfe609e
d9a2e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfe609e
d9a2e19
 
 
 
 
 
 
 
 
 
cfe609e
d9a2e19
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
import importlib
from inspect import isfunction
import itertools
import logging
import math
import os
import safetensors.torch
import torch


def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """#### Appends dimensions to the end of a tensor until it has target_dims dimensions.

    #### Args:

        - `x` (torch.Tensor): The input tensor.

        - `target_dims` (int): The target number of dimensions.

    #### Returns:

        - `torch.Tensor`: The expanded tensor.

    #### Raises:

        - `ValueError`: If `x` already has more than `target_dims` dimensions
          (previously this case returned `x` unchanged, silently).
    """
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    expanded = x[(...,) + (None,) * dims_to_append]
    # MPS can hand back views with problematic strides here; force a real copy.
    return expanded.detach().clone() if expanded.device.type == "mps" else expanded


def to_d(x: torch.Tensor, sigma: torch.Tensor, denoised: torch.Tensor) -> torch.Tensor:
    """#### Compute the derivative direction `(x - denoised) / sigma`.

    #### Args:

        - `x` (torch.Tensor): The noisy sample.

        - `sigma` (torch.Tensor): The noise level, broadcast to `x`'s rank.

        - `denoised` (torch.Tensor): The denoised prediction.

    #### Returns:

        - `torch.Tensor`: The resulting derivative tensor.
    """
    sigma_full = append_dims(sigma, x.ndim)
    residual = x - denoised
    return residual / sigma_full


def load_torch_file(ckpt: str, safe_load: bool = False, device: str = None) -> dict:
    """#### Load a PyTorch checkpoint file.

    #### Args:

        - `ckpt` (str): The path to the checkpoint file (`.safetensors`/`.sft` or a torch pickle).

        - `safe_load` (bool, optional): Whether to load with `weights_only=True`. Defaults to False.

        - `device` (str | torch.device, optional): Device to map tensors onto. Defaults to CPU.

    #### Returns:

        - `dict`: The loaded state dict (the `state_dict` entry is unwrapped if present).

    """
    # Accept both a string ("cuda:0") and a torch.device; the safetensors
    # branch reads `device.type`, which crashed on plain strings before.
    if device is None:
        device = torch.device("cpu")
    elif isinstance(device, str):
        device = torch.device(device)
    if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
        sd = safetensors.torch.load_file(ckpt, device=device.type)
    else:
        if safe_load:
            # Older torch builds lack weights_only; fall back with a warning.
            if "weights_only" not in torch.load.__code__.co_varnames:
                logging.warning(
                    "Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely."
                )
                safe_load = False
        if safe_load:
            pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
        else:
            pl_sd = torch.load(ckpt, map_location=device)
        if "global_step" in pl_sd:
            logging.debug(f"Global Step: {pl_sd['global_step']}")
        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
            sd = pl_sd
    return sd


def calculate_parameters(sd: dict, prefix: str = "") -> int:
    """#### Count the total number of tensor elements under a key prefix.

    #### Args:

        - `sd` (dict): The state dictionary (key -> tensor).

        - `prefix` (str, optional): Only keys starting with this prefix are counted. Defaults to "".

    #### Returns:

        - `int`: The total element count (was mis-annotated as `dict`).

    """
    # sum() over a generator replaces the manual accumulator loop.
    return sum(t.nelement() for k, t in sd.items() if k.startswith(prefix))


def state_dict_prefix_replace(
    state_dict: dict, replace_prefix: dict, filter_keys: bool = False
) -> dict:
    """#### Replace key prefixes in a state dictionary.

    Note: matching keys are always popped from `state_dict` (it is mutated).

    #### Args:

        - `state_dict` (dict): The state dictionary.

        - `replace_prefix` (dict): Mapping of old prefix -> new prefix
          (was mis-annotated as `str`; the code iterates and indexes it as a dict).

        - `filter_keys` (bool, optional): If True, return only the renamed keys;
          otherwise rename in place and return `state_dict` itself. Defaults to False.

    #### Returns:

        - `dict`: The updated state dictionary.

    """
    out = {} if filter_keys else state_dict
    for old_prefix, new_prefix in replace_prefix.items():
        # Materialize matches first: we pop from state_dict while renaming.
        matches = [k for k in list(state_dict.keys()) if k.startswith(old_prefix)]
        for key in matches:
            new_key = "{}{}".format(new_prefix, key[len(old_prefix):])
            out[new_key] = state_dict.pop(key)
    return out


def lcm_of_list(numbers):
    """Calculate the LCM of a list of integers.

    Uses pure-integer math instead of allocating two tensors per element
    (the previous `torch.lcm` round-trip) — faster and exact.

    Args:
        numbers (list[int]): The integers to combine.

    Returns:
        int: The least common multiple; 1 for an empty list, and 0 when the
        running LCM and the next value are both 0 (matching `torch.lcm`).
    """
    if not numbers:
        return 1

    result = numbers[0]
    for num in numbers[1:]:
        g = math.gcd(result, num)
        # gcd(0, 0) == 0; torch.lcm defines lcm(0, 0) == 0, so mirror that.
        result = 0 if g == 0 else abs(result * num) // g

    return result


def repeat_to_batch_size(
    tensor: torch.Tensor, batch_size: int, dim: int = 0
) -> torch.Tensor:
    """#### Repeat or crop a tensor along `dim` to match a target batch size.

    #### Args:

        - `tensor` (torch.Tensor): The input tensor.

        - `batch_size` (int): The target size along `dim`.

        - `dim` (int, optional): The dimension to adjust. Defaults to 0.

    #### Returns:

        - `torch.Tensor`: A tensor whose `dim` extent equals `batch_size`
          (the input itself when it already matches).
    """
    current = tensor.shape[dim]
    if current == batch_size:
        return tensor
    if current > batch_size:
        # Too large: crop from the front.
        return tensor.narrow(dim, 0, batch_size)
    # Too small: tile along `dim` just past the target, then crop.
    repeats = [1] * tensor.ndim
    repeats[dim] = math.ceil(batch_size / current)
    return tensor.repeat(repeats).narrow(dim, 0, batch_size)


def set_attr(obj: object, attr: str, value: any) -> any:
    """#### Set a (possibly dotted) attribute on an object, returning the old value.

    #### Args:

        - `obj` (object): The root object.

        - `attr` (str): Dotted attribute path, e.g. "layer.weight".

        - `value` (any): The value to assign at the end of the path.

    #### Returns:

        - `prev`: The previous attribute value.
    """
    *parents, leaf = attr.split(".")
    target = obj
    for part in parents:
        target = getattr(target, part)
    previous = getattr(target, leaf)
    setattr(target, leaf, value)
    return previous


def set_attr_param(obj: object, attr: str, value: any) -> any:
    """#### Replace an attribute with a frozen `torch.nn.Parameter` wrapping `value`.

    #### Args:

        - `obj` (object): The object.

        - `attr` (str): Dotted attribute path.

        - `value` (any): The tensor to wrap and assign.

    #### Returns:

        - `prev`: The previous attribute value.
    """
    frozen = torch.nn.Parameter(value, requires_grad=False)
    return set_attr(obj, attr, frozen)


def copy_to_param(obj: object, attr: str, value: any) -> None:
    """#### Copy a value into an existing parameter's storage in place.

    Keeps the original Parameter object alive (only its `.data` changes),
    so references held elsewhere remain valid.

    #### Args:

        - `obj` (object): The object.

        - `attr` (str): Dotted attribute path to the parameter.

        - `value` (any): The value copied into the parameter's data.
    """
    path = attr.split(".")
    target = obj
    for part in path[:-1]:
        target = getattr(target, part)
    getattr(target, path[-1]).data.copy_(value)


def get_obj_from_str(string: str, reload: bool = False) -> object:
    """#### Resolve a dotted path like "pkg.mod.Class" to the named object.

    #### Args:

        - `string` (str): Dotted path; everything before the last dot is the module.

        - `reload` (bool, optional): Reload the module before resolving. Defaults to False.

    #### Returns:

        - `object`: The resolved attribute.
    """
    module_name, attr_name = string.rsplit(".", 1)
    mod = importlib.import_module(module_name, package=None)
    if reload:
        mod = importlib.reload(mod)
    return getattr(mod, attr_name)


def get_attr(obj: object, attr: str) -> any:
    """#### Look up a dotted attribute path on an object.

    #### Args:

        - `obj` (object): The root object.

        - `attr` (str): Dotted attribute path, e.g. "layer.weight".

    #### Returns:

        - `obj`: The value at the end of the path.
    """
    current = obj
    for part in attr.split("."):
        current = getattr(current, part)
    return current


def lcm(a: int, b: int) -> int:
    """#### Calculate the least common multiple (LCM) of two integers.

    #### Args:

        - `a` (int): The first number.

        - `b` (int): The second number.

    #### Returns:

        - `int`: The LCM of the two numbers.
    """
    product = abs(a * b)
    return product // math.gcd(a, b)


def get_full_path(folder_name: str, filename: str) -> str:
    """#### Get the full path of a file in a registered folder.

    Args:

        folder_name (str): Key into the module-level `folder_names_and_paths`
            registry (raises KeyError if the key is missing).

        filename (str): The filename; normalized so `..` segments cannot
            escape the registered folders.

    Returns:

        str: The full path of the first existing match, or implicitly
        `None` when no candidate folder contains the file.

    """
    global folder_names_and_paths
    # NOTE(review): `folder_names_and_paths` is not defined anywhere in this
    # file's visible scope — presumably a module-level registry populated
    # elsewhere; confirm it is initialized before this is called.
    folders = folder_names_and_paths[folder_name]
    # Re-anchoring at "/" and taking the relpath strips leading slashes and
    # resolves ".." segments, preventing directory traversal.
    filename = os.path.relpath(os.path.join("/", filename), "/")
    # folders[0] appears to be the list of search paths (the registry entry
    # presumably is a (paths, extensions) pair) — TODO confirm.
    for x in folders[0]:
        full_path = os.path.join(x, filename)
        if os.path.isfile(full_path):
            return full_path


def zero_module(module: torch.nn.Module) -> torch.nn.Module:
    """#### Zero out all parameters of a module in place.

    #### Args:

        - `module` (torch.nn.Module): The module whose parameters are zeroed.

    #### Returns:

        - `torch.nn.Module`: The same module, for chaining.
    """
    # Zero the underlying storages without recording the ops in autograd.
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module


def append_zero(x: torch.Tensor) -> torch.Tensor:
    """#### Return a copy of `x` with a single zero appended at the end.

    #### Args:

        - `x` (torch.Tensor): The input tensor (1-D usage expected by callers).

    #### Returns:

        - `torch.Tensor`: The tensor with a trailing zero (same dtype/device).
    """
    zero = x.new_zeros([1])
    return torch.cat((x, zero))


def exists(val: any) -> bool:
    """#### Check whether a value is present (i.e. not None).

    #### Args:

        - `val` (any): The value to check.

    #### Returns:

        - `bool`: True when `val` is not None, False otherwise.
    """
    if val is None:
        return False
    return True


def default(val: any, d: any) -> any:
    """#### Return `val` when present, otherwise the default `d`.

    #### Args:

        - `val` (any): The candidate value.

        - `d` (any): Default value, or a zero-argument function producing it.

    #### Returns:

        - `any`: `val` when it is not None; otherwise `d` (called first when
          it is a function, so expensive defaults stay lazy).
    """
    if exists(val):
        return val
    if isfunction(d):
        return d()
    return d


def write_parameters_to_file(
    prompt_entry: str, neg: str, width: int, height: int, cfg: int
) -> None:
    """#### Write generation parameters to ./_internal/prompt.txt, one "key: value" per line.

    #### Args:

        - `prompt_entry` (str): The prompt entry.

        - `neg` (str): The negative prompt entry.

        - `width` (int): The width.

        - `height` (int): The height.

        - `cfg` (int): The CFG.
    """
    with open("./_internal/prompt.txt", "w") as f:
        # The prompt/neg lines previously lacked "\n", mashing everything onto
        # one line and breaking load_parameters_from_file's line parser.
        f.write(f"prompt: {prompt_entry}\n")
        f.write(f"neg: {neg}\n")
        f.write(f"w: {int(width)}\n")
        f.write(f"h: {int(height)}\n")
        f.write(f"cfg: {int(cfg)}\n")


def load_parameters_from_file() -> tuple:
    """#### Load generation parameters from ./_internal/prompt.txt.

    #### Returns:

        - `str`: The prompt entry.

        - `str`: The negative prompt entry.

        - `int`: The width.

        - `int`: The height.

        - `int`: The CFG.

    #### Raises:

        - `KeyError`: If a required key is missing from the file.
    """
    parameters = {}
    with open("./_internal/prompt.txt", "r") as f:
        for line in f:
            stripped = line.strip()
            # Skip empty lines
            if not stripped:
                continue
            # maxsplit=1 so values containing ": " (common in prompts)
            # no longer raise "too many values to unpack".
            key, value = stripped.split(": ", 1)
            parameters[key] = value.strip()
    return (
        parameters["prompt"],
        parameters["neg"],
        int(parameters["w"]),
        int(parameters["h"]),
        int(parameters["cfg"]),
    )


# Global toggle for progress reporting; read by consumers of ProgressBar.
PROGRESS_BAR_ENABLED = True
# Optional callback captured by ProgressBar.__init__; presumably installed by
# a UI layer elsewhere — TODO confirm the expected callable signature.
PROGRESS_BAR_HOOK = None


class ProgressBar:
    """#### Class representing a progress bar.

    Tracks `current`/`total` steps and forwards updates to the module-level
    `PROGRESS_BAR_HOOK` callback when one is registered. The `update` method
    is required by `tiled_scale_multidim`, which calls `pbar.update(1)` per
    tile — it was missing from this class before.
    """

    def __init__(self, total: int):
        global PROGRESS_BAR_HOOK
        self.total = total  # total number of steps expected
        self.current = 0  # steps completed so far
        self.hook = PROGRESS_BAR_HOOK  # optional progress callback (or None)

    def update_absolute(self, value: int, total: int = None, preview: any = None) -> None:
        """Set progress to an absolute value (clamped to `total`) and fire the hook.

        #### Args:

            - `value` (int): The new absolute progress value.

            - `total` (int, optional): New total, when it changed. Defaults to None.

            - `preview` (any, optional): Optional preview payload passed to the hook.
        """
        if total is not None:
            self.total = total
        if value > self.total:
            value = self.total
        self.current = value
        if self.hook is not None:
            # NOTE(review): hook signature (current, total, preview) assumed
            # from typical usage — confirm against the hook's installer.
            self.hook(self.current, self.total, preview)

    def update(self, value: int) -> None:
        """Advance progress by `value` steps (the increment used per tile)."""
        self.update_absolute(self.current + value)


def get_tiled_scale_steps(
    width: int, height: int, tile_x: int, tile_y: int, overlap: int
) -> int:
    """#### Get the number of tiles processed by a tiled scale pass.

    #### Args:

        - `width` (int): The image width.

        - `height` (int): The image height.

        - `tile_x` (int): The tile width.

        - `tile_y` (int): The tile height.

        - `overlap` (int): The overlap between adjacent tiles.

    #### Returns:

        - `int`: The number of tiles (rows * cols).
    """

    def axis_tiles(size: int, tile: int) -> int:
        # A dimension that fits in one tile needs exactly one step; otherwise
        # tiles advance by (tile - overlap) over the non-overlapped extent.
        if size <= tile:
            return 1
        return math.ceil((size - overlap) / (tile - overlap))

    return axis_tiles(height, tile_y) * axis_tiles(width, tile_x)


@torch.inference_mode()
def tiled_scale_multidim(

    samples: torch.Tensor,

    function,

    tile: tuple = (64, 64),

    overlap: int = 8,

    upscale_amount: int = 4,

    out_channels: int = 3,

    output_device: str = "cpu",

    downscale: bool = False,

    index_formulas: any = None,

    pbar: any = None,

):
    """#### Scale an image using a tiled approach.

    Splits each sample into overlapping tiles, applies `function` to every
    tile, and blends the results back together with linearly feathered edges.

    #### Args:

        - `samples` (torch.Tensor): The input samples, shape (batch, channels, *spatial).

        - `function` (function): The scaling function applied to each tile.

        - `tile` (tuple, optional): The tile size per spatial dim. Defaults to (64, 64).

        - `overlap` (int, optional): The overlap. Defaults to 8.

        - `upscale_amount` (int, optional): The upscale amount (scalar, per-dim list, or callable). Defaults to 4.

        - `out_channels` (int, optional): The number of output channels. Defaults to 3.

        - `output_device` (str, optional): The output device. Defaults to "cpu".

        - `downscale` (bool, optional): Whether to downscale. Defaults to False.

        - `index_formulas` (any, optional): The index formulas for output positions. Defaults to None.

        - `pbar` (any, optional): The progress bar (must expose `update(n)`). Defaults to None.

    #### Returns:

        - `torch.Tensor`: The scaled image.

    """
    # Number of spatial dimensions, inferred from the tile tuple.
    dims = len(tile)

    # Normalize scalar arguments to one value per spatial dimension.
    if not (isinstance(upscale_amount, (tuple, list))):
        upscale_amount = [upscale_amount] * dims

    if not (isinstance(overlap, (tuple, list))):
        overlap = [overlap] * dims

    # Output positions default to the same formula as the size scaling.
    if index_formulas is None:
        index_formulas = upscale_amount

    if not (isinstance(index_formulas, (tuple, list))):
        index_formulas = [index_formulas] * dims

    def get_upscale(dim: int, val: int) -> int:
        """#### Get the upscale value.

        #### Args:

            - `dim` (int): The dimension.

            - `val` (int): The value.

        #### Returns:

            - `int`: The upscaled value.

        """
        up = upscale_amount[dim]
        if callable(up):
            return up(val)
        else:
            return up * val

    def get_downscale(dim: int, val: int) -> int:
        """#### Get the downscale value.

        #### Args:

            - `dim` (int): The dimension.

            - `val` (int): The value.

        #### Returns:

            - `int`: The downscaled value.

        """
        up = upscale_amount[dim]
        if callable(up):
            return up(val)
        else:
            return val / up

    def get_upscale_pos(dim: int, val: int) -> int:
        """#### Get the upscaled position.

        #### Args:

            - `dim` (int): The dimension.

            - `val` (int): The value.

        #### Returns:

            - `int`: The upscaled position.

        """
        up = index_formulas[dim]
        if callable(up):
            return up(val)
        else:
            return up * val

    def get_downscale_pos(dim: int, val: int) -> int:
        """#### Get the downscaled position.

        #### Args:

            - `dim` (int): The dimension.

            - `val` (int): The value.

        #### Returns:

            - `int`: The downscaled position.

        """
        up = index_formulas[dim]
        if callable(up):
            return up(val)
        else:
            return val / up

    # Select size/position formulas for the requested direction.
    if downscale:
        get_scale = get_downscale
        get_pos = get_downscale_pos
    else:
        get_scale = get_upscale
        get_pos = get_upscale_pos

    def mult_list_upscale(a: list) -> list:
        """#### Multiply a list by the upscale amount.

        #### Args:

            - `a` (list): The list.

        #### Returns:

            - `list`: The multiplied list.

        """
        out = []
        for i in range(len(a)):
            out.append(round(get_scale(i, a[i])))
        return out

    # Final output for the whole batch, allocated on the output device.
    output = torch.empty(
        [samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]),
        device=output_device,
    )

    for b in range(samples.shape[0]):
        s = samples[b : b + 1]

        # handle entire input fitting in a single tile
        if all(s.shape[d + 2] <= tile[d] for d in range(dims)):
            output[b : b + 1] = function(s).to(output_device)
            if pbar is not None:
                pbar.update(1)
            continue

        # Accumulators: weighted sum of tile outputs, and sum of blend weights.
        out = torch.zeros(
            [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]),
            device=output_device,
        )
        out_div = torch.zeros(
            [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]),
            device=output_device,
        )

        # Tile start positions per spatial dim, stepping by (tile - overlap);
        # dims that fit in one tile get the single position 0.
        positions = [
            range(0, s.shape[d + 2] - overlap[d], tile[d] - overlap[d])
            if s.shape[d + 2] > tile[d]
            else [0]
            for d in range(dims)
        ]

        for it in itertools.product(*positions):
            s_in = s
            upscaled = []

            # Clamp each tile so it stays inside the sample, crop the input
            # view, and record the corresponding output-space position.
            for d in range(dims):
                pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
                l = min(tile[d], s.shape[d + 2] - pos)
                s_in = s_in.narrow(d + 2, pos, l)
                upscaled.append(round(get_pos(d, pos)))

            ps = function(s_in).to(output_device)
            mask = torch.ones_like(ps)

            # Feather tile edges linearly over the (scaled) overlap so
            # adjacent tiles blend without visible seams.
            for d in range(2, dims + 2):
                feather = round(get_scale(d - 2, overlap[d - 2]))
                if feather >= mask.shape[d]:
                    continue
                for t in range(feather):
                    a = (t + 1) / feather
                    mask.narrow(d, t, 1).mul_(a)
                    mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)

            # Views into the accumulators at this tile's output position.
            o = out
            o_d = out_div
            for d in range(dims):
                o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
                o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])

            o.add_(ps * mask)
            o_d.add_(mask)

            if pbar is not None:
                pbar.update(1)

        # Normalize by the accumulated blend weights to finish the average.
        output[b : b + 1] = out / out_div
    return output


def tiled_scale(
    samples: torch.Tensor,
    function,
    tile_x: int = 64,
    tile_y: int = 64,
    overlap: int = 8,
    upscale_amount: int = 4,
    out_channels: int = 3,
    output_device: str = "cpu",
    pbar: any = None,
):
    """#### 2D convenience wrapper around `tiled_scale_multidim`.

    #### Args:

        - `samples` (torch.Tensor): The input samples.

        - `function` (function): The scaling function.

        - `tile_x` (int, optional): The tile width. Defaults to 64.

        - `tile_y` (int, optional): The tile height. Defaults to 64.

        - `overlap` (int, optional): The overlap. Defaults to 8.

        - `upscale_amount` (int, optional): The upscale amount. Defaults to 4.

        - `out_channels` (int, optional): The number of output channels. Defaults to 3.

        - `output_device` (str, optional): The output device. Defaults to "cpu".

        - `pbar` (any, optional): The progress bar. Defaults to None.

    #### Returns:

        - The scaled image.
    """
    # Note the (height, width) ordering expected by the multidim variant.
    tile = (tile_y, tile_x)
    return tiled_scale_multidim(
        samples,
        function,
        tile,
        overlap=overlap,
        upscale_amount=upscale_amount,
        out_channels=out_channels,
        output_device=output_device,
        pbar=pbar,
    )