import inspect

import torch
import numpy as np

import rembg
from PIL import Image
from tqdm import tqdm
from diffusers import DDIMScheduler
from torchvision import transforms

from step1x3d_geometry.utils.typing import *
from step1x3d_geometry.utils.misc import get_device

def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigma schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
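
# Example (illustrative sketch, not part of the original pipeline code): the helper can be
# driven either by a plain step count or, when the scheduler supports it, by an explicit
# schedule. Assuming `scheduler` is a diffusers scheduler instance:
#
#     timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=50, device="cuda")
#
# Passing both `timesteps` and `sigmas` raises a ValueError, and custom schedules are only
# accepted when the scheduler's `set_timesteps` exposes the corresponding argument.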

@torch.no_grad()
def ddim_sample(
    ddim_scheduler: DDIMScheduler,
    diffusion_model: torch.nn.Module,
    shape: Union[List[int], Tuple[int]],
    visual_cond: torch.FloatTensor,
    caption_cond: torch.FloatTensor,
    label_cond: torch.FloatTensor,
    steps: int,
    eta: float = 0.0,
    guidance_scale: float = 3.0,
    do_classifier_free_guidance: bool = True,
    generator: Optional[torch.Generator] = None,
    device: torch.device = "cuda:0",
    disable_prog: bool = True,
):
    assert steps > 0, f"{steps} must be > 0."

    # Infer the batch size, device and dtype from whichever condition is provided.
    if visual_cond is not None:
        bsz = visual_cond.shape[0]
        device = visual_cond.device
        dtype = visual_cond.dtype
    if caption_cond is not None:
        bsz = caption_cond.shape[0]
        device = caption_cond.device
        dtype = caption_cond.dtype
    if label_cond is not None:
        bsz = label_cond.shape[0]
        device = label_cond.device
        dtype = label_cond.dtype

    # With classifier-free guidance the condition batch already holds the unconditional
    # and conditional embeddings concatenated along the batch dimension.
    if do_classifier_free_guidance:
        bsz = bsz // 2
    latents = torch.randn(
        (bsz, *shape),
        generator=generator,
        device=device,
        dtype=dtype,
    )
    try:
        # Scale the initial noise by the standard deviation required by the scheduler.
        latents = latents * ddim_scheduler.init_noise_sigma
    except AttributeError:
        pass

    extra_step_kwargs = {"generator": generator}

    # Set the timesteps.
    timesteps, num_inference_steps = retrieve_timesteps(
        ddim_scheduler,
        steps,
        device,
    )
    if eta > 0:
        assert 0 <= eta <= 1, f"eta must be between [0, 1]. Got {eta}."
        assert (
            ddim_scheduler.__class__.__name__ == "DDIMScheduler"
        ), "eta is only used with the DDIMScheduler."
        extra_step_kwargs["eta"] = eta

    for i, t in enumerate(
        tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)
    ):
        # Expand the latents if we are doing classifier-free guidance.
        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )

        # Predict the noise residual.
        timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
        timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
        noise_pred = diffusion_model.forward(
            latent_model_input, timestep_tensor, visual_cond, caption_cond, label_cond
        ).sample

        # Perform classifier-free guidance.
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        # Compute the previous noisy sample x_t -> x_{t-1}.
        latents = ddim_scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

        yield latents, t
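
# Usage sketch (hedged; `model`, `cond` and the latent shape are placeholders, not names
# defined in this module): `ddim_sample` is a generator that yields `(latents, t)` after
# every denoising step, so the caller iterates it and keeps the last latents. With
# `do_classifier_free_guidance=True`, the condition tensors are expected to contain the
# unconditional and conditional embeddings concatenated along the batch dimension.
#
#     latents = None
#     for latents, t in ddim_sample(
#         ddim_scheduler, model, shape=(512, 64), visual_cond=cond,
#         caption_cond=None, label_cond=None, steps=50, guidance_scale=7.5,
#     ):
#         pass  # `latents` after the final iteration is the sampled result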

@torch.no_grad()
def flow_sample(
    scheduler: DDIMScheduler,
    diffusion_model: torch.nn.Module,
    shape: Union[List[int], Tuple[int]],
    visual_cond: torch.FloatTensor,
    caption_cond: torch.FloatTensor,
    label_cond: torch.FloatTensor,
    steps: int,
    eta: float = 0.0,
    guidance_scale: float = 3.0,
    do_classifier_free_guidance: bool = True,
    generator: Optional[torch.Generator] = None,
    device: torch.device = "cuda:0",
    disable_prog: bool = True,
):
    assert steps > 0, f"{steps} must be > 0."

    # Infer the batch size, device and dtype from whichever condition is provided.
    if visual_cond is not None:
        bsz = visual_cond.shape[0]
        device = visual_cond.device
        dtype = visual_cond.dtype
    if caption_cond is not None:
        bsz = caption_cond.shape[0]
        device = caption_cond.device
        dtype = caption_cond.dtype
    if label_cond is not None:
        bsz = label_cond.shape[0]
        device = label_cond.device
        dtype = label_cond.dtype

    # With classifier-free guidance the condition batch already holds the unconditional
    # and conditional embeddings concatenated along the batch dimension.
    if do_classifier_free_guidance:
        bsz = bsz // 2
    latents = torch.randn(
        (bsz, *shape),
        generator=generator,
        device=device,
        dtype=dtype,
    )
    try:
        # Scale the initial noise by the standard deviation required by the scheduler.
        latents = latents * scheduler.init_noise_sigma
    except AttributeError:
        pass

    extra_step_kwargs = {"generator": generator}

    # Set the timesteps; one extra timestep is requested so that `steps` step sizes can
    # be formed from consecutive differences below.
    timesteps, num_inference_steps = retrieve_timesteps(
        scheduler,
        steps + 1,
        device,
    )
    if eta > 0:
        assert 0 <= eta <= 1, f"eta must be between [0, 1]. Got {eta}."
        assert (
            scheduler.__class__.__name__ == "DDIMScheduler"
        ), "eta is only used with the DDIMScheduler."
        extra_step_kwargs["eta"] = eta

    # Normalized Euler step sizes between consecutive timesteps.
    distance = (timesteps[:-1] - timesteps[1:]) / scheduler.config.num_train_timesteps
    for i, t in enumerate(
        tqdm(timesteps[:-1], disable=disable_prog, desc="Flow Sampling:", leave=False)
    ):
        # Expand the latents if we are doing classifier-free guidance.
        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )

        # Predict the velocity.
        timestep_tensor = torch.tensor([t], dtype=latents.dtype, device=device)
        timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
        noise_pred = diffusion_model.forward(
            latent_model_input, timestep_tensor, visual_cond, caption_cond, label_cond
        ).sample
        if isinstance(noise_pred, tuple):
            noise_pred, layer_idx_list, ones_list, pred_c_list = noise_pred

        # Perform classifier-free guidance.
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        # Euler update of the flow ODE.
        latents = latents - distance[i] * noise_pred

        yield latents, t
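
# Note (a reading of the code above, not an authoritative spec): `flow_sample` integrates
# a rectified-flow style ODE with a plain Euler rule. `steps + 1` timesteps are requested
# so that `distance[i] = (t_i - t_{i+1}) / num_train_timesteps` yields `steps` normalized
# step sizes, and each update is `latents <- latents - distance[i] * prediction`. It is
# consumed the same way as `ddim_sample`, e.g. with placeholder names:
#
#     for latents, t in flow_sample(scheduler, model, (512, 64), cond, None, None, steps=30):
#         pass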

def compute_snr(noise_scheduler, timesteps):
    """
    Computes SNR as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # Expand the tensors to the shape of `timesteps`.
    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
        timesteps
    ].float()
    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
        device=timesteps.device
    )[timesteps].float()
    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

    # SNR = (alpha / sigma) ** 2
    snr = (alpha / sigma) ** 2
    return snr
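
# Example (sketch, assuming an epsilon-prediction training objective): `compute_snr` is
# typically used for Min-SNR-gamma loss weighting. The weighting below follows the common
# recipe; `F` is `torch.nn.functional`, and `pred`, `target` and the gamma value 5.0 are
# placeholders, not values defined in this module:
#
#     snr = compute_snr(noise_scheduler, timesteps)
#     mse = F.mse_loss(pred, target, reduction="none").mean(dim=list(range(1, pred.ndim)))
#     weights = torch.clamp(snr, max=5.0) / snr
#     loss = (weights * mse).mean()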

def read_image(img, img_size=224):
    transform = transforms.Compose(
        [
            transforms.Resize(
                img_size, transforms.InterpolationMode.BICUBIC, antialias=True
            ),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
        ]
    )
    rgb = Image.open(img)
    rgb = transform(rgb)[:3, ...].permute(1, 2, 0)
    return rgb
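
# Example (illustrative, with a placeholder path): `read_image` returns an HWC float
# tensor in [0, 1], bicubically resized and center-cropped to `img_size`:
#
#     rgb = read_image("example.png", img_size=224)  # tensor of shape (224, 224, 3)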

def preprocess_image(
    images_pil: List[Image.Image],
    force: bool = False,
    background_color: List[int] = [255, 255, 255],
    foreground_ratio: float = 0.95,
):
    r"""
    Crop and remove the background of the input images.

    Args:
        images_pil (`List[PIL.Image.Image]`):
            List of `PIL.Image.Image` objects representing the input images.
        force (`bool`, *optional*, defaults to `False`):
            Whether to force background removal even if the image has an alpha channel.

    Returns:
        `List[PIL.Image.Image]`: List of `PIL.Image.Image` objects representing the preprocessed images.
    """
    preprocessed_images = []
    for i in range(len(images_pil)):
        image = images_pil[i]
        width, height, size = image.width, image.height, image.size
        do_remove = True
        if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
            # The alpha channel is not empty: skip background removal and use it as the mask.
            print(
                "alpha channel not empty, skip remove background, using alpha channel as mask"
            )
            do_remove = False
        do_remove = do_remove or force
        if do_remove:
            image = rembg.remove(image)

        # Rescale the foreground so it covers `foreground_ratio` of the frame.
        alpha = image.split()[-1]
        bboxs = alpha.getbbox()
        x1, y1, x2, y2 = bboxs
        dy, dx = y2 - y1, x2 - x1
        s = min(height * foreground_ratio / dy, width * foreground_ratio / dx)
        Ht, Wt = int(dy * s), int(dx * s)

        background = Image.new("RGBA", image.size, (*background_color, 255))
        image = Image.alpha_composite(background, image)
        image = image.crop(alpha.getbbox())
        alpha = alpha.crop(alpha.getbbox())

        # Resize the cropped foreground and its mask.
        resized_image = image.resize((Wt, Ht))
        resized_alpha = alpha.resize((Wt, Ht))

        # Paste the resized foreground onto a canvas of the original size, centered.
        padded_image = Image.new("RGB", size, tuple(background_color))
        padded_alpha = Image.new("L", size, 0)
        paste_position = (
            (width - resized_image.width) // 2,
            (height - resized_image.height) // 2,
        )
        padded_image.paste(resized_image, paste_position)
        padded_alpha.paste(resized_alpha, paste_position)

        # Expand to a square canvas and attach the alpha mask.
        width, height = padded_image.size
        if width == height:
            padded_image.putalpha(padded_alpha)
            preprocessed_images.append(padded_image)
            continue
        new_size = (max(width, height), max(width, height))
        new_image = Image.new("RGB", new_size, tuple(background_color))
        new_alpha = Image.new("L", new_size, 0)
        paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2)
        new_image.paste(padded_image, paste_position)
        new_alpha.paste(padded_alpha, paste_position)
        new_image.putalpha(new_alpha)
        preprocessed_images.append(new_image)

    return preprocessed_images
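
# Example (sketch with placeholder file names): `preprocess_image` removes the background
# when no usable alpha channel is present, recenters the foreground so it covers roughly
# `foreground_ratio` of the frame, and returns square RGBA images on the requested
# background color:
#
#     images = [Image.open(p).convert("RGBA") for p in ["a.png", "b.png"]]
#     processed = preprocess_image(images, force=False, background_color=[255, 255, 255])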