JiminHeo committed on
Commit
0bf5d84
·
1 Parent(s): 8732441
Files changed (1) hide show
  1. ldm/util.py +202 -0
ldm/util.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+ import torch
4
+ import numpy as np
5
+ from collections import abc
6
+ from einops import rearrange
7
+ from functools import partial
8
+
9
+ import multiprocessing as mp
10
+ from threading import Thread
11
+ from queue import Queue
12
+
13
+ from inspect import isfunction
14
+ from PIL import Image, ImageDraw, ImageFont
15
+
16
+
17
def log_txt_as_img(wh, xc, size=10):
    """Render each caption in ``xc`` as black text on a white image.

    Args:
        wh: Tuple ``(width, height)`` giving the size of every rendered image.
        xc: List of caption strings; one image is produced per caption.
        size: Font size handed to ``ImageFont.truetype``.

    Returns:
        A tensor built from the stacked per-caption images, each transposed to
        channel-first order and rescaled via ``pixel / 127.5 - 1.0``.
    """
    # The font and the wrap width do not depend on the caption, so load and
    # compute them once instead of once per caption.
    font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
    # Characters per line scale with the image width (40 chars at 256 px).
    # Clamp to 1 so very narrow images do not produce a zero slicing step.
    nc = max(1, int(40 * (wh[0] / 256)))

    txts = list()
    for caption in xc:
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        # Hard-wrap the caption every ``nc`` characters.
        lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            # Best effort: keep the blank white image for this caption.
            print("Cant encode string for logging. Skipping.")

        # HWC uint8 -> CHW float, rescaled to [-1, 1].
        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts
39
+
40
+
41
def ismap(x):
    """Return True when ``x`` is a 4-D tensor whose second dimension is larger than 3."""
    if isinstance(x, torch.Tensor):
        dims = x.shape
        return len(dims) == 4 and dims[1] > 3
    return False
45
+
46
+
47
def isimage(x):
    """Return True when ``x`` is a 4-D tensor whose second dimension is 1 or 3."""
    if isinstance(x, torch.Tensor):
        dims = x.shape
        return len(dims) == 4 and dims[1] in (3, 1)
    return False
51
+
52
+
53
def exists(x):
    """Return True for any value other than ``None`` (falsy values included)."""
    if x is None:
        return False
    return True
55
+
56
+
57
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    When ``d`` is a plain Python function (per ``inspect.isfunction``) it is
    called with no arguments and its result is used; any other object,
    including classes and builtins, is returned as-is.
    """
    if val is None:
        if isfunction(d):
            return d()
        return d
    return val
61
+
62
+
63
def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Take the mean over all non-batch dimensions.

    Returns one value per entry along dimension 0. Inputs with fewer than two
    dimensions have no non-batch dimensions and come back with their values
    and shape unchanged.
    """
    if tensor.dim() < 2:
        # With no non-batch dims the reduction list would be empty, and
        # ``mean(dim=[])`` reduces over *every* dimension, collapsing the
        # batch axis. Averaging over a size-1 dummy axis keeps the per-sample
        # shape and the usual dtype checks of ``mean``.
        return tensor.unsqueeze(-1).mean(dim=-1)
    return tensor.mean(dim=list(range(1, tensor.dim())))
69
+
70
+
71
def count_params(model, verbose=False):
    """Count the elements of every tensor yielded by ``model.parameters()``.

    When ``verbose`` is true, also print the total in millions alongside the
    model's class name. Returns the total element count.
    """
    n_elements = 0
    for param in model.parameters():
        n_elements += param.numel()
    if verbose:
        print(f"{model.__class__.__name__} has {n_elements * 1.e-6:.2f} M params.")
    return n_elements
76
+
77
+
78
def instantiate_from_config(config):
    """Build the object described by ``config``.

    ``config["target"]`` is a dotted import path resolved with
    ``get_obj_from_str``; the resolved callable is invoked with
    ``config["params"]`` (an empty dict when absent) as keyword arguments.

    The marker strings ``'__is_first_stage__'`` and ``"__is_unconditional__"``
    yield ``None``. Any other config lacking ``target`` raises ``KeyError``.
    """
    if "target" in config:
        factory = get_obj_from_str(config["target"])
        kwargs = config.get("params", dict())
        return factory(**kwargs)
    # No target: only the two marker strings are acceptable.
    if config == '__is_first_stage__' or config == "__is_unconditional__":
        return None
    raise KeyError("Expected key `target` to instantiate.")
86
+
87
+
88
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to the named object.

    The text after the last dot is the attribute name; everything before it is
    the module to import. With ``reload=True`` the module is reloaded before
    the attribute is looked up.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    loaded = importlib.import_module(module_path, package=None)
    return getattr(loaded, attr_name)
