import os
import subprocess

subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
subprocess.run(
    "pip install huggingface_hub==0.25.0",
    shell=True,
)
subprocess.run(
    "pip install numpy==1.26.4",
    shell=True,
)
# Additional dependencies for translation and UI improvements
subprocess.run(
    "pip install transformers gradio safetensors torchvision diffusers",
    shell=True,
)

os.makedirs("/home/user/app/checkpoints", exist_ok=True)

from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="Alpha-VLLM/Lumina-Image-2.0",
    local_dir="/home/user/app/checkpoints",
)

hf_token = os.environ["HF_TOKEN"]

import argparse
import builtins
import json
import math
import multiprocessing as mp
import random
import socket
import traceback

import gradio as gr
import numpy as np
import torch
from safetensors.torch import load_file
from torchvision.transforms.functional import to_pil_image

# Import translation pipeline from transformers
from transformers import pipeline

translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

import spaces

from imgproc import generate_crop_size_list
import models
from transport import Sampler, create_transport
from multiprocessing import Process, Queue, set_start_method, get_context


class ModelFailure:
    pass


# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        text_inputs = tokenizer(
            captions,
            padding=True,
            pad_to_multiple_of=8,
            max_length=256,
            truncation=True,
            return_tensors="pt",
        )

        print(f"Text Encoder Device: {text_encoder.device}")
        text_input_ids = text_inputs.input_ids.cuda()
        prompt_masks = text_inputs.attention_mask.cuda()
        print(f"Text Input Ids Device: {text_input_ids.device}")
        print(f"Prompt Masks Device: {prompt_masks.device}")

        prompt_embeds = text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_masks,
            output_hidden_states=True,
        ).hidden_states[-2]

    text_encoder.cpu()
    return prompt_embeds, prompt_masks


@torch.no_grad()
def model_main(args, master_port, rank):
    # Import here to avoid huggingface Tokenizer parallelism warnings
    from diffusers.models import AutoencoderKL
    from transformers import AutoModel, AutoTokenizer

    # Override the default print function since the delay can be large for child processes
    original_print = builtins.print

    def print(*args, **kwargs):
        kwargs.setdefault("flush", True)
        original_print(*args, **kwargs)

    builtins.print = print

    train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
    print("Loaded model arguments:", json.dumps(train_args.__dict__, indent=2))

    print("Creating lm: Gemma-2-2B")
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[args.precision]

    text_encoder = AutoModel.from_pretrained(
        "google/gemma-2-2b", torch_dtype=dtype, token=hf_token
    ).eval().to("cuda")
    cap_feat_dim = text_encoder.config.hidden_size

    if args.num_gpus > 1:
        raise NotImplementedError("Inference with >1 GPUs not yet supported")

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", token=hf_token)
    tokenizer.padding_side = "right"

    vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="vae",
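        # black-forest-labs/FLUX.1-dev is a gated repository on the Hub, so the read token is passed explicitly.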
        token=hf_token,
    ).cuda()

    print(f"Creating DiT: {train_args.model}")
    model = models.__dict__[train_args.model](
        in_channels=16,
        qk_norm=train_args.qk_norm,
        cap_feat_dim=cap_feat_dim,
    )
    model.eval().to("cuda", dtype=dtype)

    assert train_args.model_parallel_size == args.num_gpus
    if args.ema:
        print("Loading EMA model.")
    print("Loading model weights...")
    ckpt_path = os.path.join(
        args.ckpt,
        f"consolidated{'_ema' if args.ema else ''}.{rank:02d}-of-{args.num_gpus:02d}.safetensors",
    )
    if os.path.exists(ckpt_path):
        ckpt = load_file(ckpt_path)
    else:
        ckpt_path = os.path.join(
            args.ckpt,
            f"consolidated{'_ema' if args.ema else ''}.{rank:02d}-of-{args.num_gpus:02d}.pth",
        )
        assert os.path.exists(ckpt_path)
        ckpt = torch.load(ckpt_path, map_location="cuda")
    model.load_state_dict(ckpt, strict=True)
    print("Model weights loaded.")

    return text_encoder, tokenizer, vae, model


@torch.no_grad()
def inference(args, infer_args, text_encoder, tokenizer, vae, model):
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[args.precision]
    train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
    torch.cuda.set_device(0)
    with torch.autocast("cuda", dtype):
        (
            cap,
            neg_cap,
            system_type,
            resolution,
            num_sampling_steps,
            cfg_scale,
            cfg_trunc,
            renorm_cfg,
            solver,
            t_shift,
            seed,
            scaling_method,
            scaling_watershed,
            proportional_attn,
        ) = infer_args

        system_prompt = system_type
        cap = system_prompt + cap
        if neg_cap != "":
            neg_cap = system_prompt + neg_cap

        metadata = dict(
            real_cap=cap,
            real_neg_cap=neg_cap,
            system_type=system_type,
            resolution=resolution,
            num_sampling_steps=num_sampling_steps,
            cfg_scale=cfg_scale,
            cfg_trunc=cfg_trunc,
            renorm_cfg=renorm_cfg,
            solver=solver,
            t_shift=t_shift,
            seed=seed,
            scaling_method=scaling_method,
            scaling_watershed=scaling_watershed,
            proportional_attn=proportional_attn,
        )
        print("> Parameters:", json.dumps(metadata, indent=2))

        try:
            # Begin sampler
            if solver == "dpm":
                transport = create_transport("Linear", "velocity")
                sampler = Sampler(transport)
                # The DPM sample function is built further down, once model_kwargs is available.
            else:
                transport = create_transport(
                    args.path_type,
                    args.prediction,
                    args.loss_weight,
                    args.train_eps,
                    args.sample_eps,
                )
                sampler = Sampler(transport)
                sample_fn = sampler.sample_ode(
                    sampling_method=solver,
                    num_steps=num_sampling_steps,
                    atol=args.atol,
                    rtol=args.rtol,
                    reverse=args.reverse,
                    time_shifting_factor=t_shift,
                )
            # End sampler

            resolution = resolution.split(" ")[-1]
            w, h = resolution.split("x")
            w, h = int(w), int(h)
            latent_w, latent_h = w // 8, h // 8
            if int(seed) != 0:
                torch.random.manual_seed(int(seed))
            z = torch.randn([1, 16, latent_h, latent_w], device="cuda").to(dtype)
            z = z.repeat(2, 1, 1, 1)

            with torch.no_grad():
                if neg_cap != "":
                    cap_feats, cap_mask = encode_prompt([cap] + [neg_cap], text_encoder, tokenizer, 0.0)
                else:
                    cap_feats, cap_mask = encode_prompt([cap] + [""], text_encoder, tokenizer, 0.0)

            cap_mask = cap_mask.to(cap_feats.device)

            model_kwargs = dict(
                cap_feats=cap_feats,
                cap_mask=cap_mask,
                cfg_scale=cfg_scale,
                cfg_trunc=1 - cfg_trunc,
                renorm_cfg=renorm_cfg,
            )

            # model_kwargs must exist before sample_dpm is constructed, so the DPM
            # sample function is created here rather than in the sampler block above.
            if solver == "dpm":
                sample_fn = sampler.sample_dpm(
                    model.forward_with_cfg,
                    model_kwargs=model_kwargs,
                )

            print(f"> Caption: {cap}")
            print(f"> Number of sampling steps: {num_sampling_steps}")
            print(f"> CFG scale: {cfg_scale}")

            print("> Starting sampling...")
            if solver == "dpm":
                samples = sample_fn(
                    z,
                    steps=num_sampling_steps,
                    order=2,
                    skip_type="time_uniform_flow",
                    method="multistep",
                    flow_shift=t_shift,
                )
            else:
                samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
            samples = samples[:1]
            print("Sample dtype:", samples.dtype)

            vae_scale = {
                "sdxl": 0.13025,
                "sd3": 1.5305,
                "ema": 0.18215,
                "mse": 0.18215,
                "cogvideox": 1.15258426,
                "flux": 0.3611,
            }["flux"]
            vae_shift = {
                "sdxl": 0.0,
                "sd3": 0.0609,
                "ema": 0.0,
                "mse": 0.0,
                "cogvideox": 0.0,
                "flux": 0.1159,
            }["flux"]
            print(f"> VAE scale: {vae_scale}, shift: {vae_shift}")
            print("Samples shape:", samples.shape)

            samples = vae.decode(samples / vae_scale + vae_shift).sample
            samples = (samples + 1.0) / 2.0
            samples.clamp_(0.0, 1.0)

            img = to_pil_image(samples[0].float())
            print("> Generated image successfully.")

            return img, metadata
        except Exception:
            print(traceback.format_exc())
            return ModelFailure()


def none_or_str(value):
    if value == "None":
        return None
    return value


def parse_transport_args(parser):
    group = parser.add_argument_group("Transport arguments")
    group.add_argument(
        "--path-type",
        type=str,
        default="Linear",
        choices=["Linear", "GVP", "VP"],
        help="Type of interpolant path for transport: 'Linear', 'GVP' (generalized variance-preserving), or 'VP' (variance-preserving).",
    )
    group.add_argument(
        "--prediction",
        type=str,
        default="velocity",
        choices=["velocity", "score", "noise"],
        help="Prediction target for the transport dynamics: 'velocity', 'score', or 'noise'.",
    )
    group.add_argument(
        "--loss-weight",
        type=none_or_str,
        default=None,
        choices=[None, "velocity", "likelihood"],
        help="Weighting of different loss components: 'velocity', 'likelihood', or None.",
    )
    group.add_argument("--sample-eps", type=float, help="Sampling parameter in the transport model.")
    group.add_argument("--train-eps", type=float, help="Training epsilon to stabilize learning.")


def parse_ode_args(parser):
    group = parser.add_argument_group("ODE arguments")
    group.add_argument(
        "--atol",
        type=float,
        default=1e-6,
        help="Absolute tolerance for the ODE solver.",
    )
    group.add_argument(
        "--rtol",
        type=float,
        default=1e-3,
        help="Relative tolerance for the ODE solver.",
    )
    group.add_argument("--reverse", action="store_true", help="Run the ODE solver in reverse.")
    group.add_argument(
        "--likelihood",
        action="store_true",
        help="Enable likelihood calculation during the ODE solving process.",
    )


def find_free_port() -> int:
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    return port


# Utility function to translate Korean text to English if needed.
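# The regex below matches Hangul compatibility jamo (ㄱ-ㅎ, ㅏ-ㅣ) and precomposed syllables (가-힣);
# the Helsinki-NLP/opus-mt-ko-en pipeline returns a list of dicts with a "translation_text" key.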
def translate_if_korean(text: str) -> str:
    import re

    # Check if any Korean characters are present
    if re.search(r"[ㄱ-ㅎㅏ-ㅣ가-힣]", text):
        print("Translating Korean prompt to English...")
        translation = translator(text)
        # Return the translated text from the pipeline output
        return translation[0]["translation_text"]
    return text


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_gpus", type=int, default=1)
    parser.add_argument("--ckpt", type=str, default="/home/user/app/checkpoints", required=False)
    parser.add_argument("--ema", action="store_true")
    parser.add_argument("--precision", default="bf16", choices=["bf16", "fp32"])
    parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face read token for accessing gated repo.")
    parser.add_argument("--res", type=int, default=1024, choices=[256, 512, 1024])
    parser.add_argument("--port", type=int, default=12123)

    parse_transport_args(parser)
    parse_ode_args(parser)
    args = parser.parse_known_args()[0]

    if args.num_gpus != 1:
        raise NotImplementedError("Multi-GPU Inference is not yet supported")

    master_port = find_free_port()

    text_encoder, tokenizer, vae, model = model_main(args, master_port, 0)

    description = "Lumina-Image 2.0 ([Github](https://github.com/Alpha-VLLM/Lumina-Image-2.0/tree/main))"

    # Create a Gradio Blocks UI with custom CSS for a sleek, modern appearance.
    custom_css = """
    body {
        background: linear-gradient(135deg, #1a2a6c, #b21f1f, #fdbb2d);
        font-family: 'Helvetica', sans-serif;
        color: #333;
    }
    .gradio-container {
        background: #fff;
        border-radius: 15px;
        box-shadow: 0 8px 16px rgba(0, 0, 0, 0.25);
        padding: 20px;
    }
    .gradio-title {
        font-weight: bold;
        font-size: 1.5em;
        text-align: center;
        margin-bottom: 10px;
    }
    """

    with gr.Blocks(css=custom_css) as demo:
        with gr.Row():
            gr.Markdown(f"