# Provenance (scraped from the Hugging Face file page): uploaded by Aatricks
# with huggingface_hub ("Upload folder using huggingface_hub"),
# commit cfe609e (verified).
import importlib
from inspect import isfunction
import itertools
import logging
import math
import os
import safetensors.torch
import torch
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """#### Append trailing size-1 dimensions to a tensor.

    ``x`` is indexed with ``None`` until it has ``target_dims`` dimensions.
    If it already has ``target_dims`` or more dimensions, no dimension is
    added and an equivalent view is returned.

    #### Args:
        - `x` (torch.Tensor): The input tensor.
        - `target_dims` (int): The desired number of dimensions.

    #### Returns:
        - `torch.Tensor`: ``x`` with extra trailing singleton dimensions.
          This is a view, except on MPS devices, where it is a detached copy.
    """
    missing = target_dims - x.ndim
    index = (Ellipsis,) + tuple(None for _ in range(missing))
    result = x[index]
    if result.device.type != "mps":
        return result
    # On MPS the view is materialised as a detached copy (presumably a
    # workaround for MPS view handling).
    return result.detach().clone()
def to_d(x: torch.Tensor, sigma: torch.Tensor, denoised: torch.Tensor) -> torch.Tensor:
    """#### Turn a denoiser prediction into a sampling derivative.

    Computes ``(x - denoised) / sigma``, with ``sigma`` broadcast over the
    trailing dimensions of ``x``. This is the k-diffusion-style ``to_d``
    used by Euler-type samplers.

    #### Args:
        - `x` (torch.Tensor): The current noisy sample.
        - `sigma` (torch.Tensor): The noise level(s); trailing dims are
          appended so it broadcasts against ``x``.
        - `denoised` (torch.Tensor): The denoiser's prediction for ``x``.

    #### Returns:
        - `torch.Tensor`: ``(x - denoised) / sigma``.
    """
    residual = x - denoised
    sigma_b = append_dims(sigma, x.ndim)
    return residual / sigma_b
def load_torch_file(ckpt: str, safe_load: bool = False, device: str = None) -> dict:
    """#### Load a checkpoint file into a state dictionary.

    ``.safetensors`` / ``.sft`` files are read with ``safetensors``. Anything
    else goes through ``torch.load``. If the loaded object has a
    ``"state_dict"`` entry, that entry is returned instead of the wrapper.

    #### Args:
        - `ckpt` (str | os.PathLike): The path to the checkpoint file.
        - `safe_load` (bool, optional): For non-safetensors files, load with
          ``torch.load(weights_only=True)`` when this PyTorch supports it.
          Defaults to False.
        - `device` (str | torch.device, optional): Where tensors are mapped
          (e.g. ``"cpu"``, ``"cuda:1"`` or a ``torch.device``). Defaults to
          None, meaning CPU.

    #### Returns:
        - `dict`: The loaded state dictionary.
    """
    # Accept pathlib.Path and other os.PathLike values, not just str.
    ckpt = os.fspath(ckpt)
    # Normalise strings and torch.device objects alike; previously a str
    # device crashed on `.type` in the safetensors branch.
    device = torch.device("cpu") if device is None else torch.device(device)

    if ckpt.lower().endswith((".safetensors", ".sft")):
        # str(device) keeps the device index ("cuda:1"); `device.type`
        # would silently load onto the default device of that type.
        return safetensors.torch.load_file(ckpt, device=str(device))

    if safe_load and "weights_only" not in torch.load.__code__.co_varnames:
        logging.warning(
            "Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely."
        )
        safe_load = False
    if safe_load:
        pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
    else:
        # NOTE(review): without weights_only this may unpickle arbitrary
        # objects; only use it on trusted checkpoint files.
        pl_sd = torch.load(ckpt, map_location=device)

    if "global_step" in pl_sd:
        logging.debug(f"Global Step: {pl_sd['global_step']}")
    # Lightning-style checkpoints wrap the weights under "state_dict".
    if "state_dict" in pl_sd:
        return pl_sd["state_dict"]
    return pl_sd
def calculate_parameters(sd: dict, prefix: str = "") -> int:
    """#### Count the scalar parameters stored in a state dictionary.

    #### Args:
        - `sd` (dict): Mapping of parameter names to tensors.
        - `prefix` (str, optional): Only keys starting with this prefix are
          counted. Defaults to "" (count everything).

    #### Returns:
        - `int`: Total number of elements across the matching tensors
          (0 if nothing matches).
    """
    return sum(
        tensor.nelement() for key, tensor in sd.items() if key.startswith(prefix)
    )
