Spaces:
Runtime error
Runtime error
Delete inversion.py
Browse files- inversion.py +0 -104
inversion.py
DELETED
@@ -1,104 +0,0 @@
|
|
1 |
-
import torch.nn.functional as nnf
|
2 |
-
import torch
|
3 |
-
import numpy as np
|
4 |
-
|
5 |
-
from tqdm import tqdm
|
6 |
-
from torch.optim.adam import Adam
|
7 |
-
from PIL import Image
|
8 |
-
|
9 |
-
from generation import load_512
|
10 |
-
from p2p import register_attention_control
|
11 |
-
|
12 |
-
|
13 |
-
def null_optimization(solver,
                      latents,
                      guidance_scale,
                      num_inner_steps,
                      epsilon):
    """Null-text optimization: per-timestep tuning of the unconditional embedding.

    For each diffusion step, optimizes the unconditional ("null") text embedding
    with Adam so that one classifier-free-guided denoising step from the current
    latent reproduces the corresponding latent stored in the inversion
    trajectory ``latents``.

    Args:
        solver: inversion solver; must expose ``context`` (uncond+cond text
            embeddings concatenated along dim 0), ``n_steps``,
            ``model.scheduler.timesteps``, ``get_noise_pred_single``,
            ``get_noise_pred`` and ``prev_step``.
        latents: inversion latent trajectory; ``latents[-1]`` is the starting
            (most noised) latent and ``latents[len(latents) - i - 2]`` is the
            reconstruction target at outer step ``i``.
        guidance_scale: classifier-free guidance scale used while optimizing.
        num_inner_steps: maximum Adam iterations per diffusion timestep.
        epsilon: base early-stop threshold on the per-step MSE loss
            (relaxed by ``i * 2e-5`` for later timesteps).

    Returns:
        list of detached unconditional-embedding tensors (first batch element
        only), one per diffusion step.
    """
    uncond_embeddings, cond_embeddings = solver.context.chunk(2)
    uncond_embeddings_list = []
    latent_cur = latents[-1]
    bar = tqdm(total=num_inner_steps * solver.n_steps)
    for i in range(solver.n_steps):
        # Re-detach each step so Adam optimizes a fresh leaf tensor, warm-started
        # from the previous step's result.
        uncond_embeddings = uncond_embeddings.clone().detach()
        uncond_embeddings.requires_grad = True
        # Linearly decayed LR; assumes <= 100 outer steps (negative LR beyond) —
        # NOTE(review): confirm solver.n_steps <= 100 in callers.
        optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
        latent_prev = latents[len(latents) - i - 2]
        t = solver.model.scheduler.timesteps[i]
        with torch.no_grad():
            # Conditional noise prediction is constant across inner iterations.
            noise_pred_cond = solver.get_noise_pred_single(latent_cur, t, cond_embeddings)
        j = -1  # guard: with num_inner_steps == 0 the loop below never binds j
        for j in range(num_inner_steps):
            noise_pred_uncond = solver.get_noise_pred_single(latent_cur, t, uncond_embeddings)
            # Classifier-free guidance combination.
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            latents_prev_rec = solver.prev_step(noise_pred, t, latent_cur)
            loss = nnf.mse_loss(latents_prev_rec, latent_prev)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_item = loss.item()
            bar.update()
            # Early stop once reconstruction is good enough for this timestep.
            if loss_item < epsilon + i * 2e-5:
                break
        # Advance the progress bar past inner iterations skipped by the break.
        for j in range(j + 1, num_inner_steps):
            bar.update()
        uncond_embeddings_list.append(uncond_embeddings[:1].detach())
        with torch.no_grad():
            # Take one real denoising step with the optimized null embedding.
            context = torch.cat([uncond_embeddings, cond_embeddings])
            noise_pred = solver.get_noise_pred(solver.model, latent_cur, t, guidance_scale, context)
            latent_cur = solver.prev_step(noise_pred, t, latent_cur)
    bar.close()
    return uncond_embeddings_list
