# Utility helpers: tensor shape manipulation, checkpoint loading,
# state-dict handling, parameter-file round-tripping, and tiled scaling.
import importlib | |
from inspect import isfunction | |
import itertools | |
import logging | |
import math | |
import os | |
import safetensors.torch | |
import torch | |
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """#### Appends dimensions to the end of a tensor until it has target_dims dimensions.
    #### Args:
    - `x` (torch.Tensor): The input tensor.
    - `target_dims` (int): The target number of dimensions.
    #### Returns:
    - `torch.Tensor`: The expanded tensor.
    #### Raises:
    - `ValueError`: If `target_dims` is smaller than `x.ndim` (previously this
      silently returned `x` unchanged, hiding caller bugs).
    """
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    expanded = x[(...,) + (None,) * dims_to_append]
    # MPS tensors from advanced indexing can alias storage oddly; force a copy.
    return expanded.detach().clone() if expanded.device.type == "mps" else expanded
def to_d(x: torch.Tensor, sigma: torch.Tensor, denoised: torch.Tensor) -> torch.Tensor:
    """#### Convert a sample into the ODE derivative (Karras et al. "d").
    #### Args:
    - `x` (torch.Tensor): The noisy input tensor.
    - `sigma` (torch.Tensor): The noise level.
    - `denoised` (torch.Tensor): The denoised prediction.
    #### Returns:
    - `torch.Tensor`: `(x - denoised) / sigma`, with sigma broadcast over x's dims.
    """
    # Pad sigma with trailing singleton dims so it broadcasts against x.
    noise_scale = sigma[(...,) + (None,) * (x.ndim - sigma.ndim)]
    if noise_scale.device.type == "mps":
        noise_scale = noise_scale.detach().clone()
    residual = x - denoised
    return residual / noise_scale
def load_torch_file(ckpt: str, safe_load: bool = False, device=None) -> dict:
    """#### Load a PyTorch checkpoint file.
    #### Args:
    - `ckpt` (str): The path to the checkpoint file.
    - `safe_load` (bool, optional): Whether to use safe loading (weights_only). Defaults to False.
    - `device` (str | torch.device, optional): The device to load the checkpoint on. Defaults to CPU.
    #### Returns:
    - `dict`: The loaded state dictionary.
    """
    if device is None:
        device = torch.device("cpu")
    elif isinstance(device, str):
        # Bug fix: a plain string such as "cpu"/"cuda:0" previously crashed at
        # `device.type` in the safetensors branch; normalize to torch.device.
        device = torch.device(device)
    if ckpt.lower().endswith((".safetensors", ".sft")):
        sd = safetensors.torch.load_file(ckpt, device=device.type)
    else:
        if safe_load and "weights_only" not in torch.load.__code__.co_varnames:
            logging.warning(
                "Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely."
            )
            safe_load = False
        if safe_load:
            pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
        else:
            pl_sd = torch.load(ckpt, map_location=device)
        if "global_step" in pl_sd:
            logging.debug(f"Global Step: {pl_sd['global_step']}")
        # Lightning-style checkpoints nest the weights under "state_dict".
        sd = pl_sd.get("state_dict", pl_sd)
    return sd
def calculate_parameters(sd: dict, prefix: str = "") -> int:
    """#### Count the total number of elements in a state dictionary.
    #### Args:
    - `sd` (dict): The state dictionary (keys -> tensors).
    - `prefix` (str, optional): Only keys starting with this prefix are counted. Defaults to "".
    #### Returns:
    - `int`: The total element count (the original annotation/doc claimed
      `dict`, but the function has always returned an integer).
    """
    return sum(t.nelement() for k, t in sd.items() if k.startswith(prefix))
def state_dict_prefix_replace(
    state_dict: dict, replace_prefix: dict, filter_keys: bool = False
) -> dict:
    """#### Rename keys in a state dictionary by swapping prefixes.
    NOTE: matching keys are popped from `state_dict`, so the input is mutated.
    #### Args:
    - `state_dict` (dict): The state dictionary (mutated in place).
    - `replace_prefix` (dict): Mapping of old prefix -> new prefix (the original
      annotation said `str`, but the code has always treated it as a mapping).
    - `filter_keys` (bool, optional): If True, return a new dict containing only
      the renamed keys; otherwise rename within (and return) `state_dict`.
      Defaults to False.
    #### Returns:
    - `dict`: The updated state dictionary.
    """
    out = {} if filter_keys else state_dict
    for old_prefix, new_prefix in replace_prefix.items():
        # Snapshot matching keys before popping so iteration stays safe.
        matches = [k for k in list(state_dict.keys()) if k.startswith(old_prefix)]
        for old_key in matches:
            new_key = "{}{}".format(new_prefix, old_key[len(old_prefix):])
            out[new_key] = state_dict.pop(old_key)
    return out
