import os
import gc
import gradio as gr
import numpy as np
import torch
import json
import spaces
import random
import config
import utils
import logging
from PIL import Image, PngImagePlugin
from datetime import datetime
from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
import time
from typing import List, Dict, Tuple, Optional
from config import (
    MODEL,
    MIN_IMAGE_SIZE,
    MAX_IMAGE_SIZE,
    DEFAULT_PROMPT,
    DEFAULT_NEGATIVE_PROMPT,
    scheduler_list,
)
import io

MAX_SEED = np.iinfo(np.int32).max

# Enhanced logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# PyTorch settings for better performance and determinism
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Model initialization
if torch.cuda.is_available():
    try:
        logger.info("Loading VAE and pipeline...")
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        pipe = utils.load_pipeline(MODEL, device, vae=vae)
        logger.info("Pipeline loaded successfully on GPU!")
    except Exception as e:
        logger.error(f"Error loading VAE, falling back to default: {e}")
        pipe = utils.load_pipeline(MODEL, device)
else:
    logger.warning("CUDA not available, running on CPU")
    pipe = None
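
# Note: utils.load_pipeline is defined in utils.py, which is not shown here. It is
# assumed to wrap StableDiffusionXLPipeline.from_pretrained(MODEL, vae=vae,
# torch_dtype=torch.float16, ...) and move the result to `device` (which would
# explain the otherwise-unused StableDiffusionXLPipeline import above). The real
# implementation may differ.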

class GenerationError(Exception):
    """Custom exception for generation errors"""
    pass


def validate_prompt(prompt: str) -> str:
    """Validate and clean up the input prompt."""
    if not isinstance(prompt, str):
        raise GenerationError("Prompt must be a string")
    try:
        # Ensure proper UTF-8 encoding/decoding
        prompt = prompt.encode('utf-8').decode('utf-8')
        # Add space between ! and ,
        prompt = prompt.replace("!,", "! ,")
    except UnicodeError:
        raise GenerationError("Invalid characters in prompt")
    # Only check if the prompt is completely empty or only whitespace
    if not prompt or prompt.isspace():
        raise GenerationError("Prompt cannot be empty")
    return prompt.strip()


def validate_dimensions(width: int, height: int) -> None:
    """Validate image dimensions."""
    if not MIN_IMAGE_SIZE <= width <= MAX_IMAGE_SIZE:
        raise GenerationError(f"Width must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")
    if not MIN_IMAGE_SIZE <= height <= MAX_IMAGE_SIZE:
        raise GenerationError(f"Height must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")


# `spaces` is imported above and this Space runs on ZeroGPU, so the GPU-using
# entry point is decorated with spaces.GPU to request a GPU for each call.
@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str,
    width: int,
    height: int,
    scheduler: str,
    opt_strength: float,
    opt_scale: float,
    seed: int,
    randomize_seed: bool,
    guidance_scale: float,
    num_inference_steps: int,
    progress: gr.Progress = gr.Progress(),
):
    """Generate an image based on the given parameters."""
    progress(0, desc="Starting")
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    upscaler_pipe = None
    backup_scheduler = None

    def callback1(pipe, step, timestep, callback_kwargs):
        # Map the base text-to-image pass onto the 0.1-0.6 range of the progress bar.
        progress_value = 0.1 + ((step + 1.0) / num_inference_steps) * 0.5
        progress(progress_value, desc=f"Image generating, {step + 1}/{num_inference_steps} steps")
        return callback_kwargs

    optimizing_steps = int(num_inference_steps * opt_strength)

    def callback2(pipe, step, timestep, callback_kwargs):
        # Map the img2img refinement pass onto the 0.6-1.0 range of the progress bar.
        progress_value = 0.6 + ((step + 1.0) / optimizing_steps) * 0.4
        progress(progress_value, desc=f"Image optimizing, {step + 1}/{optimizing_steps} steps")
        return callback_kwargs
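
    # Worked example with the UI defaults (num_inference_steps=25, opt_strength=0.55):
    # callback1 reports 0.1 -> 0.6 over the 25 base steps, and callback2 reports
    # 0.6 -> 1.0 over int(25 * 0.55) = 13 refinement steps.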

    try:
        # Memory management
        torch.cuda.empty_cache()
        gc.collect()

        if pipe is None:
            raise GenerationError("Model is not loaded (no GPU available)")

        # Input validation
        prompt = validate_prompt(prompt)
        if negative_prompt:
            negative_prompt = negative_prompt.encode('utf-8').decode('utf-8')
        validate_dimensions(width, height)

        # Set up generation
        generator = utils.seed_everything(seed)
        width, height = utils.preprocess_image_dimensions(width, height)

        # Set up pipeline
        backup_scheduler = pipe.scheduler
        pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, scheduler)
        upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)

        # First pass: text-to-image, kept in latent space for the refinement stage
        progress(0.1, desc="Image generating")
        latents = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            output_type="latent",
            callback_on_step_end=callback1,
        ).images
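
        # "Highres fix"-style second stage: the base latents are upsampled by
        # opt_scale and then partially re-denoised with img2img at strength
        # opt_strength, which sharpens detail at the larger resolution.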
        upscaled_latents = utils.upscale(latents, "nearest-exact", opt_scale)
        images = upscaler_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=upscaled_latents,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            strength=opt_strength,
            generator=generator,
            output_type="pil",
            callback_on_step_end=callback2,
        ).images

        out_img = images[0]
        path = utils.save_image(out_img, "./outputs")
        logger.info(f"output path: {path}")
        progress(1, desc="Complete")
        return path

    except GenerationError as e:
        logger.warning(f"Generation validation error: {str(e)}")
        raise gr.Error(str(e))

    except Exception as e:
        logger.exception("Unexpected error during generation")
        raise gr.Error(f"Generation failed: {str(e)}")

    finally:
        # Cleanup: drop the temporary img2img pipeline and restore the scheduler
        # before releasing cached GPU memory.
        if upscaler_pipe is not None:
            del upscaler_pipe
        if backup_scheduler is not None and pipe is not None:
            pipe.scheduler = backup_scheduler
        torch.cuda.empty_cache()
        gc.collect()
        utils.free_memory()
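

# --- Illustrative only -------------------------------------------------------
# utils.py is not included here; the sketches below show one plausible shape of
# the helpers used above (seed_everything, upscale). They are hypothetical and
# use different names so they do not shadow the real implementations.
def _sketch_seed_everything(seed: int) -> torch.Generator:
    # Seed Python, NumPy and PyTorch, then return a generator for the diffusers calls.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return torch.Generator(device=device).manual_seed(seed)


def _sketch_upscale(latents: torch.Tensor, mode: str, scale_factor: float) -> torch.Tensor:
    # Resize SDXL latents (B, C, H, W) before the img2img refinement pass,
    # e.g. mode="nearest-exact" with scale_factor=1.5 as used in generate().
    return torch.nn.functional.interpolate(latents, scale_factor=scale_factor, mode=mode)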


title = "# Animagine XL 4.0 Demo"

custom_css = """
#row-container {
    align-items: stretch;
}
#output-image {
    flex-grow: 1;
}
#output-image * {
    max-height: none !important;
}
"""

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown(title)
    with gr.Row(
        elem_id="row-container"
    ):
        with gr.Column():
            gr.Markdown("### Input")
            with gr.Column():
                prompt = gr.Text(
                    label="Prompt",
                    max_lines=5,
                    placeholder="Enter your prompt",
                    value=DEFAULT_PROMPT,
                )
                negative_prompt = gr.Text(
                    label="Negative prompt",
                    max_lines=5,
                    placeholder="Enter a negative prompt",
                    value=DEFAULT_NEGATIVE_PROMPT,
                )
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=MIN_IMAGE_SIZE,
                        maximum=MAX_IMAGE_SIZE,
                        step=8,
                        value=832,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=MIN_IMAGE_SIZE,
                        maximum=MAX_IMAGE_SIZE,
                        step=8,
                        value=1216,
                    )
                with gr.Row():
                    optimization_strength = gr.Slider(
                        label="Optimization strength",
                        minimum=0,
                        maximum=1,
                        step=0.05,
                        value=0.55,
                    )
                    optimization_scale = gr.Slider(
                        label="Optimization scale ratio",
                        minimum=1,
                        maximum=1.5,
                        step=0.1,
                        value=1.5,
                    )
                with gr.Column():
                    scheduler = gr.Dropdown(
                        label="Scheduler",
                        choices=scheduler_list,
                        interactive=True,
                        value="Euler a",
                    )
                with gr.Column():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=1.0,
                        maximum=12.0,
                        step=0.1,
                        value=6.0,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=25,
                    )
                run_button = gr.Button("Run", variant="primary")
        with gr.Column():
            gr.Markdown("### Output")
            result = gr.Image(
                type="filepath",
                label="Generated Image",
                elem_id="output-image",
            )

    run_button.click(
        fn=generate,
        inputs=[
            prompt, negative_prompt,
            width, height,
            scheduler,
            optimization_strength, optimization_scale,
            seed, randomize_seed,
            guidance_scale, num_inference_steps,
        ],
        outputs=[result],
    )


if __name__ == "__main__":
    demo.queue(max_size=20).launch()