Update with Arroz con Cosas models.
- app.py +48 -16
- requirements.txt +1 -1
app.py
CHANGED
@@ -7,15 +7,24 @@ from rudalle import get_vae
 from einops import rearrange
 from huggingface_hub import hf_hub_download
 from modules import DenoiseUNet
+from arroz import Diffuzz, PriorModel
 
-model_repo = "pcuenq/
-model_file = "
+model_repo = "pcuenq/Arroz_con_cosas"
+model_file = "model_1b_img.pt"
+prior_file = "prior_v1_1500k_ema_fp16.pt"
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
+device_text = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
 
 batch_size = 4
-
-
+latent_shape = (64, 64)
+
+generator_timesteps = 12
+generator_cfg = 5
+prior_timesteps = 60
+prior_cfg = 3.0
+prior_sampler = 'ddpm'
+clip_embedding_shape = (batch_size, 1024)
 
 
 def to_pil(images):
@@ -34,7 +43,7 @@ def gumbel_noise(t):
 def gumbel_sample(t, temperature=1., dim=-1):
     return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
 
-def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_range=[1.0, 1.0], typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=-1, renoise_steps=11, renoise_mode='start'):
+def sample(model, c, x=None, negative_embeddings=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_range=[1.0, 1.0], typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=-1, renoise_steps=11, renoise_mode='start'):
     with torch.inference_mode():
         r_range = torch.linspace(0, 1, T+1)[:-1][:, None].expand(-1, c.size(0)).to(c.device)
         temperatures = torch.linspace(temp_range[0], temp_range[1], T)
@@ -51,7 +60,10 @@ def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_
             r, temp = r_range[i], temperatures[i]
             logits = model(x, c, r)
             if classifier_free_scale >= 0:
-
+                if negative_embeddings is not None:
+                    logits_uncond = model(x, negative_embeddings, r)
+                else:
+                    logits_uncond = model(x, torch.zeros_like(c), r)
                 logits = torch.lerp(logits_uncond, logits, classifier_free_scale)
             x = logits
             x_flat = x.permute(0, 2, 3, 1).reshape(-1, x.size(1))
@@ -70,8 +82,7 @@ def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_
                 sorted_indices_to_remove[..., :typical_min_tokens] = 0
                 indices_to_remove = sorted_indices_to_remove.scatter(1, x_flat_indices, sorted_indices_to_remove)
                 x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
-
-            x_flat = gumbel_sample(x_flat, temperature=temp)
+            x_flat = torch.multinomial(x_flat.div(temp).softmax(-1), num_samples=1)[:, 0]
             x = x_flat.view(x.size(0), *x.shape[2:])
             if mask is not None:
                 x = x * mask + (1-mask) * init_x
@@ -90,7 +101,7 @@ def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_
 vqmodel = get_vae().to(device)
 vqmodel.eval().requires_grad_(False)
 
-clip_model, _, _ = open_clip.create_model_and_transforms('ViT-
+clip_model, _, _ = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k')
 clip_model = clip_model.to(device).eval().requires_grad_(False)
 
 def encode(x):
@@ -108,22 +119,40 @@ def decode(img_seq, shape=(32,32)):
 
 model_path = hf_hub_download(repo_id=model_repo, filename=model_file)
 state_dict = torch.load(model_path, map_location=device)
-model = DenoiseUNet(num_labels=8192).to(device)
+model = DenoiseUNet(num_labels=8192, c_clip=1024, c_hidden=1280, down_levels=[1, 2, 8, 32], up_levels=[32, 8, 2, 1]).to(device)
 model.load_state_dict(state_dict)
 model.eval().requires_grad_()
 
+prior_path = hf_hub_download(repo_id=model_repo, filename=prior_file)
+prior_ckpt = torch.load(prior_path, map_location=device)
+prior = PriorModel().to(device)
+prior.load_state_dict(prior_ckpt)
+prior.eval().requires_grad_(False)
+diffuzz = Diffuzz(device=device)
+
+del prior_ckpt, state_dict
+
 # -----
 
-def infer(prompt):
-    latent_shape = (32, 32)
+def infer(prompt, negative_prompt=""):
     tokenized_text = tokenizer.tokenize([prompt] * batch_size).to(device)
+    negative_text = tokenizer.tokenize([negative_prompt] * batch_size).to(device)
     with torch.inference_mode():
         with torch.autocast(device_type="cuda"):
             clip_embeddings = clip_model.encode_text(tokenized_text)
+            neg_clip_embeddings = clip_model.encode_text(negative_text)
+
+            sampled_image_embeddings = diffuzz.sample(
+                prior, {'c': clip_embeddings}, clip_embedding_shape,
+                timesteps=prior_timesteps, cfg=prior_cfg, sampler=prior_sampler
+            )[-1]
+
             images = sample(
-                model,
-
-
+                model, sampled_image_embeddings, negative_embeddings=neg_clip_embeddings,
+                T=generator_timesteps, size=latent_shape, starting_t=0, temp_range=[2.0, 0.1],
+                typical_filtering=False, typical_mass=0.2, typical_min_tokens=1,
+                classifier_free_scale=generator_cfg, renoise_steps=generator_timesteps-1,
+                renoise_mode="start"
+            )
             images = decode(images[-1], latent_shape)
             return to_pil(images)
@@ -231,7 +260,7 @@ block = gr.Blocks(css=css)
 
 with block:
     gr.HTML(
-        """
+        f"""
         <div style="text-align: center; max-width: 650px; margin: 0 auto;">
           <div
             style="
@@ -278,6 +307,9 @@ with block:
              Paella Demo
            </h1>
          </div>
+         <p>
+           Running on <b>{device_text}</b>
+         </p>
         <p style="margin-bottom: 10px; font-size: 94%">
           Paella is a novel text-to-image model that uses a compressed quantized latent space, based on a f8 VQGAN, and a masked training objective to achieve fast generation in ~10 inference steps.
         </p>
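With this change inference is two-stage: a diffusion prior (PriorModel run through Diffuzz, from the Arroz-Con-Cosas package) maps CLIP text embeddings to CLIP image embeddings, and the Paella-style token sampler then decodes those image embeddings at a 64x64 latent resolution, with optional negative-prompt guidance. A minimal usage sketch, assuming the definitions from the updated app.py above are in scope and the checkpoints have downloaded; the prompt strings are illustrative only, and to_pil is assumed to return a list of PIL images:

# Illustrative call into the updated pipeline (prompts are placeholders).
images = infer(
    "a paella with seafood, studio photograph",   # prompt
    negative_prompt="blurry, low quality",        # optional negative prompt
)
images[0].save("sample.png")  # one of batch_size (4) generations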
requirements.txt
CHANGED
@@ -1,7 +1,7 @@
--f https://download.pytorch.org/whl/cu116
 torch
 rudalle
 open_clip_torch
 einops
 Pillow
 huggingface_hub
+git+https://github.com/pabloppp/Arroz-Con-Cosas
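The `-f https://download.pytorch.org/whl/cu116` index line is dropped, so torch now resolves from PyPI, and the new dependency is installed unpinned from GitHub, so a Space rebuild pulls pabloppp/Arroz-Con-Cosas at its current revision. A quick post-install sanity check, using only the names and constructor calls the updated app.py already uses:

# Sanity-check sketch: these imports and constructors mirror the ones in app.py above.
from arroz import Diffuzz, PriorModel

prior = PriorModel()             # prior network; weights are loaded separately in app.py
diffuzz = Diffuzz(device="cpu")  # diffusion sampler wrapper used to run the prior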