|
51 |
-
|
52 |
-
|
53 |
-
def invert(solver,
           stop_step,
           is_cons_inversion=False,
           inv_guidance_scale=1,
           nti_guidance_scale=8,
           dynamic_guidance=False,
           tau1=0.4,
           tau2=0.6,
           w_embed_dim=0,
           image_path=None,
           prompt='',
           offsets=(0, 0, 0, 0),
           do_nti=False,
           do_npi=False,
           num_inner_steps=10,
           early_stop_epsilon=1e-5,
           seed=0,
           ):
    """Invert an image into diffusion latents, optionally refining with NTI/NPI.

    Loads the target image, runs either consistency inversion
    (``solver.cons_inversion``) or DDIM inversion (``solver.ddim_inversion``),
    then optionally computes unconditional embeddings via null-text inversion
    (``do_nti``) or negative-prompt inversion (``do_npi``, which reuses the
    conditional embeddings).

    Args:
        solver: inversion solver exposing ``init_prompt``, ``context``,
            ``model``, ``n_steps`` and the inversion methods named above.
        stop_step: number of DDIM inversion steps (unused for the
            consistency-inversion branch).
        is_cons_inversion: choose consistency inversion over DDIM inversion.
        inv_guidance_scale: guidance scale used during inversion.
        nti_guidance_scale: guidance scale used by null-text optimization.
        dynamic_guidance, tau1, tau2: dynamic-guidance options forwarded to
            ``ddim_inversion``.
        w_embed_dim: guidance-embedding dimension forwarded to the solver.
        image_path: str path, list of paths, or an in-memory array
            (resized to 512x512 via PIL in the array case).
        prompt: text prompt used to build the embedding context.
        offsets: crop offsets forwarded to ``load_512``.
        do_nti: run null-text optimization for the unconditional embeddings.
        do_npi: use the conditional embeddings as negative-prompt embeddings.
        num_inner_steps, early_stop_epsilon: NTI optimization parameters.
        seed: random seed forwarded to ``cons_inversion``.

    Returns:
        ``((image_gt, image_rec), last_latent, uncond_embeddings)`` where
        ``uncond_embeddings`` is a list (NTI/NPI) or ``None``.
    """
    solver.init_prompt(prompt)
    # Only the conditional half is needed here; the unconditional half is
    # recomputed (NTI), substituted (NPI), or dropped below.
    _, cond_embeddings = solver.context.chunk(2)
    register_attention_control(solver.model, None)
    # Accept a list of paths, a single path, or a raw image array.
    if isinstance(image_path, list):
        image_gt = [load_512(path, *offsets) for path in image_path]
    elif isinstance(image_path, str):
        image_gt = load_512(image_path, *offsets)
    else:
        image_gt = np.array(Image.fromarray(image_path).resize((512, 512)))

    if is_cons_inversion:
        image_rec, ddim_latents = solver.cons_inversion(image_gt,
                                                        w_embed_dim=w_embed_dim,
                                                        guidance_scale=inv_guidance_scale,
                                                        seed=seed,)
    else:
        image_rec, ddim_latents = solver.ddim_inversion(image_gt,
                                                        n_steps=stop_step,
                                                        guidance_scale=inv_guidance_scale,
                                                        dynamic_guidance=dynamic_guidance,
                                                        tau1=tau1, tau2=tau2,
                                                        w_embed_dim=w_embed_dim)
    if do_nti:
        # Null-text inversion: optimize per-step unconditional embeddings.
        print("Null-text optimization...")
        uncond_embeddings = null_optimization(solver,
                                              ddim_latents,
                                              nti_guidance_scale,
                                              num_inner_steps,
                                              early_stop_epsilon)
    elif do_npi:
        # Negative-prompt inversion: reuse conditional embeddings at every step.
        uncond_embeddings = [cond_embeddings] * solver.n_steps
    else:
        uncond_embeddings = None
    return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|