# File size: 998 bytes (revision d4e7f2f)
from typing import Optional, Tuple
import torch
from PIL import Image
import matplotlib.pyplot as plt
from .util import bchw2hwc
def set_figsize(*args):
    """Set, or reset, matplotlib's default figure size via ``rcParams``.

    Accepted call forms:
        set_figsize()       -- restore ``rcParamsDefault["figure.figsize"]``.
        set_figsize(s)      -- square figures: width and height both ``s``.
        set_figsize(w, h)   -- width ``w`` and height ``h``.

    Raises:
        RuntimeError: if more than two positional arguments are given.
    """
    count = len(args)
    if count > 2:
        raise RuntimeError(
            'Supported argument types: set_figsize() or set_figsize(int) '
            'or set_figsize(int, int)')

    if count == 0:
        size = plt.rcParamsDefault["figure.figsize"]
    elif count == 1:
        side, = args
        size = (side, side)
    else:
        size = tuple(args)
    plt.rcParams["figure.figsize"] = size
def show_hwc(image: torch.Tensor):
    """Display one image tensor laid out as (height, width, channels).

    Non-uint8 tensors are cast with ``Tensor.to(torch.uint8)``. The values are
    neither rescaled nor clamped, so the data is expected to already be in the
    0-255 range. NOTE(review): a float image in [0, 1] would display almost
    entirely black; confirm what callers pass.

    A single-channel image is replicated to three channels before display. A
    2-D (height, width) tensor is treated as single-channel.

    Args:
        image: tensor of shape (H, W, C) or (H, W). It is moved to CPU for
            conversion to a PIL image.
    """
    if image.dtype != torch.uint8:
        image = image.to(torch.uint8)
    if image.dim() == 2:
        # Accept a bare (H, W) grayscale tensor by adding a channel axis.
        image = image.unsqueeze(2)
    if image.size(2) == 1:
        # Replicate grayscale to 3 channels so PIL builds an RGB image.
        image = image.repeat(1, 1, 3)
    pimage = Image.fromarray(image.cpu().numpy())
    plt.imshow(pimage)
    plt.show()
def show_bchw(image: torch.Tensor):
    """Display an image tensor given in (batch, channels, height, width) layout.

    The tensor is converted with ``bchw2hwc`` from ``.util`` and the result is
    passed to ``show_hwc``. NOTE(review): how a batch larger than one is
    combined into a single HWC image is decided by ``bchw2hwc``; confirm there.
    """
    hwc_image = bchw2hwc(image)
    show_hwc(hwc_image)
def show_bhw(image: torch.Tensor):
    """Display single-channel images given in (batch, height, width) layout.

    A size-1 channel axis is inserted at dimension 1 to form a BCHW tensor,
    which is then shown with ``show_bchw``.
    """
    with_channel = image.unsqueeze(dim=1)
    show_bchw(with_channel)
|