Spaces:
Running
on
Zero
Running
on
Zero
| import importlib | |
| from inspect import isfunction | |
| import itertools | |
| import logging | |
| import math | |
| import os | |
| import pickle | |
| import safetensors.torch | |
| import torch | |
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """#### Pad `x` with trailing singleton axes until it has `target_dims` dimensions.

    #### Args:
        - `x` (torch.Tensor): Tensor to expand.
        - `target_dims` (int): Desired number of dimensions. If this is not larger
          than `x.ndim`, no axes are added.
    #### Returns:
        - `torch.Tensor`: `x` indexed with extra size-1 axes at the end (a detached
          copy when the tensor lives on an MPS device).
    """
    missing = target_dims - x.ndim
    result = x[(...,) + (None,) * missing]
    if result.device.type != "mps":
        return result
    # MPS results are returned as a detached copy instead of an indexing view;
    # presumably a backend workaround — TODO confirm it is still required.
    return result.detach().clone()
def to_d(x: torch.Tensor, sigma: torch.Tensor, denoised: torch.Tensor) -> torch.Tensor:
    """#### Compute `(x - denoised) / sigma`, broadcasting `sigma` over `x`.

    #### Args:
        - `x` (torch.Tensor): The input tensor (presumably the current noisy sample).
        - `sigma` (torch.Tensor): The noise level; trailing singleton axes are appended
          via `append_dims` so it broadcasts against `x`.
        - `denoised` (torch.Tensor): The denoised tensor.
    #### Returns:
        - `torch.Tensor`: The difference `x - denoised` divided by the broadcast `sigma`.
    """
    sigma_b = append_dims(sigma, x.ndim)
    residual = x - denoised
    return residual / sigma_b
def load_torch_file(ckpt: str, safe_load: bool = False, device: str = None) -> dict:
    """#### Load a checkpoint's state dict from disk.

    `.safetensors` / `.sft` files are read with safetensors; any other file goes
    through `torch.load`. When the loaded object is a dict containing a
    `"state_dict"` entry, that entry is returned instead of the whole object.

    #### Args:
        - `ckpt` (str): Path to the checkpoint file.
        - `safe_load` (bool, optional): For non-safetensors files, pass
          `weights_only=True` to `torch.load` when the installed torch supports it.
          Defaults to False.
        - `device` (str | torch.device, optional): Where to map the loaded tensors,
          e.g. `"cpu"`, `"cuda"` or a `torch.device`. Defaults to None (CPU).
    #### Returns:
        - `dict`: The loaded state dict (or the raw loaded object if it is not a dict).
    """
    # Normalize: the safetensors branch needs `.type`, which plain strings lack.
    device = torch.device("cpu") if device is None else torch.device(device)

    if ckpt.lower().endswith((".safetensors", ".sft")):
        return safetensors.torch.load_file(ckpt, device=device.type)

    if safe_load and "weights_only" not in torch.load.__code__.co_varnames:
        logging.warning(
            "Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely."
        )
        safe_load = False

    if safe_load:
        pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
    else:
        # NOTE(security): without weights_only, torch.load unpickles arbitrary
        # objects — only load checkpoints from trusted sources.
        pl_sd = torch.load(ckpt, map_location=device)

    # Only dict-like checkpoints carry the optional metadata / nested state dict.
    if not isinstance(pl_sd, dict):
        return pl_sd
    if "global_step" in pl_sd:
        logging.debug(f"Global Step: {pl_sd['global_step']}")
    return pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
def calculate_parameters(sd: dict, prefix: str = "") -> int:
    """#### Count the tensor elements in a state dict whose keys start with `prefix`.

    #### Args:
        - `sd` (dict): The state dictionary (values must support `.nelement()`).
        - `prefix` (str, optional): Only keys starting with this prefix are counted.
          Defaults to "" (all keys).
    #### Returns:
        - `int`: The total number of elements across the matching tensors
          (0 when nothing matches).
    """
    return sum(
        tensor.nelement() for key, tensor in sd.items() if key.startswith(prefix)
    )