94
+
95
+
96
def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
    """Worker entry point: run ``func`` on one chunk and report through ``Q``.

    Puts ``[idx, result]`` on success, or ``("Error", idx, traceback_text)``
    when ``func`` raises, and always finishes with the sentinel ``"Done"`` so
    the parent is not left waiting for a worker whose ``func`` failed.
    """
    try:
        if idx_to_fn:
            res = func(data, worker_id=idx)
        else:
            res = func(data)
        Q.put([idx, res])
    except Exception:
        import traceback

        # Ship the formatted traceback as text: it is always picklable, unlike
        # arbitrary exception objects.
        Q.put(("Error", idx, traceback.format_exc()))
    finally:
        Q.put("Done")


def parallel_data_prefetch(
        func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
):
    """Split ``data`` into chunks, apply ``func`` to each chunk in parallel, and merge.

    Args:
        func: Callable applied to each chunk. Called as ``func(chunk)`` or, when
            ``use_worker_id`` is true, as ``func(chunk, worker_id=i)``.
        data: An ``np.ndarray`` or any iterable. For a dict only the values
            are used.
        n_proc: Number of chunks/workers requested (must be >= 1). In list
            mode fewer workers are started when there are fewer chunks.
        target_data_type: ``"ndarray"`` splits with ``np.array_split`` and
            returns the concatenated results; ``"list"`` slices a list and
            returns one flat list; any other value slices a list and returns
            the per-worker results unmerged.
        cpu_intensive: If true, use ``multiprocessing.Process`` workers,
            otherwise ``threading.Thread`` workers.
        use_worker_id: Forward the worker index to ``func`` as ``worker_id``.

    Returns:
        The merged results, ordered by chunk index (see ``target_data_type``).

    Raises:
        ValueError: ``n_proc`` is below 1, or an ndarray was passed together
            with ``target_data_type="list"``.
        TypeError: ``data`` is neither an ndarray nor an iterable.
        RuntimeError: ``func`` raised inside at least one worker.
    """
    import time

    if n_proc < 1:
        raise ValueError(f"n_proc must be at least 1, but is {n_proc}.")

    if isinstance(data, np.ndarray) and target_data_type == "list":
        raise ValueError("list expected but function got ndarray.")
    elif isinstance(data, abc.Iterable):
        if isinstance(data, dict):
            print(
                f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
            )
            data = list(data.values())
        if target_data_type == "ndarray":
            data = np.asarray(data)
        else:
            data = list(data)
    else:
        raise TypeError(
            f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
        )

    if cpu_intensive:
        Q = mp.Queue(1000)
        proc = mp.Process
    else:
        Q = Queue(1000)
        proc = Thread

    # Partition the data into per-worker chunks.
    if target_data_type == "ndarray":
        chunks = np.array_split(data, n_proc)
    else:
        # Ceiling division; clamp to 1 so empty data does not give range() a
        # zero step.
        step = max(1, -(-len(data) // n_proc))
        chunks = [data[i: i + step] for i in range(0, len(data), step)]

    arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(chunks)]
    # List-mode chunking can yield fewer chunks than n_proc (e.g. 5 items over
    # 4 workers gives 3 chunks), so size everything by the real chunk count.
    n_workers = len(arguments)
    processes = [proc(target=_do_parallel_data_prefetch, args=args) for args in arguments]

    print("Start prefetching...")
    start = time.time()
    gather_res = [[] for _ in range(n_workers)]
    started = []
    failures = []
    try:
        for p in processes:
            p.start()
            started.append(p)

        finished = 0
        while finished < n_workers:
            res = Q.get()
            if isinstance(res, str) and res == "Done":
                finished += 1
            elif isinstance(res, tuple) and len(res) == 3 and res[0] == "Error":
                failures.append(res)
            else:
                # Store by chunk index so the output order matches the input.
                gather_res[res[0]] = res[1]

    except Exception as e:
        print("Exception: ", e)
        for p in started:
            # Threads cannot be terminated; only processes expose terminate().
            if hasattr(p, "terminate"):
                p.terminate()

        raise
    finally:
        # Joining a worker that was never started raises, so join only those
        # that were actually started.
        for p in started:
            p.join()
        print(f"Prefetching complete. [{time.time() - start} sec.]")

    if failures:
        details = "\n".join(f"[worker {idx}] {tb}" for _, idx, tb in failures)
        raise RuntimeError(f"parallel_data_prefetch: {len(failures)} worker(s) failed:\n{details}")

    if target_data_type == 'ndarray':
        # np.asarray is a no-op for results that are already arrays.
        return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
    elif target_data_type == 'list':
        out = []
        for r in gather_res:
            out.extend(r)
        return out
    else:
        return gather_res