Commit eac223c · 1 Parent(s): 71b70df
Erwann Millon committed

cleanup and refactoring
Files changed:
- ImageState.py +0 -9
- animation.py +1 -6
- app.py +1 -1
- app_backend.py +21 -55
- configs.py +0 -7
- edit.py +4 -15
- img_processing.py +1 -1
- loaders.py +6 -22
- utils.py +1 -1
- vqgan_latent_ops.py +0 -14
- vqgan_only.pt +0 -3
ImageState.py CHANGED

@@ -63,24 +63,15 @@ class ImageState:
     def _decode_latent_to_pil(self, latent):
         current_im = self.vqgan.decode(latent.to(self.device))[0]
         return custom_to_pil(current_im)
-    # def _get_current_vector_transforms(self):
-    #     current_vector_transforms = (self.blue_eyes, self.lip_size, self.hair_gp, self.asian_transform, sum(self.current_prompt_transforms))
-    #     return (self.blend_latent, current_vector_transforms)
-    # @cache
     def get_mask(self, img, mask=None):
         if img and "mask" in img and img["mask"] is not None:
             attn_mask = torchvision.transforms.ToTensor()(img["mask"])
             attn_mask = torch.ceil(attn_mask[0].to(self.device))
-            plt.imshow(attn_mask.detach().cpu(), cmap="Blues")
-            plt.show()
-            torch.save(attn_mask, "test_mask.pt")
             print("mask set successfully")
-            # attn_mask = self.rescale_mask(attn_mask)
             print(type(attn_mask))
             print(attn_mask.shape)
         else:
             attn_mask = mask
-        print("mask in apply ", get_resized_tensor(attn_mask), get_resized_tensor(attn_mask).shape)
         return attn_mask
     def set_mask(self, img):
         attn_mask = self.get_mask(img)
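For orientation, the surviving get_mask logic turns the sketch layer of a Gradio image-editor input into a binary attention mask on the model device. A minimal standalone sketch of that step (the function name and plain device argument are illustrative; in the repo this lives on ImageState and uses self.device):

import torch
import torchvision

def mask_from_sketch(img, device="cpu"):
    # img is a Gradio sketch dict such as {"image": ..., "mask": PIL.Image}
    attn_mask = torchvision.transforms.ToTensor()(img["mask"])   # (C, H, W), values in [0, 1]
    attn_mask = torch.ceil(attn_mask[0].to(device))              # any painted pixel becomes 1.0
    return attn_mask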
animation.py CHANGED

@@ -8,7 +8,6 @@ def clear_img_dir():
     os.mkdir(img_dir)
     for filename in glob.glob(img_dir+"/*"):
         os.remove(filename)
-
 
 def create_gif(total_duration, extend_frames, folder="./img_history", gif_name="face_edit.gif"):
     images = []
@@ -23,12 +22,8 @@ def create_gif(total_duration, extend_frames, folder="./img_history", gif_name="face_edit.gif"):
         if file_name.endswith('.png'):
             file_path = os.path.join(folder, file_name)
             images.append(imageio.imread(file_path))
-    # images[0] = images[0].set_meta_data({'duration': 1})
-    # images[-1] = images[-1].set_meta_data({'duration': 1})
     imageio.mimsave(gif_name, images, duration=durations)
     return gif_name
 
 if __name__ == "__main__":
-
-    create_gif()
-    # make_animation()
+    create_gif()
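A hedged usage sketch of the cleaned-up create_gif, assuming frames were already written to ./img_history by the editing loop (the argument values are illustrative, not repo defaults):

from animation import create_gif

# stitches the saved PNG frames into face_edit.gif via imageio.mimsave
gif_path = create_gif(total_duration=2, extend_frames=True, folder="./img_history")
print(gif_path)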
app.py CHANGED

@@ -4,7 +4,7 @@ import sys
 
 import wandb
 
-from configs import set_major_global, set_major_local, set_small_local
+from presets import set_major_global, set_major_local, set_small_local
 
 sys.path.append("taming-transformers")
 import functools
app_backend.py CHANGED

@@ -81,7 +81,7 @@ class ImagePromptOptimizer(nn.Module):
         self.make_grid = make_grid
         self.return_val = return_val
         self.quantize = quantize
-        self.disc = load_disc(self.device)
+        # self.disc = load_disc(self.device)
         self.lpips_weight = lpips_weight
         self.perceptual_loss = lpips.LPIPS(net='vgg').to(self.device)
     def disc_loss_fn(self, logits):
@@ -89,7 +89,7 @@ class ImagePromptOptimizer(nn.Module):
     def set_latent(self, latent):
         self.latent = latent.detach().to(self.device)
     def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
-        self.attn_mask = attn_mask
+        self._attn_mask = attn_mask
         self.iterations = iterations
         self.lr = lr
         self.lpips_weight = lpips_weight
@@ -131,32 +131,29 @@ class ImagePromptOptimizer(nn.Module):
         else:
             plt.imshow(get_pil(processed_img[0]).detach().cpu())
             plt.show()
-    def attn_masking(self, grad):
-        # print("attnmask 1")
-        # print(f"input grad.shape = {grad.shape}")
-        # print(f"input grad = {get_resized_tensor(grad)}")
+    def _attn_mask(self, grad):
         newgrad = grad
-        if self.attn_mask is not None:
-
-            newgrad = grad * (self.attn_mask)
-        # print("output grad, ", get_resized_tensor(newgrad))
-        # print("end atn 1")
+        if self._attn_mask is not None:
+            newgrad = grad * (self._attn_mask)
         return newgrad
-    def attn_masking2(self, grad):
-        # print("attnmask 2")
-        # print(f"input grad.shape = {grad.shape}")
-        # print(f"input grad = {get_resized_tensor(grad)}")
+    def _attn_mask_inverse(self, grad):
         newgrad = grad
-        if self.attn_mask is not None:
-
-            newgrad = grad * ((self.attn_mask - 1) * -1)
-        # print("output grad, ", get_resized_tensor(newgrad))
-        # print("end atn 2")
+        if self._attn_mask is not None:
+            newgrad = grad * ((self._attn_mask - 1) * -1)
         return newgrad
+    def _get_next_inputs(self, transformed_img):
+        processed_img = loop_post_process(transformed_img) #* self.attn_mask
+        processed_img.retain_grad()
+        lpips_input = processed_img.clone()
+        lpips_input.register_hook(self._attn_mask_inverse)
+        lpips_input.retain_grad()
+        clip_input = processed_img.clone()
+        clip_input.register_hook(self._attn_mask)
+        clip_input.retain_grad()
+        return processed_img, lpips_input, clip_input
 
     def optimize(self, latent, pos_prompts, neg_prompts):
         self.set_latent(latent)
-        # self.make_grid=True
         transformed_img = self(torch.zeros_like(self.latent, requires_grad=True, device=self.device))
         original_img = loop_post_process(transformed_img)
         vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
@@ -167,27 +164,14 @@ class ImagePromptOptimizer(nn.Module):
         for i in tqdm(range(self.iterations)):
             optim.zero_grad()
             transformed_img = self(vector)
-            processed_img = loop_post_process(transformed_img) #* self.attn_mask
-            processed_img.retain_grad()
-            lpips_input = processed_img.clone()
-            lpips_input.register_hook(self.attn_masking2)
-            lpips_input.retain_grad()
-            clip_clone = processed_img.clone()
-            clip_clone.register_hook(self.attn_masking)
-            clip_clone.retain_grad()
+            processed_img, lpips_input, clip_input = self._get_next_inputs(transformed_img)
             with torch.autocast("cuda"):
-                clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_clone)
+                clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_input)
                 print("CLIP loss", clip_loss)
                 perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
                 print("LPIPS loss: ", perceptual_loss)
-                with torch.no_grad():
-                    disc_logits = self.disc(transformed_img)
-                    disc_loss = self.disc_loss_fn(disc_logits)
-                    print(f"disc_loss = {disc_loss}")
-                    disc_loss2 = self.disc(processed_img)
             if log:
                 wandb.log({"Perceptual Loss": perceptual_loss})
-                wandb.log({"Discriminator Loss": disc_loss})
                 wandb.log({"CLIP Loss": clip_loss})
             clip_loss.backward(retain_graph=True)
             perceptual_loss.backward(retain_graph=True)
@@ -207,7 +191,7 @@ class ImagePromptOptimizer(nn.Module):
             processed_img = loop_post_process(transformed_img) #* self.attn_mask
             processed_img.retain_grad()
             lpips_input = processed_img.clone()
-            lpips_input.register_hook(self.attn_masking2)
+            lpips_input.register_hook(self._attn_mask_inverse)
             lpips_input.retain_grad()
             with torch.autocast("cuda"):
                 perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
@@ -216,28 +200,10 @@ class ImagePromptOptimizer(nn.Module):
             disc_loss = self.disc_loss_fn(disc_logits)
             print(f"disc_loss = {disc_loss}")
            disc_loss2 = self.disc(processed_img)
-            # print(f"disc_loss2 = {disc_loss2}")
             if log:
                 wandb.log({"Perceptual Loss": perceptual_loss})
             print("LPIPS loss: ", perceptual_loss)
             perceptual_loss.backward(retain_graph=True)
             optim.step()
             yield vector
-        # torch.save(vector, "nose_vector.pt")
-        # print("")
-        # print("DISC STEPS")
-        # print("*************")
-        # for i in range(self.reconstruction_steps):
-        #     optim.zero_grad()
-        #     transformed_img = self(vector)
-        #     processed_img = loop_post_process(transformed_img) #* self.attn_mask
-        #     disc_logits = self.disc(transformed_img)
-        #     disc_loss = self.disc_loss_fn(disc_logits)
-        #     print(f"disc_loss = {disc_loss}")
-        #     if log:
-        #         wandb.log({"Disc Loss": disc_loss})
-        #     print("LPIPS loss: ", perceptual_loss)
-        #     disc_loss.backward(retain_graph=True)
-        #     optim.step()
-        #     yield vector
         yield vector if self.return_val == "vector" else self.latent + vector
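The renamed hooks gate gradients rather than pixels: the branch fed to the CLIP similarity loss only receives gradient inside the user-drawn mask, while the LPIPS branch only receives gradient outside it, so the prompt edits the selected region and the perceptual loss anchors everything else. One caveat worth flagging: the new method name _attn_mask collides with the attribute of the same name assigned in set_params, so whichever is set last shadows the other on the instance. Below is a self-contained sketch of the register_hook pattern itself, with toy tensors and losses standing in for the app's real ones:

import torch

mask = torch.zeros(4, 4)
mask[:2, :] = 1.0                      # editable region: the top two rows

img = torch.randn(4, 4, requires_grad=True)

clip_branch = img.clone()
clip_branch.register_hook(lambda g: g * mask)            # prompt loss only edits inside the mask
lpips_branch = img.clone()
lpips_branch.register_hook(lambda g: g * (1.0 - mask))   # perceptual loss only constrains the outside

clip_branch.sum().backward(retain_graph=True)
assert img.grad[2:].abs().sum() == 0    # no gradient leaked outside the mask

img.grad = None
lpips_branch.pow(2).sum().backward()
assert img.grad[:2].abs().sum() == 0    # and none leaked inside it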
configs.py DELETED

@@ -1,7 +0,0 @@
-import gradio as gr
-def set_small_local():
-    return (gr.Slider.update(value=25), gr.Slider.update(value=0.15), gr.Slider.update(value=1), gr.Slider.update(value=4))
-def set_major_local():
-    return (gr.Slider.update(value=25), gr.Slider.update(value=0.25), gr.Slider.update(value=35), gr.Slider.update(value=10))
-def set_major_global():
-    return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=2), gr.Slider.update(value=0.2))
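The three preset helpers deleted here reappear in app.py's new import from presets, so they presumably moved to a presets.py module rather than being dropped. A hedged sketch of how such a preset is typically wired into a Gradio Blocks UI; the slider names, ranges, and the mapping of the four returned updates are illustrative guesses, not taken from app.py:

import gradio as gr
from presets import set_small_local   # assumed relocation of the deleted configs.py helpers

with gr.Blocks() as demo:
    iterations = gr.Slider(0, 100, label="Iterations")
    lr = gr.Slider(0.0, 1.0, label="Learning rate")
    lpips_weight = gr.Slider(0, 50, label="LPIPS weight")
    reconstruction_steps = gr.Slider(0, 50, label="Reconstruction steps")
    preset_btn = gr.Button("Small local edit")
    # each preset returns four gr.Slider.update(...) values, one per output slider
    preset_btn.click(set_small_local, outputs=[iterations, lr, lpips_weight, reconstruction_steps])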
edit.py CHANGED

@@ -17,13 +17,13 @@ from utils import get_device
 
 
 def get_embedding(model, path=None, img=None, device="cpu"):
-    assert path
+    assert path or img, "Input either path or tensor"
     if img is not None:
         raise NotImplementedError
     x = preprocess(PIL.Image.open(path), target_image_size=256).to(device)
     x_processed = preprocess_vqgan(x)
-
-    return
+    z, _, [_, _, indices] = model.encode(x_processed)
+    return z
 
 
 def blend_paths(model, path1, path2, quantize=False, weight=0.5, show=True, device="cuda"):
@@ -47,23 +47,12 @@ def blend_paths(model, path1, path2, quantize=False, weight=0.5, show=True, device="cuda"):
 
 if __name__ == "__main__":
     device = get_device()
-    # conf_path = "logs/2021-04-23T18-11-19_celebahq_transformer/configs/2021-04-23T18-11-19-project.yaml"
     ckpt_path = "logs/2021-04-23T18-11-19_celebahq_transformer/checkpoints/last.ckpt"
-    # ckpt_path = "./faceshq/faceshq.pt"
     conf_path = "./unwrapped.yaml"
-    # conf_path = "./faceshq/faceshq.yaml"
     config = load_config(conf_path, display=False)
     model = taming.models.vqgan.VQModel(**config.model.params)
     sd = torch.load("./vqgan_only.pt", map_location="mps")
     model.load_state_dict(sd, strict=True)
     model.to(device)
     blend_paths(model, "./test_data/face.jpeg", "./test_data/face2.jpeg", quantize=False, weight=.5)
-    plt.show()
-
-    demo = gr.Interface(
-        get_image,
-        inputs=gr.inputs.Image(label="UploadZz a black and white face", type="filepath"),
-        outputs="image",
-        title="Upload a black and white face and get a colorized image!",
-    )
-
+    plt.show()
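With get_embedding now returning the encoder latent, blend_paths presumably interpolates two such latents before decoding. A hedged sketch of that kind of blend using the repo's own loaders (the 50/50 interpolation and the direct decode call are illustrative, not blend_paths itself):

from edit import get_embedding
from loaders import load_default
from utils import get_device

device = get_device()
model = load_default(device)

z1 = get_embedding(model, path="./test_data/face.jpeg", device=device)
z2 = get_embedding(model, path="./test_data/face2.jpeg", device=device)
blended = model.decode(0.5 * z1 + 0.5 * z2)   # VQModel.decode maps the mixed latent back to an image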
img_processing.py CHANGED

@@ -32,7 +32,7 @@ def preprocess(img, target_image_size=256, map_dalle=False):
     return img
 
 def preprocess_vqgan(x):
-    x = 2
+    x = 2. * x - 1.
     return x
 
 def custom_to_pil(x, process=True, mode="RGB"):
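The rewritten preprocess_vqgan is the standard [0, 1] to [-1, 1] remap that taming-transformers VQGANs expect; mapping back to display range presumably applies the inverse. A quick numeric check of both directions:

import torch

x = torch.tensor([0.0, 0.5, 1.0])        # image tensor scaled to [0, 1]
y = 2. * x - 1.                          # preprocess_vqgan: [0, 1] -> [-1, 1]
assert torch.equal(y, torch.tensor([-1.0, 0.0, 1.0]))
assert torch.equal((y + 1.) / 2., x)     # the inverse used when converting back to an image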
loaders.py CHANGED

@@ -1,5 +1,4 @@
 import importlib
-
 import numpy as np
 import taming
 import torch
@@ -7,9 +6,8 @@ import yaml
 from omegaconf import OmegaConf
 from PIL import Image
 from taming.models.vqgan import VQModel
-
 from utils import get_device
-
+
 
 def load_config(config_path, display=False):
     config = OmegaConf.load(config_path)
@@ -17,37 +15,23 @@ def load_config(config_path, display=False):
     print(yaml.dump(OmegaConf.to_container(config)))
     return config
 
-# def load_disc(device):
-#     dconf = load_config("disc_config.yaml")
-#     sd = torch.load("disc.pt", map_location=device)
-#     # print(sd.keys())
-#     model = discriminator.NLayerDiscriminator()
-#     model.load_state_dict(sd, strict=True)
-#     model.to(device)
-#     return model
-#     print(dconf.keys())
-
 def load_default(device):
-    # device = get_device()
     ckpt_path = "logs/2021-04-23T18-11-19_celebahq_transformer/checkpoints/last.ckpt"
     conf_path = "./unwrapped.yaml"
     config = load_config(conf_path, display=False)
     model = taming.models.vqgan.VQModel(**config.model.params)
-    sd = torch.load("./vqgan_only.pt", map_location=device)
+    sd = torch.load("./model_checkpoints/vqgan_only.pt", map_location=device)
     model.load_state_dict(sd, strict=True)
     model.to(device)
     return model
 
 
 def load_vqgan(config, ckpt_path=None, is_gumbel=False):
-    if is_gumbel:
-        model = GumbelVQ(**config.model.params)
-    else:
     model = VQModel(**config.model.params)
-
-
-
-
+    if ckpt_path is not None:
+        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+        missing, unexpected = model.load_state_dict(sd, strict=False)
+    return model.eval()
 
 def load_ffhq():
     conf = "2020-11-09T13-33-36_faceshq_vqgan/configs/2020-11-09T13-33-36-project.yaml"
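A short usage sketch of the generalized load_vqgan; the config and checkpoint paths are the ones this repo references elsewhere and are assumed to exist locally:

from loaders import load_config, load_vqgan

config = load_config("./unwrapped.yaml", display=False)
# ckpt_path is optional; when given, the Lightning-style "state_dict" is loaded non-strictly
model = load_vqgan(config, ckpt_path="logs/2021-04-23T18-11-19_celebahq_transformer/checkpoints/last.ckpt")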
utils.py CHANGED

@@ -7,10 +7,10 @@ import torch.nn.functional as F
 from skimage.color import lab2rgb, rgb2lab
 from torch import nn
 
-
 def freeze_module(module):
     for param in module.parameters():
         param.requires_grad = False
+
 def get_device():
     device = "cuda" if torch.cuda.is_available() else "cpu"
     if torch.backends.mps.is_available() and torch.backends.mps.is_built():
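A typical call pattern for these two helpers, assuming the VQGAN weights should stay frozen while only the edit vector is optimized:

from loaders import load_default
from utils import freeze_module, get_device

device = get_device()
vqgan = load_default(device)
freeze_module(vqgan)   # turns off requires_grad for every parameter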
vqgan_latent_ops.py DELETED

@@ -1,14 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from gradient_flow_ops import ReplaceGrad
-
-replace_grad = ReplaceGrad.apply
-
-def vector_quantize(x, codebook):
-
-    d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T
-    indices = d.argmin(-1)
-    x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
-    return replace_grad(x_q, x)
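The deleted helper was a straight-through vector quantizer: d expands the squared distance ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x·c against every codebook entry, the nearest code is gathered, and ReplaceGrad lets gradients flow back to the continuous x. A small numeric check of that expansion (torch.cdist is used here only for verification; it is not what the deleted file called):

import torch

x = torch.randn(5, 16)           # 5 latent vectors
codebook = torch.randn(32, 16)   # 32 codebook entries

d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T
assert torch.allclose(d, torch.cdist(x, codebook).pow(2), atol=1e-5)

indices = d.argmin(-1)           # nearest codebook entry per vector
x_q = codebook[indices]          # equivalent to the one_hot @ codebook lookup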
vqgan_only.pt DELETED

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:8e39472bae4489764c0ffc70ba84ec7815f245781020ce55cc2e7adc60e580e4
-size 288690579