Spaces:
Runtime error
Runtime error
File size: 5,527 Bytes
19c4ddf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
from typing import Any, Callable, Dict, Optional, List
import torch
import torch.nn as nn
from .gaussian_diffusion import GaussianDiffusion
from .k_diffusion import karras_sample, karras_sample_addition_condition
# Module-level default values for the Karras-style sampler settings.
# NOTE(review): none of these constants is referenced in the visible part of
# this file; presumably callers pass them as the `karras_steps`, `sigma_min`,
# `sigma_max` and `s_churn` arguments of the sampling functions below — confirm.
DEFAULT_KARRAS_STEPS = 64
DEFAULT_KARRAS_SIGMA_MIN = 1e-3
DEFAULT_KARRAS_SIGMA_MAX = 160
DEFAULT_KARRAS_S_CHURN = 0.0
def uncond_guide_model(
    model: Callable[..., torch.Tensor], scale: float
) -> Callable[..., torch.Tensor]:
    """Wrap ``model`` so that its output is blended with guidance ``scale``.

    The returned callable discards the second half of the incoming batch,
    duplicates the first half, and runs ``model`` once on that doubled batch.
    The first half of the model output is treated as the conditional
    prediction and the second half as the unconditional one; they are combined
    as ``uncond + scale * (cond - uncond)``. The blended result is repeated
    twice along dim 0 so the output batch size equals the doubled batch.

    Args:
        model: Callable invoked as ``model(x, ts, **kwargs)``; its output is
            split into two chunks along dim 0.
        scale: Guidance weight. ``1.0`` reproduces the conditional half and
            ``0.0`` the unconditional half.

    Returns:
        A callable with the signature ``(x_t, ts, **kwargs)``.
    """

    def model_fn(x_t, ts, **kwargs):
        keep = len(x_t) // 2
        first_half = x_t[:keep]
        doubled = torch.cat((first_half, first_half), dim=0)
        conditional, unconditional = torch.chunk(
            model(doubled, ts, **kwargs), 2, dim=0
        )
        guided = unconditional + scale * (conditional - unconditional)
        return torch.cat((guided, guided), dim=0)

    return model_fn
