import gc
import json
import logging
import os
import time
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
from diffusers.models import AutoencoderKL
from PIL import Image, PngImagePlugin

import config
import utils
from config import (
    MODEL,
    MIN_IMAGE_SIZE,
    MAX_IMAGE_SIZE,
    USE_TORCH_COMPILE,
    ENABLE_CPU_OFFLOAD,
    OUTPUT_DIR,
    DEFAULT_NEGATIVE_PROMPT,
    DEFAULT_ASPECT_RATIO,
    examples,
    sampler_list,
    aspect_ratios,
    style_list,
)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
HF_TOKEN = os.getenv("HF_TOKEN")
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
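
# Favor reproducibility: deterministic cuDNN kernels with autotuning disabled.
# TF32 matmuls stay enabled as a speed trade-off that is generally safe for inference.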
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")


class GenerationError(Exception):
    """Custom exception for user-facing generation errors."""
    pass


def validate_prompt(prompt: str) -> str:
    """Validate and clean up the input prompt."""
    if not isinstance(prompt, str):
        raise GenerationError("Prompt must be a string")
    try:
        # Round-trip through UTF-8 to reject undecodable characters early.
        prompt = prompt.encode('utf-8').decode('utf-8')
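        # Presumably separates "!" from a following comma so emphasis tokens
        # survive prompt parsing (the exact motivation is undocumented).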
        prompt = prompt.replace("!,", "! ,")
    except UnicodeError:
        raise GenerationError("Invalid characters in prompt")

    if not prompt or prompt.isspace():
        raise GenerationError("Prompt cannot be empty")
    return prompt.strip()


def validate_dimensions(width: int, height: int) -> None:
    """Validate image dimensions."""
    if not MIN_IMAGE_SIZE <= width <= MAX_IMAGE_SIZE:
        raise GenerationError(f"Width must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")
    if not MIN_IMAGE_SIZE <= height <= MAX_IMAGE_SIZE:
        raise GenerationError(f"Height must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")
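

# Main generation entry point. On Hugging Face ZeroGPU Spaces, the @spaces.GPU
# decorator requests a GPU allocation for the duration of each call.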
@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
    seed: int = 0,
    custom_width: int = 1024,
    custom_height: int = 1024,
    guidance_scale: float = 6.0,
    num_inference_steps: int = 25,
    sampler: str = "Euler a",
    aspect_ratio_selector: str = DEFAULT_ASPECT_RATIO,
    style_selector: str = "(None)",
    use_upscaler: bool = False,
    upscaler_strength: float = 0.55,
    upscale_by: float = 1.5,
    add_quality_tags: bool = True,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> Tuple[List[str], Dict]:
    """Generate images based on the given parameters."""
    start_time = time.time()
    upscaler_pipe = None
    backup_scheduler = None

    try:
        # Start from a clean slate so a failed previous run cannot leak VRAM.
        torch.cuda.empty_cache()
        gc.collect()
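
        # Validate and normalize user input before touching the pipeline.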
        prompt = validate_prompt(prompt)
        if negative_prompt:
            negative_prompt = negative_prompt.encode('utf-8').decode('utf-8')
        validate_dimensions(custom_width, custom_height)

        # The seeded generator is reused by both diffusion passes below, so a
        # given seed reproduces the full result, upscaling included.
        generator = utils.seed_everything(seed)
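
        # Resolve the output resolution from the aspect-ratio preset; "Custom"
        # uses the validated slider values.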
        width, height = utils.aspect_ratio_handler(
            aspect_ratio_selector,
            custom_width,
            custom_height,
        )

        if add_quality_tags:
            prompt = "masterpiece, high score, great score, absurdres, {prompt}".format(prompt=prompt)
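
        # Merge the selected style preset into the prompt and negative prompt.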
        prompt, negative_prompt = utils.preprocess_prompt(
            styles, style_selector, prompt, negative_prompt
        )
        width, height = utils.preprocess_image_dimensions(width, height)

        # Swap in the requested sampler, keeping the original so the shared
        # pipeline can be restored in the finally block.
        backup_scheduler = pipe.scheduler
        pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)

        if use_upscaler:
            # Building the img2img pipeline from pipe.components shares the
            # already-loaded weights instead of loading a second copy.
            upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
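
        # Collect the run parameters once; they are logged, returned to the UI,
        # and stored with the saved images.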
        metadata = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "resolution": f"{width} x {height}",
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "style_preset": style_selector,
            "seed": seed,
            "sampler": sampler,
            "Model": "Animagine XL 4.0",
            "Model hash": "e3c47aedb0",
        }
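
        # Record the upscaler settings (or their absence) so runs stay reproducible.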
        if use_upscaler:
            new_width = int(width * upscale_by)
            new_height = int(height * upscale_by)
            metadata["use_upscaler"] = {
                "upscale_method": "nearest-exact",
                "upscaler_strength": upscaler_strength,
                "upscale_by": upscale_by,
                "new_resolution": f"{new_width} x {new_height}",
            }
        else:
            metadata["use_upscaler"] = None

        logger.info(f"Starting generation with parameters: {json.dumps(metadata, indent=4)}")
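
        # Two-pass "highres fix": render latents at the base resolution, upscale
        # them with nearest-exact interpolation, then refine with an img2img pass.
        # Without the upscaler, one text-to-image pass yields the final images.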
        if use_upscaler:
            latents = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                output_type="latent",
            ).images
            upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
            images = upscaler_pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=upscaled_latents,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                strength=upscaler_strength,
                generator=generator,
                output_type="pil",
            ).images
        else:
            images = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                output_type="pil",
            ).images
        if not images:
            raise GenerationError("Pipeline returned no images")

        total = len(images)
        image_paths = []
        for idx, image in enumerate(images, 1):
            progress(idx / total, desc="Saving images...")
            path = utils.save_image(image, metadata, OUTPUT_DIR, IS_COLAB)
            image_paths.append(path)
            logger.info(f"Image {idx}/{total} saved as {path}")

        generation_time = time.time() - start_time
        logger.info(f"Generation completed successfully in {generation_time:.2f} seconds")
        metadata["generation_time"] = f"{generation_time:.2f}s"

        return image_paths, metadata

    except GenerationError as e:
        logger.warning(f"Generation validation error: {str(e)}")
        raise gr.Error(str(e))
    except Exception as e:
        logger.exception("Unexpected error during generation")
        raise gr.Error(f"Generation failed: {str(e)}")
    finally:
        # Drop the temporary img2img pipeline before emptying the cache so its
        # memory can actually be reclaimed, then restore the shared scheduler.
        if upscaler_pipe is not None:
            del upscaler_pipe
        if backup_scheduler is not None and pipe is not None:
            pipe.scheduler = backup_scheduler
        torch.cuda.empty_cache()
        gc.collect()
        utils.free_memory()
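

# The model is loaded once at import time so all requests share a single pipeline.
# Without CUDA the UI still launches, but `pipe` stays None and generation errors out.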
if torch.cuda.is_available():
    try:
        logger.info("Loading VAE and pipeline...")
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        pipe = utils.load_pipeline(MODEL, device, vae=vae)
        logger.info("Pipeline loaded successfully on GPU!")
    except Exception as e:
        logger.error(f"Error loading the fp16-fix VAE, falling back to the default VAE: {e}")
        pipe = utils.load_pipeline(MODEL, device)
else:
    logger.warning("CUDA not available, running on CPU")
    pipe = None

# Map each style preset name to its (prompt template, negative prompt) pair.
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
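

# Assemble the Gradio UI. "style.css" and the "Nymbo/Nymbo_Theme_5" theme are
# expected to be available to the Space.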
with gr.Blocks(css="style.css", theme="Nymbo/Nymbo_Theme_5") as demo:
    gr.HTML(
        """
        <div class="header">
            <div class="title">ANIM4GINE</div>
            <div class="subtitle">Gradio demo for <a href="https://huggingface.co/CagliostroLab/Animagine-XL-4.0" target="_blank">Animagine XL 4.0</a></div>
        </div>
        """,
    )
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group():
                prompt = gr.Text(
                    label="Prompt",
                    max_lines=5,
                    placeholder="Describe what you want to generate",
                    info="Enter your image generation prompt here. Be specific and descriptive for better results.",
                )
                negative_prompt = gr.Text(
                    label="Negative Prompt",
                    max_lines=5,
                    placeholder="Describe what you want to avoid",
                    value=DEFAULT_NEGATIVE_PROMPT,
                    info="Specify elements you don't want in the image.",
                )
                add_quality_tags = gr.Checkbox(
                    label="Quality Tags",
                    value=True,
                    info="Add quality-enhancing tags to your prompt automatically.",
                )
            with gr.Accordion(label="More Settings", open=False):
                with gr.Group():
                    aspect_ratio_selector = gr.Radio(
                        label="Aspect Ratio",
                        choices=aspect_ratios,
                        value=DEFAULT_ASPECT_RATIO,
                        container=True,
                        info="Choose the dimensions of your image.",
                    )
                with gr.Group(visible=False) as custom_resolution:
                    with gr.Row():
                        custom_width = gr.Slider(
                            label="Width",
                            minimum=MIN_IMAGE_SIZE,
                            maximum=MAX_IMAGE_SIZE,
                            step=8,
                            value=1024,
                            info=f"Image width (must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE})",
                        )
                        custom_height = gr.Slider(
                            label="Height",
                            minimum=MIN_IMAGE_SIZE,
                            maximum=MAX_IMAGE_SIZE,
                            step=8,
                            value=1024,
                            info=f"Image height (must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE})",
                        )
                with gr.Group():
                    use_upscaler = gr.Checkbox(
                        label="Use Upscaler",
                        value=False,
                        info="Enable high-resolution upscaling.",
                    )
                    with gr.Row() as upscaler_row:
                        upscaler_strength = gr.Slider(
                            label="Strength",
                            minimum=0,
                            maximum=1,
                            step=0.05,
                            value=0.55,
                            visible=False,
                            info="Control how much the upscaler affects the final image.",
                        )
                        upscale_by = gr.Slider(
                            label="Upscale by",
                            minimum=1,
                            maximum=1.5,
                            step=0.1,
                            value=1.5,
                            visible=False,
                            info="Multiplier for the final image resolution.",
                        )
            with gr.Accordion(label="Advanced Parameters", open=False):
                with gr.Group():
                    style_selector = gr.Dropdown(
                        label="Style Preset",
                        interactive=True,
                        choices=list(styles.keys()),
                        value="(None)",
                        info="Apply a predefined style to your generation.",
                    )
                with gr.Group():
                    sampler = gr.Dropdown(
                        label="Sampler",
                        choices=sampler_list,
                        interactive=True,
                        value="Euler a",
                        info="Different samplers can produce varying results.",
                    )
                with gr.Group():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=utils.MAX_SEED,
                        step=1,
                        value=0,
                        info="Set a specific seed for reproducible results.",
                    )
                    randomize_seed = gr.Checkbox(
                        label="Randomize seed",
                        value=True,
                        info="Generate a new random seed for each image.",
                    )
                with gr.Group():
                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=1,
                            maximum=12,
                            step=0.1,
                            value=6.0,
                            info="Higher values make the image more closely match your prompt.",
                        )
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=25,
                            info="More steps generally mean higher quality but slower generation.",
                        )
        with gr.Column(scale=3):
            # A plain Column holds the output side; nesting a second gr.Blocks
            # inside an active Blocks context is not supported by Gradio.
            run_button = gr.Button("Generate", variant="primary", elem_id="generate-button")
            result = gr.Gallery(
                label="Generated Images",
                columns=1,
                height='768px',
                preview=True,
                show_label=True,
            )
            with gr.Accordion(label="Generation Parameters", open=False):
                gr_metadata = gr.JSON(
                    label="Image Metadata",
                    show_label=True,
                )
            gr.Examples(
                examples=examples,
                inputs=prompt,
                outputs=[result, gr_metadata],
                # Examples only feed the prompt; everything else uses the
                # generate() defaults, with the upscaler forced on.
                fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
                cache_examples=CACHE_EXAMPLES,
            )
    with gr.Row():
        gr.HTML(
            """
            <a href="https://discord.com/invite/cqh9tZgbGc" target="_blank" class="discord-btn">
                <svg class="discord-icon" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 127.14 96.36"><path fill="currentColor" d="M107.7,8.07A105.15,105.15,0,0,0,81.47,0a72.06,72.06,0,0,0-3.36,6.83A97.68,97.68,0,0,0,49,6.83,72.37,72.37,0,0,0,45.64,0,105.89,105.89,0,0,0,19.39,8.09C2.79,32.65-1.71,56.6.54,80.21h0A105.73,105.73,0,0,0,32.71,96.36,77.7,77.7,0,0,0,39.6,85.25a68.42,68.42,0,0,1-10.85-5.18c.91-.66,1.8-1.34,2.66-2a75.57,75.57,0,0,0,64.32,0c.87.71,1.76,1.39,2.66,2a68.68,68.68,0,0,1-10.87,5.19,77,77,0,0,0,6.89,11.1A105.25,105.25,0,0,0,126.6,80.22h0C129.24,52.84,122.09,29.11,107.7,8.07ZM42.45,65.69C36.18,65.69,31,60,31,53s5-12.74,11.43-12.74S54,46,53.89,53,48.84,65.69,42.45,65.69Zm42.24,0C78.41,65.69,73.25,60,73.25,53s5-12.74,11.44-12.74S96.23,46,96.12,53,91.08,65.69,84.69,65.69Z"/></svg>
                <span class="discord-text">Join our Discord Server</span>
            </a>
            """
        )
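
    # Keep dependent controls in sync: the upscaler sliders follow the checkbox,
    # and the custom-resolution group appears only for the "Custom" preset.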
    use_upscaler.change(
        fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
        inputs=use_upscaler,
        outputs=[upscaler_strength, upscale_by],
        queue=False,
        api_name=False,
    )
    aspect_ratio_selector.change(
        fn=lambda x: gr.update(visible=x == "Custom"),
        inputs=aspect_ratio_selector,
        outputs=custom_resolution,
        queue=False,
        api_name=False,
    )
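
    # Submission chain: optionally randomize the seed, disable the button while
    # generating, run generate(), then re-enable it; .then() runs even on error.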
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=utils.randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=lambda: gr.update(interactive=False, value="Generating..."),
        outputs=run_button,
    ).then(
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            custom_width,
            custom_height,
            guidance_scale,
            num_inference_steps,
            sampler,
            aspect_ratio_selector,
            style_selector,
            use_upscaler,
            upscaler_strength,
            upscale_by,
            add_quality_tags,
        ],
        outputs=[result, gr_metadata],
    ).then(
        fn=lambda: gr.update(interactive=True, value="Generate"),
        outputs=run_button,
    )


if __name__ == "__main__":
    demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)