dbaranchuk committed on
Commit
3ad0e52
·
verified ·
1 Parent(s): bee9a9c

Delete inversion.py

Browse files
Files changed (1) hide show
  1. inversion.py +0 -104
inversion.py DELETED
@@ -1,104 +0,0 @@
1
- import torch.nn.functional as nnf
2
- import torch
3
- import numpy as np
4
-
5
- from tqdm import tqdm
6
- from torch.optim.adam import Adam
7
- from PIL import Image
8
-
9
- from generation import load_512
10
- from p2p import register_attention_control
11
-
12
-
13
def null_optimization(solver,
                      latents,
                      guidance_scale,
                      num_inner_steps,
                      epsilon):
    """Null-text inversion: optimize a per-timestep unconditional embedding.

    For each diffusion timestep, the unconditional ("null") text embedding is
    tuned with Adam so that one classifier-free-guidance DDIM step from the
    current latent reproduces the latent recorded during the inversion pass.

    Args:
        solver: inversion helper exposing ``context``, ``n_steps``, ``model``,
            ``get_noise_pred_single``, ``get_noise_pred`` and ``prev_step``.
        latents: latent trajectory from the forward inversion, ordered from
            clean image to noise; optimization walks it backwards.
        guidance_scale: classifier-free guidance weight used while optimizing.
        num_inner_steps: maximum Adam iterations per timestep.
        epsilon: base early-stop threshold on the MSE loss (relaxed by
            ``i * 2e-5`` at later timesteps).

    Returns:
        List of optimized unconditional embeddings, one per timestep, each
        trimmed to batch size 1 and detached from the autograd graph.
    """
    uncond_embeddings, cond_embeddings = solver.context.chunk(2)
    uncond_embeddings_list = []
    latent_cur = latents[-1]
    bar = tqdm(total=num_inner_steps * solver.n_steps)
    for i in range(solver.n_steps):
        uncond_embeddings = uncond_embeddings.clone().detach()
        uncond_embeddings.requires_grad = True
        # Linearly decayed learning rate; assumes solver.n_steps <= 100,
        # otherwise the rate would go non-positive — TODO confirm.
        optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
        latent_prev = latents[len(latents) - i - 2]
        t = solver.model.scheduler.timesteps[i]
        with torch.no_grad():
            # The conditional prediction is constant within this timestep.
            noise_pred_cond = solver.get_noise_pred_single(latent_cur, t, cond_embeddings)
        # BUGFIX: the original reused the leaked loop variable `j` in the
        # catch-up loop below, which raises NameError when num_inner_steps == 0.
        steps_done = 0
        for _ in range(num_inner_steps):
            steps_done += 1
            noise_pred_uncond = solver.get_noise_pred_single(latent_cur, t, uncond_embeddings)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            latents_prev_rec = solver.prev_step(noise_pred, t, latent_cur)
            loss = nnf.mse_loss(latents_prev_rec, latent_prev)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_item = loss.item()
            bar.update()
            # Early stop once reconstruction is good enough; the threshold is
            # relaxed for later (noisier) timesteps.
            if loss_item < epsilon + i * 2e-5:
                break
        # Keep the progress bar consistent with its fixed total on early stop.
        for _ in range(steps_done, num_inner_steps):
            bar.update()
        uncond_embeddings_list.append(uncond_embeddings[:1].detach())
        with torch.no_grad():
            # Advance the latent one DDIM step using the tuned embedding.
            context = torch.cat([uncond_embeddings, cond_embeddings])
            noise_pred = solver.get_noise_pred(solver.model, latent_cur, t, guidance_scale, context)
            latent_cur = solver.prev_step(noise_pred, t, latent_cur)
    bar.close()
    return uncond_embeddings_list
51
-
52
-
53
def invert(solver,
           stop_step,
           is_cons_inversion=False,
           inv_guidance_scale=1,
           nti_guidance_scale=8,
           dynamic_guidance=False,
           tau1=0.4,
           tau2=0.6,
           w_embed_dim=0,
           image_path=None,
           prompt='',
           offsets=(0, 0, 0, 0),
           do_nti=False,
           do_npi=False,
           num_inner_steps=10,
           early_stop_epsilon=1e-5,
           seed=0,
           ):
    """Invert an image into the diffusion latent space.

    Runs either consistency inversion or DDIM inversion on the input image,
    then optionally refines the unconditional embeddings via null-text
    optimization (NTI) or substitutes them with the conditional embeddings
    (negative-prompt inversion, NPI).

    Returns:
        ``((image_gt, image_rec), last_latent, uncond_embeddings)`` where
        ``uncond_embeddings`` is a list (NTI/NPI) or ``None``.
    """
    solver.init_prompt(prompt)
    register_attention_control(solver.model, None)
    uncond_embeddings, cond_embeddings = solver.context.chunk(2)

    # Normalize the input to a 512x512 image (or list of images).
    if isinstance(image_path, list):
        source_image = [load_512(p, *offsets) for p in image_path]
    elif isinstance(image_path, str):
        source_image = load_512(image_path, *offsets)
    else:
        # Assumed to be an array-like image already in memory; resize via PIL.
        source_image = np.array(Image.fromarray(image_path).resize((512, 512)))

    if is_cons_inversion:
        reconstruction, trajectory = solver.cons_inversion(
            source_image,
            w_embed_dim=w_embed_dim,
            guidance_scale=inv_guidance_scale,
            seed=seed,
        )
    else:
        reconstruction, trajectory = solver.ddim_inversion(
            source_image,
            n_steps=stop_step,
            guidance_scale=inv_guidance_scale,
            dynamic_guidance=dynamic_guidance,
            tau1=tau1,
            tau2=tau2,
            w_embed_dim=w_embed_dim,
        )

    if do_nti:
        print("Null-text optimization...")
        null_embeddings = null_optimization(solver,
                                            trajectory,
                                            nti_guidance_scale,
                                            num_inner_steps,
                                            early_stop_epsilon)
    elif do_npi:
        # Negative-prompt inversion: reuse the conditional embedding as "null".
        null_embeddings = [cond_embeddings] * solver.n_steps
    else:
        null_embeddings = None
    return (source_image, reconstruction), trajectory[-1], null_embeddings