def state_dict_prefix_replace(
    state_dict: dict, replace_prefix: dict, filter_keys: bool = False
) -> dict:
    """#### Rename state-dict keys by swapping their prefixes.

    For every ``old -> new`` pair in ``replace_prefix``, each key of
    ``state_dict`` starting with ``old`` is popped and re-inserted as
    ``new + key[len(old):]``.

    #### Args:
        - `state_dict` (dict): The state dictionary. Matched keys are always
          popped from it, even when ``filter_keys`` is True.
        - `replace_prefix` (dict): Mapping of old prefix -> new prefix.
        - `filter_keys` (bool, optional): If True, return a new dict holding
          only the renamed entries. If False, rename in place and return
          ``state_dict`` itself. Defaults to False.

    #### Returns:
        - `dict`: The dictionary holding the renamed entries.
    """
    out = {} if filter_keys else state_dict
    for old_prefix, new_prefix in replace_prefix.items():
        # Collect matches up front: the loop below pops from `state_dict`,
        # which must not change size while it is being iterated.
        matches = [key for key in state_dict.keys() if key.startswith(old_prefix)]
        for old_key in matches:
            new_key = "{}{}".format(new_prefix, old_key[len(old_prefix):])
            out[new_key] = state_dict.pop(old_key)
    return out
def lcm_of_list(numbers):
    """#### Compute the least common multiple of a sequence of integers.

    Uses exact Python integer arithmetic. The previous ``torch.lcm`` version
    allocated two tensors per step and could silently overflow int64.

    #### Args:
        - `numbers` (Sequence[int]): The integers to combine.

    #### Returns:
        - `int`: The LCM. Returns 1 for an empty sequence and the element
          itself for a one-element sequence. The result is 0 if any later
          element or running result is 0.
    """
    if not numbers:
        return 1
    result = numbers[0]
    for num in numbers[1:]:
        if result == 0 or num == 0:
            # lcm with 0 is 0; also avoids gcd(0, 0) == 0 as a divisor.
            result = 0
        else:
            result = abs(result * num) // math.gcd(result, num)
    return result
def repeat_to_batch_size(
    tensor: torch.Tensor, batch_size: int, dim: int = 0
) -> torch.Tensor:
    """#### Tile or truncate a tensor along one dimension to a given size.

    #### Args:
        - `tensor` (torch.Tensor): The input tensor.
        - `batch_size` (int): The target size along ``dim``.
        - `dim` (int, optional): The dimension to resize; negative values
          count from the end. Defaults to 0.

    #### Returns:
        - `torch.Tensor`: The input unchanged if it already has ``batch_size``
          entries along ``dim``. Otherwise, if it has more, a view of the
          first ``batch_size`` of them. If it has fewer, the tensor is
          repeated along ``dim`` and truncated to ``batch_size``.
    """
    # Normalise negative dims; previously `dim * [1]` built a repeat list of
    # the wrong length for dim < 0.
    if dim < 0:
        dim += tensor.ndim
    current = tensor.shape[dim]
    if current > batch_size:
        return tensor.narrow(dim, 0, batch_size)
    if current < batch_size:
        reps = [1] * tensor.ndim
        reps[dim] = math.ceil(batch_size / current)
        return tensor.repeat(reps).narrow(dim, 0, batch_size)
    return tensor
def set_attr(obj: object, attr: str, value: any) -> any:
    """#### Assign to a possibly dotted attribute path and return the old value.

    #### Args:
        - `obj` (object): The root object.
        - `attr` (str): Dotted attribute path, e.g. ``"block.conv.weight"``.
        - `value` (any): The new value for the final attribute.

    #### Returns:
        - `any`: The value the final attribute held before the assignment.
    """
    *parents, leaf = attr.split(".")
    target = obj
    for part in parents:
        target = getattr(target, part)
    previous = getattr(target, leaf)
    setattr(target, leaf, value)
    return previous
def set_attr_param(obj: object, attr: str, value: any) -> any:
    """#### Replace a (possibly dotted) attribute with a frozen parameter.

    ``value`` is wrapped in ``torch.nn.Parameter(..., requires_grad=False)``
    before being assigned.

    #### Args:
        - `obj` (object): The root object.
        - `attr` (str): Dotted attribute path.
        - `value` (any): Tensor data for the new parameter.

    #### Returns:
        - `any`: The attribute's previous value.
    """
    frozen = torch.nn.Parameter(value, requires_grad=False)
    return set_attr(obj, attr, frozen)
