import numpy as np
import rembg
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from diffusers import DiffusionPipeline


class LGMPipeline(DiffusionPipeline):
    """Diffusers pipeline wrapping an LGM model: takes four object views,
    removes their backgrounds, and predicts 3D gaussians via ``self.lgm``."""

    def __init__(self, lgm):
        super().__init__()
        # One rembg session, reused across every __call__ for background removal.
        self.bg_remover = rembg.new_session()
        # Standard ImageNet normalization constants used before feeding the model.
        self.imagenet_default_mean = (0.485, 0.456, 0.406)
        self.imagenet_default_std = (0.229, 0.224, 0.225)
        # Model runs in half precision on GPU; registered so diffusers tracks it.
        lgm = lgm.half().cuda()
        self.register_modules(lgm=lgm)

    def save_ply(self, gaussians, path):
        """Persist predicted gaussians to ``path`` via the model's gaussian helper."""
        self.lgm.gs.save_ply(gaussians, path)

    @torch.no_grad()
    def __call__(self, images):
        """Run LGM inference on four input views.

        Args:
            images: sequence of 4 images accepted by ``rembg.remove``; the
                removal output is treated as a uint8 RGBA array (H, W, 4)
                # assumes rembg returns an ndarray with an alpha channel — TODO confirm
        Returns:
            Whatever ``self.lgm`` returns for a (1, 4, C, 256, 256) input
            (gaussians, per ``save_ply`` usage).
        """
        processed = []
        for i in range(4):
            view = rembg.remove(images[i], session=self.bg_remover)
            # BUG FIX: the original did `images.astype(...)` here, scaling the
            # raw *input* and throwing away the background-removed result.
            view = view.astype(np.float32) / 255.0
            # Composite RGB over a white background using the alpha channel.
            view = view[..., :3] * view[..., -1:] + (1 - view[..., -1:])
            processed.append(view)

        # BUG FIX: the original first concatenated the views into one 2x2 pixel
        # grid of shape (2H, 2W, 3) and then indexed rows [1, 2, 3, 0] out of
        # that grid — a 3-D array that crashes on the 4-D permute below.
        # Stack the four processed views directly, keeping the original
        # [1, 2, 3, 0] view ordering implied by the grid layout [[1, 2], [3, 0]].
        images = np.stack(
            [processed[1], processed[2], processed[3], processed[0]], axis=0
        )

        # (4, H, W, 3) -> (4, 3, H, W), resized to the model's 256x256 input.
        images = torch.from_numpy(images).permute(0, 3, 1, 2).float().cuda()
        images = F.interpolate(
            images,
            size=(256, 256),
            mode="bilinear",
            align_corners=False,
        )
        images = TF.normalize(
            images, self.imagenet_default_mean, self.imagenet_default_std
        )

        # Append the model's default ray embeddings as extra channels, then add
        # a batch dimension: (4, C, 256, 256) -> (1, 4, C+R, 256, 256).
        rays_embeddings = self.lgm.prepare_default_rays("cuda", elevation=0)
        images = torch.cat([images, rays_embeddings], dim=1).unsqueeze(0)
        images = images.half().cuda()

        result = self.lgm(images)
        return result