stable-fast-3d / sf3d /utils.py
jammmmm's picture
Update to latest inference code
77d8010
raw
history blame
2.7 kB
import os
from typing import Any, Union
import numpy as np
import rembg
import torch
import torchvision.transforms.functional as torchvision_F
from PIL import Image
import sf3d.models.utils as sf3d_utils
def get_device():
if os.environ.get("SF3D_USE_CPU", "0") == "1":
return "cpu"
device = "cpu"
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
return device
def create_intrinsic_from_fov_deg(fov_deg: float, cond_height: int, cond_width: int):
intrinsic = sf3d_utils.get_intrinsic_from_fov(
np.deg2rad(fov_deg),
H=cond_height,
W=cond_width,
)
intrinsic_normed_cond = intrinsic.clone()
intrinsic_normed_cond[..., 0, 2] /= cond_width
intrinsic_normed_cond[..., 1, 2] /= cond_height
intrinsic_normed_cond[..., 0, 0] /= cond_width
intrinsic_normed_cond[..., 1, 1] /= cond_height
return intrinsic, intrinsic_normed_cond
def default_cond_c2w(distance: float):
c2w_cond = torch.as_tensor(
[
[0, 0, 1, distance],
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 1],
]
).float()
return c2w_cond
def remove_background(
image: Image,
rembg_session: Any = None,
force: bool = False,
**rembg_kwargs,
) -> Image:
do_remove = True
if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
do_remove = False
do_remove = do_remove or force
if do_remove:
image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
return image
def get_1d_bounds(arr):
nz = np.flatnonzero(arr)
return nz[0], nz[-1]
def get_bbox_from_mask(mask, thr=0.5):
masks_for_box = (mask > thr).astype(np.float32)
assert masks_for_box.sum() > 0, "Empty mask!"
x0, x1 = get_1d_bounds(masks_for_box.sum(axis=-2))
y0, y1 = get_1d_bounds(masks_for_box.sum(axis=-1))
return x0, y0, x1, y1
def resize_foreground(
image: Union[Image.Image, np.ndarray],
ratio: float,
out_size=None,
) -> Image:
if isinstance(image, np.ndarray):
image = Image.fromarray(image, mode="RGBA")
assert image.mode == "RGBA"
# Get bounding box
mask_np = np.array(image)[:, :, -1]
x1, y1, x2, y2 = get_bbox_from_mask(mask_np, thr=0.5)
h, w = y2 - y1, x2 - x1
yc, xc = (y1 + y2) / 2, (x1 + x2) / 2
scale = max(h, w) / ratio
new_image = torchvision_F.crop(
image,
top=int(yc - scale / 2),
left=int(xc - scale / 2),
height=int(scale),
width=int(scale),
)
if out_size is not None:
new_image = new_image.resize(out_size)
return new_image