def copy_to_param(obj: object, attr: str, value: any) -> None:
    """#### Copy data in place into an existing (possibly dotted) tensor attribute.

    The attribute object itself is kept. Only its ``.data`` is overwritten
    via ``copy_``.

    #### Args:
        - `obj` (object): The root object.
        - `attr` (str): Dotted attribute path to the target tensor/parameter.
        - `value` (any): Source data; must be copyable into the target.
    """
    *path, name = attr.split(".")
    owner = obj
    for segment in path:
        owner = getattr(owner, segment)
    getattr(owner, name).data.copy_(value)
def get_obj_from_str(string: str, reload: bool = False) -> object:
    """#### Resolve a ``"package.module.Name"`` string to the named object.

    #### Args:
        - `string` (str): Fully qualified name; the part after the last dot
          is looked up on the module named by the part before it.
        - `reload` (bool, optional): Reload the module before the lookup.
          Defaults to False.

    #### Returns:
        - `object`: The resolved attribute.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    module = importlib.import_module(module_name, package=None)
    return getattr(module, attr_name)
def get_attr(obj: object, attr: str) -> any:
    """#### Read a possibly dotted attribute path such as ``"a.b.c"``.

    #### Args:
        - `obj` (object): The root object.
        - `attr` (str): Dotted attribute path.

    #### Returns:
        - `any`: The value at the end of the path.
    """
    head, sep, rest = attr.partition(".")
    value = getattr(obj, head)
    # Recurse on the remainder only if there was a dot to split on.
    return get_attr(value, rest) if sep else value
def lcm(a: int, b: int) -> int:
    """#### Calculate the least common multiple (LCM) of two integers.

    #### Args:
        - `a` (int): The first number.
        - `b` (int): The second number.

    #### Returns:
        - `int`: The non-negative LCM of ``a`` and ``b``. It is 0 when either
          argument is 0, which avoids dividing by ``gcd(0, 0) == 0``.
    """
    if a == 0 or b == 0:
        return 0
    return abs(a * b) // math.gcd(a, b)
def get_full_path(folder_name: str, filename: str) -> str:
    """#### Find ``filename`` in the search folders registered for ``folder_name``.

    Search folders are read from ``folder_names_and_paths[folder_name][0]``
    (a module-level registry defined elsewhere). Before searching,
    ``filename`` is re-rooted at ``/`` and made relative again. On POSIX
    paths this strips leading separators and ``..`` components, so the
    lookup stays inside each search folder.

    #### Args:
        - `folder_name` (str): The registry key.
        - `filename` (str): The file name or relative path.

    #### Returns:
        - `str`: Path of the first existing regular file, or ``None`` if no
          search folder contains it.
    """
    search_dirs = folder_names_and_paths[folder_name][0]
    safe_name = os.path.relpath(os.path.join("/", filename), "/")
    for directory in search_dirs:
        candidate = os.path.join(directory, safe_name)
        if os.path.isfile(candidate):
            return candidate
    return None
def zero_module(module: torch.nn.Module) -> torch.nn.Module:
    """#### Set every parameter of a module to zero, in place.

    Parameters are zeroed through detached views, so autograd does not
    record the write and ``requires_grad`` is left untouched.

    #### Args:
        - `module` (torch.nn.Module): The module to modify.

    #### Returns:
        - `torch.nn.Module`: The same module object, for chaining.
    """
    for tensor in map(torch.Tensor.detach, module.parameters()):
        tensor.zero_()
    return module
def append_zero(x: torch.Tensor) -> torch.Tensor:
    """#### Concatenate a single zero onto the end of a tensor along dim 0.

    #### Args:
        - `x` (torch.Tensor): The input tensor (e.g. a sigma schedule).

    #### Returns:
        - `torch.Tensor`: A new tensor with the same dtype/device as ``x``
          and one extra trailing zero.
    """
    zero = x.new_zeros([1])
    return torch.cat((x, zero))
def exists(val: any) -> bool:
    """#### Return True unless ``val`` is ``None``.

    Falsy values such as ``0`` or ``""`` still count as existing.

    #### Args:
        - `val` (any): The value to test.

    #### Returns:
        - `bool`: ``val is not None``.
    """
    return not (val is None)
def default(val: any, d: any) -> any:
    """#### Return ``val`` unless it is ``None``, otherwise a fallback.

    #### Args:
        - `val` (any): The preferred value.
        - `d` (any): The fallback. If it is a plain Python function
          (``inspect.isfunction``), it is called with no arguments and its
          result is used.

    #### Returns:
        - `any`: ``val`` if it is not None, else ``d()`` or ``d``.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