def lcm_of_list(numbers):
    """#### Calculate the LCM of a list of integers.
    Uses pure-integer `math.gcd` arithmetic instead of allocating two torch
    tensors per element, which was needlessly slow for a scalar reduction.
    #### Args:
    - `numbers` (list[int]): The input integers.
    #### Returns:
    - `int`: The LCM of all numbers; 1 for an empty list.
    """
    if not numbers:
        return 1
    result = numbers[0]
    for num in numbers[1:]:
        g = math.gcd(result, num)
        # lcm(0, 0) is 0 (matching torch.lcm); avoid dividing by gcd == 0.
        result = 0 if g == 0 else abs(result * num) // g
    return result
def repeat_to_batch_size(
    tensor: torch.Tensor, batch_size: int, dim: int = 0
) -> torch.Tensor:
    """#### Trim or tile a tensor so a given dimension matches `batch_size`.
    #### Args:
    - `tensor` (torch.Tensor): The input tensor.
    - `batch_size` (int): The target size along `dim`.
    - `dim` (int, optional): The dimension to adjust. Defaults to 0.
    #### Returns:
    - `torch.Tensor`: A view/copy sized `batch_size` along `dim` (the tensor
      itself if it already matches).
    """
    current = tensor.shape[dim]
    if current == batch_size:
        return tensor
    if current > batch_size:
        # Too large: slice off the excess.
        return tensor.narrow(dim, 0, batch_size)
    # Too small: tile enough full copies along `dim`, then trim.
    reps = [1] * len(tensor.shape)
    reps[dim] = math.ceil(batch_size / current)
    return tensor.repeat(*reps).narrow(dim, 0, batch_size)
def set_attr(obj: object, attr: str, value: any) -> any:
    """#### Set a (possibly dotted) attribute of an object.
    #### Args:
    - `obj` (object): The root object.
    - `attr` (str): Dotted attribute path, e.g. "child.weight".
    - `value` (any): The value to set.
    #### Returns:
    - The previous value stored at that attribute.
    """
    *parents, leaf = attr.split(".")
    target = obj
    for part in parents:
        target = getattr(target, part)
    previous = getattr(target, leaf)
    setattr(target, leaf, value)
    return previous
def set_attr_param(obj: object, attr: str, value: any) -> any:
    """#### Replace a (possibly dotted) attribute with a frozen torch Parameter.
    #### Args:
    - `obj` (object): The root object.
    - `attr` (str): Dotted attribute path.
    - `value` (any): Tensor data to wrap in a non-trainable Parameter.
    #### Returns:
    - The previous value stored at that attribute.
    """
    frozen = torch.nn.Parameter(value, requires_grad=False)
    path = attr.split(".")
    target = obj
    for part in path[:-1]:
        target = getattr(target, part)
    previous = getattr(target, path[-1])
    setattr(target, path[-1], frozen)
    return previous
def copy_to_param(obj: object, attr: str, value: any) -> None:
    """#### Copy a value into an existing parameter's data in place.
    Unlike `set_attr`, this keeps the Parameter object and overwrites its
    `.data` buffer.
    #### Args:
    - `obj` (object): The root object.
    - `attr` (str): Dotted attribute path to the parameter.
    - `value` (any): The tensor value to copy in.
    """
    path = attr.split(".")
    target = obj
    for part in path[:-1]:
        target = getattr(target, part)
    param = getattr(target, path[-1])
    param.data.copy_(value)