def state_dict_prefix_replace(
    state_dict: dict, replace_prefix: dict, filter_keys: bool = False
) -> dict:
    """#### Rename state-dict keys by swapping prefixes.

    For every `old -> new` pair in `replace_prefix`, each key of `state_dict` that
    starts with `old` is popped and stored under `new + <rest of key>`. Pairs are
    applied in iteration order; when renaming in place, a key produced by one pair
    can be matched again by a later pair.

    #### Args:
        - `state_dict` (dict): The state dictionary. It is always mutated: matched
          keys are popped from it.
        - `replace_prefix` (dict): Mapping from old prefix to replacement prefix.
        - `filter_keys` (bool, optional): If True, return a new dict holding only the
          renamed entries; if False, rename in place and return `state_dict` itself.
          Defaults to False.
    #### Returns:
        - `dict`: The dict containing the renamed entries (see `filter_keys`).
    """
    out = {} if filter_keys else state_dict
    for old_prefix, new_prefix in replace_prefix.items():
        # Materialize the matches first: the dict is mutated while renaming.
        matches = [key for key in state_dict if key.startswith(old_prefix)]
        for key in matches:
            out[f"{new_prefix}{key[len(old_prefix):]}"] = state_dict.pop(key)
    return out
def repeat_to_batch_size(
    tensor: torch.Tensor, batch_size: int, dim: int = 0
) -> torch.Tensor:
    """#### Trim or tile `tensor` along `dim` so that dimension equals `batch_size`.

    #### Args:
        - `tensor` (torch.Tensor): The input tensor.
        - `batch_size` (int): The required size of dimension `dim`.
        - `dim` (int, optional): The dimension to adjust. Defaults to 0.
    #### Returns:
        - `torch.Tensor`: `tensor` unchanged if it already matches, otherwise a
          narrowed (and, if too small, first repeated) tensor.
    """
    size = tensor.shape[dim]
    if size == batch_size:
        return tensor
    if size < batch_size:
        # Repeat along `dim` enough times to reach at least batch_size, then trim.
        repeats = (
            [1] * dim
            + [math.ceil(batch_size / size)]
            + [1] * (tensor.ndim - 1 - dim)
        )
        tensor = tensor.repeat(repeats)
    return tensor.narrow(dim, 0, batch_size)
def set_attr(obj: object, attr: str, value: any) -> any:
    """#### Assign `value` to a (possibly dotted) attribute path on `obj`.

    #### Args:
        - `obj` (object): The root object.
        - `attr` (str): Attribute path such as `"layer.weight"`.
        - `value` (any): The new value.
    #### Returns:
        - `any`: The value the attribute held before the assignment.
    """
    *parents, leaf = attr.split(".")
    target = obj
    for part in parents:
        target = getattr(target, part)
    previous = getattr(target, leaf)
    setattr(target, leaf, value)
    return previous
def set_attr_param(obj: object, attr: str, value: any) -> any:
    """#### Replace a (possibly dotted) attribute with a frozen `torch.nn.Parameter`.

    #### Args:
        - `obj` (object): The root object.
        - `attr` (str): Attribute path such as `"layer.weight"`.
        - `value` (any): Data wrapped in `torch.nn.Parameter(..., requires_grad=False)`.
    #### Returns:
        - `any`: The previous attribute value, as returned by `set_attr`.
    """
    frozen = torch.nn.Parameter(value, requires_grad=False)
    return set_attr(obj, attr, frozen)
def copy_to_param(obj: object, attr: str, value: any) -> None:
    """#### Copy `value` in place into the `.data` of an existing attribute.

    #### Args:
        - `obj` (object): The root object.
        - `attr` (str): Attribute path such as `"layer.weight"`.
        - `value` (any): Source data for `.data.copy_()`.
    """
    *path, leaf = attr.split(".")
    owner = obj
    for name in path:
        owner = getattr(owner, name)
    getattr(owner, leaf).data.copy_(value)
def get_obj_from_str(string: str, reload: bool = False) -> object:
    """#### Import and return the object named by a dotted path like `"pkg.mod.Name"`.

    #### Args:
        - `string` (str): Module path and attribute name joined by the last dot.
          A string without a dot raises ValueError when unpacking.
        - `reload` (bool, optional): Reload the module before the lookup. Defaults to False.
    #### Returns:
        - `object`: The attribute fetched from the imported module.
    """
    module_path, name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    module = importlib.import_module(module_path, package=None)
    return getattr(module, name)
def get_attr(obj: object, attr: str) -> any:
    """#### Follow a dotted attribute path (e.g. `"layer.weight"`) starting at `obj`.

    #### Args:
        - `obj` (object): The root object.
        - `attr` (str): The dotted attribute path.
    #### Returns:
        - `any`: The value at the end of the path.
    """
    current = obj
    for part in attr.split("."):
        current = getattr(current, part)
    return current
