# Hugging Face Space — running on ZeroGPU hardware.
import html
import json
import logging
import os
import random
import time

import gradio as gr
import torch
from huggingface_hub import ModelCard
from PIL import Image
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

from hi_diffusers import HiDreamImagePipeline, HiDreamImageTransformer2DModel
from hi_diffusers.schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
# Constants
MODEL_PREFIX = "HiDream-ai"
LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Settings used by load_fast_model() and as the UI slider defaults.
# Despite the "FAST" name, this points at the HiDream-I1-Full checkpoint.
# "scheduler" is a class; load_fast_model() instantiates it with "shift".
FAST_MODEL_CONFIG = dict(
    path=f"{MODEL_PREFIX}/HiDream-I1-Full",
    guidance_scale=5.0,
    num_inference_steps=50,
    shift=3.0,
    scheduler=FlowUniPCMultistepScheduler,
)
# Resolution choices shown in the UI, labelled "<width> × <height> (<orientation>)".
RESOLUTION_OPTIONS = [
    f"{width} × {height} ({orientation})"
    for width, height, orientation in (
        (1024, 1024, "Square"),
        (768, 1360, "Portrait"),
        (1360, 768, "Landscape"),
        (880, 1168, "Portrait"),
        (1168, 880, "Landscape"),
        (1248, 832, "Landscape"),
        (832, 1248, "Portrait"),
    )
]
# Load LoRAs from JSON file (assumed to be compatible with Hi-Dream).
# Each entry is read elsewhere in this file via the keys "image", "title",
# "repo" and optionally "weights", "trigger_word", "trigger_position", "aspect".
# JSON text is UTF-8; say so explicitly so non-ASCII titles don't depend on
# the host's locale encoding.
with open("loras.json", "r", encoding="utf-8") as f:
    loras = json.load(f)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Upper bound (inclusive) for randomly drawn seeds and for the seed slider.
MAX_SEED = 2**32 - 1
# Parse resolution string to height and width
def parse_resolution(res_str):
    """Return ``(height, width)`` for a resolution label.

    Labels are written ``"<width> × <height> (<orientation>)"``; e.g.
    ``"768 × 1360 (Portrait)"`` is 768 px wide and 1360 px tall. The first
    number is therefore the width and the second the height. The previous
    mapping returned them swapped, so portrait options rendered landscape.

    Args:
        res_str: Resolution label, normally one of the UI's options.

    Returns:
        ``(height, width)`` as ints. Falls back to ``(1024, 1024)`` when no
        known size appears in ``res_str``, including ``None`` or ``""``.
    """
    # Keys are written exactly as in the UI labels: "width × height".
    known_sizes = {
        "1024 × 1024": (1024, 1024),
        "768 × 1360": (768, 1360),
        "1360 × 768": (1360, 768),
        "880 × 1168": (880, 1168),
        "1168 × 880": (1168, 880),
        "1248 × 832": (1248, 832),
        "832 × 1248": (832, 1248),
    }
    label = res_str or ""
    for key, (width, height) in known_sizes.items():
        if key in label:
            return height, width
    return 1024, 1024  # fallback
# Load the Hi-Dream Fast Model pipeline
# Lazily-initialised pipeline and the config it was built from (None until first load).
pipe, MODEL_CONFIG = None, None


def load_fast_model():
    """Build the HiDream pipeline described by ``FAST_MODEL_CONFIG``.

    Creates the scheduler from the config, loads the Llama tokenizer and text
    encoder plus the HiDream transformer (all bfloat16, moved to ``device``),
    then assembles the pipeline around them. The result is stored in the
    module globals ``pipe`` and ``MODEL_CONFIG``.

    Returns:
        ``(pipe, config)``: the assembled pipeline and the config dict used.
    """
    global pipe, MODEL_CONFIG
    cfg = FAST_MODEL_CONFIG
    dtype = torch.bfloat16
    model_path = cfg["path"]

    flow_scheduler = cfg["scheduler"](
        num_train_timesteps=1000,
        shift=cfg["shift"],
        use_dynamic_shifting=False,
    )

    # NOTE(review): use_fast is an AutoTokenizer switch; on
    # PreTrainedTokenizerFast it is presumably ignored — kept for parity.
    llama_tokenizer = PreTrainedTokenizerFast.from_pretrained(LLAMA_MODEL_NAME, use_fast=False)
    llama_encoder = LlamaForCausalLM.from_pretrained(
        LLAMA_MODEL_NAME,
        output_hidden_states=True,
        output_attentions=True,
        torch_dtype=dtype,
    ).to(device)

    # The transformer is loaded on its own and attached after the pipeline is
    # built. NOTE(review): presumably the pipeline's from_pretrained does not
    # provide it itself — confirm before folding this into the call below.
    dit = HiDreamImageTransformer2DModel.from_pretrained(
        model_path,
        subfolder="transformer",
        torch_dtype=dtype,
    ).to(device)

    pipe = HiDreamImagePipeline.from_pretrained(
        model_path,
        scheduler=flow_scheduler,
        tokenizer_4=llama_tokenizer,
        text_encoder_4=llama_encoder,
        torch_dtype=dtype,
    ).to(device, dtype)
    pipe.transformer = dit

    MODEL_CONFIG = cfg
    return pipe, cfg
