File size: 998 Bytes
d4e7f2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from typing import Optional, Tuple
import torch
from PIL import Image
import matplotlib.pyplot as plt

from .util import bchw2hwc


def set_figsize(*args):
    """Set matplotlib's ``figure.figsize`` rcParam.

    Call forms:
        set_figsize()      -- restore the value from ``plt.rcParamsDefault``.
        set_figsize(s)     -- use ``(s, s)``.
        set_figsize(w, h)  -- use ``(w, h)``.

    Raises:
        RuntimeError: if more than two positional arguments are given.
    """
    count = len(args)
    # Reject unsupported call shapes before touching any matplotlib state.
    if count > 2:
        raise RuntimeError(
            'Supported argument types: set_figsize() or set_figsize(int) or set_figsize(int, int)')

    if count == 2:
        size = tuple(args)
    elif count == 1:
        size = (args[0], args[0])
    else:
        size = plt.rcParamsDefault["figure.figsize"]
    plt.rcParams["figure.figsize"] = size


def show_hwc(image: torch.Tensor, save_path: Optional[str] = '12345.jpg') -> None:
    """Display an image tensor with matplotlib and optionally save it to disk.

    Args:
        image: A tensor with at least three dimensions whose dimension 2 is
            treated as the channel axis (presumably height x width x
            channels -- confirm against callers). Tensors that are not
            ``torch.uint8`` are cast with ``Tensor.to(torch.uint8)``; no
            rescaling is applied. When dimension 2 has size 1 it is repeated
            three times.
        save_path: File the image is written to via ``plt.imsave``. Defaults
            to ``'12345.jpg'``, the path this function has always written
            to; pass ``None`` to skip writing a file.
    """
    if image.dtype != torch.uint8:
        image = image.to(torch.uint8)
    if image.size(2) == 1:
        # Expand a single channel to three so PIL receives an (H, W, 3) array.
        image = image.repeat(1, 1, 3)
    pimage = Image.fromarray(image.cpu().numpy())
    plt.imshow(pimage)
    # Saving is a side effect of displaying; callers can opt out with None.
    if save_path is not None:
        plt.imsave(save_path, pimage)
    plt.show()


def show_bchw(image: torch.Tensor):
    """Convert ``image`` with ``bchw2hwc`` and display it with ``show_hwc``.

    NOTE(review): ``bchw2hwc`` is imported from ``.util`` and is not visible
    here; presumably it maps a batch x channel x height x width tensor to a
    height x width x channel one -- confirm against its definition.
    """
    hwc_image = bchw2hwc(image)
    show_hwc(hwc_image)


def show_bhw(image: torch.Tensor):
    """Insert a size-1 dimension at index 1 and display via ``show_bchw``.

    NOTE(review): the name suggests ``image`` is batch x height x width, so
    the inserted dimension would act as a single channel -- confirm against
    callers.
    """
    with_channel = torch.unsqueeze(image, 1)
    show_bchw(with_channel)