def get_obj_from_str(string: str, reload: bool = False) -> object:
    """#### Resolve a dotted "module.attribute" string to the attribute object.
    #### Args:
    - `string` (str): Dotted path, e.g. "math.sqrt".
    - `reload` (bool, optional): Reload the module before resolving. Defaults to False.
    #### Returns:
    - `object`: The resolved attribute.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    module = importlib.import_module(module_name, package=None)
    return getattr(module, attr_name)
def get_attr(obj: object, attr: str) -> any:
    """#### Read a (possibly dotted) attribute of an object.
    #### Args:
    - `obj` (object): The root object.
    - `attr` (str): Dotted attribute path, e.g. "child.weight".
    #### Returns:
    - The value found at the end of the path.
    """
    node = obj
    for part in attr.split("."):
        node = getattr(node, part)
    return node
def lcm(a: int, b: int) -> int:
    """#### Calculate the least common multiple (LCM) of two integers.
    #### Args:
    - `a` (int): The first number.
    - `b` (int): The second number.
    #### Returns:
    - `int`: The (non-negative) LCM of the two numbers.
    """
    product = abs(a * b)
    return product // math.gcd(a, b)
def get_full_path(folder_name: str, filename: str) -> str:
    """#### Get the full path of a file in a folder.
    Args:
        folder_name (str): Key into the module-level `folder_names_and_paths`
            registry (defined elsewhere in this project).
        filename (str): The filename to look up.
    Returns:
        str: The full path of the first matching file, or None (implicit)
        when no registered folder contains it.
    """
    global folder_names_and_paths
    # NOTE(review): `folder_names_and_paths` is not defined in this chunk;
    # presumably it maps folder names to ([search_dirs, ...], ...) — verify.
    folders = folder_names_and_paths[folder_name]
    # Normalize the filename so "../" segments cannot escape the search roots.
    filename = os.path.relpath(os.path.join("/", filename), "/")
    for x in folders[0]:
        full_path = os.path.join(x, filename)
        if os.path.isfile(full_path):
            return full_path
def zero_module(module: torch.nn.Module) -> torch.nn.Module:
    """#### Zero out all parameters of a module in place.
    #### Args:
    - `module` (torch.nn.Module): The module whose parameters are zeroed.
    #### Returns:
    - `torch.nn.Module`: The same module, for chaining.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
def append_zero(x: torch.Tensor) -> torch.Tensor:
    """#### Return `x` with a single zero element concatenated at the end.
    #### Args:
    - `x` (torch.Tensor): The input 1-D tensor.
    #### Returns:
    - `torch.Tensor`: A new tensor one element longer, ending in 0.
    """
    zero_tail = x.new_zeros([1])
    return torch.cat((x, zero_tail))
def exists(val: any) -> bool:
    """#### Check whether a value is present (not None).
    #### Args:
    - `val` (any): The value to check.
    #### Returns:
    - `bool`: False only for None; True for everything else, including
      falsy values like 0 and "".
    """
    missing = val is None
    return not missing
