pcuenq (HF Staff) committed
Commit f025569 · 1 Parent(s): ecb8116

Update with Arroz con Cosas models.

Files changed (2):
  1. app.py +48 -16
  2. requirements.txt +1 -1
app.py CHANGED
@@ -7,15 +7,24 @@ from rudalle import get_vae
 from einops import rearrange
 from huggingface_hub import hf_hub_download
 from modules import DenoiseUNet
+from arroz import Diffuzz, PriorModel
 
-model_repo = "pcuenq/Paella"
-model_file = "model_600000.pt"
+model_repo = "pcuenq/Arroz_con_cosas"
+model_file = "model_1b_img.pt"
+prior_file = "prior_v1_1500k_ema_fp16.pt"
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
+device_text = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
 
 batch_size = 4
-steps = 11
-scale = 5
+latent_shape = (64, 64)
+
+generator_timesteps = 12
+generator_cfg = 5
+prior_timesteps = 60
+prior_cfg = 3.0
+prior_sampler = 'ddpm'
+clip_embedding_shape = (batch_size, 1024)
 
 
 def to_pil(images):
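
For reference, a minimal standalone sketch of fetching the two checkpoints named above from the Hub and loading them on CPU (only huggingface_hub and torch are assumed; the repo and filenames are the constants added in this hunk):

    # Sketch: download the generator and prior weights and inspect the state dicts.
    import torch
    from huggingface_hub import hf_hub_download

    repo = "pcuenq/Arroz_con_cosas"
    generator_path = hf_hub_download(repo_id=repo, filename="model_1b_img.pt")
    prior_path = hf_hub_download(repo_id=repo, filename="prior_v1_1500k_ema_fp16.pt")

    generator_sd = torch.load(generator_path, map_location="cpu")
    prior_sd = torch.load(prior_path, map_location="cpu")
    print(len(generator_sd), len(prior_sd))  # number of tensors in each state dict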
@@ -34,7 +43,7 @@ def gumbel_noise(t):
 def gumbel_sample(t, temperature=1., dim=-1):
     return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
 
-def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_range=[1.0, 1.0], typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=-1, renoise_steps=11, renoise_mode='start'):
+def sample(model, c, x=None, negative_embeddings=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_range=[1.0, 1.0], typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=-1, renoise_steps=11, renoise_mode='start'):
     with torch.inference_mode():
         r_range = torch.linspace(0, 1, T+1)[:-1][:, None].expand(-1, c.size(0)).to(c.device)
         temperatures = torch.linspace(temp_range[0], temp_range[1], T)
@@ -51,7 +60,10 @@ def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_
             r, temp = r_range[i], temperatures[i]
             logits = model(x, c, r)
             if classifier_free_scale >= 0:
-                logits_uncond = model(x, torch.zeros_like(c), r)
+                if negative_embeddings is not None:
+                    logits_uncond = model(x, negative_embeddings, r)
+                else:
+                    logits_uncond = model(x, torch.zeros_like(c), r)
                 logits = torch.lerp(logits_uncond, logits, classifier_free_scale)
             x = logits
             x_flat = x.permute(0, 2, 3, 1).reshape(-1, x.size(1))
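
The change above lets a negative prompt stand in for the unconditional branch of classifier-free guidance. A minimal sketch of that guidance step, mirroring the hunk's logic (cfg_logits is illustrative only, not a function in the app):

    # Classifier-free guidance step: an optional negative embedding replaces
    # the all-zeros "unconditional" conditioning.
    import torch

    def cfg_logits(model, x, c, r, scale, negative_embeddings=None):
        logits = model(x, c, r)
        if scale < 0:
            return logits  # guidance disabled
        uncond_c = negative_embeddings if negative_embeddings is not None else torch.zeros_like(c)
        logits_uncond = model(x, uncond_c, r)
        # torch.lerp(a, b, w) = a + w * (b - a); scale > 1 pushes past the conditional prediction
        return torch.lerp(logits_uncond, logits, scale)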
@@ -70,8 +82,7 @@ def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_
                 sorted_indices_to_remove[..., :typical_min_tokens] = 0
                 indices_to_remove = sorted_indices_to_remove.scatter(1, x_flat_indices, sorted_indices_to_remove)
                 x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
-            # x_flat = torch.multinomial(x_flat.div(temp).softmax(-1), num_samples=1)[:, 0]
-            x_flat = gumbel_sample(x_flat, temperature=temp)
+            x_flat = torch.multinomial(x_flat.div(temp).softmax(-1), num_samples=1)[:, 0]
             x = x_flat.view(x.size(0), *x.shape[2:])
             if mask is not None:
                 x = x * mask + (1-mask) * init_x
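
This hunk switches token sampling from Gumbel-argmax to explicit temperature-scaled multinomial sampling; both draw from the same categorical distribution softmax(logits / temp), only the mechanics differ. A small self-contained sketch of the two paths, with random logits as stand-ins:

    # Sketch: the two token-sampling strategies seen in this hunk.
    import torch

    logits = torch.randn(4, 8192)  # (flattened positions, vocabulary size)
    temp = 0.7

    # New path: temperature-scaled multinomial sampling
    probs = logits.div(temp).softmax(-1)
    tokens_multinomial = torch.multinomial(probs, num_samples=1)[:, 0]

    # Old path: Gumbel-argmax sampling over the same tempered logits
    gumbel = -torch.log(-torch.log(torch.rand_like(logits).clamp(1e-20, 1)))
    tokens_gumbel = (logits / temp + gumbel).argmax(dim=-1)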
@@ -90,7 +101,7 @@ def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_
 vqmodel = get_vae().to(device)
 vqmodel.eval().requires_grad_(False)
 
-clip_model, _, _ = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k')
+clip_model, _, _ = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k')
 clip_model = clip_model.to(device).eval().requires_grad_(False)
 
 def encode(x):
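
The text encoder moves from OpenCLIP ViT-g-14 to ViT-H-14 (laion2b_s32b_b79k), whose text embeddings are 1024-dimensional, matching c_clip=1024 and clip_embedding_shape = (batch_size, 1024) elsewhere in this diff. A standalone sketch of encoding a prompt with it (open_clip.get_tokenizer is used here for self-containment; the app relies on its own tokenizer object defined elsewhere in the file):

    # Sketch: encode a prompt with the new OpenCLIP ViT-H-14 text tower.
    import torch
    import open_clip

    clip_model, _, _ = open_clip.create_model_and_transforms("ViT-H-14", pretrained="laion2b_s32b_b79k")
    tokenizer = open_clip.get_tokenizer("ViT-H-14")

    with torch.inference_mode():
        tokens = tokenizer(["a plate of arroz con cosas"])
        text_embeddings = clip_model.encode_text(tokens)
    print(text_embeddings.shape)  # expected: torch.Size([1, 1024])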
@@ -108,22 +119,40 @@ def decode(img_seq, shape=(32,32)):
 
 model_path = hf_hub_download(repo_id=model_repo, filename=model_file)
 state_dict = torch.load(model_path, map_location=device)
-model = DenoiseUNet(num_labels=8192).to(device)
+model = DenoiseUNet(num_labels=8192, c_clip=1024, c_hidden=1280, down_levels=[1, 2, 8, 32], up_levels=[32, 8, 2, 1]).to(device)
 model.load_state_dict(state_dict)
 model.eval().requires_grad_()
 
+prior_path = hf_hub_download(repo_id=model_repo, filename=prior_file)
+prior_ckpt = torch.load(prior_path, map_location=device)
+prior = PriorModel().to(device)
+prior.load_state_dict(prior_ckpt)
+prior.eval().requires_grad_(False)
+diffuzz = Diffuzz(device=device)
+
+del prior_ckpt, state_dict
+
 # -----
 
-def infer(prompt):
-    latent_shape = (32, 32)
+def infer(prompt, negative_prompt=""):
     tokenized_text = tokenizer.tokenize([prompt] * batch_size).to(device)
+    negative_text = tokenizer.tokenize([negative_prompt] * batch_size).to(device)
     with torch.inference_mode():
         with torch.autocast(device_type="cuda"):
             clip_embeddings = clip_model.encode_text(tokenized_text)
+            neg_clip_embeddings = clip_model.encode_text(negative_text)
+
+            sampled_image_embeddings = diffuzz.sample(
+                prior, {'c': clip_embeddings}, clip_embedding_shape,
+                timesteps=prior_timesteps, cfg=prior_cfg, sampler=prior_sampler
+            )[-1]
+
             images = sample(
-                model, clip_embeddings, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],
-                typical_filtering=True, typical_mass=0.2, typical_min_tokens=1,
-                classifier_free_scale=scale, renoise_steps=steps, renoise_mode="start"
+                model, sampled_image_embeddings, negative_embeddings=neg_clip_embeddings,
+                T=generator_timesteps, size=latent_shape, starting_t=0, temp_range=[2.0, 0.1],
+                typical_filtering=False, typical_mass=0.2, typical_min_tokens=1,
+                classifier_free_scale=generator_cfg, renoise_steps=generator_timesteps-1,
+                renoise_mode="start"
             )
             images = decode(images[-1], latent_shape)
             return to_pil(images)
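
infer now runs a two-stage pipeline: the diffusion prior turns CLIP text embeddings into CLIP image embeddings, and the token generator is conditioned on those, with the negative prompt feeding classifier-free guidance. A usage sketch, assuming app.py has been executed so infer and the models are in scope (a CUDA GPU is required because of torch.autocast(device_type="cuda")):

    # Usage sketch only: generate a batch and save the first image.
    images = infer(
        "a plate of arroz con cosas on a wooden table, food photography",
        negative_prompt="blurry, low quality, deformed",
    )
    images[0].save("sample_0.png")  # infer returns the decoded batch (batch_size = 4) as PIL images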
@@ -231,7 +260,7 @@ block = gr.Blocks(css=css)
 
 with block:
     gr.HTML(
-        """
+        f"""
         <div style="text-align: center; max-width: 650px; margin: 0 auto;">
             <div
                 style="
@@ -278,6 +307,9 @@ with block:
             Paella Demo
          </h1>
        </div>
+        <p>
+          Running on <b>{device_text}</b>
+        </p>
        <p style="margin-bottom: 10px; font-size: 94%">
          Paella is a novel text-to-image model that uses a compressed quantized latent space, based on a f8 VQGAN, and a masked training objective to achieve fast generation in ~10 inference steps.
        </p>
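
The diff only touches the HTML header of the Gradio UI (now an f-string so it can report the device); the wiring of infer into inputs and outputs happens further down in app.py and is not part of this commit. A hypothetical minimal sketch of such wiring, for orientation only:

    # Hypothetical wiring sketch, not the app's actual layout; `infer` is the function defined above.
    import gradio as gr

    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt")
        negative = gr.Textbox(label="Negative prompt")
        gallery = gr.Gallery(label="Generated images")
        btn = gr.Button("Generate")
        btn.click(fn=infer, inputs=[prompt, negative], outputs=gallery)

    demo.launch()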
 
requirements.txt CHANGED
@@ -1,7 +1,7 @@
--f https://download.pytorch.org/whl/cu116
 torch
 rudalle
 open_clip_torch
 einops
 Pillow
 huggingface_hub
+git+https://github.com/pabloppp/Arroz-Con-Cosas