def write_parameters_to_file(
    prompt_entry: str, neg: str, width: int, height: int, cfg: int
) -> None:
    """#### Save generation parameters to ``./_internal/prompt.txt``.

    Writes one ``key: value`` line each for ``prompt``, ``neg``, ``w``, ``h``
    and ``cfg``, in the format read back by ``load_parameters_from_file``.
    The ``_internal`` directory is created if missing.

    #### Args:
        - `prompt_entry` (str): The prompt entry.
        - `neg` (str): The negative prompt entry.
        - `width` (int): The width (written as ``int(width)``).
        - `height` (int): The height (written as ``int(height)``).
        - `cfg` (int): The CFG scale (written as ``int(cfg)``, so fractional
          values are truncated).

    NOTE(review): prompts containing newlines will not survive the
    line-based format.
    """
    os.makedirs("./_internal", exist_ok=True)
    with open("./_internal/prompt.txt", "w") as f:
        # Every entry needs its own line; previously prompt and neg lacked
        # "\n" and ran into the following entries, breaking the loader.
        f.write(f"prompt: {prompt_entry}\n")
        f.write(f"neg: {neg}\n")
        f.write(f"w: {int(width)}\n")
        f.write(f"h: {int(height)}\n")
        f.write(f"cfg: {int(cfg)}\n")
def load_parameters_from_file() -> tuple:
    """#### Load generation parameters from ``./_internal/prompt.txt``.

    Each non-blank line is parsed as ``key: value``, splitting on the first
    colon only, so prompt text may itself contain ``": "``. Keys and values
    are whitespace-stripped.

    #### Returns:
        - `tuple`: ``(prompt, neg, width, height, cfg)`` where ``prompt`` and
          ``neg`` are str and ``width``, ``height``, ``cfg`` are int.

    #### Raises:
        - `KeyError`: If one of ``prompt``/``neg``/``w``/``h``/``cfg`` is missing.
        - `ValueError`: If a non-blank line has no ``:`` or a numeric field
          is not an integer.
    """
    with open("./_internal/prompt.txt", "r") as f:
        lines = f.readlines()
    parameters = {}
    for line in lines:
        # Skip empty lines
        if not line.strip():
            continue
        # maxsplit=1: the old `split(": ")` raised on prompts like "style: oil".
        key, value = line.split(":", 1)
        parameters[key.strip()] = value.strip()
    prompt = parameters["prompt"]
    neg = parameters["neg"]
    width = int(parameters["w"])
    height = int(parameters["h"])
    cfg = int(parameters["cfg"])
    return prompt, neg, width, height, cfg
PROGRESS_BAR_ENABLED = True
PROGRESS_BAR_HOOK = None


class ProgressBar:
    """#### Minimal progress-bar state holder.

    Stores the total step count, a ``current`` counter starting at 0, and
    the module-level ``PROGRESS_BAR_HOOK`` as it was when the bar was created.
    """

    def __init__(self, total: int):
        # Reading the module-level hook needs no `global` declaration.
        self.total = total
        self.current = 0
        self.hook = PROGRESS_BAR_HOOK
def get_tiled_scale_steps(
    width: int, height: int, tile_x: int, tile_y: int, overlap: int
) -> int:
    """#### Count how many tiles a tiled pass over a width x height area uses.

    Along each axis, an extent that fits in one tile needs one tile.
    Otherwise tiles advance by ``tile - overlap``, giving
    ``ceil((size - overlap) / (tile - overlap))`` tiles.

    #### Args:
        - `width` (int): The width.
        - `height` (int): The height.
        - `tile_x` (int): The tile width.
        - `tile_y` (int): The tile height.
        - `overlap` (int): The overlap between neighbouring tiles.

    #### Returns:
        - `int`: Rows times columns of tiles.
    """
    def _tiles_along(size, tile):
        if size <= tile:
            return 1
        return math.ceil((size - overlap) / (tile - overlap))

    rows = _tiles_along(height, tile_y)
    cols = _tiles_along(width, tile_x)
    return rows * cols
