Spaces: Running on Zero
import os | |
import json | |
import copy | |
import math | |
import time | |
import random | |
import logging | |
import numpy as np | |
from typing import Any, Dict, List, Optional, Union | |
import torch | |
from PIL import Image | |
import gradio as gr | |
import spaces | |
from diffusers import ( | |
DiffusionPipeline, | |
FlowMatchEulerDiscreteScheduler) | |
from huggingface_hub import ( | |
hf_hub_download, | |
HfFileSystem, | |
ModelCard, | |
snapshot_download) | |
from diffusers.utils import load_image | |
import requests | |
from urllib.parse import urlparse | |
import tempfile | |
import shutil | |
import uuid | |
import zipfile | |
# --- CUDA / GPU diagnostics ---
# Log the CUDA environment at startup, before the model is loaded, so device
# problems are visible in the Space logs.
print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("torch.version.cuda =", torch.version.cuda)
print("cuda available:", torch.cuda.is_available())
print("cuda device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("current device:", torch.cuda.current_device())
    print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
# `processing_device` is only assigned later, during model initialisation, so
# referencing it here raised NameError at import time. Derive the same value
# inline instead.
print("Using device:", "cuda" if torch.cuda.is_available() else "cpu")
# Built-in style adapters offered in the gallery. Each row below is
# (thumbnail_url, style_name, repo_id, weight_file, activation_phrase), and the
# comprehension expands it into the dict shape the event handlers read.
# Custom styles are appended to this same list at runtime.
style_definitions = [
    {
        "thumbnail_url": thumbnail,
        "style_name": label,
        "repo_id": repo,
        "weight_file": weights,
        "activation_phrase": trigger,
    }
    for thumbnail, label, repo, weights, trigger in (
        (
            "https://huggingface.co/prithivMLmods/Qwen-Image-Studio-Realism/resolve/main/images/2.png",
            "Studio Realism",
            "prithivMLmods/Qwen-Image-Studio-Realism",
            "qwen-studio-realism.safetensors",
            "Studio Realism",
        ),
        (
            "https://huggingface.co/prithivMLmods/Qwen-Image-Sketch-Smudge/resolve/main/images/1.png",
            "Sketch Smudge",
            "prithivMLmods/Qwen-Image-Sketch-Smudge",
            "qwen-sketch-smudge.safetensors",
            "Sketch Smudge",
        ),
        (
            "https://huggingface.co/prithivMLmods/Qwen-Image-Anime-LoRA/resolve/main/images/1.png",
            "Qwen Anime",
            "prithivMLmods/Qwen-Image-Anime-LoRA",
            "qwen-anime.safetensors",
            "Qwen Anime",
        ),
        (
            "https://huggingface.co/prithivMLmods/Qwen-Image-Synthetic-Face/resolve/main/images/2.png",
            "Synthetic Face",
            "prithivMLmods/Qwen-Image-Synthetic-Face",
            "qwen-synthetic-face.safetensors",
            "Synthetic Face",
        ),
        (
            "https://huggingface.co/prithivMLmods/Qwen-Image-Fragmented-Portraiture/resolve/main/images/3.png",
            "Fragmented Portraiture",
            "prithivMLmods/Qwen-Image-Fragmented-Portraiture",
            "qwen-fragmented-portraiture.safetensors",
            "Fragmented Portraiture",
        ),
    )
]
# --- Model Initialization ---
# Base Qwen-Image weights in bfloat16, placed on the GPU when one is visible.
foundation_model_id = "Qwen/Qwen-Image"
model_precision = torch.bfloat16
processing_device = "cuda" if torch.cuda.is_available() else "cpu"

# Flow-matching Euler scheduler configured after the Qwen-Image-Lightning
# repository: dynamic exponential time shifting, with base_shift and max_shift
# both set to log(3).
sampler_settings = dict(
    base_image_seq_len=256,
    base_shift=math.log(3),
    invert_sigmas=False,
    max_image_seq_len=8192,
    max_shift=math.log(3),
    num_train_timesteps=1000,
    shift=1.0,
    shift_terminal=None,
    stochastic_sampling=False,
    time_shift_type="exponential",
    use_beta_sigmas=False,
    use_dynamic_shifting=True,
    use_exponential_sigmas=False,
    use_karras_sigmas=False,
)
sampler = FlowMatchEulerDiscreteScheduler.from_config(sampler_settings)

# Load the pipeline with the custom scheduler and move it to the chosen device.
diffusion_pipeline = DiffusionPipeline.from_pretrained(
    foundation_model_id,
    scheduler=sampler,
    torch_dtype=model_precision,
).to(processing_device)
# Lightning distillation adapter stacked on top of the style in
# "Speed (8 steps)" mode: (Hub repository, weight file inside it).
FAST_GENERATION_LORA_REPO, FAST_GENERATION_LORA_WEIGHTS = (
    "lightx2v/Qwen-Image-Lightning",
    "Qwen-Image-Lightning-8steps-V1.0.safetensors",
)

# Upper bound for random seeds: the largest signed 32-bit integer.
MAX_SEED_VALUE = np.iinfo(np.int32).max
class ExecutionTimer:
    """Context manager that measures and prints the duration of a ``with`` block.

    On exit, the duration in seconds is stored in ``elapsed_time`` and printed
    as ``Elapsed time for <activity_name>: N.NNNNNN seconds``. The
    `` for <activity_name>`` part is omitted when no name was given.
    Exceptions raised inside the block are not suppressed.
    """

    def __init__(self, activity_name=""):
        # Optional label included in the printed message.
        self.activity_name = activity_name

    def __enter__(self):
        # perf_counter is monotonic and high-resolution. time.time() follows
        # the system clock, which can be adjusted mid-measurement and yield
        # wrong or even negative durations.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.perf_counter()
        self.elapsed_time = self.end_time - self.start_time
        activity_log = f" for {self.activity_name}" if self.activity_name else ""
        print(f"Elapsed time{activity_log}: {self.elapsed_time:.6f} seconds")
        # A falsy return value lets any exception from the block propagate.
        return False
def get_dimensions_from_ratio(aspect_ratio_str):
    """Map an aspect-ratio label such as ``"16:9"`` to a ``(width, height)`` pair.

    Labels outside the preset table fall back to the square 1024x1024 canvas.
    """
    square = (1024, 1024)
    presets = {
        "1:1": square,
        "16:9": (1152, 640),
        "9:16": (640, 1152),
        "4:3": (1024, 768),
        "3:4": (768, 1024),
        "3:2": (1024, 688),
        "2:3": (688, 1024),
    }
    try:
        return presets[aspect_ratio_str]
    except KeyError:
        return square
def on_style_select(event_data: gr.SelectData, current_aspect_ratio):
    """Make the clicked gallery style the active one.

    Returns, in order:
      * a placeholder update for the prompt box naming the style;
      * a markdown line linking to the style's Hub repository;
      * the clicked index, stored as the selected style;
      * the aspect ratio to display. This is overridden only when the style
        dict carries an ``"aspect"`` key: "portrait" gives 9:16, "landscape"
        gives 16:9, and any other value gives 1:1.
    """
    clicked = event_data.index
    style = style_definitions[clicked]
    placeholder_text = f"Type a prompt for {style['style_name']}"
    repo = style["repo_id"]
    info_markdown = f"### Selected: [{repo}](https://huggingface.co/{repo}) ✨"

    if "aspect" in style:
        orientation = style["aspect"]
        current_aspect_ratio = (
            "9:16" if orientation == "portrait"
            else "16:9" if orientation == "landscape"
            else "1:1"
        )

    return (
        gr.update(placeholder=placeholder_text),
        info_markdown,
        clicked,
        current_aspect_ratio,
    )
def on_mode_change(generation_mode):
    """Sync the status line, step count and CFG scale with the chosen mode.

    "Speed (8 steps)" yields 8 steps at CFG 1.0 (generation then stacks the
    Lightning adapter). Any other value yields 45 steps at CFG 3.5.
    Returns (status markdown update, steps, cfg).
    """
    is_speed = generation_mode == "Speed (8 steps)"
    if is_speed:
        status_text = "Speed mode selected - 8 steps with Lightning LoRA"
        steps, cfg = 8, 1.0
    else:
        status_text = "Quality mode selected - 45 steps for best quality"
        steps, cfg = 45, 3.5
    return gr.update(value=status_text), steps, cfg
# On ZeroGPU hardware a GPU is attached only while a @spaces.GPU-decorated
# function runs. Without the decorator, the CUDA calls below fail there.
# NOTE(review): duration=120 is an estimate intended to cover a 45-step
# Quality run (the ZeroGPU default budget is 60 s); tune to observed runtimes.
@spaces.GPU(duration=120)
def execute_image_generation(full_prompt, steps, seed_val, cfg, width, height, negative_prompt=""):
    """Run the diffusion pipeline once and return the first generated image.

    Args:
        full_prompt: Prompt text, already combined with the style's activation phrase.
        steps: Number of denoising steps.
        seed_val: Seed for the torch generator, so a run can be reproduced.
        cfg: Forwarded to the pipeline as ``true_cfg_scale``.
        width: Output width in pixels.
        height: Output height in pixels.
        negative_prompt: Negative prompt text (empty by default).

    Returns:
        The first entry of the pipeline's ``images`` output.
    """
    # Fall back to CPU instead of crashing when no CUDA device is visible.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    diffusion_pipeline.to(device)
    # Gradio sliders can deliver numbers as floats, but manual_seed and the
    # step count require integers.
    generator = torch.Generator(device=device).manual_seed(int(seed_val))
    with ExecutionTimer("Image Generation"):
        generated_image = diffusion_pipeline(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(steps),
            true_cfg_scale=cfg,
            width=width,
            height=height,
            generator=generator,
        ).images[0]
    return generated_image
def _compose_style_prompt(selected_style, prompt_text):
    """Return *prompt_text* combined with the style's activation phrase.

    The phrase is prepended unless the style sets ``trigger_position`` to a
    value other than ``"prepend"``, in which case it is appended. A style with
    an empty activation phrase leaves the prompt unchanged.
    """
    activation_phrase = selected_style["activation_phrase"]
    if not activation_phrase:
        return prompt_text
    if selected_style.get("trigger_position", "prepend") == "prepend":
        return f"{activation_phrase} {prompt_text}"
    return f"{prompt_text} {activation_phrase}"


def _load_style_adapters(selected_style, generation_mode, style_scale):
    """Replace the pipeline's LoRA adapters with the ones *generation_mode* needs.

    All existing adapters are unloaded first. Speed mode stacks the Lightning
    adapter (weight 1.0) with the style adapter. Quality mode loads only the
    style adapter. In both modes the style adapter is weighted by *style_scale*.
    """
    with ExecutionTimer("Unloading existing adapters"):
        diffusion_pipeline.unload_lora_weights()

    style_repo_path = selected_style["repo_id"]
    weight_file = selected_style.get("weight_file", None)

    if generation_mode == "Speed (8 steps)":
        with ExecutionTimer("Loading Lightning and Style adapters"):
            diffusion_pipeline.load_lora_weights(
                FAST_GENERATION_LORA_REPO,
                weight_name=FAST_GENERATION_LORA_WEIGHTS,
                adapter_name="lightning",
            )
            diffusion_pipeline.load_lora_weights(
                style_repo_path,
                weight_name=weight_file,
                low_cpu_mem_usage=True,
                adapter_name="style",
            )
            diffusion_pipeline.set_adapters(["lightning", "style"], adapter_weights=[1.0, style_scale])
    else:  # Quality mode
        with ExecutionTimer(f"Loading adapter weights for {selected_style['style_name']}"):
            diffusion_pipeline.load_lora_weights(
                style_repo_path,
                weight_name=weight_file,
                low_cpu_mem_usage=True,
                adapter_name="style",
            )
            # The adapter used to be loaded unnamed and never weighted, so the
            # "Style Strength" slider was silently ignored in Quality mode.
            diffusion_pipeline.set_adapters(["style"], adapter_weights=[style_scale])


def handle_generate_request(prompt_text, cfg, steps, style_idx, use_random_seed, seed_val, aspect_ratio_str, style_scale, generation_mode, progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: generate one image for the selected style and settings.

    Args:
        prompt_text: The user's prompt. The style's activation phrase is added.
        cfg: Guidance scale forwarded to generation.
        steps: Number of inference steps.
        style_idx: Index into ``style_definitions``; ``None`` if nothing is selected.
        use_random_seed: When truthy, *seed_val* is replaced by a random seed.
        seed_val: Seed used when *use_random_seed* is falsy.
        aspect_ratio_str: Aspect-ratio label such as ``"16:9"``.
        style_scale: Weight applied to the style adapter.
        generation_mode: ``"Speed (8 steps)"`` or the Quality mode label.
        progress: Unused in the body. Gradio uses this default to track tqdm bars.

    Returns:
        ``(image, seed)``, where *seed* is the seed actually used, so the UI
        can display it.

    Raises:
        gr.Error: No style has been selected.
    """
    if style_idx is None:
        raise gr.Error("You must select a style before generating an image.")
    selected_style = style_definitions[style_idx]

    full_prompt = _compose_style_prompt(selected_style, prompt_text)
    _load_style_adapters(selected_style, generation_mode, style_scale)

    with ExecutionTimer("Setting seed"):
        if use_random_seed:
            seed_val = random.randint(0, MAX_SEED_VALUE)

    width, height = get_dimensions_from_ratio(aspect_ratio_str)
    final_image = execute_image_generation(full_prompt, steps, seed_val, cfg, width, height)
    return final_image, seed_val
def fetch_hf_safetensors_details(repo_link):
    """Look up a Qwen-Image LoRA in a Hugging Face model repository.

    Args:
        repo_link: Repository id of the form ``"owner/name"``.

    Returns:
        ``(repo_name, repo_id, safetensors_filename, activation_phrase, image_url)``.
        ``safetensors_filename`` is the first top-level ``.safetensors`` file
        listed. ``image_url`` is ``None`` when the model card has no widget
        image.

    Raises:
        ValueError: ``repo_link`` is not of the form ``owner/name``.
        TypeError: the card's ``base_model`` does not include ``Qwen/Qwen-Image``.
        IOError: the repository cannot be listed or has no top-level
            ``.safetensors`` file.
        Errors raised by ``ModelCard.load`` propagate unchanged.
    """
    split_link = repo_link.split("/")
    if len(split_link) != 2:
        raise ValueError("Invalid Hugging Face repository link format.")
    print(f"Attempting to load repository: {repo_link}")
    card_data = ModelCard.load(repo_link).data
    base_model = card_data.get("base_model")
    print(f"Base model identified: {base_model}")

    # Only LoRAs declared against the Qwen-Image base model are accepted.
    acceptable_models = {"Qwen/Qwen-Image"}
    models_to_check = base_model if isinstance(base_model, list) else [base_model]
    if not any(model in acceptable_models for model in models_to_check):
        raise TypeError("The provided model is not a Qwen-Image compatible LoRA.")

    # Card metadata is optional: `widget` may be missing, None or empty, and
    # indexing [0] on those used to raise IndexError/TypeError.
    widgets = card_data.get("widget")
    first_widget = widgets[0] if isinstance(widgets, list) and widgets else {}
    output = first_widget.get("output") if isinstance(first_widget, dict) else None
    image_path = output.get("url") if isinstance(output, dict) else None
    activation_phrase = card_data.get("instance_prompt", "") or ""
    image_url = f"https://huggingface.co/{repo_link}/resolve/main/{image_path}" if image_path else None

    # Find the first .safetensors file at the top level of the repository.
    fs = HfFileSystem()
    try:
        repo_files = fs.ls(repo_link, detail=False)
    except Exception as e:
        print(e)
        raise IOError("Could not access the Hugging Face repository or find a valid .safetensors file.") from e
    safetensors_filename = next(
        (path.rsplit("/", 1)[-1] for path in repo_files if path.endswith(".safetensors")),
        None,
    )
    if safetensors_filename is None:
        print("No .safetensors file was found in the repository.")
        raise IOError("Could not access the Hugging Face repository or find a valid .safetensors file.")

    return split_link[1], repo_link, safetensors_filename, activation_phrase, image_url
def parse_custom_model_source(source_text):
    """Resolve user input into ``(name, repo_id, weight_file, activation_phrase, image_url)``.

    Accepted forms:
      * a direct file URL ending in ``.safetensors``, for example
        ``https://huggingface.co/owner/repo/blob/main/lora.safetensors``;
      * a Hugging Face repository URL, for example
        ``https://huggingface.co/owner/repo``. Trailing segments such as
        ``/tree/main`` are ignored;
      * a bare repository id ``owner/repo``.

    Raises:
        ValueError: the input cannot be interpreted as one of the forms above.
        Errors from the repository lookup propagate to the caller.
    """
    # Pasted input often carries stray whitespace or newlines, which would
    # defeat the suffix test and the owner/name split below.
    source_text = source_text.strip()
    print(f"Parsing custom model source: {source_text}")

    if source_text.endswith('.safetensors') and 'huggingface.co' in source_text:
        parts = source_text.split('/')
        try:
            hf_index = parts.index('huggingface.co')
        except ValueError:
            raise ValueError("Invalid .safetensors URL format.") from None
        # Owner, repo and the file name must all follow the host segment.
        # Shorter URLs used to raise an uncaught IndexError here.
        if len(parts) < hf_index + 4:
            raise ValueError("Invalid .safetensors URL format.")
        username = parts[hf_index + 1]
        repo_name = parts[hf_index + 2]
        repo_id = f"{username}/{repo_name}"
        safetensors_filename = parts[-1]

        activation_phrase, image_url = "", None
        try:
            card_data = ModelCard.load(repo_id).data
        except Exception:
            # The weight file is already known, so a missing or unreadable
            # card only means there is no activation phrase or preview.
            pass
        else:
            activation_phrase = card_data.get("instance_prompt", "") or ""
            # A missing/empty `widget` must not discard the activation phrase.
            widgets = card_data.get("widget")
            first_widget = widgets[0] if isinstance(widgets, list) and widgets else {}
            output = first_widget.get("output") if isinstance(first_widget, dict) else None
            image_path = output.get("url") if isinstance(output, dict) else None
            if image_path:
                image_url = f"https://huggingface.co/{repo_id}/resolve/main/{image_path}"
        return repo_name, repo_id, safetensors_filename, activation_phrase, image_url

    if source_text.startswith(("https://", "http://")):
        parsed_url = urlparse(source_text)
        if "huggingface.co" in parsed_url.netloc:
            # Keep only owner/name so links like .../owner/repo/tree/main work.
            segments = [segment for segment in parsed_url.path.split("/") if segment]
            return fetch_hf_safetensors_details("/".join(segments[:2]))

    # Otherwise assume a bare repository id such as "owner/repo-name".
    return fetch_hf_safetensors_details(source_text)
def add_custom_style_model(custom_model_path):
    """Register a user-supplied LoRA as a selectable style and show its card.

    The new style is appended to the module-level ``style_definitions`` list,
    unless an entry with the same ``repo_id`` already exists.
    NOTE(review): that list is shared by every session of the app.

    Returns a 6-tuple matching the event's outputs:
      1. card HTML update;
      2. remove-button visibility update;
      3. gallery update;
      4. info text;
      5. selected style index;
      6. the value for the prompt box (the activation phrase).
    On failure a warning is shown, the card shows an error message, and the
    index is ``None``. Empty input hides the custom section.
    """
    import html  # local import keeps this block self-contained

    # If input is empty, hide the custom section.
    if not custom_model_path:
        return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

    try:
        style_name, repo_id, weight_file, activation_phrase, thumbnail_url = parse_custom_model_source(custom_model_path)
        print(f"Successfully loaded custom style: {repo_id}")

        # Name, phrase and thumbnail come from a third-party model card or
        # user input. Escape them before interpolating into HTML.
        safe_name = html.escape(style_name)
        if thumbnail_url:
            thumbnail_html = f'<img src="{html.escape(thumbnail_url)}" alt="{safe_name}" />'
        else:
            # Avoid rendering a broken <img src="None"> when the card has no preview.
            thumbnail_html = ""
        if activation_phrase:
            phrase_html = f"Activation phrase: <code><b>{html.escape(activation_phrase)}</b></code>"
        else:
            phrase_html = "No activation phrase found. If required, include it in your prompt."
        card_html = f'''
        <div class="custom_lora_card">
            <span>Loaded custom style:</span>
            <div class="card_internal">
                {thumbnail_html}
                <div>
                    <h3>{safe_name}</h3>
                    <small>{phrase_html}<br></small>
                </div>
            </div>
        </div>
        '''

        # Reuse an existing entry for the same repository instead of duplicating it.
        existing_item_index = next(
            (index for index, item in enumerate(style_definitions) if item['repo_id'] == repo_id),
            None,
        )
        if existing_item_index is None:
            # Appending mutates the shared list in place; no `global` is needed.
            style_definitions.append({
                "thumbnail_url": thumbnail_url,
                "style_name": style_name,
                "repo_id": repo_id,
                "weight_file": weight_file,
                "activation_phrase": activation_phrase,
            })
            existing_item_index = len(style_definitions) - 1
        return gr.update(visible=True, value=card_html), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {weight_file}", existing_item_index, activation_phrase
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the event.
        gr.Warning(f"Failed to load custom style. Error: {e}")
        error_message = "Invalid input. Could not load the specified style. Please check the link or repository path."
        return gr.update(visible=True, value=error_message), gr.update(visible=True), gr.update(), "", None, ""
def remove_custom_style_model():
    """Hide the custom-style card and its button, and clear the related values.

    Only the UI is reset. An entry previously appended to
    ``style_definitions`` stays in that list.

    Returns (card update, button update, gallery update, info text,
    selected index, custom-style textbox value).
    """
    card_update = gr.update(visible=False)
    button_update = gr.update(visible=False)
    gallery_update = gr.update()
    info_text, cleared_index, textbox_value = "", None, ""
    return card_update, button_update, gallery_update, info_text, cleared_index, textbox_value
# --- Gradio UI Definition ---
# Custom CSS passed to gr.Blocks(css=...). The selectors match elem_id values
# and HTML class names used elsewhere in this file, for example:
#   - #title, #gen_btn, #gen_column, #gallery, #lora_list, #speed_status;
#   - .custom_lora_card and .card_internal in the custom-style card HTML.
app_css = '''
#gen_btn{height: 100%}
#gen_column{align-self: stretch}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: 10vh}
#lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
.card_internal{display: flex;height: 100px;margin-top: .5em}
.card_internal img{margin-right: 1em; object-fit: cover;}
.styler{--form-gap-width: 0px !important}
#speed_status{padding: .5em; border-radius: 5px; margin: 1em 0}
.custom_lora_card{padding: 1em; border: 1px solid var(--border-color-primary); border-radius: var(--radius-lg)}
'''
# Build the Blocks UI. Components render in creation order, so statement order
# below defines the page layout. delete_cache=(120, 120) is passed through to
# Gradio's cache cleanup.
with gr.Blocks(theme="bethecloud/storj_theme", css=app_css, delete_cache=(120, 120)) as web_interface:
    main_title = gr.HTML("""<h1>Qwen Image LoRA DLC❤️🔥</h1>""", elem_id="title")
    # Index into style_definitions of the active style; None until one is picked.
    selected_style_index = gr.State(None)

    # Top row: prompt box and Generate button.
    with gr.Row():
        with gr.Column(scale=3):
            prompt_textbox = gr.Textbox(label="Prompt", lines=1, placeholder="Select a style to begin...")
        with gr.Column(scale=1, elem_id="gen_column"):
            generate_btn = gr.Button("Generate", variant="primary", elem_id="gen_btn")

    # Main row: style picker (left) and output plus mode controls (right).
    with gr.Row():
        with gr.Column():
            selected_style_info = gr.Markdown("")
            style_gallery = gr.Gallery(
                [(item["thumbnail_url"], item["style_name"]) for item in style_definitions],
                label="Style Gallery",
                allow_preview=False,
                columns=3,
                elem_id="gallery",
                show_share_button=False
            )
            with gr.Group():
                custom_style_textbox = gr.Textbox(label="Load Custom Style", info="Enter a Hugging Face repository path (e.g., username/repo-name)", placeholder="username/qwen-image-custom-style")
                gr.Markdown("[Find More Qwen-Image Styles Here](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image)", elem_id="lora_list")
            # Shown only after a custom style has been submitted.
            custom_style_info_html = gr.HTML(visible=False)
            remove_custom_style_btn = gr.Button("Remove Custom Style", visible=False)
        with gr.Column():
            output_image_display = gr.Image(label="Generated Image")
            with gr.Row():
                aspect_ratio_dropdown = gr.Dropdown(
                    label="Aspect Ratio",
                    choices=["1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3"],
                    value="1:1"
                )
            with gr.Row():
                generation_mode_dropdown = gr.Dropdown(
                    label="Generation Mode",
                    choices=["Speed (8 steps)", "Quality (45 steps)"],
                    value="Quality (45 steps)",
                )
            generation_mode_status_display = gr.Markdown("Quality mode active", elem_id="speed_status")

    # Collapsible advanced settings. The defaults (45 steps, CFG 3.5) match the
    # default "Quality (45 steps)" mode.
    with gr.Row():
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Column():
                with gr.Row():
                    cfg_scale_slider = gr.Slider(
                        label="Guidance Scale (CFG)",
                        minimum=1.0,
                        maximum=5.0,
                        step=0.1,
                        value=3.5,
                        info="Adjusts how strictly the model follows the prompt. Lower for speed, higher for quality."
                    )
                    steps_slider = gr.Slider(
                        label="Inference Steps",
                        minimum=4,
                        maximum=50,
                        step=1,
                        value=45,
                        info="Number of steps for the generation process. Automatically set by Generation Mode."
                    )
                with gr.Row():
                    randomize_seed_checkbox = gr.Checkbox(True, label="Use Random Seed")
                    seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED_VALUE, step=1, value=0, randomize=True)
                    # Weight applied to the style adapter at generation time.
                    style_scale_slider = gr.Slider(label="Style Strength", minimum=0, maximum=2, step=0.01, value=1.0)

    # --- Event Handlers ---
    # Gallery click: on_style_select receives the gr.SelectData event plus the
    # current aspect ratio.
    style_gallery.select(
        on_style_select,
        inputs=[aspect_ratio_dropdown],
        outputs=[prompt_textbox, selected_style_info, selected_style_index, aspect_ratio_dropdown]
    )
    # Mode switch rewrites the status line, step count and CFG scale.
    generation_mode_dropdown.change(
        on_mode_change,
        inputs=[generation_mode_dropdown],
        outputs=[generation_mode_status_display, steps_slider, cfg_scale_slider]
    )
    # Submitting a custom style writes its activation phrase into the prompt box.
    custom_style_textbox.submit(
        add_custom_style_model,
        inputs=[custom_style_textbox],
        outputs=[custom_style_info_html, remove_custom_style_btn, style_gallery, selected_style_info, selected_style_index, prompt_textbox]
    )
    remove_custom_style_btn.click(
        remove_custom_style_model,
        outputs=[custom_style_info_html, remove_custom_style_btn, style_gallery, selected_style_info, selected_style_index, custom_style_textbox]
    )
    # Combined trigger for generation
    # Clicking Generate or pressing Enter in the prompt box both run the same
    # handler. The input order must match handle_generate_request's parameters.
    generate_triggers = [generate_btn.click, prompt_textbox.submit]
    gr.on(
        triggers=generate_triggers,
        fn=handle_generate_request,
        inputs=[prompt_textbox, cfg_scale_slider, steps_slider, selected_style_index, randomize_seed_checkbox, seed_slider, aspect_ratio_dropdown, style_scale_slider, generation_mode_dropdown],
        outputs=[output_image_display, seed_slider]
    )

# Enable request queuing and start the server at import time (module-level
# side effect; the file is run directly as the Space's app).
web_interface.queue()
web_interface.launch(share=False, ssr_mode=False, show_error=True)