def lcm(a: int, b: int) -> int:
    """#### Calculate the least common multiple (LCM) of two integers.

    #### Args:
        - `a` (int): The first number.
        - `b` (int): The second number.
    #### Returns:
        - `int`: The non-negative LCM; 0 if either argument is 0.
    """
    # gcd(0, 0) == 0, so the zero case must be handled before dividing.
    if a == 0 or b == 0:
        return 0
    return abs(a * b) // math.gcd(a, b)
def get_full_path(folder_name: str, filename: str) -> "str | None":
    """#### Resolve `filename` against the search folders registered for `folder_name`.

    #### Args:
        - `folder_name` (str): Key into the module-level `folder_names_and_paths`
          registry; the first element of each entry is iterated as a list of folders.
        - `filename` (str): File name, possibly with sub-directories.
    #### Returns:
        - `str | None`: Path of the first registered folder that contains the file,
          or None when no folder does.
    #### Raises:
        - `KeyError`: If `folder_name` is not in the registry.
    """
    # NOTE(review): `folder_names_and_paths` is not defined in the visible code; it
    # must be assigned at module level elsewhere before this function is called.
    folders = folder_names_and_paths[folder_name]
    # Re-root the name at "/" and make it relative again; on POSIX this collapses
    # leading ".." components and absolute paths so lookups stay inside the folders.
    filename = os.path.relpath(os.path.join("/", filename), "/")
    for base in folders[0]:
        candidate = os.path.join(base, filename)
        if os.path.isfile(candidate):
            return candidate
    return None
def zero_module(module: torch.nn.Module) -> torch.nn.Module:
    """#### Fill every parameter of `module` with zeros in place.

    #### Args:
        - `module` (torch.nn.Module): The module to modify.
    #### Returns:
        - `torch.nn.Module`: The same module object, now zeroed.
    """
    params = list(module.parameters())
    for param in params:
        # Zero through a detached alias so the write is not recorded by autograd.
        param.detach().zero_()
    return module
def append_zero(x: torch.Tensor) -> torch.Tensor:
    """#### Return `x` with a single 0 appended along dimension 0.

    #### Args:
        - `x` (torch.Tensor): The input tensor (same dtype/device is used for the zero).
    #### Returns:
        - `torch.Tensor`: A new tensor one element longer than `x`.
    """
    tail = x.new_zeros([1])
    return torch.cat([x, tail])
def exists(val: any) -> bool:
    """#### Tell whether `val` is anything other than None.

    #### Args:
        - `val` (any): The value to test.
    #### Returns:
        - `bool`: False only for None (falsy values like 0 or "" still count).
    """
    return not (val is None)