@torch.inference_mode()
def tiled_scale_multidim(
    samples: torch.Tensor,
    function,
    tile: tuple = (64, 64),
    overlap: int = 8,
    upscale_amount: int = 4,
    out_channels: int = 3,
    output_device: str = "cpu",
    downscale: bool = False,
    index_formulas: any = None,
    pbar: any = None,
):
    """#### Apply ``function`` to ``samples`` tile by tile and blend the results.

    Each batch element is processed separately. If it fits in a single tile,
    ``function`` is applied to it directly. Otherwise it is cut into
    overlapping tiles (stride ``tile - overlap`` along each tiled dimension).
    Each tile goes through ``function``, and its output is accumulated with
    a linear feathering mask. The accumulator is finally divided by the
    summed mask weights, so overlapping regions are blended. Runs under
    ``torch.inference_mode()``.

    #### Args:
        - `samples` (torch.Tensor): The input batch. The dimensions after the
          first two (``samples.shape[2:]``) are tiled and should number
          ``len(tile)``. Presumably ``[B, C, H, W]`` for images; confirm
          against callers.
        - `function` (callable): Applied to each slice ``samples[b:b+1]``
          narrowed to a tile. Assumed to return
          ``[1, out_channels, *scaled tile dims]``, because the accumulation
          below narrows the output buffer to the returned tensor's shape.
        - `tile` (tuple, optional): Tile size per tiled dimension. Defaults to (64, 64).
        - `overlap` (int | sequence, optional): Overlap between neighbouring
          tiles, in input units. A scalar applies to every dimension.
          Defaults to 8.
        - `upscale_amount` (int | callable | sequence, optional): Per-dimension
          scale factor, or a callable mapping an input extent to an output
          extent. A scalar applies to every dimension. Defaults to 4.
        - `out_channels` (int, optional): Channel count of the output. Defaults to 3.
        - `output_device` (str, optional): Device holding the output and the
          accumulation buffers. Defaults to "cpu".
        - `downscale` (bool, optional): If True, non-callable factors divide
          instead of multiply; callables are applied unchanged either way.
          Defaults to False.
        - `index_formulas` (any, optional): Factors/callables used to map tile
          start positions into output coordinates. Defaults to ``upscale_amount``.
        - `pbar` (any, optional): Object with an ``update(n)`` method.
          ``update(1)`` is called after every processed tile, or once per
          batch element on the single-tile path. Defaults to None.

    #### Returns:
        - `torch.Tensor`: ``[B, out_channels, *scaled dims]`` on ``output_device``
          (allocated with ``torch.empty``, so torch's default dtype).
    """
    dims = len(tile)
    # Broadcast scalar settings to one entry per tiled dimension.
    if not (isinstance(upscale_amount, (tuple, list))):
        upscale_amount = [upscale_amount] * dims
    if not (isinstance(overlap, (tuple, list))):
        overlap = [overlap] * dims
    if index_formulas is None:
        index_formulas = upscale_amount
    if not (isinstance(index_formulas, (tuple, list))):
        index_formulas = [index_formulas] * dims
    def get_upscale(dim: int, val: int) -> int:
        """#### Scale an extent up along dimension ``dim``.
        #### Args:
            - `dim` (int): Index of the tiled dimension.
            - `val` (int): Extent in input units.
        #### Returns:
            - `int`: ``up(val)`` if the factor is callable, else ``up * val``
              (may be a float; callers round it).
        """
        up = upscale_amount[dim]
        if callable(up):
            return up(val)
        else:
            return up * val
    def get_downscale(dim: int, val: int) -> int:
        """#### Scale an extent down along dimension ``dim``.
        #### Args:
            - `dim` (int): Index of the tiled dimension.
            - `val` (int): Extent in input units.
        #### Returns:
            - `int`: ``up(val)`` if the factor is callable, else ``val / up``
              (a float; callers round it).
        """
        up = upscale_amount[dim]
        if callable(up):
            return up(val)
        else:
            return val / up
    def get_upscale_pos(dim: int, val: int) -> int:
        """#### Map an input start position to output coordinates (upscaling).
        #### Args:
            - `dim` (int): Index of the tiled dimension.
            - `val` (int): Start position in input units.
        #### Returns:
            - `int`: ``f(val)`` for a callable index formula, else ``f * val``.
        """
        up = index_formulas[dim]
        if callable(up):
            return up(val)
        else:
            return up * val
    def get_downscale_pos(dim: int, val: int) -> int:
        """#### Map an input start position to output coordinates (downscaling).
        #### Args:
            - `dim` (int): Index of the tiled dimension.
            - `val` (int): Start position in input units.
        #### Returns:
            - `int`: ``f(val)`` for a callable index formula, else ``val / f``.
        """
        up = index_formulas[dim]
        if callable(up):
            return up(val)
        else:
            return val / up
    # Pick the extent/position mapping once, based on the direction.
    if downscale:
        get_scale = get_downscale
        get_pos = get_downscale_pos
    else:
        get_scale = get_upscale
        get_pos = get_upscale_pos
    def mult_list_upscale(a: list) -> list:
        """#### Map a list of input extents to rounded output extents.
        #### Args:
            - `a` (list): Extents, one per tiled dimension.
        #### Returns:
            - `list`: ``round(get_scale(i, a[i]))`` for each index ``i``.
        """
        out = []
        for i in range(len(a)):
            out.append(round(get_scale(i, a[i])))
        return out
    output = torch.empty(
        [samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]),
        device=output_device,
    )
    # Each batch element is tiled and blended independently.
    for b in range(samples.shape[0]):
        s = samples[b : b + 1]
        # handle entire input fitting in a single tile
        if all(s.shape[d + 2] <= tile[d] for d in range(dims)):
            output[b : b + 1] = function(s).to(output_device)
            if pbar is not None:
                pbar.update(1)
            continue
        # out: sum of mask-weighted tile outputs; out_div: sum of mask weights.
        out = torch.zeros(
            [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]),
            device=output_device,
        )
        out_div = torch.zeros(
            [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]),
            device=output_device,
        )
        # Tile start offsets per dimension, stepping by (tile - overlap);
        # a dimension that fits in one tile has the single start 0.
        positions = [
            range(0, s.shape[d + 2] - overlap[d], tile[d] - overlap[d])
            if s.shape[d + 2] > tile[d]
            else [0]
            for d in range(dims)
        ]
        for it in itertools.product(*positions):
            s_in = s
            upscaled = []
            for d in range(dims):
                # Clamp the start into [0, size - overlap] and cut the tile,
                # shortening it if it would run past the end.
                pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
                l = min(tile[d], s.shape[d + 2] - pos)
                s_in = s_in.narrow(d + 2, pos, l)
                upscaled.append(round(get_pos(d, pos)))
            ps = function(s_in).to(output_device)
            # Feathering mask: a linear ramp (1/feather ... 1) over the scaled
            # overlap width at both ends of each tiled dimension. It is
            # skipped when the ramp would not fit inside the tile output.
            mask = torch.ones_like(ps)
            for d in range(2, dims + 2):
                feather = round(get_scale(d - 2, overlap[d - 2]))
                if feather >= mask.shape[d]:
                    continue
                for t in range(feather):
                    a = (t + 1) / feather
                    mask.narrow(d, t, 1).mul_(a)
                    mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
            # Narrow the accumulators to this tile's output window (views),
            # then add in place.
            o = out
            o_d = out_div
            for d in range(dims):
                o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
                o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
            o.add_(ps * mask)
            o_d.add_(mask)
            if pbar is not None:
                pbar.update(1)
        # Normalise by the accumulated weights to blend overlapping tiles.
        output[b : b + 1] = out / out_div
    return output
