from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import math

from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
from diffusers import DiffusionPipeline
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
    StableVideoDiffusionPipelineOutput,
    StableVideoDiffusionPipeline,
)
from PIL import Image
import cv2

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class NormalCrafterPipeline(StableVideoDiffusionPipeline):
    """Sliding-window video prediction pipeline built on ``StableVideoDiffusionPipeline``.

    Frames are padded to a multiple of 64, encoded with the CLIP image encoder
    (``_encode_image``) and the VAE (``ecnode_video_vae``), then latents are
    predicted window by window (``generate``). Overlapping windows are blended
    linearly, and the result is decoded, post-processed, and un-padded.
    (Per the class name, the predicted quantity is presumably surface normals —
    confirm.)
    """

    def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance, scale=1, image_size=None):
        """Encode frames with the CLIP image encoder.

        Args:
            image: PIL image(s) or a tensor. Non-tensor input is converted with
                ``video_processor.pil_to_numpy`` ((0, 255) -> (0, 1)) and
                ``numpy_to_pt`` ((n, h, w, c) -> (n, c, h, w)). Tensor input is
                assumed to already be (n, c, h, w) in [0, 1] — TODO confirm.
            device: device the encoder input is moved to.
            num_videos_per_prompt: how many times each embedding is repeated.
            do_classifier_free_guidance: if True, an all-zero negative batch is
                prepended to the embeddings.
            scale: multi-scale level count. ``scale=1`` encodes only the full
                frame. Each extra level ``i`` (1..scale-1) also encodes an
                ``(i+1) x (i+1)`` grid of crops. Crop embeddings are
                concatenated along the sequence dimension.
            image_size: ``(h, w)`` the encoder input is resized to; defaults to
                ``(224, 224)``.

        Returns:
            Tensor of shape ``(n * num_videos_per_prompt, seq_len, dim)``; the
            batch dimension is doubled when guidance is enabled.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.video_processor.pil_to_numpy(image)  # (0, 255) -> (0, 1)
            image = self.video_processor.numpy_to_pt(image)  # (n, h, w, c) -> (n, c, h, w)

        pixel_values = image
        _, _, H, W = pixel_values.shape

        # Level 0 is the full frame; further levels add grid crops.
        patches = [pixel_values]
        for i in range(1, scale):
            num_patches_HW_this_level = i + 1
            patch_H = H // num_patches_HW_this_level + 1
            patch_W = W // num_patches_HW_this_level + 1
            for j in range(num_patches_HW_this_level):
                for k in range(num_patches_HW_this_level):
                    patches.append(
                        pixel_values[:, :, j * patch_H:(j + 1) * patch_H, k * patch_W:(k + 1) * patch_W]
                    )

        def encode_image(image):
            # We normalize the image to (-1, 1) before resizing to match the
            # original implementation, then map it back to (0, 1) afterwards.
            image = image * 2.0 - 1.0
            if image_size is not None:
                image = _resize_with_antialiasing(image, image_size)
            else:
                image = _resize_with_antialiasing(image, (224, 224))
            image = (image + 1.0) / 2.0

            # Normalize the image for CLIP input (resize/rescale already done above).
            image = self.feature_extractor(
                images=image,
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values

            image = image.to(device=device, dtype=dtype)
            image_embeddings = self.image_encoder(image).image_embeds
            # Ensure a sequence dimension so patches can be concatenated on dim 1.
            if len(image_embeddings.shape) < 3:
                image_embeddings = image_embeddings.unsqueeze(1)
            return image_embeddings

        image_embeddings = torch.cat([encode_image(patch) for patch in patches], dim=1)

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = image_embeddings.shape
        image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
        image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            negative_image_embeddings = torch.zeros_like(image_embeddings)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and image embeddings into a single batch
            # to avoid doing two forward passes
            image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])

        return image_embeddings

    def ecnode_video_vae(self, images, chunk_size: int = 14):
        """Encode frames to VAE latents (method name kept as-is for compatibility).

        Args:
            images: sequence of PIL images or (h, w, c) numpy frames.
            chunk_size: number of frames passed to ``vae.encode`` at once.

        Returns:
            Latents (the mode of each latent distribution), one per frame,
            stacked along dim 0.

        If the VAE is fp16 with ``force_upcast`` set, it is temporarily cast to
        fp32. It is restored to fp16 afterwards, even if encoding raises.
        """
        # PIL exposes (width, height) via .size; arrays expose (height, width, ...) via .shape.
        if isinstance(images[0], PIL.Image.Image):
            width, height = images[0].size
        else:
            height, width = images[0].shape[:2]

        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        try:
            device = self._execution_device
            # torch tensor in range(-1, 1) with (1, c, t, h, w)
            images = self.video_processor.preprocess_video(images, height=height, width=width).to(
                device, self.vae.dtype
            )
            images = images.squeeze(0)  # (1, c, t, h, w) -> (c, t, h, w)
            images = images.permute(1, 0, 2, 3)  # (c, t, h, w) -> (t, c, h, w)

            video_latents = [
                self.vae.encode(images[i:i + chunk_size]).latent_dist.mode()
                for i in range(0, images.shape[0], chunk_size)
            ]
            image_latents = torch.cat(video_latents)
        finally:
            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)

        return image_latents

    def pad_image(self, images, scale=64):
        """Pad every frame so that height and width are multiples of ``scale``.

        Padding is split as evenly as possible between the two sides; any odd
        pixel goes to the bottom/right. The fill is white: 255 for PIL frames
        (which are converted through numpy and back), and 1.0 for numpy frames
        (presumably float images in [0, 1] — confirm).

        Returns:
            ``(images, None)`` unchanged when no padding is needed, otherwise
            ``(padded_images, [top, bottom, left, right])``.
        """
        def get_pad(new_size, size):
            # Split the total padding; the extra pixel (if any) goes after.
            total = new_size - size
            before = total // 2
            return [before, total - before]

        if isinstance(images[0], np.ndarray):
            H, W = images[0].shape[:2]
        else:
            W, H = images[0].size

        if W % scale == 0 and H % scale == 0:
            return images, None

        newW = int(np.ceil(W / scale) * scale)
        newH = int(np.ceil(H / scale) * scale)

        pad_Ws = get_pad(newW, W)
        pad_Hs = get_pad(newH, H)

        new_images = []
        for image in images:
            if isinstance(image, np.ndarray):
                image = cv2.copyMakeBorder(image, *pad_Hs, *pad_Ws, cv2.BORDER_CONSTANT, value=(1.0, 1.0, 1.0))
                new_images.append(image)
            else:
                image = np.array(image)
                image = cv2.copyMakeBorder(image, *pad_Hs, *pad_Ws, cv2.BORDER_CONSTANT, value=(255, 255, 255))
                new_images.append(Image.fromarray(image))
        return new_images, pad_Hs + pad_Ws

    def unpad_image(self, v, pad_HWs):
        """Crop padding (as produced by ``pad_image``) from axes 2 (height) and 3 (width).

        Args:
            v: array/tensor with height on axis 2 and width on axis 3
                (presumably ``(batch, frames, H, W, C)`` from
                ``postprocess_video(output_type="np")`` — confirm).
            pad_HWs: ``[top, bottom, left, right]``.

        Slicing uses ``start:size - end`` so a zero trailing pad keeps the data
        (``x[t:-0]`` would be empty).
        """
        t, b, l, r = pad_HWs
        if t > 0 or b > 0:
            v = v[:, :, t:v.shape[2] - b]
        if l > 0 or r > 0:
            v = v[:, :, :, l:v.shape[3] - r]
        return v

    @torch.no_grad()
    def __call__(
        self,
        images: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
        decode_chunk_size: Optional[int] = None,
        time_step_size: Optional[int] = 1,
        window_size: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        return_dict: bool = True
    ):
        """Run sliding-window inference over a video.

        Args:
            images: input frames. A single PIL image is treated as a one-frame
                video. NOTE(review): tensor input appears in the type hint but
                ``pad_image`` reads PIL-style ``.size``, so tensors are not
                handled.
            decode_chunk_size: chunk size used for both VAE encoding and
                decoding; defaults to 7.
            time_step_size: stride between consecutive window starts; must
                satisfy ``1 <= time_step_size <= window_size``. ``None`` means 1.
            window_size: frames per window; ``None`` means 1. Clamped to the
                number of frames so short videos run as a single window.
            generator: RNG forwarded to ``generate``.
            return_dict: if False, return the frames directly.

        Returns:
            ``StableVideoDiffusionPipelineOutput(frames=...)``. ``frames`` holds
            the post-processed numpy frames rescaled with ``x * 2 - 1``, with
            padding removed. When ``return_dict`` is False, the frames are
            returned directly.

        Raises:
            ValueError: if ``window_size < 1`` or ``time_step_size`` is not in
                ``[1, window_size]`` (larger steps would leave gaps between
                windows and misalign the stitched latents).
        """
        if isinstance(images, PIL.Image.Image):
            images = [images]
        images, pad_HWs = self.pad_image(images)

        # 0. Default height and width to unet
        width, height = images[0].size
        num_frames = len(images)

        # Validate the sliding-window configuration before any heavy work.
        window_size = 1 if window_size is None else window_size
        time_step_size = 1 if time_step_size is None else time_step_size
        if window_size < 1:
            raise ValueError(f"`window_size` must be >= 1, got {window_size}.")
        if not 1 <= time_step_size <= window_size:
            raise ValueError(
                f"`time_step_size` must be in [1, window_size={window_size}], got {time_step_size}."
            )
        window_size = min(window_size, num_frames)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(images, height, width)

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        self._guidance_scale = 1.0
        num_videos_per_prompt = 1
        do_classifier_free_guidance = False
        num_inference_steps = 1
        fps = 7
        motion_bucket_id = 127
        noise_aug_strength = 0.0
        output_type = "np"
        use_linear_merge = True
        determineTrain = True
        encode_image_scale = 1
        encode_image_WH = None

        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 7

        # 3. Encode input images using CLIP: one embedding row per frame.
        image_embeddings = self._encode_image(
            images,
            device,
            num_videos_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            scale=encode_image_scale,
            image_size=encode_image_WH,
        )

        # 4. Encode input images using VAE
        image_latents = self.ecnode_video_vae(images, chunk_size=decode_chunk_size).to(image_embeddings.dtype)
        # image_latents [num_frames, channels, height, width] -> [1, num_frames, channels, height, width]
        image_latents = image_latents.unsqueeze(0)

        # 5. Get Added Time IDs
        added_time_ids = self._get_add_time_ids(
            fps,
            motion_bucket_id,
            noise_aug_strength,
            image_embeddings.dtype,
            batch_size,
            num_videos_per_prompt,
            do_classifier_free_guidance,
        )
        added_time_ids = added_time_ids.to(device)

        # Start/end frame index of each window.
        ses = _get_window_ranges(num_frames, window_size, time_step_size)

        pred = None
        for i, (start, end) in enumerate(ses):
            window_image_embeddings = image_embeddings[start:end]
            window_image_latents = image_latents[:, start:end]

            # Frames shared with the previous window are used as conditioning.
            if i == 0 or time_step_size == window_size:
                to_replace_latents = None
            else:
                num_to_replace_latents = ses[i - 1][1] - start
                to_replace_latents = pred[:, -num_to_replace_latents:]

            latents = self.generate(
                num_inference_steps,
                device,
                batch_size,
                num_videos_per_prompt,
                end - start,
                height,
                width,
                window_image_embeddings,
                generator,
                determineTrain,
                to_replace_latents,
                do_classifier_free_guidance,
                window_image_latents,
                added_time_ids,
            )

            # merge last_latents and current latents in overlap window
            if to_replace_latents is not None and use_linear_merge:
                num_img_condition = to_replace_latents.shape[1]
                # Weights strictly between 1 and 0: earlier overlap frames lean
                # on the previous window, later ones on the current window.
                weight = torch.linspace(1.0, 0.0, num_img_condition + 2)[1:-1].to(device)
                weight = weight[None, :, None, None, None]
                latents[:, :num_img_condition] = (
                    to_replace_latents * weight + latents[:, :num_img_condition] * (1 - weight)
                )

            pred = latents if pred is None else torch.cat([pred[:, :start], latents], dim=1)

        if output_type != "latent":
            # (The VAE dtype is already restored by `ecnode_video_vae`; no cast is needed here.)
            frames = self.decode_latents(pred, num_frames, decode_chunk_size)  # presumably in range(-1, 1)
            frames = self.video_processor.postprocess_video(video=frames, output_type="np")
            frames = frames * 2 - 1  # from range(0, 1) -> range(-1, 1)

            if pad_HWs is not None:
                frames = self.unpad_image(frames, pad_HWs)
        else:
            frames = pred

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames

        return StableVideoDiffusionPipelineOutput(frames=frames)

    def generate(
        self,
        num_inference_steps,
        device,
        batch_size,
        num_videos_per_prompt,
        num_frames,
        height,
        width,
        image_embeddings,
        generator,
        determineTrain,
        to_replace_latents,
        do_classifier_free_guidance,
        image_latents,
        added_time_ids,
        latents=None,
    ):
        """Run the scheduler/UNet loop for one window and return its latents.

        Args:
            determineTrain: when True, the initial latents are zeroed, and the
                conditioning latents are written into the leading frames without
                noise.
            to_replace_latents: latents of frames shared with the previous
                window (or None). At every step they overwrite
                ``latents[:, :k]``, where ``k`` is their frame count. When
                ``determineTrain`` is False they are noised to timestep ``t``
                first.
            image_latents: VAE latents concatenated to the model input on dim 2
                (the channel dimension, per the original comment).
            Remaining arguments are forwarded to the scheduler,
            ``prepare_latents``, and the UNet.
        """
        # 6. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 7. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_frames,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )
        if determineTrain:
            latents[...] = 0.0

        num_img_condition = to_replace_latents.shape[1] if to_replace_latents is not None else 0

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # replace part of latents with conditions. TODO: t embedding should also be replaced
                if to_replace_latents is not None:
                    if not determineTrain:
                        _noise = randn_tensor(
                            to_replace_latents.shape, generator=generator, device=device, dtype=image_embeddings.dtype
                        )
                        noisy_to_replace_latents = self.scheduler.add_noise(to_replace_latents, _noise, t.unsqueeze(0))
                        latents[:, :num_img_condition] = noisy_to_replace_latents
                    else:
                        latents[:, :num_img_condition] = to_replace_latents

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Concatenate image_latents over channels dimension
                latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=image_embeddings,
                    added_time_ids=added_time_ids,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                scheduler_output = self.scheduler.step(noise_pred, t, latents)
                latents = scheduler_output.prev_sample

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        return latents


def _get_window_ranges(num_frames, window_size, time_step_size):
    """Return ``[start, end)`` frame ranges for sliding-window inference.

    Windows start every ``time_step_size`` frames. Windows that would run past
    ``num_frames`` are dropped. If the last kept window ends before
    ``num_frames``, a final window aligned to the end is appended. Requires
    ``1 <= window_size <= num_frames`` (guaranteed by ``__call__``), so the
    first window always fits and the result is non-empty.
    """
    ses = [
        [start, start + window_size]
        for start in range(0, num_frames, time_step_size)
        if start + window_size <= num_frames
    ]
    if ses[-1][-1] < num_frames:
        ses.append([num_frames - window_size, num_frames])
    return ses


# resizing utils
# TODO: clean up later
def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
    """Gaussian-blur ``input`` (NCHW) in proportion to the downscale factor, then interpolate to ``size``."""
    h, w = input.shape[-2:]
    factors = (h / size[0], w / size[1])

    # First, we have to determine sigma
    # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
    sigmas = (
        max((factors[0] - 1.0) / 2.0, 0.001),
        max((factors[1] - 1.0) / 2.0, 0.001),
    )

    # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
    # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
    # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
    ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))

    # Make sure it is odd
    if (ks[0] % 2) == 0:
        ks = ks[0] + 1, ks[1]

    if (ks[1] % 2) == 0:
        ks = ks[0], ks[1] + 1

    input = _gaussian_blur2d(input, ks, sigmas)

    output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
    return output


def _compute_padding(kernel_size):
    """Compute padding tuple."""
    # 4 or 6 ints:  (padding_left, padding_right,padding_top,padding_bottom)
    # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
    if len(kernel_size) < 2:
        raise AssertionError(kernel_size)
    computed = [k - 1 for k in kernel_size]

    # for even kernels we need to do asymmetric padding :(
    out_padding = 2 * len(kernel_size) * [0]

    for i in range(len(kernel_size)):
        computed_tmp = computed[-(i + 1)]

        pad_front = computed_tmp // 2
        pad_rear = computed_tmp - pad_front

        out_padding[2 * i + 0] = pad_front
        out_padding[2 * i + 1] = pad_rear

    return out_padding


def _filter2d(input, kernel):
    """Convolve each channel of NCHW ``input`` with ``kernel`` using reflect padding (output keeps h, w)."""
    # prepare kernel
    b, c, h, w = input.shape
    tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)

    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)

    height, width = tmp_kernel.shape[-2:]

    padding_shape: List[int] = _compute_padding([height, width])
    input = torch.nn.functional.pad(input, padding_shape, mode="reflect")

    # kernel and input tensor reshape to align element-wise or batch-wise params
    tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
    input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))

    # convolve the tensor with the kernel.
    output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)

    out = output.view(b, c, h, w)
    return out


def _gaussian(window_size: int, sigma):
    """Return normalized 1-D Gaussian weights of length ``window_size`` per row of ``sigma``."""
    if isinstance(sigma, (int, float)):
        sigma = torch.tensor([[float(sigma)]])

    batch_size = sigma.shape[0]

    x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)

    if window_size % 2 == 0:
        x = x + 0.5

    gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))

    return gauss / gauss.sum(-1, keepdim=True)


def _gaussian_blur2d(input, kernel_size, sigma):
    """Separable Gaussian blur: filter along x, then along y. ``sigma`` is ``(sigma_y, sigma_x)``."""
    if isinstance(sigma, tuple):
        sigma = torch.tensor([sigma], dtype=input.dtype)
    else:
        sigma = sigma.to(dtype=input.dtype)

    ky, kx = int(kernel_size[0]), int(kernel_size[1])
    bs = sigma.shape[0]
    kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
    kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
    out_x = _filter2d(input, kernel_x[..., None, :])
    out = _filter2d(out_x, kernel_y[..., None])

    return out