def default(val: any, d: any) -> any:
    """#### Return `val` if present, otherwise a default.
    #### Args:
    - `val` (any): The candidate value.
    - `d` (any): The fallback; if it is a function it is called to produce
      the fallback value lazily.
    #### Returns:
    - `any`: `val` when not None, else the (possibly computed) default.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
def write_parameters_to_file(
    prompt_entry: str, neg: str, width: int, height: int, cfg: int
) -> None:
    """#### Write generation parameters to ./_internal/prompt.txt.
    Each parameter is written on its own "key: value" line so that
    `load_parameters_from_file` can parse the file back.
    #### Args:
    - `prompt_entry` (str): The prompt entry.
    - `neg` (str): The negative prompt entry.
    - `width` (int): The width.
    - `height` (int): The height.
    - `cfg` (int): The CFG.
    """
    with open("./_internal/prompt.txt", "w") as f:
        # Bug fix: the prompt/neg lines previously lacked trailing newlines,
        # fusing them with the next line and breaking the loader's parsing.
        f.write(f"prompt: {prompt_entry}\n")
        f.write(f"neg: {neg}\n")
        f.write(f"w: {int(width)}\n")
        f.write(f"h: {int(height)}\n")
        f.write(f"cfg: {int(cfg)}\n")
def load_parameters_from_file() -> tuple:
    """#### Load generation parameters from ./_internal/prompt.txt.
    #### Returns:
    - `str`: The prompt entry.
    - `str`: The negative prompt entry.
    - `int`: The width.
    - `int`: The height.
    - `int`: The CFG.
    #### Raises:
    - `KeyError`: If a required key is missing from the file.
    - `ValueError`: If a line lacks the "key: value" form.
    """
    parameters = {}
    with open("./_internal/prompt.txt", "r") as f:
        for line in f:
            # Skip empty lines
            if line.strip() == "":
                continue
            # Bug fix: maxsplit=1 keeps values that themselves contain ": "
            # (common in prompts) from blowing up the unpacking.
            key, value = line.split(": ", 1)
            parameters[key] = value.strip()
    prompt = parameters["prompt"]
    neg = parameters["neg"]
    width = int(parameters["w"])
    height = int(parameters["h"])
    cfg = int(parameters["cfg"])
    return prompt, neg, width, height, cfg
# Module-level progress-bar configuration; PROGRESS_BAR_HOOK may be replaced
# elsewhere with a callback that receives progress updates.
PROGRESS_BAR_ENABLED = True
PROGRESS_BAR_HOOK = None
class ProgressBar:
    """#### Minimal progress tracker: counts completed steps out of `total`
    and stores whatever hook is installed at construction time."""
    def __init__(self, total: int):
        global PROGRESS_BAR_HOOK
        # Snapshot the currently-installed hook (may be None).
        self.hook = PROGRESS_BAR_HOOK
        self.current = 0
        self.total = total
def get_tiled_scale_steps(
    width: int, height: int, tile_x: int, tile_y: int, overlap: int
) -> int:
    """#### Count how many tiles a tiled-scaling pass will process.
    #### Args:
    - `width` (int): The image width.
    - `height` (int): The image height.
    - `tile_x` (int): The tile width.
    - `tile_y` (int): The tile height.
    - `overlap` (int): The overlap between adjacent tiles.
    #### Returns:
    - `int`: rows * cols of tiles (1 per axis when the image fits in one tile).
    """
    def tiles_along(size: int, tile: int) -> int:
        # One tile suffices when the axis fits; otherwise step by (tile - overlap).
        if size <= tile:
            return 1
        return math.ceil((size - overlap) / (tile - overlap))
    return tiles_along(height, tile_y) * tiles_along(width, tile_x)
def tiled_scale_multidim(
    samples: torch.Tensor,
    function,
    tile: tuple = (64, 64),
    overlap: int = 8,
    upscale_amount: int = 4,
    out_channels: int = 3,
    output_device: str = "cpu",
    downscale: bool = False,
    index_formulas: any = None,
    pbar: any = None,
):
    """#### Scale an image using a tiled approach.
    Applies `function` tile by tile over the spatial dims of `samples`,
    feathers overlapping tile borders with a linear ramp mask, accumulates
    the weighted tile outputs, and divides by the summed weights so the
    overlap regions blend smoothly.
    #### Args:
    - `samples` (torch.Tensor): The input samples, shaped (batch, channels, *spatial).
    - `function` (function): The per-tile scaling function.
    - `tile` (tuple, optional): Tile size per spatial dim. Defaults to (64, 64).
    - `overlap` (int, optional): The overlap between tiles (scalar or per-dim). Defaults to 8.
    - `upscale_amount` (int, optional): The upscale amount (scalar, per-dim, or callable). Defaults to 4.
    - `out_channels` (int, optional): The number of output channels. Defaults to 3.
    - `output_device` (str, optional): The output device. Defaults to "cpu".
    - `downscale` (bool, optional): Whether `function` shrinks rather than grows tiles. Defaults to False.
    - `index_formulas` (any, optional): Per-dim position mapping; defaults to `upscale_amount`.
    - `pbar` (any, optional): The progress bar (must expose `.update(n)`). Defaults to None.
    #### Returns:
    - `torch.Tensor`: The scaled image.
    """
    dims = len(tile)
    # Normalize scalar arguments into one-entry-per-spatial-dimension lists.
    if not (isinstance(upscale_amount, (tuple, list))):
        upscale_amount = [upscale_amount] * dims
    if not (isinstance(overlap, (tuple, list))):
        overlap = [overlap] * dims
    if index_formulas is None:
        index_formulas = upscale_amount
    if not (isinstance(index_formulas, (tuple, list))):
        index_formulas = [index_formulas] * dims
    def get_upscale(dim: int, val: int) -> int:
        """#### Scale a size up along `dim`; `upscale_amount[dim]` may be a
        callable or a multiplicative factor."""
        up = upscale_amount[dim]
        if callable(up):
            return up(val)
        else:
            return up * val
    def get_downscale(dim: int, val: int) -> int:
        """#### Scale a size down along `dim`; `upscale_amount[dim]` may be a
        callable or a divisor."""
        up = upscale_amount[dim]
        if callable(up):
            return up(val)
        else:
            return val / up
    def get_upscale_pos(dim: int, val: int) -> int:
        """#### Map an input tile position to output coordinates (upscale)."""
        up = index_formulas[dim]
        if callable(up):
            return up(val)
        else:
            return up * val
    def get_downscale_pos(dim: int, val: int) -> int:
        """#### Map an input tile position to output coordinates (downscale)."""
        up = index_formulas[dim]
        if callable(up):
            return up(val)
        else:
            return val / up
    # Select the size/position mappers once, up front.
    if downscale:
        get_scale = get_downscale
        get_pos = get_downscale_pos
    else:
        get_scale = get_upscale
        get_pos = get_upscale_pos
    def mult_list_upscale(a: list) -> list:
        """#### Scale each spatial size in `a`, rounding to int."""
        out = []
        for i in range(len(a)):
            out.append(round(get_scale(i, a[i])))
        return out
    # Pre-allocate the full output; filled one batch element at a time.
    output = torch.empty(
        [samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]),
        device=output_device,
    )
    for b in range(samples.shape[0]):
        s = samples[b : b + 1]
        # handle entire input fitting in a single tile
        if all(s.shape[d + 2] <= tile[d] for d in range(dims)):
            output[b : b + 1] = function(s).to(output_device)
            if pbar is not None:
                pbar.update(1)
            continue
        # Accumulators: weighted sum of tile outputs, and the sum of weights.
        out = torch.zeros(
            [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]),
            device=output_device,
        )
        out_div = torch.zeros(
            [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]),
            device=output_device,
        )
        # Tile start positions per spatial dim, stepping by (tile - overlap).
        positions = [
            range(0, s.shape[d + 2] - overlap[d], tile[d] - overlap[d])
            if s.shape[d + 2] > tile[d]
            else [0]
            for d in range(dims)
        ]
        for it in itertools.product(*positions):
            s_in = s
            upscaled = []
            for d in range(dims):
                # Clamp the start so the tile never runs past the input edge.
                pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
                l = min(tile[d], s.shape[d + 2] - pos)
                s_in = s_in.narrow(d + 2, pos, l)
                upscaled.append(round(get_pos(d, pos)))
            ps = function(s_in).to(output_device)
            # Feather the tile borders with a linear ramp so overlaps blend.
            mask = torch.ones_like(ps)
            for d in range(2, dims + 2):
                feather = round(get_scale(d - 2, overlap[d - 2]))
                if feather >= mask.shape[d]:
                    continue
                for t in range(feather):
                    a = (t + 1) / feather
                    mask.narrow(d, t, 1).mul_(a)
                    mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
            # Narrow views into the accumulators at this tile's output position.
            o = out
            o_d = out_div
            for d in range(dims):
                o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
                o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
            o.add_(ps * mask)
            o_d.add_(mask)
            if pbar is not None:
                pbar.update(1)
        # Normalize by the accumulated mask weights to finish the blend.
        output[b : b + 1] = out / out_div
    return output
