MohamedRashad committed
Commit 9d8246c · 1 Parent(s): 9508399

Refactoring

Files changed (1):
  1. app.py  +140 -135
app.py CHANGED
@@ -8,73 +8,33 @@ import os.path as osp
 import time
 import hashlib
 import argparse
+import shutil
+import re
 import random
 from pathlib import Path
-from typing import List, Dict, Optional
-from dataclasses import dataclass
+from typing import List
+import json

 import cv2
 import numpy as np
 import torch
 import torch.nn.functional as F
 from PIL import Image, ImageEnhance
+import PIL.Image as PImage
 from torchvision.transforms.functional import to_tensor
 from transformers import AutoTokenizer, T5EncoderModel, T5TokenizerFast
 from huggingface_hub import hf_hub_download
 import gradio as gr
 import spaces
-import json

 from models.infinity import Infinity
 from models.basic import *
-from utils.dynamic_resolution import dynamic_resolution_h_w
+from utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
 from gradio_client import Client

-# Performance optimizations
 torch._dynamo.config.cache_size_limit = 64
-torch.backends.cudnn.benchmark = True  # Enable cudnn auto-tuner
 client = Client("Qwen/Qwen2.5-72B-Instruct")

-@dataclass
-class ModelConfig:
-    """Configuration for Infinity model."""
-    depth: int
-    embed_dim: int
-    num_heads: int
-    drop_path_rate: float = 0.1
-    mlp_ratio: float = 4.0
-    block_chunks: int = 8
-
-    @classmethod
-    def from_type(cls, model_type: str) -> 'ModelConfig':
-        """Create model config from predefined types."""
-        configs = {
-            'infinity_2b': dict(depth=32, embed_dim=2048, num_heads=2048//128),
-            'infinity_layer12': dict(depth=12, embed_dim=768, num_heads=8),
-            'infinity_layer16': dict(depth=16, embed_dim=1152, num_heads=12),
-            'infinity_layer24': dict(depth=24, embed_dim=1536, num_heads=16),
-            'infinity_layer32': dict(depth=32, embed_dim=2080, num_heads=20),
-            'infinity_layer40': dict(depth=40, embed_dim=2688, num_heads=24),
-            'infinity_layer48': dict(depth=48, embed_dim=3360, num_heads=28),
-        }
-        if model_type not in configs:
-            raise ValueError(f"Unknown model type: {model_type}")
-        return cls(**configs[model_type])
-
-    def to_dict(self) -> Dict:
-        """Convert config to dictionary."""
-        return {
-            'depth': self.depth,
-            'embed_dim': self.embed_dim,
-            'num_heads': self.num_heads,
-            'drop_path_rate': self.drop_path_rate,
-            'mlp_ratio': self.mlp_ratio,
-            'block_chunks': self.block_chunks
-        }
-
-# Global device configuration
-DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
 # Define a function to download weights if not present
 def download_infinity_weights(weights_path):
     try:
@@ -89,19 +49,7 @@ def download_infinity_weights(weights_path):
     except Exception as e:
         print(f"Error downloading weights: {e}")

-def extract_key_val(text):
-    pattern = r'<(.+?):(.+?)>'
-    matches = re.findall(pattern, text)
-    key_val = {}
-    for match in matches:
-        key_val[match[0]] = match[1].lstrip()
-    return key_val
-
-def encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt=False):
-    if enable_positive_prompt:
-        print(f'before positive_prompt aug: {prompt}')
-        prompt = aug_with_positive_prompt(prompt)
-        print(f'after positive_prompt aug: {prompt}')
+def encode_prompt(text_tokenizer, text_encoder, prompt):
     print(f'prompt={prompt}')
     captions = [prompt]
     tokens = text_tokenizer(text=captions, max_length=512, padding='max_length', truncation=True, return_tensors='pt')  # todo: put this into dataset
@@ -118,14 +66,6 @@ def encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt=F
     text_cond_tuple = (kv_compact, lens, cu_seqlens_k, Ltext)
     return text_cond_tuple

-def aug_with_positive_prompt(prompt):
-    for key in ['man', 'woman', 'men', 'women', 'boy', 'girl', 'child', 'person', 'human', 'adult', 'teenager', 'employee',
-                'employer', 'worker', 'mother', 'father', 'sister', 'brother', 'grandmother', 'grandfather', 'son', 'daughter']:
-        if key in prompt:
-            prompt = prompt + '. very smooth faces, good looking faces, face to the camera, perfect facial features'
-            break
-    return prompt
-
 def enhance_image(image):
     for t in range(1):
         contrast_image = image.copy()
@@ -136,6 +76,71 @@ def enhance_image(image):
         color_image = color_enhancer.enhance(1.05)  # enhance saturation
     return color_image

+def gen_one_img(
+    infinity_test,
+    vae,
+    text_tokenizer,
+    text_encoder,
+    prompt,
+    cfg_list=[],
+    tau_list=[],
+    negative_prompt='',
+    scale_schedule=None,
+    top_k=900,
+    top_p=0.97,
+    cfg_sc=3,
+    cfg_exp_k=0.0,
+    cfg_insertion_layer=-5,
+    vae_type=0,
+    gumbel=0,
+    softmax_merge_topk=-1,
+    gt_leak=-1,
+    gt_ls_Bl=None,
+    g_seed=None,
+    sampling_per_bits=1,
+):
+    sstt = time.time()
+    if not isinstance(cfg_list, list):
+        cfg_list = [cfg_list] * len(scale_schedule)
+    if not isinstance(tau_list, list):
+        tau_list = [tau_list] * len(scale_schedule)
+    text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt)
+    if negative_prompt:
+        negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
+    else:
+        negative_label_B_or_BLT = None
+    print(f'cfg: {cfg_list}, tau: {tau_list}')
+
+    # Set device if not provided
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    # Set autocast dtype based on bf16 and device support
+    if device == 'cuda' and torch.cuda.is_bf16_supported():
+        autocast_dtype = torch.bfloat16
+    else:
+        autocast_dtype = torch.float32
+
+    torch.cuda.empty_cache()
+
+    with torch.amp.autocast(device_type=device, dtype=autocast_dtype), torch.no_grad():
+        stt = time.time()
+        _, _, img_list = infinity_test.autoregressive_infer_cfg(
+            vae=vae,
+            scale_schedule=scale_schedule,
+            label_B_or_BLT=text_cond_tuple, g_seed=g_seed,
+            B=1, negative_label_B_or_BLT=negative_label_B_or_BLT, force_gt_Bhw=None,
+            cfg_sc=cfg_sc, cfg_list=cfg_list, tau_list=tau_list, top_k=top_k, top_p=top_p,
+            returns_vemb=1, ratio_Bl1=None, gumbel=gumbel, norm_cfg=False,
+            cfg_exp_k=cfg_exp_k, cfg_insertion_layer=cfg_insertion_layer,
+            vae_type=vae_type, softmax_merge_topk=softmax_merge_topk,
+            ret_img=True, trunk_scale=1000,
+            gt_leak=gt_leak, gt_ls_Bl=gt_ls_Bl, inference_mode=True,
+            sampling_per_bits=sampling_per_bits,
+        )
+    print(f"cost: {time.time() - sstt}, infinity cost={time.time() - stt}")
+    img = img_list[0]
+    return img
+
 def get_prompt_id(prompt):
     md5 = hashlib.md5()
     md5.update(prompt.encode('utf-8'))
@@ -159,7 +164,7 @@ def load_tokenizer(t5_path =''):
     text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
     text_tokenizer.model_max_length = 512
     text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
-    text_encoder.to(DEVICE)
+    text_encoder.to('cuda')
     text_encoder.eval()
     text_encoder.requires_grad_(False)
     return text_tokenizer, text_encoder
@@ -174,6 +179,7 @@ def load_infinity(
     model_path='',
     scale_schedule=None,
     vae=None,
+    device=None,  # Make device optional
     model_kwargs=None,
     text_channels=2048,
     apply_spatial_patchify=0,
@@ -182,8 +188,13 @@
 ):
     print(f'[Loading Infinity]')

+    # Set device if not provided
+    if device is None:
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        print(f'Using device: {device}')
+
     # Set autocast dtype based on bf16 and device support
-    if bf16 and DEVICE.type == 'cuda' and torch.cuda.is_bf16_supported():
+    if bf16 and device == 'cuda' and torch.cuda.is_bf16_supported():
         autocast_dtype = torch.bfloat16
     else:
         autocast_dtype = torch.float32
@@ -192,7 +203,7 @@
     text_maxlen = 512
     torch.cuda.empty_cache()

-    with torch.amp.autocast(device_type=DEVICE.type, dtype=autocast_dtype), torch.no_grad():
+    with torch.amp.autocast(device_type=device, dtype=autocast_dtype), torch.no_grad():
         infinity_test: Infinity = Infinity(
             vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
             shared_aln=True, raw_scale_schedule=scale_schedule,
@@ -210,7 +221,7 @@
             inference_mode=True,
             train_h_div_w_list=[1.0],
             **model_kwargs,
-        ).to(DEVICE)
+        ).to(device)

     print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')

@@ -222,11 +233,11 @@
     infinity_test.requires_grad_(False)

     print(f'[Load Infinity weights]')
-    state_dict = torch.load(model_path, map_location=DEVICE)
+    state_dict = torch.load(model_path, map_location=device)
     print(infinity_test.load_state_dict(state_dict))

-    # Initialize random number generator
-    infinity_test.rng = torch.Generator(device=DEVICE)
+    # Initialize random number generator on the correct device
+    infinity_test.rng = torch.Generator(device=device)

     return infinity_test

@@ -268,7 +279,7 @@ def joint_vi_vae_encode_decode(vae, image_path, scale_schedule, device, tgt_h, t
     return gt_img, recons_img, all_bit_indices

 def load_visual_tokenizer(args):
-    device = DEVICE
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     # load vae
     if args.vae_type in [16,18,20,24,32,64]:
         from models.bsq_vae.vae import vae_model
@@ -311,7 +322,7 @@ def load_transformer(vae, args):
         if not osp.exists(local_model_path):
             print(f'copy {model_path} to {local_model_path}')
             shutil.copyfile(model_path, local_model_path)
-        save_slim_model(local_model_path, save_file=local_slim_model_path, device=DEVICE)
+        save_slim_model(local_model_path, save_file=local_slim_model_path, device=device)
         print(f'copy {local_slim_model_path} to {slim_model_path}')
         if not osp.exists(slim_model_path):
             shutil.copyfile(local_slim_model_path, slim_model_path)
@@ -322,7 +333,20 @@
         slim_model_path = model_path
     print(f'load checkpoint from {slim_model_path}')

-    model_config = ModelConfig.from_type(args.model_type)
+    if args.model_type == 'infinity_2b':
+        kwargs_model = dict(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, mlp_ratio=4, block_chunks=8)  # 2b model
+    elif args.model_type == 'infinity_layer12':
+        kwargs_model = dict(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+    elif args.model_type == 'infinity_layer16':
+        kwargs_model = dict(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+    elif args.model_type == 'infinity_layer24':
+        kwargs_model = dict(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+    elif args.model_type == 'infinity_layer32':
+        kwargs_model = dict(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+    elif args.model_type == 'infinity_layer40':
+        kwargs_model = dict(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+    elif args.model_type == 'infinity_layer48':
+        kwargs_model = dict(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
     infinity = load_infinity(
         rope2d_each_sa_layer=args.rope2d_each_sa_layer,
         rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
@@ -333,7 +357,8 @@
         model_path=slim_model_path,
         scale_schedule=None,
         vae=vae,
-        model_kwargs=model_config.to_dict(),
+        device=None,
+        model_kwargs=kwargs_model,
         text_channels=args.text_channels,
         apply_spatial_patchify=args.apply_spatial_patchify,
         use_flex_attn=args.use_flex_attn,
@@ -400,6 +425,10 @@ weights_path = Path(__file__).parent / 'weights'
 weights_path.mkdir(exist_ok=True)
 download_infinity_weights(weights_path)

+# Dtype setup
+dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
+print(f"Using dtype: {dtype}")
+
 # Define args
 args = argparse.Namespace(
     pn='1M',
@@ -421,7 +450,7 @@
     cache_dir='/dev/shm',
     checkpoint_type='torch',
     seed=0,
-    bf16=1 if torch.bfloat16 == torch.get_default_dtype() else 0,
+    bf16=1 if dtype == torch.bfloat16 else 0,
     save_file='tmp.jpg',
     enable_model_cache=False,
 )
@@ -433,62 +462,43 @@ infinity = load_transformer(vae, args)

 # Define the image generation function
 @spaces.GPU
-def generate_image(prompt, cfg, tau, h_div_w, seed, enable_positive_prompt=False):
-    """Generate an image from a prompt with integrated generation logic."""
+def generate_image(prompt, cfg, tau, h_div_w, seed):
     try:
-        # Set random seed for reproducibility
-        if seed is not None:
-            torch.manual_seed(seed)
-            random.seed(seed)
-            np.random.seed(seed)
-
-        # Calculate image dimensions
-        tgt_h, tgt_w = dynamic_resolution_h_w(h_div_w)
-        scale_schedule = None
+        args.prompt = prompt
+        args.cfg = cfg
+        args.tau = tau
+        args.h_div_w = h_div_w
+        args.seed = seed

-        # Process text prompt
-        text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
-
-        # Set up negative prompt if needed
-        negative_prompt = ''
-        if negative_prompt:
-            negative_cond_tuple = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
-            negative_label_B_or_BLT = negative_cond_tuple[0]
-        else:
-            negative_label_B_or_BLT = None
+        # Find the closest h_div_w_template
+        h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates - h_div_w))]

-        print(f'cfg: {cfg}, tau: {tau}')
+        # Get scale_schedule based on h_div_w_template_
+        scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']
+        scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]

-        # Generate image with automatic mixed precision
-        with torch.amp.autocast(device_type=DEVICE.type, dtype=torch.bfloat16):
-            stt = time.time()
-            _, _, img_list = infinity.autoregressive_infer_cfg(
-                vae=vae,
-                text_cond_tuple=text_cond_tuple,
-                negative_label_B_or_BLT=negative_label_B_or_BLT,
-                cfg_list=[cfg],
-                tau_list=[tau],
-                top_k=900,
-                top_p=0.97,
-                cfg_sc=3,
-                cfg_exp_k=0.0,
-                cfg_insertion_layer=[args.cfg_insertion_layer],
-                vae_type=args.vae_type,
-                gumbel=0,
-                softmax_merge_topk=-1,
-                gt_leak=0,
-                gt_ls_Bl=None,
-                g_seed=seed,
-                sampling_per_bits=args.sampling_per_bits,
-                scale_schedule=scale_schedule,
-            )
-            print(f'inference time: {time.time()-stt:.3f}s')
+        # Generate the image
+        generated_image = gen_one_img(
+            infinity,
+            vae,
+            text_tokenizer,
+            text_encoder,
+            prompt,
+            g_seed=seed,
+            gt_leak=0,
+            gt_ls_Bl=None,
+            cfg_list=cfg,
+            tau_list=tau,
+            scale_schedule=scale_schedule,
+            cfg_insertion_layer=[args.cfg_insertion_layer],
+            vae_type=args.vae_type,
+            sampling_per_bits=args.sampling_per_bits,
+        )

-        # Convert the image efficiently
-        with torch.no_grad():
-            image = img_list[0].cpu().numpy()
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-            image = np.uint8(image)
+        # Convert the image to RGB and uint8
+        image = generated_image.cpu().numpy()
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        image = np.uint8(image)

         return image
     except Exception as e:
@@ -498,11 +508,6 @@ def generate_image(prompt, cfg, tau, h_div_w, seed, enable_positive_prompt=False
 # Set up Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("<h1><center>Infinity Image Generator</center></h1>")
-    gr.Markdown("### Instructions")
-    gr.Markdown("1. Enter a prompt in the **Prompt Settings** section.")
-    gr.Markdown("2. Click the **Enhance Prompt** button to generate a more creative and detailed prompt.")
-    gr.Markdown("3. Adjust the **Image Settings** as desired.")
-    gr.Markdown("4. Click the **Generate Image** button to generate the image on the right.")

     with gr.Row():
         with gr.Column():
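Two patterns recur throughout this refactor. First, every use of the removed module-level DEVICE constant is replaced by a locally resolved device plus a bf16 capability check, as seen in gen_one_img and load_infinity. A minimal standalone sketch of that gating:

    import torch

    # Resolve the device locally, as the refactored helpers now do.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Use bf16 autocast only when CUDA is present and reports bf16 support;
    # otherwise fall back to fp32.
    if device == 'cuda' and torch.cuda.is_bf16_supported():
        autocast_dtype = torch.bfloat16
    else:
        autocast_dtype = torch.float32
    print(f'device={device}, autocast_dtype={autocast_dtype}')

    # Inference is then wrapped as:
    #   with torch.amp.autocast(device_type=device, dtype=autocast_dtype), torch.no_grad():
    #       ...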
 
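With gen_one_img factored out, the Gradio callback only forwards UI values and the resolved schedule. A hypothetical call into the refactored module (argument values are examples, not Space defaults):

    # Assumes app.py has already built vae, infinity, text_tokenizer and
    # text_encoder at import time.
    image = generate_image(
        prompt='a watercolor painting of a lighthouse at dusk',
        cfg=3.0,      # classifier-free guidance weight (UI slider)
        tau=0.5,      # sampling temperature (UI slider)
        h_div_w=1.0,  # requested height/width ratio
        seed=42,
    )
    # image is an RGB uint8 numpy array, ready for a gr.Image output.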