# Spaces: Running on Zero
import numpy as np | |
import torch | |
from einops import rearrange | |
from PIL import Image | |
def tensor_to_pil(tensor, mask=None, normalize: bool = True):
    """
    Convert an image tensor (optionally with an alpha mask) to a PIL Image.

    Multi-view (4D) inputs are tiled side by side along the width axis, so
    the result has height H and width Nv * W.

    :param tensor: torch.Tensor or str (file path to tensor), shape can be
        (Nv, H, W, C), (Nv, C, H, W), (H, W, C) or (C, H, W). The layout is
        guessed from the sizes: a last dimension of 3 or 4 means
        channels-last, anything else is treated as channels-first.
    :param mask: optional torch.Tensor or str (file path to tensor) that is
        appended as an extra (alpha) channel; effective when C=3. It may
        follow the layout of ``tensor`` or omit the channel axis entirely,
        e.g. (H, W), (H, W, 1) or (1, H, W). Values are clamped to [0, 1];
        ``normalize`` is not applied to the mask.
    :param normalize: if True, map ``tensor`` from [-1, 1] to [0, 1] before
        clamping; if False, values are only clamped to [0, 1].
    :return: PIL.Image in "RGB" (3 channels) or "RGBA" (4 channels) mode
    :raises ValueError: if ``tensor`` is not 3D/4D, or the final image does
        not have 3 or 4 channels.
    """
    # If input is a file path, load the tensor
    if isinstance(tensor, str):
        from utils.file_utils import load_tensor_from_file

        tensor = load_tensor_from_file(tensor, map_location="cpu")
    if isinstance(mask, str):
        from utils.file_utils import load_tensor_from_file

        mask = load_tensor_from_file(mask, map_location="cpu")

    # Detach and move to CPU unconditionally: `.cpu()` is a no-op for CPU
    # tensors and also covers non-CUDA accelerators.
    tensor = tensor.detach().cpu().float()
    if tensor.ndim not in (3, 4):
        raise ValueError(f"Unsupported tensor shape: {tensor.shape}")

    if normalize:
        tensor = (tensor + 1.0) / 2.0
    tensor = torch.clamp(tensor, 0.0, 1.0)

    layout_known = False  # True once `tensor` is known to be channels-last
    if mask is not None:
        # The mask must be detached as well; otherwise a mask that requires
        # grad makes the concatenated result unusable with `.numpy()`.
        # Clamping keeps the later uint8 conversion within range.
        mask = torch.clamp(mask.detach().cpu().float(), 0.0, 1.0)

        # Channels-first image: move channels last first so that the mask is
        # appended on the channel axis rather than on the width axis.
        if tensor.shape[-1] not in (3, 4) and tensor.shape[-3] in (3, 4):
            tensor = rearrange(tensor, "... c h w -> ... h w c")
            layout_known = True
            if (
                mask.ndim == tensor.ndim
                and mask.shape[-2:] == tensor.shape[-3:-1]
            ):
                # The mask is channels-first too, e.g. (1, H, W).
                mask = rearrange(mask, "... c h w -> ... h w c")

        # A mask without a channel axis gets a trailing singleton channel.
        if mask.shape == tensor.shape[:-1]:
            mask = mask.unsqueeze(-1)
        tensor = torch.cat([tensor, mask], dim=-1)

    # Bring the image to channels-last: (Nv, H, W, C) or (H, W, C).
    if not (layout_known or tensor.shape[-1] in (3, 4)):
        tensor = rearrange(tensor, "... c h w -> ... h w c")
    # Tile the views of a 4D input horizontally.
    if tensor.ndim == 4:
        tensor = rearrange(tensor, "nv h w c -> h (nv w) c")

    # Convert to numpy (uint8 in [0, 255])
    np_img = (tensor.numpy() * 255).round().astype(np.uint8)

    # Create PIL Image
    channels = np_img.shape[2]
    if channels == 3:
        return Image.fromarray(np_img, mode="RGB")
    if channels == 4:
        return Image.fromarray(np_img, mode="RGBA")
    raise ValueError("Only support 3 or 4 channel images.")