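# HiDream-I1-nf4 Gradio dashboard: loads the NF4-quantized "fast" pipeline with a
# quantized Llama-3.1 text encoder and exposes a simple text-to-image UI.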
import os
import spaces
import gradio as gr
import torch
import logging
from diffusers import DiffusionPipeline
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
from transformer_hidream_image import HiDreamImageTransformer2DModel
from pipeline_hidream_image import HiDreamImagePipeline
from schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
import subprocess
try:
    print(subprocess.check_output(["nvcc", "--version"]).decode("utf-8"))
except Exception:
    print("nvcc version check error")
# subprocess.run('python -m pip install flash-attn --no-build-isolation', shell=True)
def log_vram(msg: str):
    print(f"{msg} (used {torch.cuda.memory_allocated() / 1024**2:.2f} MB VRAM)\n")
# from nf4 import *
# Resolution options
RESOLUTION_OPTIONS = [
    "1024 × 1024 (Square)",
    "768 × 1360 (Portrait)",
    "1360 × 768 (Landscape)",
    "880 × 1168 (Portrait)",
    "1168 × 880 (Landscape)",
    "1248 × 832 (Landscape)",
    "832 × 1248 (Portrait)"
]
MODEL_PREFIX = "azaneko"
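# HiDream-I1 pairs its diffusion transformer with an instruction-tuned Llama model
# as an additional text encoder; the GPTQ-INT4 checkpoint keeps VRAM usage low.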
LLAMA_MODEL_NAME = "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
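# NOTE: FAST_CONFIG mirrors the "fast" entry of MODEL_CONFIGS below and is not
# referenced elsewhere in this file.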
FAST_CONFIG = {
    "path": "azaneko/HiDream-I1-Fast-nf4",
    "guidance_scale": 0.0,
    "num_inference_steps": 16,
    "shift": 3.0,
    "scheduler": FlashFlowMatchEulerDiscreteScheduler
}
tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(LLAMA_MODEL_NAME)
log_vram("βœ… Tokenizer loaded!")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
LLAMA_MODEL_NAME,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
torch_dtype=torch.bfloat16,
device_map="auto",
)
log_vram("βœ… Text encoder loaded!")
transformer = HiDreamImageTransformer2DModel.from_pretrained(
"azaneko/HiDream-I1-Fast-nf4",
subfolder="transformer",
torch_dtype=torch.bfloat16
)
log_vram("βœ… Transformer loaded!")
pipe = HiDreamImagePipeline.from_pretrained(
    "azaneko/HiDream-I1-Fast-nf4",
    scheduler=FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=3.0, use_dynamic_shifting=False),
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
)
pipe.transformer = transformer
log_vram("✅ Pipeline loaded!")
pipe.enable_sequential_cpu_offload()
# Model configurations
MODEL_CONFIGS = {
    "dev": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Dev-nf4",
        "guidance_scale": 0.0,
        "num_inference_steps": 28,
        "shift": 6.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    },
    "full": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Full-nf4",
        "guidance_scale": 5.0,
        "num_inference_steps": 50,
        "shift": 3.0,
        "scheduler": FlowUniPCMultistepScheduler
    },
    "fast": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Fast-nf4",
        "guidance_scale": 0.0,
        "num_inference_steps": 16,
        "shift": 3.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    }
}
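# Only the "fast" variant is actually loaded at startup; the commented-out
# load_models() below handled switching between variants.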
# Parse a resolution label like "1024 × 1024 (Square)" into a (width, height) tuple
def parse_resolution(resolution_str):
    return tuple(map(int, resolution_str.split("(")[0].strip().split(" × ")))
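# e.g. parse_resolution("768 × 1360 (Portrait)") -> (768, 1360)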
# def load_models(model_type: str):
#     config = MODEL_CONFIGS[model_type]
#     tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(LLAMA_MODEL_NAME)
#     log_vram("✅ Tokenizer loaded!")
#     text_encoder_4 = LlamaForCausalLM.from_pretrained(
#         LLAMA_MODEL_NAME,
#         output_hidden_states=True,
#         output_attentions=True,
#         return_dict_in_generate=True,
#         torch_dtype=torch.bfloat16,
#         device_map="auto",
#     )
#     log_vram("✅ Text encoder loaded!")
#     transformer = HiDreamImageTransformer2DModel.from_pretrained(
#         config["path"],
#         subfolder="transformer",
#         torch_dtype=torch.bfloat16
#     )
#     log_vram("✅ Transformer loaded!")
#     pipe = HiDreamImagePipeline.from_pretrained(
#         config["path"],
#         scheduler=FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=config["shift"], use_dynamic_shifting=False),
#         tokenizer_4=tokenizer_4,
#         text_encoder_4=text_encoder_4,
#         torch_dtype=torch.bfloat16,
#     )
#     pipe.transformer = transformer
#     log_vram("✅ Pipeline loaded!")
#     pipe.enable_sequential_cpu_offload()
#     return pipe, config
#@torch.inference_mode()
@spaces.GPU()
def generate_image(pipe: HiDreamImagePipeline, model_type: str, prompt: str, resolution: tuple[int, int], seed: int):
    # Get the sampling settings for the selected model variant
    config = MODEL_CONFIGS[model_type]
    guidance_scale = config["guidance_scale"]
    num_inference_steps = config["num_inference_steps"]
    # Parse resolution
    width, height = resolution
    # Handle seed (-1 means pick a random one and report it back to the UI)
    if seed == -1:
        seed = torch.randint(0, 1000000, (1,)).item()
    generator = torch.Generator("cuda").manual_seed(seed)
    images = pipe(
        prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=generator
    ).images
    return images[0], seed
@spaces.GPU()
def gen_img_helper(prompt, res, seed):
    global pipe
    # 1. Model switching is disabled: only the "fast" pipeline loaded above is used.
    # if model != current_model:
    #     print(f"Unloading model {current_model}...")
    #     del pipe
    #     torch.cuda.empty_cache()
    #     print(f"Loading model {model}...")
    #     pipe, _ = load_models(model)
    #     current_model = model
    #     print("Model loaded successfully!")
    # 2. Generate image
    res = parse_resolution(res)
    return generate_image(pipe, "fast", prompt, res, seed)
if __name__ == "__main__":
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
    # Initialize with default model
    # print("Loading default model (fast)...")
    # current_model = "fast"
    # pipe, _ = load_models(current_model)
    # print("Model loaded successfully!")
    # Create Gradio interface
    with gr.Blocks(title="HiDream-I1-nf4 Dashboard") as demo:
        gr.Markdown("# HiDream-I1-nf4 Dashboard")
        with gr.Row():
            with gr.Column():
                model_type = gr.Radio(
                    choices=list(MODEL_CONFIGS.keys()),
                    value="fast",
                    label="Model Type",
                    info="Select model variant"
                )
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="A cat holding a sign that says \"Hi-Dreams.ai\".",
                    lines=3
                )
                resolution = gr.Radio(
                    choices=RESOLUTION_OPTIONS,
                    value=RESOLUTION_OPTIONS[0],
                    label="Resolution",
                    info="Select image resolution"
                )
                seed = gr.Number(
                    label="Seed (use -1 for random)",
                    value=-1,
                    precision=0
                )
                generate_btn = gr.Button("Generate Image")
                seed_used = gr.Number(label="Seed Used", interactive=False)
            with gr.Column():
                output_image = gr.Image(label="Generated Image", type="pil")
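        # NOTE: model_type is not wired into the click handler below; generation
        # always uses the "fast" pipeline loaded at import time.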
        generate_btn.click(
            fn=gen_img_helper,
            inputs=[prompt, resolution, seed],
            outputs=[output_image, seed_used]
        )
    demo.launch()