dbaranchuk committed on
Commit bee9a9c · verified · 1 Parent(s): 2a35740

Delete generation.py

Files changed (1)
  1. generation.py +0 -621
generation.py DELETED
@@ -1,621 +0,0 @@
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
from typing import Union
from IPython.display import display
import p2p


# Main function to run
# ----------------------------------------------------------------------
@torch.no_grad()
def runner(
        model,
        prompt,
        controller,
        solver,
        is_cons_forward=False,
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=None,
        latent=None,
        uncond_embeddings=None,
        start_time=50,
        return_type='image',
        dynamic_guidance=False,
        tau1=0.4,
        tau2=0.6,
        w_embed_dim=0,
):
    p2p.register_attention_control(model, controller)
    height = width = 512
    solver.init_prompt(prompt, None)
    latent, latents = init_latent(latent, model, height, width, generator, len(prompt))
    model.scheduler.set_timesteps(num_inference_steps)
    # Dynamic guidance is enabled whenever either threshold is below 1.0.
    dynamic_guidance = tau1 < 1.0 or tau2 < 1.0

    if not is_cons_forward:
        latents = solver.ddim_loop(latents,
                                   num_inference_steps,
                                   is_forward=False,
                                   guidance_scale=guidance_scale,
                                   dynamic_guidance=dynamic_guidance,
                                   tau1=tau1,
                                   tau2=tau2,
                                   w_embed_dim=w_embed_dim,
                                   uncond_embeddings=uncond_embeddings,
                                   controller=controller)
        latents = latents[-1]
    else:
        latents = solver.cons_generation(
            latents,
            guidance_scale=guidance_scale,
            w_embed_dim=w_embed_dim,
            dynamic_guidance=dynamic_guidance,
            tau1=tau1,
            tau2=tau2,
            controller=controller)
        latents = latents[-1]

    if return_type == 'image':
        image = latent2image(model.vae, latents.to(model.vae.dtype))
    else:
        image = latents

    return image, latent
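# Illustrative usage (not part of the original file; all names below are hypothetical):
# a minimal sketch of how `runner` is typically wired up, assuming a diffusers Stable
# Diffusion pipeline loaded as `model`, a prompt-to-prompt controller from `p2p`, and a
# `Generator` solver constructed further below.
#
#   solver = Generator(model=model, n_steps=50, noise_scheduler=model.scheduler)
#   controller = p2p.AttentionStore()  # assumed p2p controller class
#   image, latent = runner(model, ["a photo of a cat"], controller, solver,
#                          num_inference_steps=50, guidance_scale=7.5)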
# ----------------------------------------------------------------------


# Utils
# ----------------------------------------------------------------------
def linear_schedule_old(t, guidance_scale, tau1, tau2):
    t = t / 1000
    if t <= tau1:
        gamma = 1.0
    elif t >= tau2:
        gamma = 0.0
    else:
        gamma = (tau2 - t) / (tau2 - tau1)
    return gamma * guidance_scale


def linear_schedule(t, guidance_scale, tau1=0.4, tau2=0.8):
    t = t / 1000
    if t <= tau1:
        return guidance_scale
    if t >= tau2:
        return 1.0
    gamma = (tau2 - t) / (tau2 - tau1) * (guidance_scale - 1.0) + 1.0

    return gamma
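# Worked example (illustrative): with guidance_scale=8.0, tau1=0.4, tau2=0.8, a timestep
# t=500 maps to t/1000 = 0.5, so
#   gamma = (0.8 - 0.5) / (0.8 - 0.4) * (8.0 - 1.0) + 1.0 = 0.75 * 7.0 + 1.0 = 6.25,
# i.e. the effective guidance scale decays linearly from 8.0 at t <= 400 to 1.0 at t >= 800.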
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """
    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            guidance scales to generate embedding vectors for
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
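# Shape check (illustrative): guidance_scale_embedding(torch.tensor([2.0, 8.0]), embedding_dim=512)
# returns a (2, 512) tensor; the scales are multiplied by 1000 and encoded with the usual
# sinusoidal timestep-embedding scheme (sin in the first half, cos in the second half).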
# ----------------------------------------------------------------------


# Diffusion step with scheduler from diffusers and controller for editing
# ----------------------------------------------------------------------
def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def predicted_origin(model_output, timesteps, boundary_timesteps, sample, prediction_type, alphas, sigmas):
    sigmas_s = extract_into_tensor(sigmas, boundary_timesteps, sample.shape)
    alphas_s = extract_into_tensor(alphas, boundary_timesteps, sample.shape)

    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
    alphas = extract_into_tensor(alphas, timesteps, sample.shape)

    # Set hard boundaries to ensure equivalence with forward (direct) CD
    alphas_s[boundary_timesteps == 0] = 1.0
    sigmas_s[boundary_timesteps == 0] = 0.0

    if prediction_type == "epsilon":
        pred_x_0 = (sample - sigmas * model_output) / alphas  # x0 prediction
        pred_x_0 = alphas_s * pred_x_0 + sigmas_s * model_output  # Euler step to the boundary step
    elif prediction_type == "v_prediction":
        assert (boundary_timesteps == 0).all(), "v_prediction does not support multiple endpoints at the moment"
        pred_x_0 = alphas * sample - sigmas * model_output
    else:
        raise ValueError(f"Prediction type {prediction_type} currently not supported.")
    return pred_x_0
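# In the epsilon branch above, the two lines implement, per sample:
#   x_0_hat = (x_t - sigma_t * eps_hat) / alpha_t      (predicted clean latent)
#   x_s     = alpha_s * x_0_hat + sigma_s * eps_hat    (jump to the boundary timestep s)
# With s = 0 the second line reduces to x_0_hat, matching standard consistency distillation.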
def guided_step(noise_prediction_text,
                noise_pred_uncond,
                t,
                guidance_scale,
                dynamic_guidance=False,
                tau1=0.4,
                tau2=0.6):
    if dynamic_guidance:
        if not isinstance(t, int):
            t = t.item()
        new_guidance_scale = linear_schedule(t, guidance_scale, tau1=tau1, tau2=tau2)
    else:
        new_guidance_scale = guidance_scale

    noise_pred = noise_pred_uncond + new_guidance_scale * (noise_prediction_text - noise_pred_uncond)
    return noise_pred
# ----------------------------------------------------------------------


# DDIM scheduler with inversion
# ----------------------------------------------------------------------
class Generator:

    def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
                  sample: Union[torch.FloatTensor, np.ndarray]):
        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.scheduler.alphas_cumprod[
            prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
        prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
        return prev_sample

    def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
                  sample: Union[torch.FloatTensor, np.ndarray]):
        timestep, next_timestep = min(
            timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
        beta_prod_t = 1 - alpha_prod_t
        next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
        next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
        return next_sample

    def get_noise_pred_single(self, latents, t, context):
        noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
        return noise_pred

    def get_noise_pred(self,
                       model,
                       latent,
                       t,
                       guidance_scale=1,
                       context=None,
                       w_embed_dim=0,
                       dynamic_guidance=False,
                       tau1=0.4,
                       tau2=0.6):
        latents_input = torch.cat([latent] * 2)
        if context is None:
            context = self.context

        # w embed
        # --------------------------------------
        if w_embed_dim > 0:
            if dynamic_guidance:
                if not isinstance(t, int):
                    t_item = t.item()
                guidance_scale = linear_schedule_old(t_item, guidance_scale, tau1=tau1, tau2=tau2)  # TODO UPDATE
            if len(latents_input) == 4:
                guidance_scale_tensor = torch.tensor([0.0, 0.0, 0.0, guidance_scale])
            else:
                guidance_scale_tensor = torch.tensor([guidance_scale] * len(latents_input))
            w_embedding = guidance_scale_embedding(guidance_scale_tensor, embedding_dim=w_embed_dim)
            w_embedding = w_embedding.to(device=latent.device, dtype=latent.dtype)
        else:
            w_embedding = None
        # --------------------------------------
        noise_pred = model.unet(latents_input.to(dtype=model.unet.dtype),
                                t,
                                timestep_cond=w_embedding.to(dtype=model.unet.dtype) if w_embed_dim > 0 else None,
                                encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)

        if guidance_scale > 1 and w_embedding is None:
            noise_pred = guided_step(noise_prediction_text, noise_pred_uncond, t, guidance_scale, dynamic_guidance,
                                     tau1, tau2)
        else:
            noise_pred = noise_prediction_text

        return noise_pred
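    # Note on the two guidance paths above: when w_embed_dim > 0 the guidance scale is passed
    # to the UNet as a timestep_cond embedding (guidance-distilled models) and the text branch
    # of the prediction is returned as is; otherwise classifier-free guidance is applied
    # explicitly via guided_step on the unconditional/text pair.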
    @torch.no_grad()
    def latent2image(self, latents, return_type='np'):
        latents = 1 / 0.18215 * latents.detach()
        image = self.model.vae.decode(latents.to(dtype=self.model.dtype))['sample']
        if return_type == 'np':
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
            image = (image * 255).astype(np.uint8)
        return image

    @torch.no_grad()
    def image2latent(self, image):
        with torch.no_grad():
            if isinstance(image, Image.Image):
                image = np.array(image)
            if type(image) is torch.Tensor and image.dim() == 4:
                latents = image
            elif type(image) is list:
                image = [np.array(i).reshape(1, 512, 512, 3) for i in image]
                image = np.concatenate(image)
                image = torch.from_numpy(image).float() / 127.5 - 1
                image = image.permute(0, 3, 1, 2).to(self.model.device, dtype=self.model.vae.dtype)
                latents = self.model.vae.encode(image)['latent_dist'].mean
                latents = latents * 0.18215
            else:
                image = torch.from_numpy(image).float() / 127.5 - 1
                image = image.permute(2, 0, 1).unsqueeze(0).to(self.model.device, dtype=self.model.dtype)
                latents = self.model.vae.encode(image)['latent_dist'].mean
                latents = latents * 0.18215
        return latents
    @torch.no_grad()
    def init_prompt(self, prompt, uncond_embeddings=None):
        if uncond_embeddings is None:
            uncond_input = self.model.tokenizer(
                [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
                return_tensors="pt"
            )
            uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
        text_input = self.model.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.model.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
        self.context = torch.cat([uncond_embeddings.expand(*text_embeddings.shape), text_embeddings])
        self.prompt = prompt

    @torch.no_grad()
    def ddim_loop(self,
                  latent,
                  n_steps,
                  is_forward=True,
                  guidance_scale=1,
                  dynamic_guidance=False,
                  tau1=0.4,
                  tau2=0.6,
                  w_embed_dim=0,
                  uncond_embeddings=None,
                  controller=None):
        all_latent = [latent]
        latent = latent.clone().detach()
        for i in tqdm(range(n_steps)):
            if uncond_embeddings is not None:
                self.init_prompt(self.prompt, uncond_embeddings[i])
            if is_forward:
                t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
            else:
                t = self.model.scheduler.timesteps[i]
            noise_pred = self.get_noise_pred(
                model=self.model,
                latent=latent,
                t=t,
                context=None,
                guidance_scale=guidance_scale,
                dynamic_guidance=dynamic_guidance,
                w_embed_dim=w_embed_dim,
                tau1=tau1,
                tau2=tau2)
            if is_forward:
                latent = self.next_step(noise_pred, t, latent)
            else:
                latent = self.prev_step(noise_pred, t, latent)
            if controller is not None:
                latent = controller.step_callback(latent)
            all_latent.append(latent)
        return all_latent
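    # ddim_loop runs the DDIM recursion in either direction: with is_forward=True the timesteps
    # are traversed from small to large and next_step re-noises the latent (inversion, x_0 -> x_T);
    # with is_forward=False it walks the usual sampling order and prev_step denoises (x_T -> x_0).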
    @property
    def scheduler(self):
        return self.model.scheduler

    @torch.no_grad()
    def ddim_inversion(self,
                       image,
                       n_steps=None,
                       guidance_scale=1,
                       dynamic_guidance=False,
                       tau1=0.4,
                       tau2=0.6,
                       w_embed_dim=0):

        if n_steps is None:
            n_steps = self.n_steps
        latent = self.image2latent(image)
        image_rec = self.latent2image(latent)
        ddim_latents = self.ddim_loop(latent,
                                      is_forward=True,
                                      guidance_scale=guidance_scale,
                                      n_steps=n_steps,
                                      dynamic_guidance=dynamic_guidance,
                                      tau1=tau1,
                                      tau2=tau2,
                                      w_embed_dim=w_embed_dim)
        return image_rec, ddim_latents

    @torch.no_grad()
    def cons_generation(self,
                        latent,
                        guidance_scale=1,
                        dynamic_guidance=False,
                        tau1=0.4,
                        tau2=0.6,
                        w_embed_dim=0,
                        controller=None, ):

        all_latent = [latent]
        latent = latent.clone().detach()
        alpha_schedule = torch.sqrt(self.model.scheduler.alphas_cumprod).to(self.model.device)
        sigma_schedule = torch.sqrt(1 - self.model.scheduler.alphas_cumprod).to(self.model.device)

        for i, (t, s) in enumerate(tqdm(zip(self.reverse_timesteps, self.reverse_boundary_timesteps))):
            noise_pred = self.get_noise_pred(
                model=self.reverse_cons_model,
                latent=latent,
                t=t.to(self.model.device),
                context=None,
                tau1=tau1, tau2=tau2,
                w_embed_dim=w_embed_dim,
                guidance_scale=guidance_scale,
                dynamic_guidance=dynamic_guidance)

            latent = predicted_origin(
                noise_pred,
                torch.tensor([t] * len(latent), device=self.model.device),
                torch.tensor([s] * len(latent), device=self.model.device),
                latent,
                self.model.scheduler.config.prediction_type,
                alpha_schedule,
                sigma_schedule,
            )
            if controller is not None:
                latent = controller.step_callback(latent)
            all_latent.append(latent)

        return all_latent

    @torch.no_grad()
    def cons_inversion(self,
                       image,
                       guidance_scale=0.0,
                       w_embed_dim=0,
                       seed=0):
        alpha_schedule = torch.sqrt(self.model.scheduler.alphas_cumprod).to(self.model.device)
        sigma_schedule = torch.sqrt(1 - self.model.scheduler.alphas_cumprod).to(self.model.device)

        # Prepare latent variables
        latent = self.image2latent(image)
        generator = torch.Generator().manual_seed(seed)
        noise = torch.randn(latent.shape, generator=generator).to(latent.device)
        latent = self.noise_scheduler.add_noise(latent, noise, torch.tensor([self.start_timestep]))
        image_rec = self.latent2image(latent)

        for i, (t, s) in enumerate(tqdm(zip(self.forward_timesteps, self.forward_boundary_timesteps))):
            # predict the noise residual
            noise_pred = self.get_noise_pred(
                model=self.forward_cons_model,
                latent=latent,
                t=t.to(self.model.device),
                context=None,
                guidance_scale=guidance_scale,
                w_embed_dim=w_embed_dim,
                dynamic_guidance=False)

            latent = predicted_origin(
                noise_pred,
                torch.tensor([t] * len(latent), device=self.model.device),
                torch.tensor([s] * len(latent), device=self.model.device),
                latent,
                self.model.scheduler.config.prediction_type,
                alpha_schedule,
                sigma_schedule,
            )

        return image_rec, [latent]
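    # cons_inversion encodes the image with the VAE, partially noises it to self.start_timestep
    # with a seeded generator, and then applies the forward consistency model over the
    # (forward_timesteps, forward_boundary_timesteps) pairs to carry the latent toward the
    # noise boundary.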
    def _create_forward_inverse_timesteps(self,
                                          num_endpoints,
                                          n_steps,
                                          max_inverse_timestep_index):
        timestep_interval = n_steps // num_endpoints + int(n_steps % num_endpoints > 0)
        endpoint_idxs = torch.arange(timestep_interval, n_steps, timestep_interval) - 1
        inverse_endpoint_idxs = torch.arange(timestep_interval, n_steps, timestep_interval) - 1
        inverse_endpoint_idxs = torch.tensor(inverse_endpoint_idxs.tolist() + [max_inverse_timestep_index])

        endpoints = torch.tensor([0] + self.ddim_timesteps[endpoint_idxs].tolist())
        inverse_endpoints = self.ddim_timesteps[inverse_endpoint_idxs]

        return endpoints, inverse_endpoints

    def __init__(self,
                 model,
                 n_steps,
                 noise_scheduler,
                 forward_cons_model=None,
                 reverse_cons_model=None,
                 num_endpoints=1,
                 num_forward_endpoints=1,
                 reverse_timesteps=None,
                 forward_timesteps=None,
                 max_forward_timestep_index=49,
                 start_timestep=19):

        self.model = model
        self.forward_cons_model = forward_cons_model
        self.reverse_cons_model = reverse_cons_model
        self.noise_scheduler = noise_scheduler

        self.n_steps = n_steps
        self.tokenizer = self.model.tokenizer
        self.model.scheduler.set_timesteps(n_steps)
        self.prompt = None
        self.context = None
        step_ratio = 1000 // n_steps
        self.ddim_timesteps = (np.arange(1, n_steps + 1) * step_ratio).round().astype(np.int64) - 1
        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
        self.start_timestep = start_timestep

        # Set endpoints for direct CTM
        if reverse_timesteps is None or forward_timesteps is None:
            endpoints, inverse_endpoints = self._create_forward_inverse_timesteps(num_endpoints, n_steps,
                                                                                  max_forward_timestep_index)
            self.reverse_timesteps, self.reverse_boundary_timesteps = inverse_endpoints.flip(0), endpoints.flip(0)

            # Set endpoints for forward CTM
            endpoints, inverse_endpoints = self._create_forward_inverse_timesteps(num_forward_endpoints, n_steps,
                                                                                  max_forward_timestep_index)
            self.forward_timesteps, self.forward_boundary_timesteps = endpoints, inverse_endpoints
            self.forward_timesteps[0] = self.start_timestep
        else:
            self.reverse_timesteps, self.reverse_boundary_timesteps = reverse_timesteps, reverse_timesteps
            self.reverse_timesteps.reverse()
            self.reverse_boundary_timesteps = self.reverse_boundary_timesteps[1:] + [self.reverse_boundary_timesteps[0]]
            self.reverse_boundary_timesteps[-1] = 0
            self.reverse_timesteps, self.reverse_boundary_timesteps = torch.tensor(reverse_timesteps), torch.tensor(
                self.reverse_boundary_timesteps)

            self.forward_timesteps, self.forward_boundary_timesteps = forward_timesteps, forward_timesteps
            self.forward_boundary_timesteps = self.forward_boundary_timesteps[1:] + [self.forward_boundary_timesteps[0]]
            self.forward_boundary_timesteps[-1] = 999
            self.forward_timesteps, self.forward_boundary_timesteps = torch.tensor(
                self.forward_timesteps), torch.tensor(self.forward_boundary_timesteps)

        print(f"Endpoints reverse CTM: {self.reverse_timesteps}, {self.reverse_boundary_timesteps}")
        print(f"Endpoints forward CTM: {self.forward_timesteps}, {self.forward_boundary_timesteps}")
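# Illustrative construction (not from the original file; names are hypothetical): the solver
# would typically be built from a diffusers Stable Diffusion pipeline plus distilled
# forward/reverse consistency models, e.g.
#
#   solver = Generator(model=pipe, n_steps=50, noise_scheduler=pipe.scheduler,
#                      forward_cons_model=forward_pipe, reverse_cons_model=reverse_pipe,
#                      num_endpoints=4, num_forward_endpoints=4)
#   image_rec, latents = solver.cons_inversion(load_512("img.png"), w_embed_dim=512)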
# ----------------------------------------------------------------------


# 3rd party utils
# ----------------------------------------------------------------------
def latent2image(vae, latents):
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents)['sample']
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image * 255).astype(np.uint8)
    return image


def init_latent(latent, model, height, width, generator, batch_size):
    if latent is None:
        latent = torch.randn(
            (1, model.unet.in_channels, height // 8, width // 8),
            generator=generator,
        )
    latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
    return latent, latents


def load_512(image_path, left=0, right=0, top=0, bottom=0):
    # if type(image_path) is str:
    #     image = np.array(Image.open(image_path))[:, :, :3]
    # else:
    #     image = image_path
    # h, w, c = image.shape
    # left = min(left, w - 1)
    # right = min(right, w - left - 1)
    # top = min(top, h - left - 1)
    # bottom = min(bottom, h - top - 1)
    # image = image[top:h - bottom, left:w - right]
    # h, w, c = image.shape
    # if h < w:
    #     offset = (w - h) // 2
    #     image = image[:, offset:offset + h]
    # elif w < h:
    #     offset = (h - w) // 2
    #     image = image[offset:offset + w]
    image = np.array(Image.open(image_path).convert('RGB'))[:, :, :3]
    image = np.array(Image.fromarray(image).resize((512, 512)))
    return image


def to_pil_images(images, num_rows=1, offset_ratio=0.02):
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h, j * (w + offset): j * (w + offset) + w] = images[
                i * num_cols + j]

    pil_img = Image.fromarray(image_)
    return pil_img


def view_images(images, num_rows=1, offset_ratio=0.02):
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h, j * (w + offset): j * (w + offset) + w] = images[
                i * num_cols + j]

    pil_img = Image.fromarray(image_)
    display(pil_img)
# ----------------------------------------------------------------------