Last commit not found
from typing import Optional, Tuple | |
import torch | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
from .util import bchw2hwc | |
def set_figsize(*args):
    """Set matplotlib's default figure size (``rcParams["figure.figsize"]``).

    Accepted call forms:
        set_figsize()                 -- restore matplotlib's default size
        set_figsize(side)             -- square figure, ``(side, side)``
        set_figsize(width, height)    -- explicit ``(width, height)``

    Raises:
        RuntimeError: if called with more than two arguments.
    """
    n_args = len(args)
    if n_args > 2:
        raise RuntimeError(
            f'Supported argument types: set_figsize() or set_figsize(int) or set_figsize(int, int)')

    key = "figure.figsize"
    if n_args == 0:
        size = plt.rcParamsDefault[key]
    elif n_args == 1:
        side = args[0]
        size = (side, side)
    else:
        size = tuple(args)
    plt.rcParams[key] = size
def show_hwc(image: torch.Tensor, save_path: Optional[str] = None):
    """Display an image tensor laid out as (height, width, channels).

    Previously this function also unconditionally wrote the image to
    ``12345.jpg`` in the current working directory (a leftover debug
    side effect). Saving is now opt-in via ``save_path``.

    Args:
        image: Tensor indexed as (H, W, C). Tensors whose dtype is not
            ``torch.uint8`` are cast to ``torch.uint8`` before display.
            A single-channel image (C == 1) is replicated to three channels.
        save_path: If not ``None``, the image is also written to this path
            with ``PIL.Image.save`` (format inferred from the extension).
    """
    if image.dtype != torch.uint8:
        # NOTE(review): this is a raw cast, not a rescale -- it assumes values
        # are already in [0, 255]. Float images in [0, 1] would come out
        # (nearly) black; out-of-range values are not clamped.
        image = image.to(torch.uint8)
    if image.size(2) == 1:
        # Replicate the single channel into three so the image is handed to
        # PIL/matplotlib as RGB rather than as an (H, W, 1) array.
        image = image.repeat(1, 1, 3)
    pimage = Image.fromarray(image.cpu().numpy())
    plt.imshow(pimage)
    if save_path is not None:
        pimage.save(save_path)
    plt.show()
def show_bchw(image: torch.Tensor):
    """Display a (batch, channels, height, width) image tensor.

    The tensor is converted with ``bchw2hwc`` and the result is passed to
    ``show_hwc``. Presumably ``bchw2hwc`` arranges the batch into a single
    HWC image -- confirm against its definition in ``.util``.
    """
    hwc_image = bchw2hwc(image)
    show_hwc(hwc_image)
def show_bhw(image: torch.Tensor):
    """Display a channel-less (batch, height, width) image tensor.

    A singleton channel axis is inserted at dim 1, turning the input into
    (batch, 1, height, width), and the result is passed to ``show_bchw``.
    """
    with_channel = image.unsqueeze(1)
    show_bchw(with_channel)