MohamedRashad committed
Commit 0e3e704 · 1 Parent(s): 1776f2c

Refactoring

Files changed (1):
  1. app.py +112 -135
app.py CHANGED
@@ -8,19 +8,16 @@ import os.path as osp
 import time
 import hashlib
 import argparse
 import shutil
- import re
 import random
 from pathlib import Path
- from typing import List
- import json
+ from typing import List, Dict, Optional
+ from dataclasses import dataclass

 import cv2
 import numpy as np
 import torch
 import torch.nn.functional as F
 from PIL import Image, ImageEnhance
- import PIL.Image as PImage
 from torchvision.transforms.functional import to_tensor
 from transformers import AutoTokenizer, T5EncoderModel, T5TokenizerFast
 from huggingface_hub import hf_hub_download
@@ -29,12 +26,54 @@ import spaces

 from models.infinity import Infinity
 from models.basic import *
 from utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
 from gradio_client import Client

+ # Performance optimizations
 torch._dynamo.config.cache_size_limit = 64
+ torch.backends.cudnn.benchmark = True  # Enable cudnn auto-tuner
 client = Client("Qwen/Qwen2.5-72B-Instruct")

+ @dataclass
+ class ModelConfig:
+     """Configuration for Infinity model."""
+     depth: int
+     embed_dim: int
+     num_heads: int
+     drop_path_rate: float = 0.1
+     mlp_ratio: float = 4.0
+     block_chunks: int = 8
+
+     @classmethod
+     def from_type(cls, model_type: str) -> 'ModelConfig':
+         """Create model config from predefined types."""
+         configs = {
+             'infinity_2b': dict(depth=32, embed_dim=2048, num_heads=2048//128, block_chunks=8),
+             'infinity_layer12': dict(depth=12, embed_dim=768, num_heads=8, block_chunks=4),
+             'infinity_layer16': dict(depth=16, embed_dim=1152, num_heads=12, block_chunks=4),
+             'infinity_layer24': dict(depth=24, embed_dim=1536, num_heads=16, block_chunks=4),
+             'infinity_layer32': dict(depth=32, embed_dim=2080, num_heads=20, block_chunks=4),
+             'infinity_layer40': dict(depth=40, embed_dim=2688, num_heads=24, block_chunks=4),
+             'infinity_layer48': dict(depth=48, embed_dim=3360, num_heads=28, block_chunks=4),
+         }
+         if model_type not in configs:
+             raise ValueError(f"Unknown model type: {model_type}")
+         return cls(**configs[model_type])
+
+     def to_dict(self) -> Dict:
+         """Convert config to dictionary."""
+         return {
+             'depth': self.depth,
+             'embed_dim': self.embed_dim,
+             'num_heads': self.num_heads,
+             'drop_path_rate': self.drop_path_rate,
+             'mlp_ratio': self.mlp_ratio,
+             'block_chunks': self.block_chunks,
+         }
+
+ # Global device configuration
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
 # Define a function to download weights if not present
 def download_infinity_weights(weights_path):
     try:
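A quick usage sketch of the new ModelConfig helper (illustrative only, not part of the commit; the printed values follow the 'infinity_2b' entry above):

    config = ModelConfig.from_type('infinity_2b')
    kwargs = config.to_dict()
    print(kwargs['depth'], kwargs['num_heads'], kwargs['block_chunks'])  # 32 16 8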
@@ -96,60 +135,6 @@ def enhance_image(image):
     color_image = color_enhancer.enhance(1.05)  # enhance saturation
     return color_image

- def gen_one_img(
-     infinity_test,
-     vae,
-     text_tokenizer,
-     text_encoder,
-     prompt,
-     cfg_list=[],
-     tau_list=[],
-     negative_prompt='',
-     scale_schedule=None,
-     top_k=900,
-     top_p=0.97,
-     cfg_sc=3,
-     cfg_exp_k=0.0,
-     cfg_insertion_layer=-5,
-     vae_type=0,
-     gumbel=0,
-     softmax_merge_topk=-1,
-     gt_leak=-1,
-     gt_ls_Bl=None,
-     g_seed=None,
-     sampling_per_bits=1,
-     enable_positive_prompt=0,
- ):
-     sstt = time.time()
-     if not isinstance(cfg_list, list):
-         cfg_list = [cfg_list] * len(scale_schedule)
-     if not isinstance(tau_list, list):
-         tau_list = [tau_list] * len(scale_schedule)
-     text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
-     if negative_prompt:
-         negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
-     else:
-         negative_label_B_or_BLT = None
-     print(f'cfg: {cfg_list}, tau: {tau_list}')
-     with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
-         stt = time.time()
-         _, _, img_list = infinity_test.autoregressive_infer_cfg(
-             vae=vae,
-             scale_schedule=scale_schedule,
-             label_B_or_BLT=text_cond_tuple, g_seed=g_seed,
-             B=1, negative_label_B_or_BLT=negative_label_B_or_BLT, force_gt_Bhw=None,
-             cfg_sc=cfg_sc, cfg_list=cfg_list, tau_list=tau_list, top_k=top_k, top_p=top_p,
-             returns_vemb=1, ratio_Bl1=None, gumbel=gumbel, norm_cfg=False,
-             cfg_exp_k=cfg_exp_k, cfg_insertion_layer=cfg_insertion_layer,
-             vae_type=vae_type, softmax_merge_topk=softmax_merge_topk,
-             ret_img=True, trunk_scale=1000,
-             gt_leak=gt_leak, gt_ls_Bl=gt_ls_Bl, inference_mode=True,
-             sampling_per_bits=sampling_per_bits,
-         )
-     print(f"cost: {time.time() - sstt}, infinity cost={time.time() - stt}")
-     img = img_list[0]
-     return img
-
 def get_prompt_id(prompt):
     md5 = hashlib.md5()
     md5.update(prompt.encode('utf-8'))
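get_prompt_id hashes the UTF-8 prompt with MD5 (the function body is truncated by the hunk; it presumably returns the hex digest). A standalone equivalent, for illustration:

    import hashlib
    prompt_id = hashlib.md5('a photo of a cat'.encode('utf-8')).hexdigest()
    print(prompt_id)  # 32-character hex string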
@@ -173,7 +158,7 @@ def load_tokenizer(t5_path =''):
     text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
     text_tokenizer.model_max_length = 512
     text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
-     text_encoder.to('cuda')
+     text_encoder.to(DEVICE)
     text_encoder.eval()
     text_encoder.requires_grad_(False)
     return text_tokenizer, text_encoder
@@ -188,7 +173,6 @@ def load_infinity(
     model_path='',
     scale_schedule=None,
     vae=None,
-     device=None, # Make device optional
     model_kwargs=None,
     text_channels=2048,
     apply_spatial_patchify=0,
@@ -197,13 +181,8 @@
 ):
     print(f'[Loading Infinity]')

-     # Set device if not provided
-     if device is None:
-         device = 'cuda' if torch.cuda.is_available() else 'cpu'
-         print(f'Using device: {device}')
-
     # Set autocast dtype based on bf16 and device support
-     if bf16 and device == 'cuda' and torch.cuda.is_bf16_supported():
+     if bf16 and DEVICE.type == 'cuda' and torch.cuda.is_bf16_supported():
         autocast_dtype = torch.bfloat16
     else:
         autocast_dtype = torch.float32
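The capability-gated dtype selection above works as a standalone pattern; a minimal sketch (illustrative, not from the commit):

    import torch
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.bfloat16 if device.type == 'cuda' and torch.cuda.is_bf16_supported() else torch.float32
    # Autocast is only enabled when a reduced-precision dtype was selected
    with torch.amp.autocast(device_type=device.type, dtype=dtype, enabled=dtype != torch.float32):
        pass  # model construction / inference would run here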
@@ -212,7 +191,7 @@
     text_maxlen = 512
     torch.cuda.empty_cache()

-     with torch.amp.autocast(device_type=device, dtype=autocast_dtype), torch.no_grad():
+     with torch.amp.autocast(device_type=DEVICE.type, dtype=autocast_dtype), torch.no_grad():
         infinity_test: Infinity = Infinity(
             vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
             shared_aln=True, raw_scale_schedule=scale_schedule,
@@ -230,7 +209,7 @@
             inference_mode=True,
             train_h_div_w_list=[1.0],
             **model_kwargs,
-         ).to(device)
+         ).to(DEVICE)

     print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')

@@ -242,17 +221,11 @@ def load_infinity(
     infinity_test.requires_grad_(False)

     print(f'[Load Infinity weights]')
-     state_dict = torch.load(model_path, map_location=device)
+     state_dict = torch.load(model_path, map_location=DEVICE)
     print(infinity_test.load_state_dict(state_dict))

-     # Initialize random number generator, falling back to CPU if CUDA is not available
-     try:
-         infinity_test.rng = torch.Generator(device=device)
-     except RuntimeError:
-         print("CUDA device not available. Falling back to CPU...")
-         device = 'cpu'
-         infinity_test = infinity_test.to(device)
-         infinity_test.rng = torch.Generator(device=device)
+     # Initialize random number generator
+     infinity_test.rng = torch.Generator(device=DEVICE)

     return infinity_test
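A torch.Generator pinned to a device gives reproducible sampling without touching the global RNG; minimal sketch (illustrative):

    import torch
    g = torch.Generator(device='cpu')  # or the model's device
    g.manual_seed(0)
    noise = torch.randn(2, 3, generator=g)  # identical values on every run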
@@ -294,7 +267,7 @@ def joint_vi_vae_encode_decode(vae, image_path, scale_schedule, device, tgt_h, tgt_w):
     return gt_img, recons_img, all_bit_indices

 def load_visual_tokenizer(args):
-     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     device = DEVICE
     # load vae
     if args.vae_type in [16,18,20,24,32,64]:
         from models.bsq_vae.vae import vae_model
@@ -337,7 +310,7 @@ def load_transformer(vae, args):
     if not osp.exists(local_model_path):
         print(f'copy {model_path} to {local_model_path}')
         shutil.copyfile(model_path, local_model_path)
-     save_slim_model(local_model_path, save_file=local_slim_model_path, device=device)
+     save_slim_model(local_model_path, save_file=local_slim_model_path, device=DEVICE)
     print(f'copy {local_slim_model_path} to {slim_model_path}')
     if not osp.exists(slim_model_path):
         shutil.copyfile(local_slim_model_path, slim_model_path)
@@ -348,20 +321,7 @@
         slim_model_path = model_path
     print(f'load checkpoint from {slim_model_path}')

-     if args.model_type == 'infinity_2b':
-         kwargs_model = dict(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, mlp_ratio=4, block_chunks=8) # 2b model
-     elif args.model_type == 'infinity_layer12':
-         kwargs_model = dict(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
-     elif args.model_type == 'infinity_layer16':
-         kwargs_model = dict(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
-     elif args.model_type == 'infinity_layer24':
-         kwargs_model = dict(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
-     elif args.model_type == 'infinity_layer32':
-         kwargs_model = dict(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
-     elif args.model_type == 'infinity_layer40':
-         kwargs_model = dict(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
-     elif args.model_type == 'infinity_layer48':
-         kwargs_model = dict(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+     model_config = ModelConfig.from_type(args.model_type)
     infinity = load_infinity(
         rope2d_each_sa_layer=args.rope2d_each_sa_layer,
         rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
@@ -372,8 +332,7 @@
         model_path=slim_model_path,
         scale_schedule=None,
         vae=vae,
-         device=None,
-         model_kwargs=kwargs_model,
+         model_kwargs=model_config.to_dict(),
         text_channels=args.text_channels,
         apply_spatial_patchify=args.apply_spatial_patchify,
         use_flex_attn=args.use_flex_attn,
@@ -440,10 +399,6 @@ weights_path = Path(__file__).parent / 'weights'
 weights_path.mkdir(exist_ok=True)
 download_infinity_weights(weights_path)

- # Device setup
- dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
- print(f"Using dtype: {dtype}")
-
 # Define args
 args = argparse.Namespace(
     pn='1M',
@@ -465,7 +420,7 @@ args = argparse.Namespace(
     cache_dir='/dev/shm',
     checkpoint_type='torch',
     seed=0,
-     bf16=1 if dtype == torch.bfloat16 else 0,
+     bf16=1 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 0,
     save_file='tmp.jpg',
     enable_model_cache=False,
 )
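argparse.Namespace is used here purely as an attribute container, with no CLI parsing involved; minimal sketch (illustrative):

    import argparse
    cfg = argparse.Namespace(pn='1M', seed=0)
    print(cfg.pn, cfg.seed)  # plain attribute access: 1M 0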
@@ -478,44 +433,61 @@ infinity = load_transformer(vae, args)

 # Define the image generation function
 @spaces.GPU
 def generate_image(prompt, cfg, tau, h_div_w, seed, enable_positive_prompt=False):
+     """Generate an image from a prompt with integrated generation logic."""
     try:
-         args.prompt = prompt
-         args.cfg = cfg
-         args.tau = tau
-         args.h_div_w = h_div_w
-         args.seed = seed
-         args.enable_positive_prompt = enable_positive_prompt
+         # Set random seed for reproducibility
+         if seed is not None:
+             torch.manual_seed(seed)
+             random.seed(seed)
+             np.random.seed(seed)

         # Find the closest h_div_w_template
         h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates - h_div_w))]

         # Get scale_schedule based on h_div_w_template_
         scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']
         scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]

-         # Generate the image
-         generated_image = gen_one_img(
-             infinity,
-             vae,
-             text_tokenizer,
-             text_encoder,
-             prompt,
-             g_seed=seed,
-             gt_leak=0,
-             gt_ls_Bl=None,
-             cfg_list=cfg,
-             tau_list=tau,
-             scale_schedule=scale_schedule,
-             cfg_insertion_layer=[args.cfg_insertion_layer],
-             vae_type=args.vae_type,
-             sampling_per_bits=args.sampling_per_bits,
-             enable_positive_prompt=enable_positive_prompt,
-         )
+         # Process the text prompt; no negative prompt in this demo
+         text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
+         negative_label_B_or_BLT = None
+
+         print(f'cfg: {cfg}, tau: {tau}')
+
+         # Generate the image with automatic mixed precision (logic inlined from gen_one_img)
+         with torch.amp.autocast(device_type=DEVICE.type, dtype=torch.bfloat16):
+             stt = time.time()
+             _, _, img_list = infinity.autoregressive_infer_cfg(
+                 vae=vae,
+                 scale_schedule=scale_schedule,
+                 label_B_or_BLT=text_cond_tuple,
+                 negative_label_B_or_BLT=negative_label_B_or_BLT,
+                 B=1,
+                 cfg_list=[cfg] * len(scale_schedule),
+                 tau_list=[tau] * len(scale_schedule),
+                 top_k=900,
+                 top_p=0.97,
+                 cfg_sc=3,
+                 cfg_exp_k=0.0,
+                 cfg_insertion_layer=[args.cfg_insertion_layer],
+                 vae_type=args.vae_type,
+                 gumbel=0,
+                 softmax_merge_topk=-1,
+                 returns_vemb=1,
+                 ret_img=True,
+                 trunk_scale=1000,
+                 gt_leak=0,
+                 gt_ls_Bl=None,
+                 g_seed=seed,
+                 sampling_per_bits=args.sampling_per_bits,
+                 inference_mode=True,
+             )
+         print(f'inference time: {time.time()-stt:.3f}s')

         # Convert the image to RGB and uint8
-         image = generated_image.cpu().numpy()
-         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-         image = np.uint8(image)
+         with torch.no_grad():
+             image = img_list[0].cpu().numpy()
+             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+             image = np.uint8(image)

         return image
     except Exception as e:
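The nearest-template lookup inside generate_image is a plain argmin over absolute differences; a numeric sketch with hypothetical template values (the real h_div_w_templates comes from utils.dynamic_resolution):

    import numpy as np
    templates = np.array([0.5, 0.75, 1.0, 1.333, 2.0])  # hypothetical ratios
    closest = templates[np.argmin(np.abs(templates - 0.8))]
    print(closest)  # 0.75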
@@ -525,6 +497,11 @@ def generate_image(prompt, cfg, tau, h_div_w, seed, enable_positive_prompt=False):

 # Set up Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("<h1><center>Infinity Image Generator</center></h1>")
+     gr.Markdown("### Instructions")
+     gr.Markdown("1. Enter a prompt in the **Prompt Settings** section.")
+     gr.Markdown("2. Click the **Enhance Prompt** button to generate a more creative and detailed prompt.")
+     gr.Markdown("3. Adjust the **Image Settings** as desired.")
+     gr.Markdown("4. Click the **Generate Image** button to generate the image on the right.")

     with gr.Row():
         with gr.Column():
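The Blocks layout is truncated by the diff here; a minimal sketch of how such a UI typically wires a click handler to the generator (component names and the stub are hypothetical, not from this commit):

    import gradio as gr
    import numpy as np

    def fake_generate(prompt):  # stand-in for generate_image
        return np.zeros((64, 64, 3), dtype=np.uint8)

    with gr.Blocks() as sketch_demo:
        prompt_box = gr.Textbox(label='Prompt')          # hypothetical names
        output_img = gr.Image(label='Generated Image')
        generate_btn = gr.Button('Generate Image')
        generate_btn.click(fn=fake_generate, inputs=prompt_box, outputs=output_img)
    # sketch_demo.launch()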
 