File size: 6,065 Bytes
499e141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
# -*- coding: utf-8 -*-
# @Author  : xuelun

import os
import time
import yaml
import torch
import random
import numpy as np


project_name = os.path.basename(os.getcwd())


def make_reproducible(iscuda, seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if iscuda:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # set True will make data load faster
        #   but, it will influence reproducible
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True


def hint(msg):
    timestamp = f'{time.strftime("%m/%d %H:%M:%S", time.localtime(time.time()))}'
    print('\033[1m' + project_name + ' >> ' + timestamp + ' >> ' + '\033[0m' + msg)


def datainfo(infos, datalen, gpuid):
    if gpuid != 0: return
    # print informations about benchmarks
    print('')
    print(f'{" Benchmarks":14}|{" Sequence":20}|{" Count":8}')
    print(f'{"-" * 45}')
    for k0, v0 in infos.items():
        isfirst = True
        for k1, v1 in v0.items():
            line = f' {k0:13}|' if isfirst else f'{" " * 14}|'
            line += f' {k1:19}|'
            line += f' {str(v1):7}'
            print(line)
            print(f'{"-" * 45}')
            isfirst = False
    print(f'{" " * 37}{str(datalen)}')
    print(f'{"-" * 45}')
    print('')


# noinspection PyTypeChecker
def mesh_positions(h: int, w: int):
    gy, gx = torch.meshgrid(torch.arange(h), torch.arange(w))
    gx, gy = gx.contiguous()[None, :], gy.contiguous()[None, :]
    pos = torch.cat((gx.view(1, -1), gy.view(1, -1)))  # [2, H*W]
    return pos


def current_time(f=None):
    """
    :param f: default for log, "f" for file name
    :return: formatted time
    """
    if f == "f":
        return f'{time.strftime("%m.%d_%H.%M.%S", time.localtime(time.time()))}'
    return f'{time.strftime("%m/%d %H:%M:%S", time.localtime(time.time()))}'


def mkdir(dir):
    if not os.path.isdir(dir):
        os.makedirs(dir, exist_ok=False)


def pdist(x, y=None):
    """
    Pairwise Distance
    Args:
        x: [bs, n, 2]
        y: [bs, n, 2]
    Returns: [bs, n, n] value in euclidean *square* distance
    """
    # B, n, two = x.shape
    x = x.double()  # [bs, n, 2]
    
    x_norm = (x ** 2).sum(-1, keepdim=True)  # [bs, n, 1]
    if y is not None:
        y = y.double()
        y_t = y.transpose(1, 2)  # [bs, 2, n]
        y_norm = (y ** 2).sum(-1, keepdim=True).transpose(1, 2)  # [bs, 1, n]
    else:
        y_t = x.transpose(1, 2)  # [bs, 2, n]
        y_norm = x_norm.transpose(1, 2)  # [bs, 1, n]
    
    dist = x_norm + y_norm - 2.0 * torch.matmul(x, y_t)  # [bs, n, n]
    return dist


mean = lambda lis: sum(lis) / len(lis)
eps = lambda x: x + 1e-8


def load_configs(configs):
    with open(configs, 'r') as stream:
        try:
            x = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
        return x


def find_in_dir(run, dir):
    runs = os.listdir(dir)
    runs = [r for r in runs if run in r]
    if len(runs) <= 0:
        hint(f'Not exist run name contain : {run}')
        exit(-1)
    elif len(runs) >= 2:
        hint(f'{len(runs)} runs name contain : {run}')
        hint(f'I will return the first one : {runs[-1]}')
    else:
        hint(f'Success match {runs[-1]}')
    return runs[-1]


def ckpt_in_dir(key, dir):
    runs = os.listdir(dir)
    runs = [r for r in runs if key in r]
    if len(runs) <= 0:
        hint(f'Not exist run name contain : {key}')
        exit(-1)
    elif len(runs) >= 2:
        hint(f'{len(runs)} runs name contain : {key}')
        hint(f'I will return the first one : {runs[-1]}')
    else:
        hint(f'Success match {runs[-1]}')
    return runs[-1]


def kpts2grid(kpts, scale, size):
    """
    change coordinates for keypoints from size0 to size1
    and format as grid which coordinates from [-1, 1]
    Args:
        kpts: (b, n, 2) - (x, y)
        scale: (b, 2) - (w, h) - the keypoints working shape to unet working shape 
        size: (b, 2) - (h, w) - the unet working shape which is 'resize0/1' in data
    Returns: new kpts: (b, 1, n, 2) - (x, y) in [-1, 1]
    """
    # kpts coordinates in unet shape
    kpts /= scale[:,None,:]
    # kpts[:,:,0] - (b, n)
    kpts[:, :, 0] *= 2 / (size[:, 1][:, None] - 1)
    kpts[:, :, 1] *= 2 / (size[:, 0][:, None] - 1)
    # make kpts from [0, 2] to [-1, 1]
    kpts -= 1
    # assume all kpts in [-1, 1]
    kpts = kpts.clamp(min=-1, max=1) # (b, n, 2)
    # make kpts shape from (b, n, 2) to (b, 1, n, 2)
    kpts = kpts[:,None]

    return kpts


def debug(x):
    if 'DATASET' in list(x.keys()):
        y = x.DATASET
        y.TRAIN.LIST_PATH = y.TRAIN.LIST_PATH.replace('scene_list', 'scene_list_debug')
        y.VALID.LIST_PATH = y.VALID.LIST_PATH.replace('scene_list', 'scene_list_debug')
    return x


def summary_loss(loss_list):
    n = 0
    sums = 0
    for loss in loss_list:
        if (loss is not None) and (not torch.isnan(loss)):
            sums += loss
            n += 1
    sums = sums / n if n != 0 else None
    return sums


def summary_metrics(dic, h1, h2):
    print('')

    # Head
    print(f'RunID     {h1:9}', end='')
    print('   | ', end='')
    print(f'Version   {h2:10}', end='')

    # Content
    print(f'{"| ".join(f"{key:10}" for key in dic[0].keys())}')
    for metric in dic:
        print(f'{"-" * 12 * len(dic[0].keys())}')
        print(f'{"| ".join(f"{metric[key]:<10.5f}" for key in metric.keys())}')

    print('')


def get_padding_size(image, h, w):
    orig_width = image.shape[3]
    orig_height = image.shape[2]
    aspect_ratio = w / h

    new_width = max(orig_width, int(orig_height * aspect_ratio))
    new_height = max(orig_height, int(orig_width / aspect_ratio))

    pad_height = new_height - orig_height
    pad_width = new_width - orig_width

    pad_top = pad_height // 2
    pad_bottom = pad_height - pad_top
    pad_left = pad_width // 2
    pad_right = pad_width - pad_left

    return orig_width, orig_height, pad_left, pad_right, pad_top, pad_bottom