def tiled_scale(
    samples: torch.Tensor,
    function,
    tile_x: int = 64,
    tile_y: int = 64,
    overlap: int = 8,
    upscale_amount: int = 4,
    out_channels: int = 3,
    output_device: str = "cpu",
    pbar: any = None,
):
    """#### Two-dimensional convenience wrapper around ``tiled_scale_multidim``.

    The tile size is passed as ``(tile_y, tile_x)``: ``tile_y`` applies to
    the first tiled dimension and ``tile_x`` to the second. This is height
    then width, presumably for ``[B, C, H, W]`` inputs.

    #### Args:
        - `samples` (torch.Tensor): The input samples.
        - `function` (callable): The per-tile function.
        - `tile_x` (int, optional): The tile width. Defaults to 64.
        - `tile_y` (int, optional): The tile height. Defaults to 64.
        - `overlap` (int, optional): The overlap. Defaults to 8.
        - `upscale_amount` (int, optional): The upscale amount. Defaults to 4.
        - `out_channels` (int, optional): The number of output channels. Defaults to 3.
        - `output_device` (str, optional): The output device. Defaults to "cpu".
        - `pbar` (any, optional): The progress bar. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The blended, scaled result.
    """
    tile_hw = (tile_y, tile_x)
    return tiled_scale_multidim(
        samples,
        function,
        tile=tile_hw,
        overlap=overlap,
        upscale_amount=upscale_amount,
        out_channels=out_channels,
        output_device=output_device,
        pbar=pbar,
    )