| from typing import Optional, Tuple | |
| import torch | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| from .util import bchw2hwc | |
def set_figsize(*args):
    """Set matplotlib's default figure size via ``plt.rcParams``.

    Accepted call forms:
        set_figsize()      -- restore the value from ``plt.rcParamsDefault``.
        set_figsize(s)     -- use ``(s, s)``.
        set_figsize(w, h)  -- use ``(w, h)``.

    Raises:
        RuntimeError: If more than two positional arguments are given.
    """
    count = len(args)
    # Reject unsupported call shapes before touching any matplotlib state.
    if count > 2:
        raise RuntimeError(
            'Supported argument types: set_figsize() or set_figsize(int) or set_figsize(int, int)')

    if count == 2:
        figsize = tuple(args)
    elif count == 1:
        figsize = (args[0], args[0])
    else:
        figsize = plt.rcParamsDefault["figure.figsize"]
    plt.rcParams["figure.figsize"] = figsize
def show_hwc(image: torch.Tensor, save_path: Optional[str] = '12345.jpg'):
    """Display an image tensor with matplotlib and optionally save it to disk.

    Args:
        image: 3-D tensor whose last axis is treated as the channel axis
            (a size-1 last axis is repeated three times to form RGB).
            Non-uint8 input is cast with ``Tensor.to(torch.uint8)`` and is
            NOT rescaled.
            NOTE(review): values are presumably expected to already be in
            the 0-255 range -- confirm against callers.
        save_path: File path the image is also written to via
            ``plt.imsave``. Defaults to ``'12345.jpg'`` (relative to the
            current working directory) to preserve the historical
            behaviour; pass ``None`` to skip writing a file.
    """
    if image.dtype != torch.uint8:
        image = image.to(torch.uint8)
    if image.size(2) == 1:
        # Single-channel input: replicate the channel so it renders as RGB.
        image = image.repeat(1, 1, 3)

    pixels = image.cpu().numpy()
    plt.imshow(Image.fromarray(pixels))
    if save_path is not None:
        # imsave is documented to take an array, so hand it the ndarray
        # rather than the PIL image object.
        plt.imsave(save_path, pixels)
    plt.show()
def show_bchw(image: torch.Tensor):
    """Convert ``image`` with ``bchw2hwc`` and display it via ``show_hwc``.

    NOTE(review): ``bchw2hwc`` lives in ``.util``; presumably it turns a
    batched (B, C, H, W) tensor into a single (H, W, C) image -- confirm
    there how the batch axis is handled.
    """
    hwc_image = bchw2hwc(image)
    show_hwc(hwc_image)
def show_bhw(image: torch.Tensor):
    """Insert a size-1 axis at index 1 of ``image`` and pass it to ``show_bchw``.

    NOTE(review): the name suggests a (B, H, W) input that becomes
    (B, 1, H, W) -- confirm against callers.
    """
    with_channel_axis = image.unsqueeze(1)
    show_bchw(with_channel_axis)