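"""Multi-GPU image generation demo.

A Gradio app that runs each configured diffusion model on its own GPU in
parallel, streams results into per-model tabs, and keeps a gallery of the
images it saves to disk.
"""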
import gradio as gr
import numpy as np
import random
import torch
import torch.multiprocessing as mp
from torch.amp import autocast
from diffusers import (
    DiffusionPipeline, StableDiffusion3Pipeline, FluxPipeline, PixArtSigmaPipeline,
    AuraFlowPipeline, Kandinsky3Pipeline, HunyuanDiTPipeline,
    LuminaText2ImgPipeline
)
import spaces
import gc
import os
import psutil
import threading
from pathlib import Path
import shutil
import time
import glob
from datetime import datetime
from PIL import Image
from queue import Queue
from concurrent.futures import ThreadPoolExecutor, as_completed
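
# Assumed runtime dependencies (the Space's requirements file is not shown here):
# gradio, spaces, torch, diffusers, transformers, accelerate, sentencepiece,
# protobuf, numpy, pillow, psutil. The FLUX.1-dev and Stable Diffusion 3.5
# checkpoints are gated on the Hugging Face Hub and require accepting their licenses.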

# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
TORCH_DTYPE = torch.bfloat16
OUTPUT_DIR = "generated_images"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Get available GPU devices
AVAILABLE_GPUS = list(range(torch.cuda.device_count()))
print(f"Available GPUs: {AVAILABLE_GPUS}")

# Model configurations
MODEL_CONFIGS = {
    "FLUX": {
        "repo_id": "black-forest-labs/FLUX.1-dev",
        "pipeline_class": FluxPipeline
    },
    "Stable Diffusion 3.5": {
        "repo_id": "stabilityai/stable-diffusion-3.5-large",
        "pipeline_class": StableDiffusion3Pipeline
    }
}
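
# Any of the other pipelines imported above can be registered the same way by
# adding an entry to MODEL_CONFIGS, for example (the repo id is an assumption
# and should be checked on the Hugging Face Hub before enabling):
# "PixArt Sigma": {
#     "repo_id": "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",  # assumed checkpoint name
#     "pipeline_class": PixArtSigmaPipeline
# },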

# GPU allocation queue and model cache
gpu_queue = Queue()
for gpu_id in AVAILABLE_GPUS:
    gpu_queue.put(gpu_id)
model_cache = {}
model_locks = {model_name: threading.Lock() for model_name in MODEL_CONFIGS.keys()}

def get_next_available_gpu():
    """Get the next available GPU from the queue"""
    gpu_id = gpu_queue.get()
    return gpu_id

def release_gpu(gpu_id):
    """Release GPU back to the queue"""
    gpu_queue.put(gpu_id)
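
# Note: Queue.get() blocks until a GPU is put back via release_gpu(), so when
# every GPU is busy, additional generation tasks simply wait for a free device.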

def load_pipeline_on_gpu(model_name, gpu_id):
    """Load a model pipeline onto a specific GPU."""
    config = MODEL_CONFIGS[model_name]
    with torch.cuda.device(gpu_id):
        pipe = config["pipeline_class"].from_pretrained(
            config["repo_id"],
            torch_dtype=TORCH_DTYPE
        )
        if hasattr(pipe, 'enable_model_cpu_offload'):
            # CPU offload manages device placement itself, so the pipeline is
            # not moved to the GPU explicitly in this branch.
            pipe.enable_model_cpu_offload(gpu_id=gpu_id)
        else:
            pipe = pipe.to(f"cuda:{gpu_id}")
    return pipe

def save_generated_image(image, model_name, prompt):
    """Save generated image with timestamp and model name"""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    prompt_part = "".join(c for c in prompt[:30] if c.isalnum() or c in (' ', '-', '_')).strip()
    filename = f"{timestamp}_{model_name}_{prompt_part}.png"
    filepath = os.path.join(OUTPUT_DIR, filename)
    image.save(filepath)
    return filepath

def get_generated_images():
    """Get list of generated images with their details"""
    files = glob.glob(os.path.join(OUTPUT_DIR, "*.png"))
    files.sort(key=os.path.getctime, reverse=True)
    return [
        {
            "path": f,
            "name": os.path.basename(f),
            "date": datetime.fromtimestamp(os.path.getctime(f)).strftime("%Y-%m-%d %H:%M:%S"),
            "size": f"{os.path.getsize(f) / 1024:.1f} KB"
        }
        for f in files
    ]

def generate_image_on_gpu(args):
    """Generate an image for one model on the next free GPU."""
    model_name, prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps = args
    gpu_id = get_next_available_gpu()
    print(f"Generating {model_name} on GPU {gpu_id}")
    try:
        # Load or reuse the cached pipeline for this model/GPU pair
        cache_key = f"{model_name}_{gpu_id}"
        if cache_key not in model_cache:
            with model_locks[model_name]:
                # Re-check inside the lock so two threads don't load the same model twice
                if cache_key not in model_cache:
                    model_cache[cache_key] = load_pipeline_on_gpu(model_name, gpu_id)
        pipe = model_cache[cache_key]
        # Generate image
        with torch.cuda.device(gpu_id), autocast("cuda", dtype=TORCH_DTYPE):
            generator = torch.Generator(f"cuda:{gpu_id}").manual_seed(seed)
            image = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                generator=generator,
            ).images[0]
        filepath = save_generated_image(image, model_name, prompt)
        print(f"Saved image from {model_name} to: {filepath}")
        return image, seed
    except Exception as e:
        print(f"Error with {model_name} on GPU {gpu_id}: {e}")
        raise
    finally:
        # Always hand the GPU back, whether generation succeeded or failed
        release_gpu(gpu_id)
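
# Minimal standalone usage sketch (outside the Gradio UI), with made-up values:
#
#   image, used_seed = generate_image_on_gpu(
#       ("FLUX", "a lighthouse at dusk, oil painting", "", 42, 1024, 1024, 3.5, 28)
#   )
#   image.save("lighthouse.png")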

def generate_all(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress()):
    outputs = [None] * (len(MODEL_CONFIGS) * 2)
    # Prepare one generation task per model
    tasks = []
    for model_name in MODEL_CONFIGS.keys():
        current_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
        tasks.append((
            model_name, prompt, negative_prompt, current_seed,
            width, height, guidance_scale, num_inference_steps
        ))
    # Run generation in parallel, one worker per available GPU
    with ThreadPoolExecutor(max_workers=len(AVAILABLE_GPUS)) as executor:
        future_to_model = {
            executor.submit(generate_image_on_gpu, task): idx
            for idx, task in enumerate(tasks)
        }
        for future in as_completed(future_to_model):
            idx = future_to_model[future]
            try:
                image, used_seed = future.result()
                outputs[idx * 2] = image
                outputs[idx * 2 + 1] = used_seed
            except Exception as e:
                print(f"Generation failed for model {idx}: {e}")
                outputs[idx * 2] = None
                outputs[idx * 2 + 1] = None
            # Stream partial results to the UI as each model finishes
            yield outputs
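
# Because generate_all is a generator, Gradio streams each yielded list of
# values to the output components, so every model's tab updates as soon as
# that model finishes rather than waiting for all of them.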

# Gradio Interface
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"# Multi-GPU Image Generation ({len(AVAILABLE_GPUS)} GPUs Available)")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Generate", scale=0, variant="primary")

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=7.5,
                    step=0.1,
                    value=4.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=40,
                )

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Tabs() as tabs:
                    results = {}
                    seeds = {}
                    for model_name in MODEL_CONFIGS.keys():
                        with gr.Tab(model_name):
                            results[model_name] = gr.Image(label=f"{model_name} Result")
                            seeds[model_name] = gr.Number(label="Seed used", visible=False)

            with gr.Column(scale=1):
                gr.Markdown("### Generated Images")
                file_gallery = gr.Gallery(
                    label="Generated Images",
                    show_label=False,
                    elem_id="file_gallery",
                    columns=2,
                    height=400
                )
                refresh_button = gr.Button("Refresh Gallery")

    def update_gallery():
        """Update the file gallery"""
        files = get_generated_images()
        return [
            (f["path"], f"{f['name']}\n{f['date']}")
            for f in files
        ]

    output_components = []
    for model_name in MODEL_CONFIGS.keys():
        output_components.extend([results[model_name], seeds[model_name]])

    run_button.click(
        fn=generate_all,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=output_components,
    )

    refresh_button.click(
        fn=update_gallery,
        inputs=[],
        outputs=[file_gallery],
    )

    demo.load(
        fn=update_gallery,
        inputs=[],
        outputs=[file_gallery],
    )

if __name__ == "__main__":
    # Initialize multiprocessing for PyTorch
    mp.set_start_method('spawn', force=True)
    demo.launch()