def sample_latents(
    *,
    batch_size: int,
    model: nn.Module,
    diffusion: GaussianDiffusion,
    model_kwargs: Dict[str, Any],
    guidance_scale: float,
    clip_denoised: bool,
    use_fp16: bool,
    use_karras: bool,
    karras_steps: int,
    sigma_min: float,
    sigma_max: float,
    s_churn: float,
    device: Optional[torch.device] = None,
    progress: bool = False,
    initial_noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Sample a batch of latents from ``model`` using ``diffusion``.

    Args:
        batch_size: Number of latents requested; the sample shape is
            ``(batch_size, model.d_latent)``.
        model: Denoising network. If it defines ``cached_model_kwargs`` that
            method is called to build the conditioning kwargs.
        diffusion: Diffusion object handed to ``karras_sample`` or used for
            ``p_sample_loop``.
        model_kwargs: Conditioning tensors keyed by name. The caller's dict is
            never modified.
        guidance_scale: Guidance weight. When it is neither ``1.0`` nor
            ``0.0`` every conditioning tensor is extended along dim 0 with a
            zero tensor of the same shape.
        clip_denoised: Forwarded to the sampler.
        use_fp16: Enables ``torch.autocast`` around sampling.
        use_karras: Selects ``karras_sample``; otherwise
            ``diffusion.p_sample_loop`` is used.
        karras_steps: Forwarded to ``karras_sample`` as ``steps``.
        sigma_min: Forwarded to ``karras_sample``.
        sigma_max: Forwarded to ``karras_sample``.
        s_churn: Forwarded to ``karras_sample``.
        device: Target device; defaults to the device of the model's first
            parameter.
        progress: Forwarded to the sampler.
        initial_noise: Forwarded to ``karras_sample`` only. It is NOT passed
            to ``p_sample_loop``, so it has no effect when ``use_karras`` is
            false.

    Returns:
        The sampled latents. On the Karras path the second value returned by
        ``karras_sample`` is discarded.
    """
    if device is None:
        device = next(model.parameters()).device

    if hasattr(model, "cached_model_kwargs"):
        model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)

    if guidance_scale != 1.0 and guidance_scale != 0.0:
        # Build a new dict instead of assigning into `model_kwargs`: the
        # original in-place update leaked the doubled tensors back to the
        # caller, so reusing the same dict doubled the batch axis again.
        model_kwargs = {
            key: torch.cat([value, torch.zeros_like(value)], dim=0)
            for key, value in model_kwargs.items()
        }

    sample_shape = (batch_size, model.d_latent)

    with torch.autocast(device_type=device.type, enabled=use_fp16):
        if use_karras:
            # Only the final samples are returned; the per-step sequence that
            # karras_sample also yields is not needed here.
            samples, _ = karras_sample(
                diffusion=diffusion,
                model=model,
                shape=sample_shape,
                steps=karras_steps,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                device=device,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                s_churn=s_churn,
                guidance_scale=guidance_scale,
                progress=progress,
                initial_noise=initial_noise,
            )
        else:
            internal_batch_size = batch_size
            if guidance_scale != 1.0:
                # NOTE(review): with guidance_scale == 0.0 the batch is doubled
                # here although the conditioning tensors above were not;
                # confirm the model tolerates that combination.
                model = uncond_guide_model(model, guidance_scale)
                internal_batch_size *= 2
            samples = diffusion.p_sample_loop(
                model,
                shape=(internal_batch_size, *sample_shape[1:]),
                model_kwargs=model_kwargs,
                device=device,
                clip_denoised=clip_denoised,
                progress=progress,
            )

    return samples
def sample_latents_with_additional_latent(
    *,
    batch_size: int,
    model: nn.Module,
    diffusion: GaussianDiffusion,
    model_kwargs: Dict[str, Any],
    text_guidance_scale: float,
    image_guidance_scale: float,
    clip_denoised: bool,
    use_fp16: bool,
    use_karras: bool,
    karras_steps: int,
    sigma_min: float,
    sigma_max: float,
    s_churn: float,
    device: Optional[torch.device] = None,
    progress: bool = False,
    condition_latent: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Sample latents with an extra ``condition_latent`` and two guidance scales.

    When either guidance scale is neither ``1.0`` nor ``0.0`` the inputs are
    tripled along dim 0: every conditioning tensor becomes
    ``[v, zeros, zeros]`` and ``condition_latent`` becomes
    ``[latent, latent, zeros]``.

    Args:
        batch_size: Number of latents requested; the sample shape is
            ``(batch_size, model.d_latent)``.
        model: Denoising network. If it defines ``cached_model_kwargs`` that
            method is called to build the conditioning kwargs.
        diffusion: Diffusion object handed to
            ``karras_sample_addition_condition`` or used for
            ``p_sample_loop``.
        model_kwargs: Conditioning tensors keyed by name. The caller's dict is
            never modified.
        text_guidance_scale: Forwarded to the Karras sampler; also used by the
            non-Karras fallback.
        image_guidance_scale: Forwarded to the Karras sampler only.
        clip_denoised: Forwarded to the sampler.
        use_fp16: Enables ``torch.autocast`` around sampling.
        use_karras: Selects ``karras_sample_addition_condition``; otherwise
            ``diffusion.p_sample_loop`` is used.
        karras_steps: Forwarded to the Karras sampler as ``steps``.
        sigma_min: Forwarded to the Karras sampler.
        sigma_max: Forwarded to the Karras sampler.
        s_churn: Forwarded to the Karras sampler.
        device: Target device; defaults to the device of the model's first
            parameter.
        progress: Forwarded to the sampler.
        condition_latent: Extra conditioning latent. Despite the ``None``
            default it must be a tensor whenever guidance is active, because
            it is passed to ``torch.zeros_like`` / ``torch.cat``.

    Returns:
        The sampled latents. On the Karras path the second value returned by
        ``karras_sample_addition_condition`` is discarded.
    """
    if device is None:
        device = next(model.parameters()).device

    if hasattr(model, "cached_model_kwargs"):
        model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)

    text_guided = text_guidance_scale != 1.0 and text_guidance_scale != 0.0
    image_guided = image_guidance_scale != 1.0 and image_guidance_scale != 0.0
    if text_guided or image_guided:
        # Build a new dict instead of assigning into `model_kwargs`: the
        # original in-place update leaked the tripled tensors back to the
        # caller, so reusing the same dict tripled the batch axis again.
        model_kwargs = {
            key: torch.cat(
                [value, torch.zeros_like(value), torch.zeros_like(value)], dim=0
            )
            for key, value in model_kwargs.items()
        }
        condition_latent = torch.cat(
            [condition_latent, condition_latent, torch.zeros_like(condition_latent)],
            dim=0,
        )

    sample_shape = (batch_size, model.d_latent)

    with torch.autocast(device_type=device.type, enabled=use_fp16):
        if use_karras:
            # Only the final samples are returned; the per-step sequence the
            # sampler also yields is not needed here.
            samples, _ = karras_sample_addition_condition(
                diffusion=diffusion,
                model=model,
                shape=sample_shape,
                steps=karras_steps,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                device=device,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                s_churn=s_churn,
                text_guidance_scale=text_guidance_scale,
                image_guidance_scale=image_guidance_scale,
                progress=progress,
                condition_latent=condition_latent,
            )
        else:
            # NOTE(review): this fallback uses neither `image_guidance_scale`
            # nor `condition_latent`, and it at most doubles the batch while
            # the conditioning tensors above may have been tripled. It looks
            # like a copy of the single-guidance sampler; confirm it is
            # actually supported before relying on it.
            internal_batch_size = batch_size
            if text_guidance_scale != 1.0:
                model = uncond_guide_model(model, text_guidance_scale)
                internal_batch_size *= 2
            samples = diffusion.p_sample_loop(
                model,
                shape=(internal_batch_size, *sample_shape[1:]),
                model_kwargs=model_kwargs,
                device=device,
                clip_denoised=clip_denoised,
                progress=progress,
            )

    return samples