from typing import Optional, Tuple import torch from PIL import Image import matplotlib.pyplot as plt from .util import bchw2hwc def set_figsize(*args): if len(args) == 0: plt.rcParams["figure.figsize"] = plt.rcParamsDefault["figure.figsize"] elif len(args) == 1: plt.rcParams["figure.figsize"] = (args[0], args[0]) elif len(args) == 2: plt.rcParams["figure.figsize"] = tuple(args) else: raise RuntimeError( f'Supported argument types: set_figsize() or set_figsize(int) or set_figsize(int, int)') def show_hwc(image: torch.Tensor): if image.dtype != torch.uint8: image = image.to(torch.uint8) if image.size(2) == 1: image = image.repeat(1, 1, 3) pimage = Image.fromarray(image.cpu().numpy()) plt.imshow(pimage) plt.imsave('12345.jpg', pimage) plt.show() def show_bchw(image: torch.Tensor): show_hwc(bchw2hwc(image)) def show_bhw(image: torch.Tensor): show_bchw(image.unsqueeze(1))