# Generate image
def generate_image(prompt, resolution, seed, guidance_scale, num_inference_steps):
    """Run the pipeline once and return its first image.

    Loads the pipeline on first use.

    Args:
        prompt: Text prompt passed straight to the pipeline.
        resolution: Resolution label, converted by ``parse_resolution``.
        seed: RNG seed; ``None`` or ``-1`` draws a random one in ``[0, MAX_SEED]``.
        guidance_scale: Passed through to the pipeline.
        num_inference_steps: Passed through to the pipeline.

    Returns:
        ``(image, seed)``: ``images[0]`` of the pipeline output and the seed used.

    NOTE(review): this Space's header says it runs on ZeroGPU, where GPU work
    usually has to happen inside a ``@spaces.GPU``-decorated function — confirm
    how GPU access is obtained here.
    """
    global pipe, MODEL_CONFIG
    if pipe is None:
        pipe, MODEL_CONFIG = load_fast_model()

    height, width = parse_resolution(resolution)

    if seed is None or seed == -1:
        seed = random.randint(0, MAX_SEED)
    rng = torch.Generator(device=device).manual_seed(int(seed))

    output = pipe(
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=rng,
    )
    first_image = output.images[0]
    return first_image, seed
class calculateDuration:
    """Context manager that times its ``with`` body and prints the duration.

    On exit it sets ``elapsed_time`` (seconds) and prints either
    ``"Elapsed time for <activity_name>: N.NNNNNN seconds"`` or, with no
    name, ``"Elapsed time: N.NNNNNN seconds"``. Exceptions raised in the
    body are never suppressed.

    ``start_time`` / ``end_time`` are ``time.perf_counter()`` readings. They
    are only meaningful as a difference, not as wall-clock timestamps.

    Args:
        activity_name: Optional label included in the printed message.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        # perf_counter is monotonic and high-resolution. time.time() follows
        # the system clock, which can be stepped (e.g. by NTP) mid-measurement.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.perf_counter()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
        return False  # never swallow exceptions from the timed block
def update_selection(evt: gr.SelectData, resolution):
    """Handle a click on a LoRA gallery tile.

    Args:
        evt: Gradio select event; ``evt.index`` is used as the index into ``loras``.
        resolution: Current resolution label. It is replaced when the chosen
            entry has an ``"aspect"`` key: "portrait" and "landscape" pick the
            matching preset, and any other value falls back to square.

    Returns:
        A 4-tuple: a prompt-textbox update with a new placeholder, the
        "Selected: …" markdown text, the selected index, and the resolution label.
    """
    idx = evt.index
    lora = loras[idx]
    placeholder = f"Type a prompt for {lora['title']}"
    repo = lora["repo"]
    info_markdown = f"### Selected: [{repo}](https://huggingface.co/{repo}) ✨"

    if "aspect" in lora:
        aspect = lora["aspect"]
        if aspect == "portrait":
            resolution = "768 × 1360 (Portrait)"
        elif aspect == "landscape":
            resolution = "1360 × 768 (Landscape)"
        else:
            resolution = "1024 × 1024 (Square)"

    return gr.update(placeholder=placeholder), info_markdown, idx, resolution
def run_lora(prompt, resolution, cfg_scale, steps, selected_index, randomize_seed, seed):
    """Generate an image, with the selected LoRA applied if there is one.

    Args:
        prompt: User prompt. A selected LoRA's ``trigger_word`` is prepended
            when its ``trigger_position`` is ``"prepend"``, otherwise appended.
        resolution: Resolution label forwarded to ``generate_image``.
        cfg_scale: Guidance scale.
        steps: Number of inference steps.
        selected_index: Index into ``loras``, or ``None`` for no LoRA.
        randomize_seed: When truthy, ``seed`` is replaced by a random value.
        seed: Seed to use when ``randomize_seed`` is falsy.

    Returns:
        ``(image, seed_used)`` from ``generate_image``.
    """
    global pipe
    if pipe is None:
        pipe, _ = load_fast_model()

    # Start every request from the base model. Each diffusers
    # load_lora_weights() call without an adapter_name registers a new
    # adapter, so without unloading, LoRAs from earlier requests stay attached.
    # Clearing the selection ("Remove custom LoRA" sets selected_index to None)
    # would also keep the last LoRA in effect. The hasattr guard leaves
    # pipelines without LoRA support untouched.
    if hasattr(pipe, "unload_lora_weights"):
        with calculateDuration("Unloading LoRA weights"):
            pipe.unload_lora_weights()

    if selected_index is not None:
        selected_lora = loras[selected_index]
        lora_path = selected_lora["repo"]
        weight_name = selected_lora.get("weights", None)
        with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
            pipe.load_lora_weights(lora_path, weight_name=weight_name, low_cpu_mem_usage=True)

        trigger_word = selected_lora.get("trigger_word", "")
        if trigger_word:
            if selected_lora.get("trigger_position") == "prepend":
                prompt = f"{trigger_word} {prompt}"
            else:
                prompt = f"{prompt} {trigger_word}"

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    with calculateDuration("Generating image"):
        final_image, used_seed = generate_image(prompt, resolution, seed, cfg_scale, steps)
    return final_image, used_seed
def _normalize_repo_id(link):
    """Reduce a repo id or huggingface.co URL to ``"user/repo"``; raise ValueError if malformed."""
    repo_id = (link or "").strip()
    for prefix in ("https://huggingface.co/", "http://huggingface.co/", "huggingface.co/"):
        if repo_id.startswith(prefix):
            repo_id = repo_id[len(prefix):]
            break
    parts = repo_id.strip("/").split("/")
    # Exactly two non-empty segments: owner and repository name.
    if len(parts) != 2 or not all(parts):
        raise ValueError("Invalid Hugging Face repository link format.")
    return "/".join(parts)


def _widget_image_path(card_data):
    """Return ``widget[0]["output"]["url"]`` from card metadata, or None if any level is missing."""
    # "widget" may be absent, null or an empty list; each nested level may be missing too.
    widgets = card_data.get("widget") or []
    first = widgets[0] if isinstance(widgets, list) and widgets else {}
    output = first.get("output") if isinstance(first, dict) else None
    url = output.get("url") if isinstance(output, dict) else None
    return url if isinstance(url, str) and url else None


def check_custom_model(link):
    """Resolve a custom LoRA reference to the metadata the app needs.

    Args:
        link: Hugging Face repo id (``"user/repo"``) or its
            ``https://huggingface.co/user/repo`` URL. Surrounding whitespace and
            leading/trailing slashes are ignored.

    Returns:
        ``(title, repo_id, weight_name, trigger_word, image_url)``:
        ``title`` is the repo name, ``repo_id`` is the normalised
        ``"user/repo"``, ``weight_name`` is always ``None``, ``trigger_word``
        is the card's ``instance_prompt`` (``""`` if missing or null), and
        ``image_url`` is ``None`` when the card has no widget output image.

    Raises:
        ValueError: if ``link`` does not name a ``user/repo`` repository.
        Errors from ``ModelCard.load`` (network, missing repo) propagate.
    """
    repo_id = _normalize_repo_id(link)
    card_data = ModelCard.load(repo_id).data

    trigger_word = card_data.get("instance_prompt") or ""

    image_path = _widget_image_path(card_data)
    if image_path is None:
        image_url = None
    elif image_path.startswith(("http://", "https://")):
        image_url = image_path  # already absolute
    else:
        image_url = f"https://huggingface.co/{repo_id}/resolve/main/{image_path}"

    # NOTE(review): None is forwarded as load_lora_weights(weight_name=...);
    # presumably the loader then locates the weights file itself — confirm.
    weight_name = None
    return repo_id.split("/")[1], repo_id, weight_name, trigger_word, image_url
def _custom_lora_card(title, trigger_word, image):
    """Build the HTML info card for a custom LoRA, escaping all model-card-derived text."""
    if trigger_word:
        trigger_html = f"Using: <code><b>{html.escape(str(trigger_word))}</b></code> as the trigger word"
    else:
        trigger_html = "No trigger word found."
    # Omit the <img> entirely rather than rendering src="None".
    image_html = f'<img src="{html.escape(str(image))}" />' if image else ""
    return f'''
<div class="custom_lora_card">
<span>Loaded custom LoRA:</span>
<div class="card_internal">
{image_html}
<div>
<h3>{html.escape(str(title))}</h3>
<small>{trigger_html}</small>
</div>
</div>
</div>
'''


def add_custom_lora(custom_lora):
    """Register a user-supplied LoRA repo and show its info card.

    Appends a new entry to the module-level ``loras`` list unless an entry
    with the same ``repo`` already exists.

    Args:
        custom_lora: Text from the "Custom LoRA" textbox.

    Returns:
        A 6-tuple for the outputs wired in the UI: info-card update,
        remove-button update, gallery update, selected-info text, selected
        index, and the new prompt value (the trigger word, or "").
    """
    if not custom_lora:
        return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

    try:
        title, repo, path, trigger_word, image = check_custom_model(custom_lora)
    except Exception as e:  # UI boundary: report any lookup failure to the user
        gr.Warning(f"Invalid LoRA: {str(e)}")
        return gr.update(visible=True, value=f"Invalid LoRA: {str(e)}"), gr.update(visible=True), gr.update(), "", None, ""

    card = _custom_lora_card(title, trigger_word, image)

    # Compare with None: index 0 is a valid match and must not count as "not found".
    existing_item_index = next((i for i, item in enumerate(loras) if item["repo"] == repo), None)
    if existing_item_index is None:
        loras.append({
            "image": image,
            "title": title,
            "repo": repo,
            "weights": path,
            "trigger_word": trigger_word,
        })
        existing_item_index = len(loras) - 1

    return (
        gr.update(visible=True, value=card),
        gr.update(visible=True),
        gr.Gallery(selected_index=None),
        f"Custom: {path or repo}",  # path may be None; fall back to the repo id
        existing_item_index,
        trigger_word or "",
    )
def remove_custom_lora():
    """Reset the custom-LoRA UI state.

    Returns a 6-tuple in the order of the remove button's outputs: hide the
    info card, hide the remove button, leave the gallery unchanged, clear the
    selected-info text, reset the selected index to ``None``, and clear the
    custom-LoRA textbox.
    """
    hide_card = gr.update(visible=False)
    hide_button = gr.update(visible=False)
    gallery_unchanged = gr.update()
    cleared_info, cleared_index, cleared_textbox = "", None, ""
    return hide_card, hide_button, gallery_unchanged, cleared_info, cleared_index, cleared_textbox
# Custom CSS passed to gr.Blocks. The #ids match elem_id values set on the UI
# components (gen_btn, gen_column, title, gallery, lora_list), and
# .card_internal matches the custom-LoRA card HTML.
# NOTE(review): "#title img" and ".styler" match nothing visible in this file
# (the title HTML has no <img>, and no component sets that class) — possibly leftovers.
css = '''
#gen_btn{height: 100%}
#gen_column{align-self: stretch}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: 10vh}
#lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
.card_internal{display: flex;height: 100px;margin-top: .5em}
.card_internal img{margin-right: 1em}
.styler{--form-gap-width: 0px !important}
'''
font = [gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]

with gr.Blocks(theme=gr.themes.Soft(font=font), css=css, delete_cache=(60, 60)) as app:
    title = gr.HTML(
        """<h1>Hi-Dream Full LoRA DLC 🤩</h1>""",
        elem_id="title",
    )
    # Index into `loras` of the chosen LoRA; None means "no LoRA".
    selected_index = gr.State(None)

    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
        with gr.Column(scale=1, elem_id="gen_column"):
            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")

    with gr.Row():
        with gr.Column():
            selected_info = gr.Markdown("")
            gallery = gr.Gallery(
                [(item["image"], item["title"]) for item in loras],
                label="LoRA Gallery",
                allow_preview=False,
                columns=3,
                elem_id="gallery",
                show_share_button=False,
            )
            with gr.Group():
                custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path", placeholder="multimodalart/vintage-ads-flux")
                # Link to Hub adapters whose base model is the checkpoint this app
                # loads. A Markdown link needs a target, otherwise it renders as
                # literal brackets.
                gr.Markdown(
                    "[Check the list of Hi-Dream LoRAs]"
                    f"(https://huggingface.co/models?other=base_model:adapter:{FAST_MODEL_CONFIG['path']})",
                    elem_id="lora_list",
                )
            custom_lora_info = gr.HTML(visible=False)
            custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
        with gr.Column():
            result = gr.Image(label="Generated Image")

    with gr.Row():
        with gr.Accordion("Advanced Settings", open=False):
            cfg_scale = gr.Slider(label="Guidance Scale", minimum=0, maximum=20, step=0.1, value=FAST_MODEL_CONFIG["guidance_scale"])
            steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=FAST_MODEL_CONFIG["num_inference_steps"])
            resolution = gr.Radio(
                choices=RESOLUTION_OPTIONS,
                value=RESOLUTION_OPTIONS[0],
                label="Resolution",
            )
            randomize_seed = gr.Checkbox(True, label="Randomize seed")
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)

    gallery.select(
        update_selection,
        inputs=[resolution],
        outputs=[prompt, selected_info, selected_index, resolution],
    )
    custom_lora.input(
        add_custom_lora,
        inputs=[custom_lora],
        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt],
    )
    custom_lora_button.click(
        remove_custom_lora,
        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora],
    )
    # Both the button and pressing Enter in the prompt box trigger generation.
    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        inputs=[prompt, resolution, cfg_scale, steps, selected_index, randomize_seed, seed],
        outputs=[result, seed],
    )

app.queue()
app.launch()