FSFM-3C
Add V1.0
d4e7f2f
raw
history blame contribute delete
998 Bytes
from typing import Optional, Tuple
import torch
from PIL import Image
import matplotlib.pyplot as plt
from .util import bchw2hwc
def set_figsize(*args):
    """Set matplotlib's default figure size in ``plt.rcParams``.

    Call forms:
        ``set_figsize()``      -- restore matplotlib's default figure size.
        ``set_figsize(s)``     -- use a square ``(s, s)`` figure.
        ``set_figsize(w, h)``  -- use a ``(w, h)`` figure.

    Raises:
        RuntimeError: if more than two arguments are given.
    """
    n_args = len(args)
    # Reject unsupported call forms up front; the remaining cases are 0, 1, 2.
    if n_args > 2:
        raise RuntimeError(
            f'Supported argument types: set_figsize() or set_figsize(int) or set_figsize(int, int)')
    if n_args == 0:
        size = plt.rcParamsDefault["figure.figsize"]
    elif n_args == 1:
        side = args[0]
        size = (side, side)
    else:
        size = tuple(args)
    plt.rcParams["figure.figsize"] = size
def show_hwc(image: torch.Tensor, save_path: Optional[str] = None):
    """Display one image tensor laid out as (height, width, channels).

    Args:
        image: Tensor of shape ``(H, W, C)``. A 2-D ``(H, W)`` tensor is
            treated as single-channel. Single-channel images are expanded
            to 3 identical channels before display. Non-``uint8`` tensors
            are cast to ``uint8`` without rescaling, so values are expected
            to already be in ``[0, 255]``. Floating-point values are clamped
            to that range before the cast.
        save_path: If given, the image is also written to this path with
            ``plt.imsave``. Defaults to ``None``, meaning nothing is written
            to disk.
    """
    # Detach so tensors that require grad can be converted to NumPy below.
    image = image.detach()
    if image.dim() == 2:
        # Treat an (H, W) tensor as a single-channel (H, W, 1) image.
        image = image.unsqueeze(-1)
    if image.dtype != torch.uint8:
        if image.is_floating_point():
            # Casting out-of-range floats to uint8 is platform-dependent,
            # so saturate to the displayable range first.
            image = image.clamp(0, 255)
        image = image.to(torch.uint8)
    if image.size(2) == 1:
        # Repeat the single channel so PIL builds an RGB image.
        image = image.repeat(1, 1, 3)
    array = image.cpu().numpy()
    pimage = Image.fromarray(array)
    plt.imshow(pimage)
    if save_path is not None:
        # plt.imsave is documented to take (M, N, 3)/(M, N, 4) array data.
        plt.imsave(save_path, array)
    plt.show()
def show_bchw(image: torch.Tensor):
    """Display an image tensor by converting it to HWC layout first.

    The input is presumably laid out as (B, C, H, W), as the name suggests.
    It is converted with ``bchw2hwc`` and the result is passed to
    ``show_hwc``. How the batch axis is collapsed is decided by
    ``bchw2hwc`` (TODO confirm in ``.util``).
    """
    hwc_image = bchw2hwc(image)
    show_hwc(hwc_image)
def show_bhw(image: torch.Tensor):
    """Display a channel-less image tensor.

    A singleton axis is inserted at dimension 1, so an input assumed to be
    (B, H, W) becomes (B, 1, H, W). The result is passed to ``show_bchw``.
    """
    with_channel_axis = image.unsqueeze(dim=1)
    show_bchw(with_channel_axis)