import json
import math
import os
import re
import shutil
from typing import List, Optional, Union

import cv2
import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch

from matplotlib import cm
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image, ImageDraw

from .typing import *


def tensor_to_image(
    data: Union[Image.Image, torch.Tensor, np.ndarray],
    batched: bool = False,
    format: str = "HWC",
) -> Union[Image.Image, List[Image.Image]]:
    if isinstance(data, Image.Image):
        return data
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    if data.dtype == np.float32 or data.dtype == np.float16:
        data = (data * 255).astype(np.uint8)
    elif data.dtype == np.bool_:
        data = data.astype(np.uint8) * 255
    assert data.dtype == np.uint8
    if format == "CHW":
        if batched and data.ndim == 4:
            data = data.transpose((0, 2, 3, 1))
        elif not batched and data.ndim == 3:
            data = data.transpose((1, 2, 0))

    if batched:
        return [Image.fromarray(d) for d in data]
    return Image.fromarray(data)

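
# Illustrative usage of tensor_to_image (a sketch, not exercised in this module):
# a batched CHW float tensor in [0, 1] becomes a list of PIL images.
#
#     frames = torch.rand(4, 3, 64, 64)
#     pils = tensor_to_image(frames, batched=True, format="CHW")
#     assert isinstance(pils, list) and pils[0].size == (64, 64)
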
def largest_factor_near_sqrt(n: int) -> int:
    """
    Finds the largest factor of n that is closest to the square root of n.

    Args:
        n (int): The integer for which to find the largest factor near its square root.

    Returns:
        int: The largest factor of n that is closest to the square root of n.
    """
    sqrt_n = int(math.sqrt(n))

    # A perfect square: the square root itself is the answer.
    if sqrt_n * sqrt_n == n:
        return sqrt_n

    # Walk downwards from floor(sqrt(n)); the first divisor found is the
    # largest factor not exceeding the square root.
    for i in range(sqrt_n, 0, -1):
        if n % i == 0:
            return i

    # Defensive fallback; for any positive n the loop above already returns,
    # since 1 divides every positive integer.
    return 1

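
# For example, largest_factor_near_sqrt(12) == 3 (12 = 3 x 4),
# largest_factor_near_sqrt(16) == 4, and a prime such as 7 yields 1.
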
def make_image_grid(
    images: List[Image.Image],
    rows: Optional[int] = None,
    cols: Optional[int] = None,
    resize: Optional[int] = None,
) -> Image.Image:
    """
    Prepares a single grid of images. Useful for visualization purposes.
    """
    if rows is None and cols is not None:
        assert len(images) % cols == 0
        rows = len(images) // cols
    elif cols is None and rows is not None:
        assert len(images) % rows == 0
        cols = len(images) // rows
    elif rows is None and cols is None:
        rows = largest_factor_near_sqrt(len(images))
        cols = len(images) // rows

    assert len(images) == rows * cols

    if resize is not None:
        images = [img.resize((resize, resize)) for img in images]

    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(images):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

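
# Illustrative usage of make_image_grid (a sketch): four 32x32 tiles are laid
# out on an inferred 2 x 2 grid.
#
#     tiles = [Image.new("RGB", (32, 32), c) for c in ("red", "green", "blue", "white")]
#     grid = make_image_grid(tiles)
#     assert grid.size == (64, 64)
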
class SaverMixin:
    _save_dir: Optional[str] = None
    _wandb_logger: Optional[Any] = None

    def set_save_dir(self, save_dir: str):
        self._save_dir = save_dir

    def get_save_dir(self):
        if self._save_dir is None:
            raise ValueError("Save dir is not set")
        return self._save_dir

    def convert_data(self, data):
        if data is None:
            return None
        elif isinstance(data, np.ndarray):
            return data
        elif isinstance(data, torch.Tensor):
            if data.dtype in [torch.float16, torch.bfloat16]:
                data = data.float()
            return data.detach().cpu().numpy()
        elif isinstance(data, list):
            return [self.convert_data(d) for d in data]
        elif isinstance(data, dict):
            return {k: self.convert_data(v) for k, v in data.items()}
        else:
            raise TypeError(
                f"Data must be of type numpy.ndarray, torch.Tensor, list or dict, got {type(data)}"
            )

    def get_save_path(self, filename):
        save_path = os.path.join(self.get_save_dir(), filename)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        return save_path

    DEFAULT_RGB_KWARGS = {"data_format": "HWC", "data_range": (0, 1)}
    DEFAULT_UV_KWARGS = {
        "data_format": "HWC",
        "data_range": (0, 1),
        "cmap": "checkerboard",
    }
    DEFAULT_GRAYSCALE_KWARGS = {"data_range": None, "cmap": "jet"}
    DEFAULT_GRID_KWARGS = {"align": "max"}

    def get_rgb_image_(self, img, data_format, data_range, rgba=False):
        img = self.convert_data(img)
        assert data_format in ["CHW", "HWC"]
        if data_format == "CHW":
            img = img.transpose(1, 2, 0)
        if img.dtype != np.uint8:
            img = img.clip(min=data_range[0], max=data_range[1])
            img = (
                (img - data_range[0]) / (data_range[1] - data_range[0]) * 255.0
            ).astype(np.uint8)
        nc = 4 if rgba else 3
        imgs = [img[..., start : start + nc] for start in range(0, img.shape[-1], nc)]
        imgs = [
            (
                img_
                if img_.shape[-1] == nc
                else np.concatenate(
                    [
                        img_,
                        np.zeros(
                            (img_.shape[0], img_.shape[1], nc - img_.shape[2]),
                            dtype=img_.dtype,
                        ),
                    ],
                    axis=-1,
                )
            )
            for img_ in imgs
        ]
        img = np.concatenate(imgs, axis=1)
        if rgba:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        return img

    def _save_rgb_image(
        self,
        filename,
        img,
        data_format,
        data_range,
        name: Optional[str] = None,
        step: Optional[int] = None,
    ):
        img = self.get_rgb_image_(img, data_format, data_range)
        cv2.imwrite(filename, img)
        if name and self._wandb_logger:
            self._wandb_logger.log_image(
                key=name, images=[self.get_save_path(filename)], step=step
            )

    def save_rgb_image(
        self,
        filename,
        img,
        data_format=DEFAULT_RGB_KWARGS["data_format"],
        data_range=DEFAULT_RGB_KWARGS["data_range"],
        name: Optional[str] = None,
        step: Optional[int] = None,
    ) -> str:
        save_path = self.get_save_path(filename)
        self._save_rgb_image(save_path, img, data_format, data_range, name, step)
        return save_path

    def get_uv_image_(self, img, data_format, data_range, cmap):
        img = self.convert_data(img)
        assert data_format in ["CHW", "HWC"]
        if data_format == "CHW":
            img = img.transpose(1, 2, 0)
        img = img.clip(min=data_range[0], max=data_range[1])
        img = (img - data_range[0]) / (data_range[1] - data_range[0])
        assert cmap in ["checkerboard", "color"]
        if cmap == "checkerboard":
            n_grid = 64
            mask = (img * n_grid).astype(int)
            mask = (mask[..., 0] + mask[..., 1]) % 2 == 0
            img = np.ones((img.shape[0], img.shape[1], 3), dtype=np.uint8) * 255
            img[mask] = np.array([255, 0, 255], dtype=np.uint8)
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        elif cmap == "color":
            img_ = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8)
            img_[..., 0] = (img[..., 0] * 255).astype(np.uint8)
            img_[..., 1] = (img[..., 1] * 255).astype(np.uint8)
            img_ = cv2.cvtColor(img_, cv2.COLOR_RGB2BGR)
            img = img_
        return img

    def save_uv_image(
        self,
        filename,
        img,
        data_format=DEFAULT_UV_KWARGS["data_format"],
        data_range=DEFAULT_UV_KWARGS["data_range"],
        cmap=DEFAULT_UV_KWARGS["cmap"],
    ) -> str:
        save_path = self.get_save_path(filename)
        img = self.get_uv_image_(img, data_format, data_range, cmap)
        cv2.imwrite(save_path, img)
        return save_path

    def get_grayscale_image_(self, img, data_range, cmap):
        img = self.convert_data(img)
        img = np.nan_to_num(img)
        if data_range is None:
            img = (img - img.min()) / (img.max() - img.min())
        else:
            img = img.clip(data_range[0], data_range[1])
            img = (img - data_range[0]) / (data_range[1] - data_range[0])
        assert cmap in [None, "jet", "magma", "spectral"]
        if cmap is None:
            img = (img * 255.0).astype(np.uint8)
            img = np.repeat(img[..., None], 3, axis=2)
        elif cmap == "jet":
            img = (img * 255.0).astype(np.uint8)
            img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
        elif cmap == "magma":
            img = 1.0 - img
            base = cm.get_cmap("magma")
            num_bins = 256
            colormap = LinearSegmentedColormap.from_list(
                f"{base.name}{num_bins}", base(np.linspace(0, 1, num_bins)), num_bins
            )(np.linspace(0, 1, num_bins))[:, :3]
            a = np.floor(img * 255.0)
            b = (a + 1).clip(max=255.0)
            f = img * 255.0 - a
            a = a.astype(np.uint16).clip(0, 255)
            b = b.astype(np.uint16).clip(0, 255)
            img = colormap[a] + (colormap[b] - colormap[a]) * f[..., None]
            img = (img * 255.0).astype(np.uint8)
        elif cmap == "spectral":
            colormap = plt.get_cmap("Spectral")

            def blend_rgba(image):
                # Alpha-blend the RGBA colormap output onto a white background.
                image = image[..., :3] * image[..., -1:] + (
                    1.0 - image[..., -1:]
                )
                return image

            img = colormap(img)
            img = blend_rgba(img)
            img = (img * 255).astype(np.uint8)
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        return img

    def _save_grayscale_image(
        self,
        filename,
        img,
        data_range,
        cmap,
        name: Optional[str] = None,
        step: Optional[int] = None,
    ):
        img = self.get_grayscale_image_(img, data_range, cmap)
        cv2.imwrite(filename, img)
        if name and self._wandb_logger:
            self._wandb_logger.log_image(
                key=name, images=[self.get_save_path(filename)], step=step
            )

    def save_grayscale_image(
        self,
        filename,
        img,
        data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"],
        cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"],
        name: Optional[str] = None,
        step: Optional[int] = None,
    ) -> str:
        save_path = self.get_save_path(filename)
        self._save_grayscale_image(save_path, img, data_range, cmap, name, step)
        return save_path

    def get_image_grid_(self, imgs, align):
        if isinstance(imgs[0], list):
            return np.concatenate(
                [self.get_image_grid_(row, align) for row in imgs], axis=0
            )
        cols = []
        for col in imgs:
            assert col["type"] in ["rgb", "uv", "grayscale"]
            if col["type"] == "rgb":
                rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy()
                rgb_kwargs.update(col["kwargs"])
                cols.append(self.get_rgb_image_(col["img"], **rgb_kwargs))
            elif col["type"] == "uv":
                uv_kwargs = self.DEFAULT_UV_KWARGS.copy()
                uv_kwargs.update(col["kwargs"])
                cols.append(self.get_uv_image_(col["img"], **uv_kwargs))
            elif col["type"] == "grayscale":
                grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy()
                grayscale_kwargs.update(col["kwargs"])
                cols.append(self.get_grayscale_image_(col["img"], **grayscale_kwargs))

        if align == "max":
            h = max([col.shape[0] for col in cols])
        elif align == "min":
            h = min([col.shape[0] for col in cols])
        elif isinstance(align, int):
            h = align
        else:
            raise ValueError(
                f"Unsupported image grid align: {align}, should be min, max, or int"
            )

        for i in range(len(cols)):
            if cols[i].shape[0] != h:
                w = int(cols[i].shape[1] * h / cols[i].shape[0])
                cols[i] = cv2.resize(cols[i], (w, h), interpolation=cv2.INTER_CUBIC)
        return np.concatenate(cols, axis=1)

    def save_image_grid(
        self,
        filename,
        imgs,
        align=DEFAULT_GRID_KWARGS["align"],
        name: Optional[str] = None,
        step: Optional[int] = None,
        texts: Optional[List[float]] = None,
    ):
        save_path = self.get_save_path(filename)
        img = self.get_image_grid_(imgs, align=align)

        if texts is not None:
            img = Image.fromarray(img)
            draw = ImageDraw.Draw(img)
            black, white = (0, 0, 0), (255, 255, 255)
            for i, text in enumerate(texts):
                draw.text((2, (img.size[1] // len(texts)) * i + 1), f"{text}", white)
                draw.text((0, (img.size[1] // len(texts)) * i + 1), f"{text}", white)
                draw.text((2, (img.size[1] // len(texts)) * i - 1), f"{text}", white)
                draw.text((0, (img.size[1] // len(texts)) * i - 1), f"{text}", white)
                draw.text((1, (img.size[1] // len(texts)) * i), f"{text}", black)
            img = np.asarray(img)

        cv2.imwrite(save_path, img)
        if name and self._wandb_logger:
            self._wandb_logger.log_image(key=name, images=[save_path], step=step)
        return save_path

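    # Illustrative call of save_image_grid (a sketch; `rgb` and `depth` are
    # hypothetical HxWx3 / HxW tensors): each cell is a dict selecting a
    # conversion path and its kwargs, and a nested list would produce several rows.
    #
    #     self.save_image_grid(
    #         "it0-train.png",
    #         [
    #             {"type": "rgb", "img": rgb, "kwargs": {"data_format": "HWC"}},
    #             {"type": "grayscale", "img": depth, "kwargs": {"cmap": "jet"}},
    #         ],
    #     )
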
    def save_image(self, filename, img) -> str:
        save_path = self.get_save_path(filename)
        img = self.convert_data(img)
        assert img.dtype == np.uint8 or img.dtype == np.uint16
        if img.ndim == 3 and img.shape[-1] == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        elif img.ndim == 3 and img.shape[-1] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
        cv2.imwrite(save_path, img)
        return save_path

    def save_cubemap(self, filename, img, data_range=(0, 1), rgba=False) -> str:
        save_path = self.get_save_path(filename)
        img = self.convert_data(img)
        assert img.ndim == 4 and img.shape[0] == 6 and img.shape[1] == img.shape[2]

        imgs_full = []
        for start in range(0, img.shape[-1], 3):
            img_ = img[..., start : start + 3]
            img_ = np.stack(
                [
                    self.get_rgb_image_(img_[i], "HWC", data_range, rgba=rgba)
                    for i in range(img_.shape[0])
                ],
                axis=0,
            )
            size = img_.shape[1]
            placeholder = np.zeros((size, size, 3), dtype=np.float32)
            img_full = np.concatenate(
                [
                    np.concatenate(
                        [placeholder, img_[2], placeholder, placeholder], axis=1
                    ),
                    np.concatenate([img_[1], img_[4], img_[0], img_[5]], axis=1),
                    np.concatenate(
                        [placeholder, img_[3], placeholder, placeholder], axis=1
                    ),
                ],
                axis=0,
            )
            imgs_full.append(img_full)

        imgs_full = np.concatenate(imgs_full, axis=1)
        cv2.imwrite(save_path, imgs_full)
        return save_path

    def save_data(self, filename, data) -> str:
        data = self.convert_data(data)
        if isinstance(data, dict):
            if not filename.endswith(".npz"):
                filename += ".npz"
            save_path = self.get_save_path(filename)
            np.savez(save_path, **data)
        else:
            if not filename.endswith(".npy"):
                filename += ".npy"
            save_path = self.get_save_path(filename)
            np.save(save_path, data)
        return save_path

    def save_state_dict(self, filename, data) -> str:
        save_path = self.get_save_path(filename)
        torch.save(data, save_path)
        return save_path

    def save_img_sequence(
        self,
        filename,
        img_dir,
        matcher,
        save_format="mp4",
        fps=30,
        name: Optional[str] = None,
        step: Optional[int] = None,
    ) -> str:
        assert save_format in ["gif", "mp4"]
        if not filename.endswith(save_format):
            filename += f".{save_format}"
        save_path = self.get_save_path(filename)
        matcher = re.compile(matcher)
        img_dir = os.path.join(self.get_save_dir(), img_dir)
        imgs = []
        for f in os.listdir(img_dir):
            if matcher.search(f):
                imgs.append(f)
        imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0]))
        imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs]

        if save_format == "gif":
            imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs]
            imageio.mimsave(save_path, imgs, fps=fps, palettesize=256)
        elif save_format == "mp4":
            imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs]
            imageio.mimsave(save_path, imgs, fps=fps)
        if name and self._wandb_logger:
            from .core import warn

            warn("Wandb logger does not support video logging yet!")
        return save_path

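    # Illustrative call of save_img_sequence (paths are hypothetical): the matcher
    # must contain one capture group holding the integer frame index used for
    # sorting, e.g. frames saved as "0.png", "1.png", ... under <save_dir>/it0-val:
    #
    #     self.save_img_sequence("it0-val", "it0-val", r"(\d+)\.png", save_format="mp4", fps=30)
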
    def save_img_sequences(
        self,
        seq_dir,
        matcher,
        save_format="mp4",
        fps=30,
        delete=True,
        name: Optional[str] = None,
        step: Optional[int] = None,
    ):
        seq_dir_ = os.path.join(self.get_save_dir(), seq_dir)
        for f in os.listdir(seq_dir_):
            img_dir_ = os.path.join(seq_dir_, f)
            if not os.path.isdir(img_dir_):
                continue
            try:
                self.save_img_sequence(
                    os.path.join(seq_dir, f),
                    os.path.join(seq_dir, f),
                    matcher,
                    save_format=save_format,
                    fps=fps,
                    name=f"{name}_{f}",
                    step=step,
                )
                if delete:
                    shutil.rmtree(img_dir_)
            except Exception:
                from .core import warn

                warn(f"Video saving for directory {seq_dir_} failed!")

    def save_file(self, filename, src_path, delete=False) -> str:
        save_path = self.get_save_path(filename)
        shutil.copyfile(src_path, save_path)
        if delete:
            os.remove(src_path)
        return save_path

    def save_json(self, filename, payload) -> str:
        save_path = self.get_save_path(filename)
        with open(save_path, "w") as f:
            f.write(json.dumps(payload))
        return save_path
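

# A minimal usage sketch of SaverMixin (assumptions: the class name `Exporter`
# and the paths below are hypothetical; `set_save_dir` must be called before any
# save_* method, and `_wandb_logger` may optionally be set for image logging):
#
#     class Exporter(SaverMixin):
#         pass
#
#     exporter = Exporter()
#     exporter.set_save_dir("outputs/demo")
#     exporter.save_rgb_image("rgb/frame0.png", torch.rand(64, 64, 3))
#     exporter.save_data("buffers", {"depth": np.zeros((64, 64), dtype=np.float32)})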