LGM / pipeline.py
dylanebert's picture
dylanebert HF staff
add pipeline
5e1c565
raw
history blame
1.83 kB
import numpy as np
import rembg
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from diffusers import DiffusionPipeline
class LGMPipeline(DiffusionPipeline):
    """Diffusers pipeline that runs an LGM model on four views of an object.

    The pipeline strips the background from each view, composites it onto
    white, normalizes the views, appends the model's default ray embeddings
    and runs the wrapped ``lgm`` module in half precision on CUDA.
    """

    def __init__(self, lgm):
        """Wrap *lgm* and prepare the preprocessing state.

        Args:
            lgm: The LGM module. It is cast to half precision, moved to the
                CUDA device and registered on the pipeline as ``self.lgm``.
        """
        super().__init__()

        # One rembg session is created up front and reused for every view.
        self.bg_remover = rembg.new_session()
        self.imagenet_default_mean = (0.485, 0.456, 0.406)
        self.imagenet_default_std = (0.229, 0.224, 0.225)

        lgm = lgm.half().cuda()
        self.register_modules(lgm=lgm)

    def save_ply(self, gaussians, path):
        """Write *gaussians* to *path* by delegating to ``self.lgm.gs.save_ply``."""
        self.lgm.gs.save_ply(gaussians, path)

    @torch.no_grad()
    def __call__(self, images):
        """Run the LGM model on the first four entries of *images*.

        Args:
            images: Indexable collection holding at least four views. Each
                view is passed to ``rembg.remove``.
                NOTE(review): assumes 8-bit images of identical size
                (values are divided by 255 and the views are stacked) --
                confirm against the caller.

        Returns:
            Whatever ``self.lgm`` returns for the prepared input
            (presumably the predicted gaussians accepted by ``save_ply``).
        """
        views = []
        for i in range(4):
            cutout = rembg.remove(images[i], session=self.bg_remover)
            # rembg returns the same container type it was given;
            # np.asarray accepts both ndarray and PIL results.
            rgba = np.asarray(cutout).astype(np.float32) / 255.0
            alpha = rgba[..., -1:]
            # Alpha-composite the cutout onto a white background.
            views.append(rgba[..., :3] * alpha + (1 - alpha))

        # Stack the individual views (not a tiled grid) so the batch axis
        # indexes views. Order is 1, 2, 3, 0.
        # NOTE(review): presumably the view order the model expects -- confirm.
        batch = np.stack([views[1], views[2], views[3], views[0]], axis=0)

        # (4, H, W, 3) -> (4, 3, H, W), then resize to 256x256.
        batch = torch.from_numpy(batch).permute(0, 3, 1, 2).float().cuda()
        batch = F.interpolate(
            batch,
            size=(256, 256),
            mode="bilinear",
            align_corners=False,
        )
        batch = TF.normalize(
            batch, self.imagenet_default_mean, self.imagenet_default_std
        )

        # Append the ray embeddings along the channel axis and add a leading
        # batch dimension.
        rays_embeddings = self.lgm.prepare_default_rays("cuda", elevation=0)
        batch = torch.cat([batch, rays_embeddings], dim=1).unsqueeze(0)

        # The model was cast to half precision in __init__; match its dtype.
        # The tensor is already on the CUDA device at this point.
        return self.lgm(batch.half())