def tiled_scale(
    samples: torch.Tensor,
    function,
    tile_x: int = 64,
    tile_y: int = 64,
    overlap: int = 8,
    upscale_amount: int = 4,
    out_channels: int = 3,
    output_device: str = "cpu",
    pbar: any = None,
):
    """#### 2D convenience wrapper around `tiled_scale_multidim`.
    #### Args:
    - `samples` (torch.Tensor): The input samples.
    - `function` (function): The per-tile scaling function.
    - `tile_x` (int, optional): The tile width. Defaults to 64.
    - `tile_y` (int, optional): The tile height. Defaults to 64.
    - `overlap` (int, optional): The overlap. Defaults to 8.
    - `upscale_amount` (int, optional): The upscale amount. Defaults to 4.
    - `out_channels` (int, optional): The number of output channels. Defaults to 3.
    - `output_device` (str, optional): The output device. Defaults to "cpu".
    - `pbar` (any, optional): The progress bar. Defaults to None.
    #### Returns:
    - The scaled image.
    """
    # The multidim helper expects tile sizes in (height, width) order.
    tile_hw = (tile_y, tile_x)
    return tiled_scale_multidim(
        samples,
        function,
        tile_hw,
        overlap=overlap,
        upscale_amount=upscale_amount,
        out_channels=out_channels,
        output_device=output_device,
        pbar=pbar,
    )