from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union

import math

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import cv2
from PIL import Image

from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
from diffusers import DiffusionPipeline
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
    StableVideoDiffusionPipelineOutput,
    StableVideoDiffusionPipeline,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class NormalCrafterPipeline(StableVideoDiffusionPipeline):

    def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance, scale=1, image_size=None):
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.video_processor.pil_to_numpy(image)  # (0, 255) -> (0, 1)
            image = self.video_processor.numpy_to_pt(image)  # (n, h, w, c) -> (n, c, h, w)

        # We normalize the image before resizing to match with the original implementation.
        # Then we unnormalize it after resizing.
        pixel_values = image
        B, C, H, W = pixel_values.shape

        # Multi-scale patches: the full frame, plus an (i+1) x (i+1) grid of crops per extra level.
        patches = [pixel_values]
        for i in range(1, scale):
            num_patches_HW_this_level = i + 1
            patch_H = H // num_patches_HW_this_level + 1
            patch_W = W // num_patches_HW_this_level + 1
            for j in range(num_patches_HW_this_level):
                for k in range(num_patches_HW_this_level):
                    patches.append(pixel_values[:, :, j * patch_H:(j + 1) * patch_H, k * patch_W:(k + 1) * patch_W])

        def encode_image(image):
            image = image * 2.0 - 1.0
            if image_size is not None:
                image = _resize_with_antialiasing(image, image_size)
            else:
                image = _resize_with_antialiasing(image, (224, 224))
            image = (image + 1.0) / 2.0

            # Normalize the image for CLIP input
            image = self.feature_extractor(
                images=image,
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values

            image = image.to(device=device, dtype=dtype)
            image_embeddings = self.image_encoder(image).image_embeds
            if len(image_embeddings.shape) < 3:
                image_embeddings = image_embeddings.unsqueeze(1)
            return image_embeddings

        image_embeddings = []
        for patch in patches:
            image_embeddings.append(encode_image(patch))
        image_embeddings = torch.cat(image_embeddings, dim=1)

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = image_embeddings.shape
        image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
        image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            negative_image_embeddings = torch.zeros_like(image_embeddings)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and conditional embeddings into a single batch
            # to avoid doing two forward passes
            image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])

        return image_embeddings
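
    # Note on the multi-scale conditioning above: with the default scale=1 only the full frame is
    # encoded, so each frame contributes a single CLIP embedding (seq_len = 1). With scale=2 the
    # frame is additionally split into a 2x2 grid of patches, giving 1 + 4 = 5 embeddings per
    # frame, concatenated along the sequence dimension and consumed by the UNet cross-attention.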

    def ecnode_video_vae(self, images, chunk_size: int = 14):
        if isinstance(images, list):
            width, height = images[0].size
        else:
            height, width = images[0].shape[:2]

        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        device = self._execution_device
        images = self.video_processor.preprocess_video(images, height=height, width=width).to(device, self.vae.dtype)  # in range (-1, 1) with shape (1, c, t, h, w)
        images = images.squeeze(0)  # (1, c, t, h, w) -> (c, t, h, w)
        images = images.permute(1, 0, 2, 3)  # (c, t, h, w) -> (t, c, h, w)

        # encode frames through the VAE in chunks to bound peak memory
        video_latents = []
        for i in range(0, images.shape[0], chunk_size):
            video_latents.append(self.vae.encode(images[i : i + chunk_size]).latent_dist.mode())
        image_latents = torch.cat(video_latents)

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        return image_latents

    def pad_image(self, images, scale=64):
        def get_pad(newW, W):
            pad_W = (newW - W) // 2
            if W % 2 == 1:
                pad_Ws = [pad_W, pad_W + 1]
            else:
                pad_Ws = [pad_W, pad_W]
            return pad_Ws

        if type(images[0]) is np.ndarray:
            H, W = images[0].shape[:2]
        else:
            W, H = images[0].size

        if W % scale == 0 and H % scale == 0:
            return images, None

        newW = int(np.ceil(W / scale) * scale)
        newH = int(np.ceil(H / scale) * scale)
        pad_Ws = get_pad(newW, W)
        pad_Hs = get_pad(newH, H)

        new_images = []
        for image in images:
            if type(image) is np.ndarray:
                image = cv2.copyMakeBorder(image, *pad_Hs, *pad_Ws, cv2.BORDER_CONSTANT, value=(1., 1., 1.))
                new_images.append(image)
            else:
                image = np.array(image)
                image = cv2.copyMakeBorder(image, *pad_Hs, *pad_Ws, cv2.BORDER_CONSTANT, value=(255, 255, 255))
                new_images.append(Image.fromarray(image))
        return new_images, pad_Hs + pad_Ws

    def unpad_image(self, v, pad_HWs):
        t, b, l, r = pad_HWs
        if t > 0 or b > 0:
            v = v[:, :, t:-b]
        if l > 0 or r > 0:
            v = v[:, :, :, l:-r]
        return v
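
    # Padding round trip: pad_image pads every frame with white so that H and W become multiples
    # of `scale` (64 by default). For example, a 900x600 (WxH) input is padded to 960x640 with
    # pad_Ws = [30, 30] and pad_Hs = [20, 20]; unpad_image later crops the same margins back off
    # the decoded frames, which are laid out as (batch, frame, H, W, C).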

    def __call__(
        self,
        images: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
        decode_chunk_size: Optional[int] = None,
        time_step_size: Optional[int] = 1,
        window_size: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        return_dict: bool = True,
    ):
        images, pad_HWs = self.pad_image(images)

        # 0. Default height and width to unet
        width, height = images[0].size
        num_frames = len(images)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(images, height, width)

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        self._guidance_scale = 1.0

        num_videos_per_prompt = 1
        do_classifier_free_guidance = False
        num_inference_steps = 1
        fps = 7
        motion_bucket_id = 127
        noise_aug_strength = 0.
        output_type = "np"
        data_keys = ["normal"]
        use_linear_merge = True
        determineTrain = True
        encode_image_scale = 1
        encode_image_WH = None
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 7

        # 3. Encode input image using CLIP. (num_image * num_videos_per_prompt, 1, 1024)
        image_embeddings = self._encode_image(images, device, num_videos_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, scale=encode_image_scale, image_size=encode_image_WH)

        # 4. Encode input image using VAE
        image_latents = self.ecnode_video_vae(images, chunk_size=decode_chunk_size).to(image_embeddings.dtype)

        # image_latents [num_frames, channels, height, width] -> [1, num_frames, channels, height, width]
        image_latents = image_latents.unsqueeze(0)

        # 5. Get Added Time IDs
        added_time_ids = self._get_add_time_ids(
            fps,
            motion_bucket_id,
            noise_aug_strength,
            image_embeddings.dtype,
            batch_size,
            num_videos_per_prompt,
            do_classifier_free_guidance,
        )
        added_time_ids = added_time_ids.to(device)

        # get start and end frame indices for each sliding window
        def get_ses(num_frames):
            ses = []
            for i in range(0, num_frames, time_step_size):
                ses.append([i, i + window_size])
            num_to_remain = 0
            for se in ses:
                if se[1] > num_frames:
                    continue
                num_to_remain += 1
            ses = ses[:num_to_remain]
            if ses[-1][-1] < num_frames:
                ses.append([num_frames - window_size, num_frames])
            return ses
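
        # Worked example: with num_frames=20, window_size=14, time_step_size=10 the candidate
        # windows are [0, 14] and [10, 24]; the second overruns the clip and is dropped, and a
        # final window [6, 20] is appended instead. Frames shared by consecutive windows are
        # re-used as conditioning and blended with linearly decaying weights further below.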
        ses = get_ses(num_frames)

        pred = None
        for i, se in enumerate(ses):
            window_num_frames = window_size
            window_image_embeddings = image_embeddings[se[0]:se[1]]
            window_image_latents = image_latents[:, se[0]:se[1]]
            window_added_time_ids = added_time_ids

            if i == 0 or time_step_size == window_size:
                to_replace_latents = None
            else:
                last_se = ses[i - 1]
                num_to_replace_latents = last_se[1] - se[0]
                to_replace_latents = pred[:, -num_to_replace_latents:]

            latents = self.generate(
                num_inference_steps,
                device,
                batch_size,
                num_videos_per_prompt,
                window_num_frames,
                height,
                width,
                window_image_embeddings,
                generator,
                determineTrain,
                to_replace_latents,
                do_classifier_free_guidance,
                window_image_latents,
                window_added_time_ids,
            )

            # merge last_latents and current latents in the overlap window
            if to_replace_latents is not None and use_linear_merge:
                num_img_condition = to_replace_latents.shape[1]
                weight = torch.linspace(1., 0., num_img_condition + 2)[1:-1].to(device)
                weight = weight[None, :, None, None, None]
                latents[:, :num_img_condition] = to_replace_latents * weight + latents[:, :num_img_condition] * (1 - weight)

            if pred is None:
                pred = latents
            else:
                pred = torch.cat([pred[:, :se[0]], latents], dim=1)

        if output_type != "latent":
            # cast back to fp16 if needed
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)

            # latents has shape (1, num_frames, 12, h, w)
            def decode_latents(latents, num_frames, decode_chunk_size):
                frames = self.decode_latents(latents, num_frames, decode_chunk_size)  # in range (-1, 1)
                frames = self.video_processor.postprocess_video(video=frames, output_type="np")
                frames = frames * 2 - 1  # from range (0, 1) -> range (-1, 1)
                return frames

            frames = decode_latents(pred, num_frames, decode_chunk_size)
            if pad_HWs is not None:
                frames = self.unpad_image(frames, pad_HWs)
        else:
            frames = pred

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames

        return StableVideoDiffusionPipelineOutput(frames=frames)

    def generate(
        self,
        num_inference_steps,
        device,
        batch_size,
        num_videos_per_prompt,
        num_frames,
        height,
        width,
        image_embeddings,
        generator,
        determineTrain,
        to_replace_latents,
        do_classifier_free_guidance,
        image_latents,
        added_time_ids,
        latents=None,
    ):
        # 6. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 7. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_frames,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )
        if determineTrain:
            latents[...] = 0.

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # replace part of the latents with conditions. TODO: the t embedding should also be replaced
                if to_replace_latents is not None:
                    num_img_condition = to_replace_latents.shape[1]
                    if not determineTrain:
                        _noise = randn_tensor(to_replace_latents.shape, generator=generator, device=device, dtype=image_embeddings.dtype)
                        noisy_to_replace_latents = self.scheduler.add_noise(to_replace_latents, _noise, t.unsqueeze(0))
                        latents[:, :num_img_condition] = noisy_to_replace_latents
                    else:
                        latents[:, :num_img_condition] = to_replace_latents

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                timestep = t

                # Concatenate image_latents over the channels dimension
                latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    timestep,
                    encoder_hidden_states=image_embeddings,
                    added_time_ids=added_time_ids,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                scheduler_output = self.scheduler.step(noise_pred, t, latents)
                latents = scheduler_output.prev_sample

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        return latents


# resizing utils
# TODO: clean up later
def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
    h, w = input.shape[-2:]
    factors = (h / size[0], w / size[1])

    # First, we have to determine sigma
    # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
    sigmas = (
        max((factors[0] - 1.0) / 2.0, 0.001),
        max((factors[1] - 1.0) / 2.0, 0.001),
    )

    # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
    # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
    # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
    ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))

    # Make sure it is odd
    if (ks[0] % 2) == 0:
        ks = ks[0] + 1, ks[1]
    if (ks[1] % 2) == 0:
        ks = ks[0], ks[1] + 1

    input = _gaussian_blur2d(input, ks, sigmas)

    output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
    return output
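
# Worked example for the anti-aliased resize above: downscaling a 576x1024 frame to 224x224
# gives factors of roughly (2.57, 4.57), hence sigmas of roughly (0.79, 1.79) and a (3, 7)
# kernel after rounding up to odd sizes. For upscaling (factors < 1) the sigmas clamp to
# 0.001, so the blur is effectively a no-op and only the bicubic interpolation remains.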


def _compute_padding(kernel_size):
    """Compute padding tuple."""
    # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
    # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
    if len(kernel_size) < 2:
        raise AssertionError(kernel_size)
    computed = [k - 1 for k in kernel_size]

    # for even kernels we need to do asymmetric padding :(
    out_padding = 2 * len(kernel_size) * [0]

    for i in range(len(kernel_size)):
        computed_tmp = computed[-(i + 1)]

        pad_front = computed_tmp // 2
        pad_rear = computed_tmp - pad_front

        out_padding[2 * i + 0] = pad_front
        out_padding[2 * i + 1] = pad_rear

    return out_padding


def _filter2d(input, kernel):
    # prepare kernel
    b, c, h, w = input.shape
    tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)

    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)

    height, width = tmp_kernel.shape[-2:]

    padding_shape: list[int] = _compute_padding([height, width])
    input = torch.nn.functional.pad(input, padding_shape, mode="reflect")

    # kernel and input tensor reshape to align element-wise or batch-wise params
    tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
    input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))

    # convolve the tensor with the kernel.
    output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)

    out = output.view(b, c, h, w)
    return out


def _gaussian(window_size: int, sigma):
    if isinstance(sigma, float):
        sigma = torch.tensor([[sigma]])

    batch_size = sigma.shape[0]
    x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)

    if window_size % 2 == 0:
        x = x + 0.5

    gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))

    return gauss / gauss.sum(-1, keepdim=True)


def _gaussian_blur2d(input, kernel_size, sigma):
    if isinstance(sigma, tuple):
        sigma = torch.tensor([sigma], dtype=input.dtype)
    else:
        sigma = sigma.to(dtype=input.dtype)

    ky, kx = int(kernel_size[0]), int(kernel_size[1])
    bs = sigma.shape[0]
    kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
    kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
    out_x = _filter2d(input, kernel_x[..., None, :])
    out = _filter2d(out_x, kernel_y[..., None])
    return out
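

# --- Usage sketch (illustrative only) ---
# A minimal example of how this pipeline is meant to be driven: load weights via
# `from_pretrained` (inherited from StableVideoDiffusionPipeline), pass a list of PIL frames,
# and read back per-frame normal maps in [-1, 1]. The checkpoint path, the frame glob, and the
# window_size / time_step_size values below are placeholders, not values fixed by this file.
if __name__ == "__main__":
    import glob

    # Assumption: replace with the actual NormalCrafter weights (repo id or local path).
    pipe = NormalCrafterPipeline.from_pretrained(
        "path/or/repo-id-of-normalcrafter-weights", torch_dtype=torch.float16
    ).to("cuda")

    # The clip must contain at least `window_size` frames, otherwise get_ses() finds no window.
    frames = [Image.open(p).convert("RGB") for p in sorted(glob.glob("frames/*.png"))]

    with torch.no_grad():
        out = pipe(
            frames,
            window_size=14,     # frames denoised jointly per window
            time_step_size=10,  # stride between windows; the overlap is blended linearly
            decode_chunk_size=4,
        )
    normals = out.frames  # numpy array in [-1, 1], shape (1, num_frames, H, W, 3)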