ginipick committed on
Commit 5153322 · verified · 1 Parent(s): 5b0c1dc

Update app.py

Files changed (1)
  1. app.py +220 -191
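
The main functional addition in this commit (visible in the diff below) is a Korean-to-English prompt translation step built on the transformers translation pipeline with the Helsinki-NLP/opus-mt-ko-en model, applied to the caption and negative caption before inference. The following minimal sketch mirrors that helper in isolation; it is illustrative only, not part of the committed file, and the strings in the usage comments are made-up examples.

# Minimal sketch of the Korean-detection / translation helper introduced in this commit.
import re

from transformers import pipeline

# Model name as used in the diff; the checkpoint is downloaded on first use.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")


def translate_if_korean(text: str) -> str:
    # Translate only when Hangul jamo or syllables are present; otherwise pass through.
    if re.search(r"[ㄱ-ㅎㅏ-ㅣ가-힣]", text):
        return translator(text)[0]["translation_text"]
    return text


# Usage with made-up prompts:
#   translate_if_korean("하늘을 나는 용")    -> an English translation
#   translate_if_korean("A flying dragon")   -> returned unchanged
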
app.py CHANGED
@@ -17,6 +17,12 @@ subprocess.run(
17
  shell=True,
18
  )
19
 
20
  os.makedirs("/home/user/app/checkpoints", exist_ok=True)
21
  from huggingface_hub import snapshot_download
22
  snapshot_download(
@@ -26,30 +32,31 @@ snapshot_download(
26
  hf_token = os.environ["HF_TOKEN"]
27
 
28
  import argparse
29
- import os
30
  import builtins
31
  import json
32
  import math
33
  import multiprocessing as mp
34
- import os
35
  import random
36
  import socket
37
  import traceback
38
 
39
- #import fairscale.nn.model_parallel.initialize as fs_init
40
  import gradio as gr
41
  import numpy as np
42
  from safetensors.torch import load_file
43
  import torch
44
- #i#mport torch.distributed as dist
45
  from torchvision.transforms.functional import to_pil_image
46
  import spaces
47
 
48
  from imgproc import generate_crop_size_list
49
  import models
50
  from transport import Sampler, create_transport
51
 
52
- from multiprocessing import Process,Queue,set_start_method,get_context
53
 
54
  class ModelFailure:
55
  pass
@@ -93,22 +100,19 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt
93
 
94
  @torch.no_grad()
95
  def model_main(args, master_port, rank):
96
- # import here to avoid huggingface Tokenizer parallelism warnings
97
  from diffusers.models import AutoencoderKL
98
  from transformers import AutoModel, AutoTokenizer
99
 
100
- # override the default print function since the delay can be large for child process
101
  original_print = builtins.print
102
 
103
- # Redefine the print function with flush=True by default
104
  def print(*args, **kwargs):
105
  kwargs.setdefault("flush", True)
106
  original_print(*args, **kwargs)
107
 
108
- # Override the built-in print with the new version
109
  builtins.print = print
110
 
111
-
112
  train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
113
  print("Loaded model arguments:", json.dumps(train_args.__dict__, indent=2))
114
 
@@ -139,8 +143,8 @@ def model_main(args, master_port, rank):
139
 
140
  assert train_args.model_parallel_size == args.num_gpus
141
  if args.ema:
142
- print("Loading ema model.")
143
- print('load model')
144
  ckpt_path = os.path.join(
145
  args.ckpt,
146
  f"consolidated{'_ema' if args.ema else ''}.{rank:02d}-of-{args.num_gpus:02d}.safetensors",
@@ -155,158 +159,150 @@ def model_main(args, master_port, rank):
155
  assert os.path.exists(ckpt_path)
156
  ckpt = torch.load(ckpt_path, map_location="cuda")
157
  model.load_state_dict(ckpt, strict=True)
158
- print('load model finish')
159
 
160
  return text_encoder, tokenizer, vae, model
161
 
162
 
163
  @torch.no_grad()
164
  def inference(args, infer_args, text_encoder, tokenizer, vae, model):
165
- dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
166
- args.precision
167
- ]
168
  train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
169
  torch.cuda.set_device(0)
170
  with torch.autocast("cuda", dtype):
171
- while True:
172
- (
173
- cap,
174
- neg_cap,
175
- system_type,
176
- resolution,
177
- num_sampling_steps,
178
- cfg_scale,
179
- cfg_trunc,
180
- renorm_cfg,
181
- solver,
182
- t_shift,
183
- seed,
184
- scaling_method,
185
- scaling_watershed,
186
- proportional_attn,
187
- ) = infer_args
188
-
189
-
190
- system_prompt = system_type
191
- cap = system_prompt + cap
192
- if neg_cap != "":
193
- neg_cap = system_prompt + neg_cap
194
-
195
- metadata = dict(
196
- real_cap=cap,
197
- real_neg_cap=neg_cap,
198
- system_type=system_type,
199
- resolution=resolution,
200
- num_sampling_steps=num_sampling_steps,
201
- cfg_scale=cfg_scale,
202
- cfg_trunc=cfg_trunc,
203
- renorm_cfg=renorm_cfg,
204
- solver=solver,
205
- t_shift=t_shift,
206
- seed=seed,
207
- scaling_method=scaling_method,
208
- scaling_watershed=scaling_watershed,
209
- proportional_attn=proportional_attn,
210
- )
211
- print("> params:", json.dumps(metadata, indent=2))
212
-
213
- try:
214
- # begin sampler
215
- if solver == "dpm":
216
- transport = create_transport(
217
- "Linear",
218
- "velocity",
219
- )
220
- sampler = Sampler(transport)
221
- sample_fn = sampler.sample_dpm(
222
  model.forward_with_cfg,
223
  model_kwargs=model_kwargs,
224
- )
225
- else:
226
- transport = create_transport(
227
- args.path_type,
228
- args.prediction,
229
- args.loss_weight,
230
- args.train_eps,
231
- args.sample_eps,
232
- )
233
- sampler = Sampler(transport)
234
- sample_fn = sampler.sample_ode(
235
- sampling_method=solver,
236
- num_steps=num_sampling_steps,
237
- atol=args.atol,
238
- rtol=args.rtol,
239
- reverse=args.reverse,
240
- time_shifting_factor=t_shift,
241
- )
242
- # end sampler
243
-
244
- resolution = resolution.split(" ")[-1]
245
- w, h = resolution.split("x")
246
- w, h = int(w), int(h)
247
- latent_w, latent_h = w // 8, h // 8
248
- if int(seed) != 0:
249
- torch.random.manual_seed(int(seed))
250
- z = torch.randn([1, 16, latent_h, latent_w], device="cuda").to(dtype)
251
- z = z.repeat(2, 1, 1, 1)
252
-
253
- with torch.no_grad():
254
- if neg_cap != "":
255
- cap_feats, cap_mask = encode_prompt([cap] + [neg_cap], text_encoder, tokenizer, 0.0)
256
- else:
257
- cap_feats, cap_mask = encode_prompt([cap] + [""], text_encoder, tokenizer, 0.0)
258
-
259
- cap_mask = cap_mask.to(cap_feats.device)
260
-
261
- model_kwargs = dict(
262
- cap_feats=cap_feats,
263
- cap_mask=cap_mask,
264
- cfg_scale=cfg_scale,
265
- cfg_trunc=1 - cfg_trunc,
266
- renorm_cfg=renorm_cfg,
267
  )
268
-
269
- #if dist.get_rank() == 0:
270
- print(f"> caption: {cap}")
271
- print(f"> num_sampling_steps: {num_sampling_steps}")
272
- print(f"> cfg_scale: {cfg_scale}")
273
- print("> start sample")
274
- if solver == "dpm":
275
- samples = sample_fn(z, steps=num_sampling_steps, order=2, skip_type="time_uniform_flow", method="multistep", flow_shift=t_shift)
276
  else:
277
- samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
278
- samples = samples[:1]
279
- print("smaple_dtype", samples.dtype)
280
-
281
- vae_scale = {
282
- "sdxl": 0.13025,
283
- "sd3": 1.5305,
284
- "ema": 0.18215,
285
- "mse": 0.18215,
286
- "cogvideox": 1.15258426,
287
- "flux": 0.3611,
288
- }["flux"]
289
- vae_shift = {
290
- "sdxl": 0.0,
291
- "sd3": 0.0609,
292
- "ema": 0.0,
293
- "mse": 0.0,
294
- "cogvideox": 0.0,
295
- "flux": 0.1159,
296
- }["flux"]
297
- print(f"> vae scale: {vae_scale}, shift: {vae_shift}")
298
- print("samples.shape", samples.shape)
299
- samples = vae.decode(samples / vae_scale + vae_shift).sample
300
- samples = (samples + 1.0) / 2.0
301
- samples.clamp_(0.0, 1.0)
302
-
303
- img = to_pil_image(samples[0].float())
304
- print("> generated image, done.")
305
-
306
- return img, metadata
307
- except Exception:
308
- print(traceback.format_exc())
309
- return ModelFailure()
310
 
311
 
312
  def none_or_str(value):
@@ -322,24 +318,24 @@ def parse_transport_args(parser):
322
  type=str,
323
  default="Linear",
324
  choices=["Linear", "GVP", "VP"],
325
- help="the type of path for transport: 'Linear', 'GVP' (Geodesic Vector Pursuit), or 'VP' (Vector Pursuit).",
326
  )
327
  group.add_argument(
328
  "--prediction",
329
  type=str,
330
  default="velocity",
331
  choices=["velocity", "score", "noise"],
332
- help="the prediction model for the transport dynamics.",
333
  )
334
  group.add_argument(
335
  "--loss-weight",
336
  type=none_or_str,
337
  default=None,
338
  choices=[None, "velocity", "likelihood"],
339
- help="the weighting of different components in the loss function, can be 'velocity' for dynamic modeling, 'likelihood' for statistical consistency, or None for no weighting.",
340
  )
341
- group.add_argument("--sample-eps", type=float, help="sampling in the transport model.")
342
- group.add_argument("--train-eps", type=float, help="training to stabilize the learning process.")
343
 
344
 
345
  def parse_ode_args(parser):
@@ -356,11 +352,11 @@ def parse_ode_args(parser):
356
  default=1e-3,
357
  help="Relative tolerance for the ODE solver.",
358
  )
359
- group.add_argument("--reverse", action="store_true", help="run the ODE solver in reverse.")
360
  group.add_argument(
361
  "--likelihood",
362
  action="store_true",
363
- help="Enable calculation of likelihood during the ODE solving process.",
364
  )
365
 
366
 
@@ -372,14 +368,26 @@ def find_free_port() -> int:
372
  return port
373
 
374
 
375
  def main():
376
  parser = argparse.ArgumentParser()
377
 
378
  parser.add_argument("--num_gpus", type=int, default=1)
379
- parser.add_argument("--ckpt", type=str,default='/home/user/app/checkpoints', required=False)
380
  parser.add_argument("--ema", action="store_true")
381
  parser.add_argument("--precision", default="bf16", choices=["bf16", "fp32"])
382
- parser.add_argument("--hf_token", type=str, default=None, help="huggingface read token for accessing gated repo.")
383
  parser.add_argument("--res", type=int, default=1024, choices=[256, 512, 1024])
384
  parser.add_argument("--port", type=int, default=12123)
385
 
@@ -397,9 +405,30 @@ def main():
397
 
398
  description = "Lumina-Image 2.0 ([Github](https://github.com/Alpha-VLLM/Lumina-Image-2.0/tree/main))"
399
 
400
- with gr.Blocks() as demo:
401
  with gr.Row():
402
- gr.Markdown(description)
403
  with gr.Row():
404
  with gr.Column():
405
  cap = gr.Textbox(
@@ -416,21 +445,20 @@ def main():
416
  value="",
417
  placeholder="Enter a negative caption.",
418
  )
419
- default_value = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts."
420
  system_type = gr.Dropdown(
421
  value=default_value,
422
  choices=[
423
- "You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts.",
424
  "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts.",
425
  "",
426
- ],
427
  label="System Type",
428
  )
429
 
430
  with gr.Row():
431
  res_choices = [f"{w}x{h}" for w, h in generate_crop_size_list((args.res // 64) ** 2, 64)]
432
- default_value = "1024x1024" # Set the default value to 256x256
433
-
434
  resolution = gr.Dropdown(
435
  value=default_value, choices=res_choices, label="Resolution"
436
  )
@@ -441,7 +469,7 @@ def main():
441
  value=40,
442
  step=1,
443
  interactive=True,
444
- label="Sampling steps",
445
  )
446
  seed = gr.Slider(
447
  minimum=0,
@@ -463,7 +491,7 @@ def main():
463
  solver = gr.Dropdown(
464
  value="euler",
465
  choices=["euler", "midpoint", "rk4"],
466
- label="solver",
467
  )
468
  t_shift = gr.Slider(
469
  minimum=1,
@@ -471,17 +499,17 @@ def main():
471
  value=6,
472
  step=1,
473
  interactive=True,
474
- label="Time shift",
475
  )
476
  cfg_scale = gr.Slider(
477
  minimum=1.0,
478
  maximum=20.0,
479
  value=4.0,
480
  interactive=True,
481
- label="CFG scale",
482
  )
483
  with gr.Row():
484
- renorm_cfg = gr.Dropdown(
485
  value=True,
486
  choices=[True, False, 2.0],
487
  label="CFG Renorm",
@@ -491,51 +519,52 @@ def main():
491
  scaling_method = gr.Dropdown(
492
  value="Time-aware",
493
  choices=["Time-aware", "None"],
494
- label="RoPE scaling method",
495
  )
496
  scaling_watershed = gr.Slider(
497
  minimum=0.0,
498
  maximum=1.0,
499
  value=0.3,
500
  interactive=True,
501
- label="Linear/NTK watershed",
502
  )
503
  with gr.Row():
504
  proportional_attn = gr.Checkbox(
505
  value=True,
506
  interactive=True,
507
- label="Proportional attention",
508
  )
509
  with gr.Row():
510
  submit_btn = gr.Button("Submit", variant="primary")
511
  with gr.Column():
512
  output_img = gr.Image(
513
- label="Generated image",
514
  interactive=False,
515
  )
516
  with gr.Accordion(label="Generation Parameters", open=True):
517
- gr_metadata = gr.JSON(label="metadata", show_label=False)
518
 
519
  with gr.Row():
520
-
521
- prompts=[ "Close-up portrait of a young woman with light brown hair, looking to the right, illuminated by warm, golden sunlight. Her hair is gently tousled, catching the light and creating a halo effect around her head. She wears a white garment with a V-neck, visible in the lower left of the frame. The background is dark and out of focus, enhancing the contrast between her illuminated face and the shadows. Soft, ethereal lighting, high contrast, warm color palette, shallow depth of field, natural backlighting, serene and contemplative mood, cinematic quality, intimate and visually striking composition.",
522
- "一个剑客,武侠风,红色腰带,戴着斗笠,低头,盖住眼睛,白色背景,细致,精品,杰作,水墨画,墨烟,墨云,泼墨,色带,墨水,墨黑白莲花,光影艺术,笔触。",
523
- "Aesthetic photograph of a bouquet of pink and white ranunculus flowers in a clear glass vase, centrally positioned on a wooden surface. The flowers are in full bloom, displaying intricate layers of petals with a soft gradient from pale pink to white. The vase is filled with water, visible through the clear glass, and the stems are submerged. In the background, a blurred vase with green stems is partially visible, adding depth to the composition. The lighting is warm and natural, casting soft shadows and highlighting the delicate textures of the petals. The scene is serene and intimate, with a focus on the organic beauty of the flowers. Photorealistic, shallow depth of field, soft natural lighting, warm color palette, high contrast, glossy texture, tranquil, visually balanced.",
524
- "一只优雅的白猫穿着一件紫色的旗袍,旗袍上绣有精致的牡丹花图案,显得高贵典雅。它头上戴着一朵金色的发饰,嘴里叼着一根象征好运的红色丝带。周围环绕着许多飘动的纸鹤和金色的光点,营造出一种祥瑞和梦幻的氛围。超写实风格。"
525
- ]
526
- prompts = [[_] for _ in prompts]
527
- gr.Examples( # noqa
528
- prompts,
529
- [cap],
530
- label="Examples",
531
- ) # noqa
532
 
533
  @spaces.GPU(duration=200)
534
- def on_submit(*infer_args, progress=gr.Progress(track_tqdm=True),):
 
535
  result = inference(args, infer_args, text_encoder, tokenizer, vae, model)
536
  if isinstance(result, ModelFailure):
537
  raise RuntimeError("Model failed to generate the image.")
538
-
539
  return result
540
 
541
  submit_btn.click(
 
17
  shell=True,
18
  )
19
 
20
+ # Additional dependencies for translation and UI improvements
21
+ subprocess.run(
22
+ "pip install transformers gradio safetensors torchvision diffusers",
23
+ shell=True,
24
+ )
25
+
26
  os.makedirs("/home/user/app/checkpoints", exist_ok=True)
27
  from huggingface_hub import snapshot_download
28
  snapshot_download(
 
32
  hf_token = os.environ["HF_TOKEN"]
33
 
34
  import argparse
 
35
  import builtins
36
  import json
37
  import math
38
  import multiprocessing as mp
 
39
  import random
40
  import socket
41
  import traceback
42
 
 
43
  import gradio as gr
44
  import numpy as np
45
  from safetensors.torch import load_file
46
  import torch
 
47
  from torchvision.transforms.functional import to_pil_image
48
+
49
+ # Import translation pipeline from transformers
50
+ from transformers import pipeline
51
+ translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
52
+
53
  import spaces
54
 
55
  from imgproc import generate_crop_size_list
56
  import models
57
  from transport import Sampler, create_transport
58
 
59
+ from multiprocessing import Process, Queue, set_start_method, get_context
60
 
61
  class ModelFailure:
62
  pass
 
100
 
101
  @torch.no_grad()
102
  def model_main(args, master_port, rank):
103
+ # Import here to avoid huggingface Tokenizer parallelism warnings
104
  from diffusers.models import AutoencoderKL
105
  from transformers import AutoModel, AutoTokenizer
106
 
107
+ # Override the default print function since the delay can be large for child processes
108
  original_print = builtins.print
109
 
 
110
  def print(*args, **kwargs):
111
  kwargs.setdefault("flush", True)
112
  original_print(*args, **kwargs)
113
 
 
114
  builtins.print = print
115
 
 
116
  train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
117
  print("Loaded model arguments:", json.dumps(train_args.__dict__, indent=2))
118
 
 
143
 
144
  assert train_args.model_parallel_size == args.num_gpus
145
  if args.ema:
146
+ print("Loading EMA model.")
147
+ print('Loading model weights...')
148
  ckpt_path = os.path.join(
149
  args.ckpt,
150
  f"consolidated{'_ema' if args.ema else ''}.{rank:02d}-of-{args.num_gpus:02d}.safetensors",
 
159
  assert os.path.exists(ckpt_path)
160
  ckpt = torch.load(ckpt_path, map_location="cuda")
161
  model.load_state_dict(ckpt, strict=True)
162
+ print('Model weights loaded.')
163
 
164
  return text_encoder, tokenizer, vae, model
165
 
166
 
167
  @torch.no_grad()
168
  def inference(args, infer_args, text_encoder, tokenizer, vae, model):
169
+ dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[args.precision]
170
  train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
171
  torch.cuda.set_device(0)
172
  with torch.autocast("cuda", dtype):
173
+ (
174
+ cap,
175
+ neg_cap,
176
+ system_type,
177
+ resolution,
178
+ num_sampling_steps,
179
+ cfg_scale,
180
+ cfg_trunc,
181
+ renorm_cfg,
182
+ solver,
183
+ t_shift,
184
+ seed,
185
+ scaling_method,
186
+ scaling_watershed,
187
+ proportional_attn,
188
+ ) = infer_args
189
+
190
+ system_prompt = system_type
191
+ cap = system_prompt + cap
192
+ if neg_cap != "":
193
+ neg_cap = system_prompt + neg_cap
194
+
195
+ metadata = dict(
196
+ real_cap=cap,
197
+ real_neg_cap=neg_cap,
198
+ system_type=system_type,
199
+ resolution=resolution,
200
+ num_sampling_steps=num_sampling_steps,
201
+ cfg_scale=cfg_scale,
202
+ cfg_trunc=cfg_trunc,
203
+ renorm_cfg=renorm_cfg,
204
+ solver=solver,
205
+ t_shift=t_shift,
206
+ seed=seed,
207
+ scaling_method=scaling_method,
208
+ scaling_watershed=scaling_watershed,
209
+ proportional_attn=proportional_attn,
210
+ )
211
+ print("> Parameters:", json.dumps(metadata, indent=2))
212
+
213
+ try:
214
+ # Begin sampler
215
+ if solver == "dpm":
216
+ transport = create_transport("Linear", "velocity")
217
+ sampler = Sampler(transport)
218
+ sample_fn = sampler.sample_dpm(
219
  model.forward_with_cfg,
220
  model_kwargs=model_kwargs,
 
221
  )
222
+ else:
223
+ transport = create_transport(
224
+ args.path_type,
225
+ args.prediction,
226
+ args.loss_weight,
227
+ args.train_eps,
228
+ args.sample_eps,
229
+ )
230
+ sampler = Sampler(transport)
231
+ sample_fn = sampler.sample_ode(
232
+ sampling_method=solver,
233
+ num_steps=num_sampling_steps,
234
+ atol=args.atol,
235
+ rtol=args.rtol,
236
+ reverse=args.reverse,
237
+ time_shifting_factor=t_shift,
238
+ )
239
+ # End sampler
240
+
241
+ resolution = resolution.split(" ")[-1]
242
+ w, h = resolution.split("x")
243
+ w, h = int(w), int(h)
244
+ latent_w, latent_h = w // 8, h // 8
245
+ if int(seed) != 0:
246
+ torch.random.manual_seed(int(seed))
247
+ z = torch.randn([1, 16, latent_h, latent_w], device="cuda").to(dtype)
248
+ z = z.repeat(2, 1, 1, 1)
249
+
250
+ with torch.no_grad():
251
+ if neg_cap != "":
252
+ cap_feats, cap_mask = encode_prompt([cap] + [neg_cap], text_encoder, tokenizer, 0.0)
253
  else:
254
+ cap_feats, cap_mask = encode_prompt([cap] + [""], text_encoder, tokenizer, 0.0)
255
+
256
+ cap_mask = cap_mask.to(cap_feats.device)
257
+
258
+ model_kwargs = dict(
259
+ cap_feats=cap_feats,
260
+ cap_mask=cap_mask,
261
+ cfg_scale=cfg_scale,
262
+ cfg_trunc=1 - cfg_trunc,
263
+ renorm_cfg=renorm_cfg,
264
+ )
265
+
266
+ print(f"> Caption: {cap}")
267
+ print(f"> Number of sampling steps: {num_sampling_steps}")
268
+ print(f"> CFG scale: {cfg_scale}")
269
+ print("> Starting sampling...")
270
+ if solver == "dpm":
271
+ samples = sample_fn(z, steps=num_sampling_steps, order=2, skip_type="time_uniform_flow", method="multistep", flow_shift=t_shift)
272
+ else:
273
+ samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
274
+ samples = samples[:1]
275
+ print("Sample dtype:", samples.dtype)
276
+
277
+ vae_scale = {
278
+ "sdxl": 0.13025,
279
+ "sd3": 1.5305,
280
+ "ema": 0.18215,
281
+ "mse": 0.18215,
282
+ "cogvideox": 1.15258426,
283
+ "flux": 0.3611,
284
+ }["flux"]
285
+ vae_shift = {
286
+ "sdxl": 0.0,
287
+ "sd3": 0.0609,
288
+ "ema": 0.0,
289
+ "mse": 0.0,
290
+ "cogvideox": 0.0,
291
+ "flux": 0.1159,
292
+ }["flux"]
293
+ print(f"> VAE scale: {vae_scale}, shift: {vae_shift}")
294
+ print("Samples shape:", samples.shape)
295
+ samples = vae.decode(samples / vae_scale + vae_shift).sample
296
+ samples = (samples + 1.0) / 2.0
297
+ samples.clamp_(0.0, 1.0)
298
+
299
+ img = to_pil_image(samples[0].float())
300
+ print("> Generated image successfully.")
301
+
302
+ return img, metadata
303
+ except Exception:
304
+ print(traceback.format_exc())
305
+ return ModelFailure()
306
 
307
 
308
  def none_or_str(value):
 
318
  type=str,
319
  default="Linear",
320
  choices=["Linear", "GVP", "VP"],
321
+ help="Type of path for transport: 'Linear', 'GVP' (Geodesic Vector Pursuit), or 'VP' (Vector Pursuit).",
322
  )
323
  group.add_argument(
324
  "--prediction",
325
  type=str,
326
  default="velocity",
327
  choices=["velocity", "score", "noise"],
328
+ help="Prediction model for the transport dynamics.",
329
  )
330
  group.add_argument(
331
  "--loss-weight",
332
  type=none_or_str,
333
  default=None,
334
  choices=[None, "velocity", "likelihood"],
335
+ help="Weighting of different loss components: 'velocity', 'likelihood', or None.",
336
  )
337
+ group.add_argument("--sample-eps", type=float, help="Sampling parameter in the transport model.")
338
+ group.add_argument("--train-eps", type=float, help="Training epsilon to stabilize learning.")
339
 
340
 
341
  def parse_ode_args(parser):
 
352
  default=1e-3,
353
  help="Relative tolerance for the ODE solver.",
354
  )
355
+ group.add_argument("--reverse", action="store_true", help="Run the ODE solver in reverse.")
356
  group.add_argument(
357
  "--likelihood",
358
  action="store_true",
359
+ help="Enable likelihood calculation during the ODE solving process.",
360
  )
361
 
362
 
 
368
  return port
369
 
370
 
371
+ # Utility function to translate Korean text to English if needed.
372
+ def translate_if_korean(text: str) -> str:
373
+ import re
374
+ # Check if any Korean characters are present
375
+ if re.search(r"[ㄱ-ㅎㅏ-ㅣ가-힣]", text):
376
+ print("Translating Korean prompt to English...")
377
+ translation = translator(text)
378
+ # Return the translated text from the pipeline output
379
+ return translation[0]["translation_text"]
380
+ return text
381
+
382
+
383
  def main():
384
  parser = argparse.ArgumentParser()
385
 
386
  parser.add_argument("--num_gpus", type=int, default=1)
387
+ parser.add_argument("--ckpt", type=str, default='/home/user/app/checkpoints', required=False)
388
  parser.add_argument("--ema", action="store_true")
389
  parser.add_argument("--precision", default="bf16", choices=["bf16", "fp32"])
390
+ parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face read token for accessing gated repo.")
391
  parser.add_argument("--res", type=int, default=1024, choices=[256, 512, 1024])
392
  parser.add_argument("--port", type=int, default=12123)
393
 
 
405
 
406
  description = "Lumina-Image 2.0 ([Github](https://github.com/Alpha-VLLM/Lumina-Image-2.0/tree/main))"
407
 
408
+ # Create a Gradio Blocks UI with custom CSS for a sleek, modern appearance.
409
+ custom_css = """
410
+ body {
411
+ background: linear-gradient(135deg, #1a2a6c, #b21f1f, #fdbb2d);
412
+ font-family: 'Helvetica', sans-serif;
413
+ color: #333;
414
+ }
415
+ .gradio-container {
416
+ background: #fff;
417
+ border-radius: 15px;
418
+ box-shadow: 0 8px 16px rgba(0, 0, 0, 0.25);
419
+ padding: 20px;
420
+ }
421
+ .gradio-title {
422
+ font-weight: bold;
423
+ font-size: 1.5em;
424
+ text-align: center;
425
+ margin-bottom: 10px;
426
+ }
427
+ """
428
+
429
+ with gr.Blocks(css=custom_css) as demo:
430
  with gr.Row():
431
+ gr.Markdown(f"<div class='gradio-title'>{description}</div>")
432
  with gr.Row():
433
  with gr.Column():
434
  cap = gr.Textbox(
 
445
  value="",
446
  placeholder="Enter a negative caption.",
447
  )
448
+ default_value = "You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."
449
  system_type = gr.Dropdown(
450
  value=default_value,
451
  choices=[
452
+ "You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts.",
453
  "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts.",
454
  "",
455
+ ],
456
  label="System Type",
457
  )
458
 
459
  with gr.Row():
460
  res_choices = [f"{w}x{h}" for w, h in generate_crop_size_list((args.res // 64) ** 2, 64)]
461
+ default_value = "1024x1024"
 
462
  resolution = gr.Dropdown(
463
  value=default_value, choices=res_choices, label="Resolution"
464
  )
 
469
  value=40,
470
  step=1,
471
  interactive=True,
472
+ label="Sampling Steps",
473
  )
474
  seed = gr.Slider(
475
  minimum=0,
 
491
  solver = gr.Dropdown(
492
  value="euler",
493
  choices=["euler", "midpoint", "rk4"],
494
+ label="Solver",
495
  )
496
  t_shift = gr.Slider(
497
  minimum=1,
 
499
  value=6,
500
  step=1,
501
  interactive=True,
502
+ label="Time Shift",
503
  )
504
  cfg_scale = gr.Slider(
505
  minimum=1.0,
506
  maximum=20.0,
507
  value=4.0,
508
  interactive=True,
509
+ label="CFG Scale",
510
  )
511
  with gr.Row():
512
+ renorm_cfg = gr.Dropdown(
513
  value=True,
514
  choices=[True, False, 2.0],
515
  label="CFG Renorm",
 
519
  scaling_method = gr.Dropdown(
520
  value="Time-aware",
521
  choices=["Time-aware", "None"],
522
+ label="RoPE Scaling Method",
523
  )
524
  scaling_watershed = gr.Slider(
525
  minimum=0.0,
526
  maximum=1.0,
527
  value=0.3,
528
  interactive=True,
529
+ label="Linear/NTK Watershed",
530
  )
531
  with gr.Row():
532
  proportional_attn = gr.Checkbox(
533
  value=True,
534
  interactive=True,
535
+ label="Proportional Attention",
536
  )
537
  with gr.Row():
538
  submit_btn = gr.Button("Submit", variant="primary")
539
  with gr.Column():
540
  output_img = gr.Image(
541
+ label="Generated Image",
542
  interactive=False,
543
  )
544
  with gr.Accordion(label="Generation Parameters", open=True):
545
+ gr_metadata = gr.JSON(label="Metadata", show_label=False)
546
 
547
  with gr.Row():
548
+ prompts = [
549
+ "Close-up portrait of a young woman with light brown hair, looking to the right, illuminated by warm, golden sunlight. Her hair is gently tousled, catching the light and creating a halo effect around her head. She wears a white garment with a V-neck, visible in the lower left of the frame. The background is dark and out of focus, enhancing the contrast between her illuminated face and the shadows. Soft, ethereal lighting, high contrast, warm color palette, shallow depth of field, natural backlighting, serene and contemplative mood, cinematic quality, intimate and visually striking composition.",
550
+ "하늘을 나는 용, 신비로운 분위기, 구름 위를 날며 빛나는 비늘을 가진, 전설 속의 존재, 강렬한 색채와 디테일한 묘사.",
551
+ "Aesthetic photograph of a bouquet of pink and white ranunculus flowers in a clear glass vase, centrally positioned on a wooden surface. The flowers are in full bloom, displaying intricate layers of petals with a soft gradient from pale pink to white. The vase is filled with water, visible through the clear glass, and the stems are submerged. In the background, a blurred vase with green stems is partially visible, adding depth to the composition. The lighting is warm and natural, casting soft shadows and highlighting the delicate textures of the petals. The scene is serene and intimate, with a focus on the organic beauty of the flowers. Photorealistic, shallow depth of field, soft natural lighting, warm color palette, high contrast, glossy texture, tranquil, visually balanced.",
552
+ "한只优雅的白猫穿着一件紫色的旗袍,旗袍上绣有精致的牡丹花图案,显得高贵典雅。它头上戴着一朵金色的发饰,嘴里叼着一根象征好运的红色丝带。周围环绕着许多飘动的纸鹤和金色的光点,营造出一种祥瑞和梦幻的氛围。超写实风格。"
553
+ ]
554
+ prompts = [[p] for p in prompts]
555
+ gr.Examples(prompts, [cap], label="Examples")
 
556
 
557
  @spaces.GPU(duration=200)
558
+ def on_submit(cap, neg_cap, system_type, resolution, num_sampling_steps, cfg_scale, cfg_trunc, renorm_cfg, solver, t_shift, seed, scaling_method, scaling_watershed, proportional_attn, progress=gr.Progress(track_tqdm=True)):
559
+ # Translate the caption and negative caption if they contain Korean characters
560
+ cap = translate_if_korean(cap)
561
+ if neg_cap and neg_cap.strip():
562
+ neg_cap = translate_if_korean(neg_cap)
563
+ # Pack updated arguments and call inference
564
+ infer_args = (cap, neg_cap, system_type, resolution, num_sampling_steps, cfg_scale, cfg_trunc, renorm_cfg, solver, t_shift, seed, scaling_method, scaling_watershed, proportional_attn)
565
  result = inference(args, infer_args, text_encoder, tokenizer, vae, model)
566
  if isinstance(result, ModelFailure):
567
  raise RuntimeError("Model failed to generate the image.")
 
568
  return result
569
 
570
  submit_btn.click(
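
The diff is cut off at the submit_btn.click( call above. As a self-contained illustration of the UI pattern this commit adopts (a gr.Blocks app with custom CSS and a submit button wired to a handler), here is a small toy sketch; every name, string, and handler in it is a placeholder, not the committed code.

import gradio as gr

# Toy stand-in for the real on_submit handler, which runs diffusion inference in app.py.
def on_submit(prompt: str) -> str:
    return f"Would generate an image for: {prompt}"

# Minimal custom CSS, echoing the gr.Blocks(css=custom_css) pattern used in the diff.
custom_css = ".gradio-title { font-weight: bold; font-size: 1.5em; text-align: center; }"

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("<div class='gradio-title'>Lumina-Image 2.0 (toy example)</div>")
    cap = gr.Textbox(label="Caption", placeholder="Enter a caption.")
    out = gr.Textbox(label="Result", interactive=False)
    submit_btn = gr.Button("Submit", variant="primary")
    # Wire the button to the handler: inputs are UI components, outputs receive the return value.
    submit_btn.click(on_submit, inputs=[cap], outputs=[out])

if __name__ == "__main__":
    demo.launch()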