def default(val: any, d: any) -> any:
    """#### Return `val`, or a fallback when it is None.

    #### Args:
        - `val` (any): The preferred value.
        - `d` (any): The fallback; if it is a plain Python function (per
          `inspect.isfunction`) it is called and its result returned.
    #### Returns:
        - `any`: `val` if it is not None, otherwise `d()` or `d`.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
def write_parameters_to_file(
    prompt_entry: str, neg: str, width: int, height: int, cfg: int
) -> None:
    """#### Save generation parameters to `./_internal/prompt.txt`.

    One `key: value` pair is written per line (`prompt`, `neg`, `w`, `h`, `cfg`),
    which is the line-based format `load_parameters_from_file` parses.

    #### Args:
        - `prompt_entry` (str): The prompt. Line breaks are replaced with spaces
          so the entry stays on one line.
        - `neg` (str): The negative prompt (line breaks replaced likewise).
        - `width` (int): The width; written via `int()`.
        - `height` (int): The height; written via `int()`.
        - `cfg` (int): The CFG scale; written via `int()` (fractions are truncated).
    """

    def _one_line(text) -> str:
        # The reader splits the file on newlines, so an embedded newline would
        # break an entry into an unparsable line.
        return " ".join(str(text).splitlines())

    with open("./_internal/prompt.txt", "w") as f:
        # Every entry must end with "\n"; previously prompt/neg did not, which
        # merged them with the following line.
        f.write(f"prompt: {_one_line(prompt_entry)}\n")
        f.write(f"neg: {_one_line(neg)}\n")
        f.write(f"w: {int(width)}\n")
        f.write(f"h: {int(height)}\n")
        f.write(f"cfg: {int(cfg)}\n")
def load_parameters_from_file() -> tuple:
    """#### Load generation parameters from `./_internal/prompt.txt`.

    The file is expected to hold `key: value` lines for `prompt`, `neg`, `w`, `h`
    and `cfg`; blank lines are ignored and values are whitespace-stripped.

    #### Returns:
        - `tuple`: `(prompt, neg, width, height, cfg)` as `(str, str, int, int, int)`.
    #### Raises:
        - `KeyError`: If one of the expected keys is missing.
        - `ValueError`: If `w`, `h` or `cfg` is not an integer string.
    """
    parameters = {}
    with open("./_internal/prompt.txt", "r") as f:
        for line in f:
            if not line.strip():
                continue
            # Split on the first colon only: prompts may themselves contain ": ",
            # and an empty value may have lost its trailing space.
            key, _, value = line.partition(":")
            parameters[key.strip()] = value.strip()
    prompt = parameters["prompt"]
    neg = parameters["neg"]
    width = int(parameters["w"])
    height = int(parameters["h"])
    cfg = int(parameters["cfg"])
    return prompt, neg, width, height, cfg
# Module-level progress-bar settings.
# PROGRESS_BAR_ENABLED is not read by any of the code visible in this file.
PROGRESS_BAR_ENABLED = True
# Hook captured by every ProgressBar at construction time (None = no hook).
PROGRESS_BAR_HOOK = None


class ProgressBar:
    """#### Minimal progress-bar state: a total, a current position and a hook.

    #### Args:
        - `total` (int): The number of steps the bar represents.
    """

    def __init__(self, total: int):
        # Reading the module global needs no `global` declaration; the hook in
        # effect at construction time is stored on the instance.
        self.total = total
        self.current = 0
        self.hook = PROGRESS_BAR_HOOK
def get_tiled_scale_steps(
    width: int, height: int, tile_x: int, tile_y: int, overlap: int
) -> int:
    """#### Count the tiles a tiled scale pass visits for a `width` x `height` input.

    A dimension that fits in one tile contributes 1; otherwise it contributes
    `ceil((size - overlap) / (tile - overlap))`, the length of the tile-start range
    used by `tiled_scale_multidim`.

    #### Args:
        - `width` (int): The width.
        - `height` (int): The height.
        - `tile_x` (int): The tile width.
        - `tile_y` (int): The tile height.
        - `overlap` (int): The overlap between neighbouring tiles.
    #### Returns:
        - `int`: rows * columns.
    """

    def _tiles_along(size: int, tile: int) -> int:
        if size <= tile:
            return 1
        return math.ceil((size - overlap) / (tile - overlap))

    return _tiles_along(height, tile_y) * _tiles_along(width, tile_x)
def tiled_scale_multidim(
    samples: torch.Tensor,
    function,
    tile: tuple = (64, 64),
    overlap: int = 8,
    upscale_amount: int = 4,
    out_channels: int = 3,
    output_device: str = "cpu",
    downscale: bool = False,
    index_formulas: any = None,
    pbar: any = None,
):
    """#### Scale a batch by running `function` over overlapping tiles and blending.

    Each sample is processed separately. If it fits in one tile on every spatial
    dim, `function` is applied to it directly. Otherwise the spatial dims are cut
    into overlapping, full-size tiles; each tile's output is weighted by a linear
    feathering mask, accumulated, and finally divided by the accumulated weights.

    #### Args:
        - `samples` (torch.Tensor): Input indexed as (batch, channels, *spatial) with
          `len(spatial) == len(tile)`.
        - `function` (callable): Maps a tile tensor to its scaled output.
        - `tile` (tuple, optional): Tile size per spatial dim. Defaults to (64, 64).
        - `overlap` (int | sequence, optional): Overlap between neighbouring tiles,
          shared or per dim. Defaults to 8.
        - `upscale_amount` (number | callable | sequence, optional): Scale factor per
          dim, or a callable mapping an input size to an output size. Defaults to 4.
        - `out_channels` (int, optional): Channel count of `function`'s output. Defaults to 3.
        - `output_device` (str, optional): Device of the returned tensor. Defaults to "cpu".
        - `downscale` (bool, optional): Divide by numeric factors instead of
          multiplying. Defaults to False.
        - `index_formulas` (optional): Factors/callables mapping tile start positions
          into output coordinates; defaults to `upscale_amount`.
        - `pbar` (optional): Object with an `update(int)` method; advanced by 1 per
          tile (or per sample that fits in one tile). Defaults to None.
    #### Returns:
        - `torch.Tensor`: The scaled output, (batch, out_channels, *scaled_spatial).
    #### Raises:
        - `ValueError`: If a dim that needs tiling has `tile <= overlap`.
    """
    dims = len(tile)
    if not isinstance(upscale_amount, (tuple, list)):
        upscale_amount = [upscale_amount] * dims
    if not isinstance(overlap, (tuple, list)):
        overlap = [overlap] * dims
    if index_formulas is None:
        index_formulas = upscale_amount
    if not isinstance(index_formulas, (tuple, list)):
        index_formulas = [index_formulas] * dims

    def _apply(factor, val):
        # A factor is either a callable computing the result directly, or a
        # number to multiply by (upscale) or divide by (downscale).
        if callable(factor):
            return factor(val)
        return val / factor if downscale else factor * val

    def get_scale(dim: int, val: int):
        # Scale a size along `dim` into output space.
        return _apply(upscale_amount[dim], val)

    def get_pos(dim: int, val: int):
        # Map a tile start position along `dim` into output space.
        return _apply(index_formulas[dim], val)

    def mult_list_upscale(sizes) -> list:
        # Scaled, rounded sizes for all spatial dims.
        return [round(get_scale(i, v)) for i, v in enumerate(sizes)]

    output = torch.empty(
        [samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]),
        device=output_device,
    )
    for b in range(samples.shape[0]):
        s = samples[b : b + 1]
        spatial = s.shape[2:]

        # Fast path: the entire sample fits in a single tile.
        if all(spatial[d] <= tile[d] for d in range(dims)):
            output[b : b + 1] = function(s).to(output_device)
            if pbar is not None:
                pbar.update(1)
            continue

        for d in range(dims):
            if spatial[d] > tile[d] and tile[d] <= overlap[d]:
                raise ValueError(
                    f"tile size {tile[d]} must be larger than overlap {overlap[d]} "
                    f"(spatial dim {d})"
                )

        out_shape = [s.shape[0], out_channels] + mult_list_upscale(spatial)
        out = torch.zeros(out_shape, device=output_device)
        out_div = torch.zeros(out_shape, device=output_device)

        positions = [
            range(0, spatial[d] - overlap[d], tile[d] - overlap[d])
            if spatial[d] > tile[d]
            else [0]
            for d in range(dims)
        ]
        for it in itertools.product(*positions):
            s_in = s
            upscaled = []
            for d in range(dims):
                # Pull the last tile back so it stays full-size instead of being
                # truncated at the border (the former clamp to `size - overlap`
                # could never trigger, since the range already stops there).
                pos = max(0, min(spatial[d] - tile[d], it[d]))
                length = min(tile[d], spatial[d] - pos)
                s_in = s_in.narrow(d + 2, pos, length)
                upscaled.append(round(get_pos(d, pos)))

            ps = function(s_in).to(output_device)

            # Linear ramp over `overlap` (in output space) at both ends of each
            # spatial dim; weights stay > 0, so the final division is safe.
            mask = torch.ones_like(ps)
            for d in range(2, dims + 2):
                feather = round(get_scale(d - 2, overlap[d - 2]))
                if feather >= mask.shape[d]:
                    continue
                for t in range(feather):
                    a = (t + 1) / feather
                    mask.narrow(d, t, 1).mul_(a)
                    mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)

            o = out
            o_d = out_div
            for d in range(dims):
                o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
                o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
            o.add_(ps * mask)
            o_d.add_(mask)

            if pbar is not None:
                pbar.update(1)

        output[b : b + 1] = out / out_div
    return output
def tiled_scale(
    samples: torch.Tensor,
    function,
    tile_x: int = 64,
    tile_y: int = 64,
    overlap: int = 8,
    upscale_amount: int = 4,
    out_channels: int = 3,
    output_device: str = "cpu",
    pbar: any = None,
):
    """#### 2-D convenience wrapper around `tiled_scale_multidim`.

    #### Args:
        - `samples` (torch.Tensor): The input samples.
        - `function` (function): The per-tile scaling function.
        - `tile_x` (int, optional): Tile size along the last spatial dim. Defaults to 64.
        - `tile_y` (int, optional): Tile size along the first spatial dim. Defaults to 64.
        - `overlap` (int, optional): The overlap. Defaults to 8.
        - `upscale_amount` (int, optional): The upscale amount. Defaults to 4.
        - `out_channels` (int, optional): The number of output channels. Defaults to 3.
        - `output_device` (str, optional): The output device. Defaults to "cpu".
        - `pbar` (any, optional): The progress bar. Defaults to None.
    #### Returns:
        - The scaled tensor produced by `tiled_scale_multidim`.
    """
    # The multidim helper takes tile sizes in spatial-dim order: y first, then x.
    tile_shape = (tile_y, tile_x)
    return tiled_scale_multidim(
        samples,
        function,
        tile_shape,
        overlap=overlap,
        upscale_amount=upscale_amount,
        output_device=output_device,
        out_channels=out_channels,
        pbar=pbar,
    )