import glob
import gradio as gr
import sys
import os
from PIL import Image
import numpy as np
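# `spaces` supplies the @spaces.GPU decorator that Hugging Face ZeroGPU
# Spaces use to attach a GPU for the duration of a decorated call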
import spaces
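# Make the repository root importable so `modules.user.pipeline` resolves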
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from modules.user.pipeline import pipeline
import torch
def load_generated_images():
"""Load generated images with given prefix from disk"""
image_files = glob.glob("./_internal/output/**/*.png")
# If there are no image files, return
if not image_files:
return []
# Sort files by modification time in descending order
image_files.sort(key=os.path.getmtime, reverse=True)
# Get most recent timestamp
latest_time = os.path.getmtime(image_files[0])
# Get all images from same batch (within 1 second of most recent)
batch_images = []
for file in image_files:
if abs(os.path.getmtime(file) - latest_time) < 1.0:
try:
img = Image.open(file)
batch_images.append(img)
            except Exception:
                # Skip files that cannot be opened as images
                continue
if not batch_images:
return []
return batch_images
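# ZeroGPU allocation: the decorated call may hold a GPU for up to 120 seconds
# before the allocation is reclaimed by Hugging Face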
@spaces.GPU(duration=120)
def generate_images(
prompt: str,
width: int = 512,
height: int = 512,
num_images: int = 1,
batch_size: int = 1,
hires_fix: bool = False,
adetailer: bool = False,
enhance_prompt: bool = False,
img2img_enabled: bool = False,
    img2img_image: np.ndarray | None = None,
stable_fast: bool = False,
reuse_seed: bool = False,
flux_enabled: bool = False,
prio_speed: bool = False,
realistic_model: bool = False,
progress=gr.Progress(),
):
"""Generate images using the LightDiffusion pipeline"""
try:
if img2img_enabled and img2img_image is not None:
# Convert numpy array to PIL Image
if isinstance(img2img_image, np.ndarray):
img_pil = Image.fromarray(img2img_image)
img_pil.save("temp_img2img.png")
prompt = "temp_img2img.png"
# Run pipeline and capture saved images
with torch.inference_mode():
pipeline(
prompt=prompt,
w=width,
h=height,
number=num_images,
batch=batch_size,
hires_fix=hires_fix,
adetailer=adetailer,
enhance_prompt=enhance_prompt,
img2img=img2img_enabled,
stable_fast=stable_fast,
reuse_seed=reuse_seed,
flux_enabled=flux_enabled,
prio_speed=prio_speed,
autohdr=True,
realistic_model=realistic_model,
)
        return load_generated_images()
    except Exception:
        import traceback

        print(traceback.format_exc())
        # Return a black placeholder so the gallery still renders on failure
        return [Image.new("RGB", (512, 512), color="black")]
    finally:
        # Remove the temporary img2img file whether generation succeeded or not
        if os.path.exists("temp_img2img.png"):
            os.remove("temp_img2img.png")
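# Hypothetical local smoke test (bypasses the UI; same parameters as above):
#   images = generate_images("a red apple on a table", width=512, height=512)
#   images[0].save("smoke_test.png")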
# Create Gradio interface
with gr.Blocks(title="LightDiffusion Web UI") as demo:
gr.Markdown("# LightDiffusion Web UI")
gr.Markdown("Generate AI images using LightDiffusion")
gr.Markdown(
"This is the demo for LightDiffusion, the fastest diffusion backend for generating images. https://github.com/LightDiffusion/LightDiffusion-Next"
)
with gr.Row():
with gr.Column():
# Input components
prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
with gr.Row():
width = gr.Slider(
minimum=64, maximum=2048, value=512, step=64, label="Width"
)
height = gr.Slider(
minimum=64, maximum=2048, value=512, step=64, label="Height"
)
with gr.Row():
num_images = gr.Slider(
minimum=1, maximum=10, value=1, step=1, label="Number of Images"
)
batch_size = gr.Slider(
minimum=1, maximum=4, value=1, step=1, label="Batch Size"
)
with gr.Row():
hires_fix = gr.Checkbox(label="HiRes Fix")
adetailer = gr.Checkbox(label="Auto Face/Body Enhancement")
enhance_prompt = gr.Checkbox(label="Enhance Prompt")
stable_fast = gr.Checkbox(label="Stable Fast Mode")
with gr.Row():
reuse_seed = gr.Checkbox(label="Reuse Seed")
flux_enabled = gr.Checkbox(label="Flux Mode")
prio_speed = gr.Checkbox(label="Prioritize Speed")
realistic_model = gr.Checkbox(label="Realistic Model")
with gr.Row():
img2img_enabled = gr.Checkbox(label="Image to Image Mode")
img2img_image = gr.Image(label="Input Image for img2img", visible=False)
# Make input image visible only when img2img is enabled
img2img_enabled.change(
fn=lambda x: gr.update(visible=x),
inputs=[img2img_enabled],
outputs=[img2img_image],
)
generate_btn = gr.Button("Generate")
# Output gallery
gallery = gr.Gallery(
label="Generated Images",
show_label=True,
elem_id="gallery",
columns=[2],
rows=[2],
object_fit="contain",
height="auto",
)
# Connect generate button to pipeline
generate_btn.click(
fn=generate_images,
inputs=[
prompt,
width,
height,
num_images,
batch_size,
hires_fix,
adetailer,
enhance_prompt,
img2img_enabled,
img2img_image,
stable_fast,
reuse_seed,
flux_enabled,
prio_speed,
realistic_model,
],
outputs=gallery,
)
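# Hugging Face Spaces set the SPACE_ID environment variable; its presence
# distinguishes hosted runs from local ones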
def is_huggingface_space():
return "SPACE_ID" in os.environ
# Entry point: use Spaces settings when hosted, local dev settings otherwise
if __name__ == "__main__":
if is_huggingface_space():
demo.launch(
debug=False,
server_name="0.0.0.0",
server_port=7860, # Standard HF Spaces port
)
else:
demo.launch(
server_name="0.0.0.0",
server_port=8000,
auth=None,
share=True, # Only enable sharing locally
debug=True,
)