SunderAli17 commited on
Commit
7778272
·
verified ·
1 Parent(s): 4cbf63e

Update toonmage/utils.py

Browse files
Files changed (1) hide show
  1. toonmage/utils.py +85 -0
toonmage/utils.py CHANGED
@@ -7,6 +7,91 @@ import numpy as np
7
  import torch
8
  import torch.nn.functional as F
9
  from transformers import PretrainedConfig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  def seed_everything(seed):
 
7
  import torch
8
  import torch.nn.functional as F
9
  from transformers import PretrainedConfig
10
+ from torchvision.utils import make_grid
11
+ import math
12
+
13
+
14
+ # from basicsr
15
+ def img2tensor(imgs, bgr2rgb=True, float32=True):
16
+ """Numpy array to tensor.
17
+ Args:
18
+ imgs (list[ndarray] | ndarray): Input images.
19
+ bgr2rgb (bool): Whether to change bgr to rgb.
20
+ float32 (bool): Whether to change to float32.
21
+ Returns:
22
+ list[tensor] | tensor: Tensor images. If returned results only have
23
+ one element, just return tensor.
24
+ """
25
+
26
+ def _totensor(img, bgr2rgb, float32):
27
+ if img.shape[2] == 3 and bgr2rgb:
28
+ if img.dtype == 'float64':
29
+ img = img.astype('float32')
30
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
31
+ img = torch.from_numpy(img.transpose(2, 0, 1))
32
+ if float32:
33
+ img = img.float()
34
+ return img
35
+
36
+ if isinstance(imgs, list):
37
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
38
+ return _totensor(imgs, bgr2rgb, float32)
39
+
40
+
41
+ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
42
+ """Convert torch Tensors into image numpy arrays.
43
+ After clamping to [min, max], values will be normalized to [0, 1].
44
+ Args:
45
+ tensor (Tensor or list[Tensor]): Accept shapes:
46
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
47
+ 2) 3D Tensor of shape (3/1 x H x W);
48
+ 3) 2D Tensor of shape (H x W).
49
+ Tensor channel should be in RGB order.
50
+ rgb2bgr (bool): Whether to change rgb to bgr.
51
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
52
+ to uint8 type with range [0, 255]; otherwise, float type with
53
+ range [0, 1]. Default: ``np.uint8``.
54
+ min_max (tuple[int]): min and max values for clamp.
55
+ Returns:
56
+ (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
57
+ shape (H x W). The channel order is BGR.
58
+ """
59
+ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
60
+ raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
61
+
62
+ if torch.is_tensor(tensor):
63
+ tensor = [tensor]
64
+ result = []
65
+ for _tensor in tensor:
66
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
67
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
68
+
69
+ n_dim = _tensor.dim()
70
+ if n_dim == 4:
71
+ img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
72
+ img_np = img_np.transpose(1, 2, 0)
73
+ if rgb2bgr:
74
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
75
+ elif n_dim == 3:
76
+ img_np = _tensor.numpy()
77
+ img_np = img_np.transpose(1, 2, 0)
78
+ if img_np.shape[2] == 1: # gray image
79
+ img_np = np.squeeze(img_np, axis=2)
80
+ else:
81
+ if rgb2bgr:
82
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
83
+ elif n_dim == 2:
84
+ img_np = _tensor.numpy()
85
+ else:
86
+ raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
87
+ if out_type == np.uint8:
88
+ # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
89
+ img_np = (img_np * 255.0).round()
90
+ img_np = img_np.astype(out_type)
91
+ result.append(img_np)
92
+ if len(result) == 1:
93
+ result = result[0]
94
+ return result
95
 
96
 
97
  def seed_everything(seed):