import collections
import logging
import numpy as np
import scipy.stats
import torch
from modules.sample import sampling_util


def calculate_start_end_timesteps(model: torch.nn.Module, conds: list) -> None:
    """#### Calculate the start and end timesteps for a model.

    #### Args:
        - `model` (torch.nn.Module): The input model.
        - `conds` (list): The list of conditions.
    """
    s = model.model_sampling
    for t in range(len(conds)):
        x = conds[t]

        timestep_start = None
        timestep_end = None
        if "start_percent" in x:
            timestep_start = s.percent_to_sigma(x["start_percent"])
        if "end_percent" in x:
            timestep_end = s.percent_to_sigma(x["end_percent"])

        if (timestep_start is not None) or (timestep_end is not None):
            n = x.copy()
            if timestep_start is not None:
                n["timestep_start"] = timestep_start
            if timestep_end is not None:
                n["timestep_end"] = timestep_end
            conds[t] = n
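

# Illustrative usage sketch (not part of the original pipeline): how a conds
# list is annotated in place. `_FakeSampling` and `_FakeModel` are hypothetical
# stubs; only the `percent_to_sigma` method is assumed.
def _demo_calculate_start_end_timesteps():
    class _FakeSampling:
        def percent_to_sigma(self, percent):
            # Toy mapping: percent 0.0 -> high sigma, percent 1.0 -> 0.
            return 14.6 * (1.0 - percent)

    class _FakeModel:
        model_sampling = _FakeSampling()

    conds = [{"start_percent": 0.0, "end_percent": 0.5}, {}]
    calculate_start_end_timesteps(_FakeModel(), conds)
    # conds[0] now carries "timestep_start"/"timestep_end"; conds[1] is untouched.
    return conds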


def pre_run_control(model: torch.nn.Module, conds: list) -> None:
    """#### Pre-run control for a model.

    #### Args:
        - `model` (torch.nn.Module): The input model.
        - `conds` (list): The list of conditions.
    """
    s = model.model_sampling

    # The closure only captures `s`, so it can be defined once for all conds.
    def percent_to_timestep_function(a):
        return s.percent_to_sigma(a)

    for t in range(len(conds)):
        x = conds[t]
        if "control" in x:
            x["control"].pre_run(model, percent_to_timestep_function)


def apply_empty_x_to_equal_area(
    conds: list, uncond: list, name: str, uncond_fill_func: callable
) -> None:
    """#### Apply empty x to equal area.

    Mirrors `name` entries from `conds` onto `uncond` when `uncond` has none
    of its own.

    #### Args:
        - `conds` (list): The list of conditions.
        - `uncond` (list): The list of unconditional conditions.
        - `name` (str): The condition key to mirror (e.g. "control").
        - `uncond_fill_func` (callable): Function that builds the unconditional value.
    """
    cond_cnets = []
    cond_other = []
    uncond_cnets = []
    uncond_other = []
    for t in range(len(conds)):
        x = conds[t]
        if "area" not in x:
            if name in x and x[name] is not None:
                cond_cnets.append(x[name])
            else:
                cond_other.append((x, t))
    for t in range(len(uncond)):
        x = uncond[t]
        if "area" not in x:
            if name in x and x[name] is not None:
                uncond_cnets.append(x[name])
            else:
                uncond_other.append((x, t))

    # If uncond already has its own `name` entries, there is nothing to fill.
    if len(uncond_cnets) > 0:
        return

    for x in range(len(cond_cnets)):
        temp = uncond_other[x % len(uncond_other)]
        o = temp[0]
        if name in o and o[name] is not None:
            n = o.copy()
            n[name] = uncond_fill_func(cond_cnets, x)
            uncond += [n]
        else:
            n = o.copy()
            n[name] = uncond_fill_func(cond_cnets, x)
            uncond[temp[1]] = n
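

# Illustrative sketch: when uncond has no `name` entries of its own, each one
# found in conds is mirrored onto uncond via uncond_fill_func. The lambda here
# simply copies the cond-side value; a real fill function may transform it.
def _demo_apply_empty_x_to_equal_area():
    conds = [{"control": "cnet-a"}]
    uncond = [{}]
    apply_empty_x_to_equal_area(
        conds, uncond, "control", lambda cond_cnets, x: cond_cnets[x]
    )
    return uncond  # [{"control": "cnet-a"}]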


# Define the namedtuple class once at module level so it is not re-created on
# every call to get_area_and_mult.
CondObj = collections.namedtuple(
    "cond_obj", ["input_x", "mult", "conditioning", "area", "control", "patches"]
)


def get_area_and_mult(conds: dict, x_in: torch.Tensor, timestep_in: int) -> CondObj:
    """#### Get the area and multiplier.

    #### Args:
        - `conds` (dict): The conditions.
        - `x_in` (torch.Tensor): The input tensor.
        - `timestep_in` (int): The timestep.

    #### Returns:
        - `CondObj`: The area and multiplier.
    """
    # Cache shape information to avoid repeated attribute access
    x_shape = x_in.shape

    # The area covers the full latent: (height, width, y-offset, x-offset)
    area = (x_shape[2], x_shape[3], 0, 0)

    # Since the offsets are 0, this slice takes the full tensor; the slice is
    # kept for consistency with area-based conditioning.
    input_x = x_in[:, :, : area[0], : area[1]]

    # Strength is 1.0, so create the multiplier directly as ones instead of
    # allocating a mask and multiplying.
    mult = torch.ones_like(input_x)

    # Prepare the conditioning dictionary with cached device and batch size
    conditioning = {}
    model_conds = conds["model_conds"]
    batch_size = x_shape[0]
    device = x_in.device
    for c in model_conds:
        conditioning[c] = model_conds[c].process_cond(
            batch_size=batch_size, device=device, area=area
        )

    control = conds.get("control", None)
    patches = None

    return CondObj(input_x, mult, conditioning, area, control, patches)
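

# Illustrative sketch of the CondObj layout. `_FakeCond` is a hypothetical
# stand-in for a model cond; only the process_cond(batch_size, device, area)
# signature used above is assumed.
def _demo_get_area_and_mult():
    class _FakeCond:
        def process_cond(self, batch_size, device, area):
            # A real cond would repeat/crop its tensor; here we echo metadata.
            return {"batch_size": batch_size, "area": area}

    x_in = torch.zeros(1, 4, 8, 8)
    conds = {"model_conds": {"c_crossattn": _FakeCond()}}
    out = get_area_and_mult(conds, x_in, timestep_in=0)
    # out.input_x is the full latent, out.mult is all ones, out.area == (8, 8, 0, 0)
    return out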


def normal_scheduler(
    model_sampling: torch.nn.Module, steps: int, sgm: bool = False, floor: bool = False
) -> torch.FloatTensor:
    """#### Create a normal scheduler.

    #### Args:
        - `model_sampling` (torch.nn.Module): The model sampling module.
        - `steps` (int): The number of steps.
        - `sgm` (bool, optional): Whether to use SGM spacing. Defaults to False. (Currently unused.)
        - `floor` (bool, optional): Whether to floor the values. Defaults to False. (Currently unused.)

    #### Returns:
        - `torch.FloatTensor`: The scheduler sigmas.
    """
    s = model_sampling
    start = s.timestep(s.sigma_max)
    end = s.timestep(s.sigma_min)
    timesteps = torch.linspace(start, end, steps)

    sigs = []
    for x in range(len(timesteps)):
        ts = timesteps[x]
        sigs.append(s.sigma(ts))
    sigs += [0.0]
    return torch.FloatTensor(sigs)
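

# Illustrative sketch: normal_scheduler only needs timestep()/sigma() plus the
# sigma_min/sigma_max bounds. `_FakeSampling` is a hypothetical identity-mapped
# stub, so the output is simply linspace(sigma_max, sigma_min, steps) plus 0.0.
def _demo_normal_scheduler():
    class _FakeSampling:
        sigma_min = 0.03
        sigma_max = 14.6

        def timestep(self, sigma):
            return sigma

        def sigma(self, timestep):
            return float(timestep)

    return normal_scheduler(_FakeSampling(), steps=4)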


def simple_scheduler(model_sampling: torch.nn.Module, steps: int) -> torch.FloatTensor:
    """#### Create a simple scheduler.

    #### Args:
        - `model_sampling` (torch.nn.Module): The model sampling module.
        - `steps` (int): The number of steps.

    #### Returns:
        - `torch.FloatTensor`: The scheduler sigmas.
    """
    s = model_sampling
    sigs = []
    # Stride through the sigma table from the top (highest sigma) down.
    ss = len(s.sigmas) / steps
    for x in range(steps):
        sigs += [float(s.sigmas[-(1 + int(x * ss))])]
    sigs += [0.0]
    return torch.FloatTensor(sigs)
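

# Illustrative sketch: simple_scheduler only reads the `sigmas` table, so a
# bare namespace is enough to try it. The linspace table below is a made-up
# stand-in for a real model's ascending sigma schedule.
def _demo_simple_scheduler():
    from types import SimpleNamespace

    _fake = SimpleNamespace(sigmas=torch.linspace(0.03, 14.6, 1000))
    return simple_scheduler(_fake, steps=4)  # 5 values, high -> low, ending in 0.0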


# Implemented based on: https://arxiv.org/abs/2407.12173
def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
    """Creates a beta scheduler for noise levels based on the beta distribution.

    Computes sigmas by sampling the inverse CDF of the beta distribution and
    deduplicating the resulting timestep indices.

    Args:
        model_sampling: Model sampling module
        steps: Number of steps
        alpha: Alpha parameter for the beta distribution
        beta: Beta parameter for the beta distribution

    Returns:
        torch.FloatTensor: Tensor of sigma values for each step
    """
    # Calculate total timesteps once
    total_timesteps = len(model_sampling.sigmas) - 1

    # Alias the sigma table locally to avoid repeated attribute lookups
    model_sigmas = model_sampling.sigmas

    # Generate evenly spaced values in the [0, 1) interval
    ts_normalized = np.linspace(0, 1, steps, endpoint=False)

    # Apply the beta inverse CDF to get sampled time points (vectorized)
    ts_beta = scipy.stats.beta.ppf(1 - ts_normalized, alpha, beta)

    # Scale to timestep indices and round to integers
    ts_indices = np.rint(ts_beta * total_timesteps).astype(np.int32)

    # Deduplicate indices while preserving their original order
    unique_ts, indices = np.unique(ts_indices, return_index=True)
    ordered_unique_ts = unique_ts[np.argsort(indices)]

    # Map indices to sigma values
    sigs = [float(model_sigmas[idx]) for idx in ordered_unique_ts]

    # Add the final sigma value of 0.0
    sigs.append(0.0)

    return torch.FloatTensor(sigs)
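

# Illustrative sketch: with alpha = beta = 0.6 the beta density is U-shaped,
# so the inverse-CDF sampling clusters steps toward the high- and low-noise
# ends. The sigma table is the same made-up stand-in used above; note the
# output may hold fewer than `steps` sigmas after index deduplication.
def _demo_beta_scheduler():
    from types import SimpleNamespace

    _fake = SimpleNamespace(sigmas=torch.linspace(0.03, 14.6, 1000))
    return beta_scheduler(_fake, steps=8)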


def calculate_sigmas(
    model_sampling: torch.nn.Module, scheduler_name: str, steps: int
) -> torch.Tensor:
    """#### Calculate the sigmas for a model.

    #### Args:
        - `model_sampling` (torch.nn.Module): The model sampling module.
        - `scheduler_name` (str): The scheduler name.
        - `steps` (int): The number of steps.

    #### Returns:
        - `torch.Tensor`: The calculated sigmas.
    """
    if scheduler_name == "karras":
        sigmas = sampling_util.get_sigmas_karras(
            n=steps,
            sigma_min=float(model_sampling.sigma_min),
            sigma_max=float(model_sampling.sigma_max),
        )
    elif scheduler_name == "normal":
        sigmas = normal_scheduler(model_sampling, steps)
    elif scheduler_name == "simple":
        sigmas = simple_scheduler(model_sampling, steps)
    elif scheduler_name == "beta":
        sigmas = beta_scheduler(model_sampling, steps)
    else:
        logging.error("invalid scheduler %s", scheduler_name)
        raise ValueError(f"invalid scheduler {scheduler_name}")
    return sigmas
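

# Illustrative sketch of the dispatch: "simple" only needs a sigma table, so
# the same made-up namespace works. "karras" would additionally require
# sigma_min/sigma_max, and "normal" the timestep()/sigma() methods.
def _demo_calculate_sigmas():
    from types import SimpleNamespace

    _fake = SimpleNamespace(sigmas=torch.linspace(0.03, 14.6, 1000))
    return calculate_sigmas(_fake, "simple", steps=10)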


def prepare_noise(
    latent_image: torch.Tensor, seed: int, noise_inds: list = None
) -> torch.Tensor:
    """#### Prepare noise for a latent image.

    #### Args:
        - `latent_image` (torch.Tensor): The latent image tensor.
        - `seed` (int): The seed for random noise.
        - `noise_inds` (list, optional): The noise indices. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The prepared noise tensor.
    """
    # Seed the global CPU RNG and reuse the default generator it returns.
    generator = torch.manual_seed(seed)
    if noise_inds is None:
        return torch.randn(
            latent_image.size(),
            dtype=latent_image.dtype,
            layout=latent_image.layout,
            generator=generator,
            device="cpu",
        )

    # Generate one noise slice per unique index, then gather them back in the
    # order requested by noise_inds. Noise is still drawn for skipped indices
    # so the RNG stream stays aligned with index positions.
    unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
    noises = []
    for i in range(unique_inds[-1] + 1):
        noise = torch.randn(
            [1] + list(latent_image.size())[1:],
            dtype=latent_image.dtype,
            layout=latent_image.layout,
            generator=generator,
            device="cpu",
        )
        if i in unique_inds:
            noises.append(noise)
    noises = [noises[i] for i in inverse]
    noises = torch.cat(noises, dim=0)
    return noises
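

# Illustrative sketch: seeding makes the noise reproducible, and noise_inds
# lets a batch reuse per-index noise slices.
def _demo_prepare_noise():
    latent = torch.zeros(2, 4, 8, 8)
    a = prepare_noise(latent, seed=42)
    b = prepare_noise(latent, seed=42)
    assert torch.equal(a, b)  # same seed -> identical noise

    # Indices [0, 0] reuse the slice generated for index 0 twice.
    c = prepare_noise(latent, seed=42, noise_inds=[0, 0])
    assert torch.equal(c[0], c[1])
    return a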