import collections
import logging
from typing import Callable

import numpy as np
import scipy.stats
import torch

from modules.sample import sampling_util

def calculate_start_end_timesteps(model: torch.nn.Module, conds: list) -> None:
    """#### Calculate the start and end timesteps for a model.
    #### Args:
        - `model` (torch.nn.Module): The input model.
        - `conds` (list): The list of conditions.
    """
    s = model.model_sampling
    for t in range(len(conds)):
        x = conds[t]
        timestep_start = None
        timestep_end = None
        if "start_percent" in x:
            timestep_start = s.percent_to_sigma(x["start_percent"])
        if "end_percent" in x:
            timestep_end = s.percent_to_sigma(x["end_percent"])
        if (timestep_start is not None) or (timestep_end is not None):
            n = x.copy()
            if timestep_start is not None:
                n["timestep_start"] = timestep_start
            if timestep_end is not None:
                n["timestep_end"] = timestep_end
            conds[t] = n
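
# Usage sketch (hypothetical; assumes `model` exposes a `model_sampling`
# object with `percent_to_sigma`, which is defined outside this module):
#
#   conds = [{"start_percent": 0.0, "end_percent": 0.5}]
#   calculate_start_end_timesteps(model, conds)
#   # conds[0] now also carries "timestep_start"/"timestep_end" sigma bounds;
#   # the original percent keys are left untouched.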

def pre_run_control(model: torch.nn.Module, conds: list) -> None:
    """#### Pre-run control for a model.
    #### Args:
        - `model` (torch.nn.Module): The input model.
        - `conds` (list): The list of conditions.
    """
    s = model.model_sampling
    for t in range(len(conds)):
        x = conds[t]

        def percent_to_timestep_function(a):
            return s.percent_to_sigma(a)

        if "control" in x:
            x["control"].pre_run(model, percent_to_timestep_function)

def apply_empty_x_to_equal_area(
    conds: list, uncond: list, name: str, uncond_fill_func: Callable
) -> None:
    """#### Apply empty x to equal area.
    #### Args:
        - `conds` (list): The list of conditions.
        - `uncond` (list): The list of unconditional conditions.
        - `name` (str): The condition key to mirror onto the unconditional side.
        - `uncond_fill_func` (Callable): The unconditional fill function.
    """
    cond_cnets = []
    cond_other = []
    uncond_cnets = []
    uncond_other = []
    for t in range(len(conds)):
        x = conds[t]
        if "area" not in x:
            if name in x and x[name] is not None:
                cond_cnets.append(x[name])
            else:
                cond_other.append((x, t))
    for t in range(len(uncond)):
        x = uncond[t]
        if "area" not in x:
            if name in x and x[name] is not None:
                uncond_cnets.append(x[name])
            else:
                uncond_other.append((x, t))
    if len(uncond_cnets) > 0:
        return

    for x in range(len(cond_cnets)):
        temp = uncond_other[x % len(uncond_other)]
        o = temp[0]
        if name in o and o[name] is not None:
            n = o.copy()
            n[name] = uncond_fill_func(cond_cnets, x)
            uncond += [n]
        else:
            n = o.copy()
            n[name] = uncond_fill_func(cond_cnets, x)
            uncond[temp[1]] = n
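
# Behavior sketch: if the unconditional side already carries any `name`
# entries, this is a no-op. Otherwise each collected conditional value is
# mirrored onto the unconditional side through `uncond_fill_func`, either by
# rewriting an area-free uncond entry in place or appending a copy. A
# hypothetical fill function matching the call signature used above:
#
#   def fill_from_cond(cond_cnets, x):
#       return cond_cnets[x % len(cond_cnets)]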

# Define the namedtuple class once outside the function for reuse
CondObj = collections.namedtuple(
    "cond_obj", ["input_x", "mult", "conditioning", "area", "control", "patches"]
)


def get_area_and_mult(conds: dict, x_in: torch.Tensor, timestep_in: int) -> CondObj:
    """#### Get the area and multiplier.
    #### Args:
        - `conds` (dict): The conditions.
        - `x_in` (torch.Tensor): The input tensor.
        - `timestep_in` (int): The timestep.
    #### Returns:
        - `CondObj`: The input region, multiplier, conditioning, area, control,
          and patches fields packed into one namedtuple.
    """
    # Cache shape information to avoid repeated attribute access
    x_shape = x_in.shape

    # Define the area dimensions in one operation
    area = (x_shape[2], x_shape[3], 0, 0)

    # Extract the input region. Since area[2] and area[3] are 0, this is
    # essentially taking the full tensor, but the slice is kept for consistency.
    input_x = x_in[:, :, : area[0], : area[1]]

    # Create the multiplier tensor directly without an intermediate mask;
    # this avoids an unnecessary tensor allocation and multiplication.
    mult = torch.ones_like(input_x)  # strength is 1.0, so just create ones directly

    # Prepare the conditioning dictionary with cached device and batch_size
    conditioning = {}
    model_conds = conds["model_conds"]
    batch_size = x_shape[0]
    device = x_in.device

    # Process conditions with the cached parameters
    for c in model_conds:
        conditioning[c] = model_conds[c].process_cond(
            batch_size=batch_size, device=device, area=area
        )

    # Get control directly without a redundant variable assignment
    control = conds.get("control", None)
    patches = None

    # Use the pre-defined namedtuple class instead of creating it every call
    return CondObj(input_x, mult, conditioning, area, control, patches)
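
# Field sketch for the returned CondObj (values as produced above):
#   input_x      - the full latent tensor (area offsets are fixed at 0)
#   mult         - all-ones multiplier with input_x's shape (strength 1.0)
#   conditioning - per-key processed model conditions
#   area         - (height, width, 0, 0) taken from the latent shape
#   control      - the optional "control" entry, or None
#   patches      - always None in this simplified path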

def normal_scheduler(
    model_sampling: torch.nn.Module, steps: int, sgm: bool = False, floor: bool = False
) -> torch.FloatTensor:
    """#### Create a normal scheduler.
    #### Args:
        - `model_sampling` (torch.nn.Module): The model sampling module.
        - `steps` (int): The number of steps.
        - `sgm` (bool, optional): Whether to use SGM. Accepted for API
          compatibility but unused in this implementation. Defaults to False.
        - `floor` (bool, optional): Whether to floor the values. Accepted for
          API compatibility but unused in this implementation. Defaults to False.
    #### Returns:
        - `torch.FloatTensor`: The scheduler.
    """
    s = model_sampling
    start = s.timestep(s.sigma_max)
    end = s.timestep(s.sigma_min)
    timesteps = torch.linspace(start, end, steps)

    sigs = []
    for x in range(len(timesteps)):
        ts = timesteps[x]
        sigs.append(s.sigma(ts))
    sigs += [0.0]
    return torch.FloatTensor(sigs)
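
# Usage sketch (hypothetical `model_sampling` providing the methods used
# above): the result has steps + 1 entries, descending from sigma_max toward
# sigma_min plus a trailing 0.0:
#
#   sigmas = normal_scheduler(model_sampling, steps=20)
#   assert sigmas.shape == (21,) and sigmas[-1].item() == 0.0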

def simple_scheduler(model_sampling: torch.nn.Module, steps: int) -> torch.FloatTensor:
    """#### Create a simple scheduler.
    #### Args:
        - `model_sampling` (torch.nn.Module): The model sampling module.
        - `steps` (int): The number of steps.
    #### Returns:
        - `torch.FloatTensor`: The scheduler.
    """
    s = model_sampling
    sigs = []
    ss = len(s.sigmas) / steps
    for x in range(steps):
        sigs += [float(s.sigmas[-(1 + int(x * ss))])]
    sigs += [0.0]
    return torch.FloatTensor(sigs)
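
# Index sketch: with len(s.sigmas) == 1000 and steps == 10, ss == 100.0, so
# the loop reads s.sigmas[-1], s.sigmas[-101], ..., s.sigmas[-901]: an evenly
# strided walk from the largest sigma toward the smallest, before the
# trailing 0.0 is appended.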

# Implemented based on: https://arxiv.org/abs/2407.12173
def beta_scheduler(
    model_sampling: torch.nn.Module, steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.FloatTensor:
    """Creates a beta scheduler for noise levels based on the beta distribution.
    This optimized implementation efficiently computes sigmas using the beta
    distribution and caches calculations where possible.
    Args:
        model_sampling: Model sampling module
        steps: Number of steps
        alpha: Alpha parameter for the beta distribution
        beta: Beta parameter for the beta distribution
    Returns:
        torch.FloatTensor: Tensor of sigma values for each step
    """
    # Calculate the total number of timesteps once
    total_timesteps = len(model_sampling.sigmas) - 1

    # Cache a local reference to the sigma table for reuse
    model_sigmas = model_sampling.sigmas

    # Generate evenly spaced values in the [0, 1) interval
    ts_normalized = np.linspace(0, 1, steps, endpoint=False)

    # Apply the beta inverse CDF to get sampled time points - vectorized operation
    ts_beta = scipy.stats.beta.ppf(1 - ts_normalized, alpha, beta)

    # Scale to timestep indices and round to integers
    ts_indices = np.rint(ts_beta * total_timesteps).astype(np.int32)

    # Use numpy's unique function with return_index to efficiently find unique
    # values while preserving their order of first occurrence
    unique_ts, indices = np.unique(ts_indices, return_index=True)
    ordered_unique_ts = unique_ts[np.argsort(indices)]

    # Map indices to sigma values
    sigs = [float(model_sigmas[idx]) for idx in ordered_unique_ts]

    # Add the final sigma value of 0.0
    sigs.append(0.0)

    return torch.FloatTensor(sigs)
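
# Shape-of-the-schedule sketch: ts_normalized starts at 0, so the first
# sample is scipy.stats.beta.ppf(1.0, alpha, beta) == 1.0, which maps to the
# last (largest) sigma index. With alpha == beta < 1 the beta distribution is
# U-shaped, so sampled points cluster near both ends of the sigma table and
# thin out in the middle; duplicate rounded indices are dropped, so the
# result may hold fewer than steps + 1 values.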

def calculate_sigmas(
    model_sampling: torch.nn.Module, scheduler_name: str, steps: int
) -> torch.Tensor:
    """#### Calculate the sigmas for a model.
    #### Args:
        - `model_sampling` (torch.nn.Module): The model sampling module.
        - `scheduler_name` (str): The scheduler name.
        - `steps` (int): The number of steps.
    #### Returns:
        - `torch.Tensor`: The calculated sigmas.
    """
    if scheduler_name == "karras":
        sigmas = sampling_util.get_sigmas_karras(
            n=steps,
            sigma_min=float(model_sampling.sigma_min),
            sigma_max=float(model_sampling.sigma_max),
        )
    elif scheduler_name == "normal":
        sigmas = normal_scheduler(model_sampling, steps)
    elif scheduler_name == "simple":
        sigmas = simple_scheduler(model_sampling, steps)
    elif scheduler_name == "beta":
        sigmas = beta_scheduler(model_sampling, steps)
    else:
        # Raise instead of falling through: returning here would hit an
        # unbound `sigmas` and mask the problem with an UnboundLocalError.
        logging.error("error invalid scheduler {}".format(scheduler_name))
        raise ValueError("invalid scheduler {}".format(scheduler_name))
    return sigmas
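
# Dispatch sketch (hypothetical steps value):
#
#   sigmas = calculate_sigmas(model_sampling, "karras", steps=20)
#   sigmas = calculate_sigmas(model_sampling, "beta", steps=20)
#   # Unknown scheduler names log an error and raise a ValueError.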

def prepare_noise(
    latent_image: torch.Tensor, seed: int, noise_inds: list = None
) -> torch.Tensor:
    """#### Prepare noise for a latent image.
    #### Args:
        - `latent_image` (torch.Tensor): The latent image tensor.
        - `seed` (int): The seed for random noise.
        - `noise_inds` (list, optional): The noise indices. Defaults to None.
    #### Returns:
        - `torch.Tensor`: The prepared noise tensor.
    """
    generator = torch.manual_seed(seed)
    if noise_inds is None:
        return torch.randn(
            latent_image.size(),
            dtype=latent_image.dtype,
            layout=latent_image.layout,
            generator=generator,
            device="cpu",
        )

    unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
    noises = []
    for i in range(unique_inds[-1] + 1):
        noise = torch.randn(
            [1] + list(latent_image.size())[1:],
            dtype=latent_image.dtype,
            layout=latent_image.layout,
            generator=generator,
            device="cpu",
        )
        if i in unique_inds:
            noises.append(noise)
    noises = [noises[i] for i in inverse]
    noises = torch.cat(noises, dim=0)
    return noises
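
# Usage sketch: with noise_inds, batch entries sharing an index receive
# identical noise (hypothetical shapes):
#
#   latent = torch.zeros(4, 4, 64, 64)
#   n = prepare_noise(latent, seed=42, noise_inds=[0, 0, 1, 1])
#   # n[0] equals n[1], and n[2] equals n[3], but the two pairs differ.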