# Author: Bingxin Ke
# Last modified: 2023-12-15


from typing import Dict, Union

import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from tqdm.auto import tqdm
from PIL import Image

from diffusers import (
    DiffusionPipeline,
    DDIMScheduler,
    UNet2DConditionModel,
    AutoencoderKL,
)
from diffusers.utils import BaseOutput
from transformers import CLIPTextModel, CLIPTokenizer

from .util.image_util import chw2hwc, colorize_depth_maps, resize_max_res
from .util.batchsize import find_batch_size
from .util.ensemble import ensemble_depths


class MarigoldDepthOutput(BaseOutput):
    """
    Output class for Marigold monocular depth prediction pipeline.

    Args:
        depth_np (np.ndarray):
            Predicted depth map, with depth values in the range of [0, 1].
        depth_colored (PIL.Image.Image):
            Colorized depth map, as an RGB image.
        uncertainty (None or np.ndarray):
            Uncalibrated uncertainty (MAD, median absolute deviation)
            coming from ensembling.
    """

    depth_np: np.ndarray
    depth_colored: Image.Image
    uncertainty: Union[None, np.ndarray]
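

# ----------------------------------------------------------------------------
# Minimal sketch (added for illustration) of the median-absolute-deviation
# uncertainty referenced in the docstring above. The pipeline itself obtains
# both the ensembled depth and this uncertainty from `ensemble_depths`; this
# helper only illustrates the statistic on a stack of ensemble predictions,
# without the alignment that ensembling may additionally perform.
# ----------------------------------------------------------------------------
def _mad_uncertainty_sketch(depth_preds: np.ndarray) -> np.ndarray:
    """Per-pixel MAD across the ensemble axis of an [N, H, W] prediction stack."""
    median = np.median(depth_preds, axis=0)  # per-pixel ensemble median
    return np.median(np.abs(depth_preds - median), axis=0)  # per-pixel MAD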


class MarigoldPipeline(DiffusionPipeline):
    """
    Pipeline for monocular depth estimation using Marigold: https://arxiv.org/abs/2312.02145.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic
    methods the library implements for all the pipelines (such as downloading or saving,
    running on a particular device, etc.).

    Args:
        unet (UNet2DConditionModel):
            Conditional U-Net to denoise the depth latent, conditioned on image latent.
        vae (AutoencoderKL):
            Variational Auto-Encoder (VAE) model to encode and decode images and depth maps
            to and from latent representations.
        scheduler (DDIMScheduler):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (CLIPTextModel):
            Text-encoder, for empty text embedding.
        tokenizer (CLIPTokenizer):
            CLIP tokenizer.
    """

    rgb_latent_scale_factor = 0.18215
    depth_latent_scale_factor = 0.18215

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: DDIMScheduler,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            vae=vae,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )

        self.empty_text_embed = None

    @torch.no_grad()
    def __call__(
        self,
        input_image: Image.Image,
        denoising_steps: int = 10,
        ensemble_size: int = 10,
        processing_res: int = 768,
        match_input_res: bool = True,
        batch_size: int = 0,
        color_map: str = "Spectral",
        show_progress_bar: bool = True,
        ensemble_kwargs: Dict = None,
    ) -> MarigoldDepthOutput:
        """
        Function invoked when calling the pipeline.

        Args:
            input_image (Image.Image):
                Input RGB (or gray-scale) image.
            denoising_steps (int, optional):
                Number of diffusion denoising steps (DDIM) during inference. Defaults to 10.
            ensemble_size (int, optional):
                Number of predictions to be ensembled. Defaults to 10.
            processing_res (int, optional):
                Maximum resolution of processing. If set to 0, the input is not resized at all.
                Defaults to 768.
            match_input_res (bool, optional):
                Resize the depth prediction back to the input resolution.
                Only relevant if `processing_res` > 0. Defaults to True.
            batch_size (int, optional):
                Inference batch size, no bigger than `ensemble_size`.
                If set to 0, a proper batch size is chosen automatically. Defaults to 0.
            color_map (str, optional):
                Colormap used to colorize the depth map. Defaults to "Spectral".
            show_progress_bar (bool, optional):
                Display a progress bar of diffusion denoising. Defaults to True.
            ensemble_kwargs (Dict, optional):
                Arguments for detailed ensembling settings, forwarded to `ensemble_depths`.
                Defaults to None.

        Returns:
            `MarigoldDepthOutput`
        """
        device = self.device
        input_size = input_image.size

        if not match_input_res:
            assert (
                processing_res is not None
            ), "Value error: `processing_res` must be set when `match_input_res` is False."
        assert processing_res >= 0
        assert denoising_steps >= 1
        assert ensemble_size >= 1

        # ----------------- Image Preprocess -----------------
        # Resize image
        if processing_res > 0:
            input_image = resize_max_res(
                input_image, max_edge_resolution=processing_res
            )
        # Convert the image to RGB, to 1. remove the alpha channel, 2. convert B&W to 3-channel
        input_image = input_image.convert("RGB")
        image = np.asarray(input_image)

        # Normalize rgb values
        rgb = np.transpose(image, (2, 0, 1))  # [H, W, rgb] -> [rgb, H, W]
        rgb_norm = rgb / 255.0
        rgb_norm = torch.from_numpy(rgb_norm).to(self.vae.dtype)
        rgb_norm = rgb_norm.to(device)
        assert rgb_norm.min() >= 0.0 and rgb_norm.max() <= 1.0

        # ----------------- Predicting depth -----------------
        # Batch repeated input image
        duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
        single_rgb_dataset = TensorDataset(duplicated_rgb)
        if batch_size > 0:
            _bs = batch_size
        else:
            _bs = find_batch_size(
                ensemble_size=ensemble_size, input_res=max(rgb_norm.shape[1:])
            )

        single_rgb_loader = DataLoader(
            single_rgb_dataset, batch_size=_bs, shuffle=False
        )

        # Predict depth maps (batched)
        depth_pred_ls = []
        if show_progress_bar:
            iterable = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable = single_rgb_loader
        for batch in iterable:
            (batched_img,) = batch
            depth_pred_raw = self.single_infer(
                rgb_in=batched_img,
                num_inference_steps=denoising_steps,
                show_pbar=show_progress_bar,
            )
            depth_pred_ls.append(depth_pred_raw.detach().clone())
        depth_preds = torch.cat(depth_pred_ls, dim=0).squeeze()
        torch.cuda.empty_cache()  # clear vram cache for ensembling

        # ----------------- Test-time ensembling -----------------
        if ensemble_size > 1:
            depth_pred, pred_uncert = ensemble_depths(
                depth_preds, **(ensemble_kwargs or {})
            )
        else:
            depth_pred = depth_preds
            pred_uncert = None

        # ----------------- Post processing -----------------
        # Scale prediction to [0, 1]
        min_d = torch.min(depth_pred)
        max_d = torch.max(depth_pred)
        depth_pred = (depth_pred - min_d) / (max_d - min_d)

        # Convert to numpy
        depth_pred = depth_pred.cpu().numpy().astype(np.float32)

        # Resize back to original resolution
        if match_input_res:
            pred_img = Image.fromarray(depth_pred)
            pred_img = pred_img.resize(input_size)
            depth_pred = np.asarray(pred_img)

        # Clip output range
        depth_pred = depth_pred.clip(0, 1)

        # Colorize
        depth_colored = colorize_depth_maps(
            depth_pred, 0, 1, cmap=color_map
        ).squeeze()  # [3, H, W], value in (0, 1)
        depth_colored = (depth_colored * 255).astype(np.uint8)
        depth_colored_hwc = chw2hwc(depth_colored)
        depth_colored_img = Image.fromarray(depth_colored_hwc)
        return MarigoldDepthOutput(
            depth_np=depth_pred,
            depth_colored=depth_colored_img,
            uncertainty=pred_uncert,
        )

    def __encode_empty_text(self):
        """
        Encode text embedding for empty prompt.
        """
        prompt = ""
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        self.empty_text_embed = self.text_encoder(text_input_ids)[0]
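
    # ------------------------------------------------------------------
    # Note (explanatory comment, added): in `single_infer` below, the image
    # latent and the noisy depth latent are concatenated along the channel
    # dimension, so the U-Net receives an 8-channel input [B, 4+4, h, w]
    # and predicts the noise on the depth latent only. Text conditioning
    # is the constant empty-prompt embedding computed above, so the model
    # is effectively image-conditioned rather than text-conditioned.
    # ------------------------------------------------------------------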

    @torch.no_grad()
    def single_infer(
        self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar: bool
    ) -> torch.Tensor:
        """
        Perform an individual depth prediction without ensembling.

        Args:
            rgb_in (torch.Tensor):
                Input RGB image.
            num_inference_steps (int):
                Number of diffusion denoising steps (DDIM) during inference.
            show_pbar (bool):
                Display a progress bar of diffusion denoising.

        Returns:
            torch.Tensor: Predicted depth map.
        """
        device = rgb_in.device

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)

        # Initial depth map (noise)
        depth_latent = torch.randn(
            rgb_latent.shape, device=device, dtype=rgb_latent.dtype
        )  # [B, 4, h, w]

        # Batched empty text embedding
        if self.empty_text_embed is None:
            self.__encode_empty_text()
        batch_empty_text_embed = self.empty_text_embed.repeat(
            (rgb_latent.shape[0], 1, 1)
        )  # [B, 2, 1024]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )
        else:
            iterable = enumerate(timesteps)

        for i, t in iterable:
            unet_input = torch.cat(
                [rgb_latent, depth_latent], dim=1
            )  # this order is important

            # predict the noise residual
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=batch_empty_text_embed
            ).sample  # [B, 4, h, w]

            # compute the previous noisy sample x_t -> x_t-1
            depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample
        depth = self.decode_depth(depth_latent)

        # clip prediction to the VAE output range
        depth = torch.clip(depth, -1.0, 1.0)
        # shift from [-1, 1] to [0, 1]
        depth = (depth + 1.0) / 2.0

        return depth

    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Args:
            rgb_in (torch.Tensor):
                Input RGB image to be encoded.

        Returns:
            torch.Tensor: Image latent.
        """
        # encode
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        rgb_latent = mean * self.rgb_latent_scale_factor
        return rgb_latent

    def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode depth latent into depth map.

        Args:
            depth_latent (torch.Tensor):
                Depth latent to be decoded.

        Returns:
            torch.Tensor: Decoded depth map.
        """
        # scale latent
        depth_latent = depth_latent / self.depth_latent_scale_factor
        # decode
        z = self.vae.post_quant_conv(depth_latent)
        stacked = self.vae.decoder(z)
        # mean of output channels
        depth_mean = stacked.mean(dim=1, keepdim=True)
        return depth_mean
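

# ----------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the pipeline).
# The checkpoint id and file paths below are assumptions -- substitute the
# Marigold checkpoint and input image you actually use. Not called on import.
# ----------------------------------------------------------------------------
def _example_usage():
    # assumed checkpoint id on the Hugging Face Hub
    pipe = MarigoldPipeline.from_pretrained("prs-eth/marigold-v1-0")
    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    image = Image.open("input.jpg")  # hypothetical input path
    output = pipe(
        image,
        denoising_steps=10,
        ensemble_size=10,
        processing_res=768,
        match_input_res=True,
    )

    np.save("depth_pred.npy", output.depth_np)  # raw depth map in [0, 1]
    output.depth_colored.save("depth_colored.png")  # colorized PIL image
    if output.uncertainty is not None:
        np.save("uncertainty.npy", output.uncertainty)  # MAD from ensembling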