# Hugging Face Space page banner captured with this file: "Runtime error".
import json
import logging
import os
import random
from datetime import datetime

# `spaces` is kept ahead of the torch-based imports, matching the original order.
import spaces
import torch
import gradio as gr
from diffusers import (
    AutoencoderKL,
    EulerAncestralDiscreteScheduler,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLPipeline,
)
from PIL import Image, PngImagePlugin
# Setup logging: INFO level on the root logger, plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
# Directory where generated images are written; overridable via the environment.
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
# Largest seed offered by the UI and used for random seeds (32-bit unsigned max).
MAX_SEED = 2**32 - 1
def seed_everything(seed):
    """Seed the global RNGs and return a seeded ``torch.Generator``.

    Args:
        seed: Seed value. ``None`` picks a random seed in ``[0, MAX_SEED]``.
            Integral floats (e.g. values coming from a UI slider) are accepted
            and converted to ``int``.

    Returns:
        A ``torch.Generator`` seeded with ``seed``. It lives on CUDA when CUDA
        is available and on the CPU otherwise.
    """
    if seed is None:
        seed = random.randint(0, MAX_SEED)
    # torch.manual_seed only accepts ints; normalise once so every RNG below
    # is seeded with the same value.
    seed = int(seed)

    torch.manual_seed(seed)
    random.seed(seed)

    # Fall back to a CPU generator instead of raising on hosts without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.Generator(device=device).manual_seed(seed)
def save_image(image, metadata, output_dir, is_colab=False):
    """Write ``image`` as a PNG with ``metadata`` embedded and return its path.

    Args:
        image: Object exposing ``save(path, format, pnginfo=...)``
            (presumably a ``PIL.Image.Image`` — it is what the caller passes).
        metadata: JSON-serialisable object stored in the PNG ``parameters``
            text chunk.
        output_dir: Target directory; created if it does not exist.
        is_colab: Unused; kept so existing callers keep working.

    Returns:
        Path of the file that was written.
    """
    os.makedirs(output_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    stem = f"generated_{timestamp}"
    filepath = os.path.join(output_dir, f"{stem}.png")

    # The timestamp only has one-second resolution; add a numeric suffix
    # rather than silently overwriting an image saved in the same second.
    suffix = 1
    while os.path.exists(filepath):
        filepath = os.path.join(output_dir, f"{stem}_{suffix}.png")
        suffix += 1

    # Save with metadata
    png_info = PngImagePlugin.PngInfo()
    png_info.add_text("parameters", json.dumps(metadata))
    image.save(filepath, "PNG", pnginfo=png_info)
    return filepath
# Load optimized VAE first so it can be handed to the pipeline directly,
# instead of loading the checkpoint's own VAE and then replacing it.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)

# Load the diffusion pipeline with optimized VAE
pipe = StableDiffusionXLPipeline.from_pretrained(
    "kayfahaarukku/irAsu-1.0",
    vae=vae,
    torch_dtype=torch.float16,
    custom_pipeline="lpw_stable_diffusion_xl",
)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

# Move the pipeline to the GPU: it was loaded in float16 and nothing else in
# this file places it on a device.
pipe.to("cuda")
# Style presets: name -> (text prepended to the prompt, text appended to the
# negative prompt). The prompt prefixes end with ", " because they are
# concatenated directly in front of the user's prompt.
styles = {
    "(None)": ("", ""),
    "Detailed": ("highly detailed, intricate details, ", ""),
    "Simple": ("simple style, minimalistic, ", "complex, detailed"),
    "Soft": ("soft lighting, dreamy atmosphere, ", "harsh lighting, sharp contrast"),
}

# Quality presets: name -> (text appended to the prompt, text appended to the
# negative prompt).
quality_presets = {
    "Standard": (
        "best quality, amazing quality, very aesthetic",
        "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts"
    ),
    "High Detail": (
        "masterpiece, best quality, amazing quality, very aesthetic, highly detailed",
        "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality"
    ),
    "Basic": (
        "good quality",
        "nsfw, lowres, bad quality"
    )
}
def _join_prompt_parts(*parts):
    """Join the non-blank prompt fragments with ", ", skipping empty ones."""
    return ", ".join(part for part in parts if part and part.strip())


# Function to generate an image
# `spaces.GPU` requests a GPU for the call on ZeroGPU hardware.
# NOTE(review): per the Spaces docs it is a no-op elsewhere — confirm.
@spaces.GPU
def generate_image(
    prompt,
    negative_prompt,
    use_quality_preset,
    resolution,
    guidance_scale,
    num_inference_steps,
    seed,
    randomize_seed,
    style_preset="(None)",
    use_upscaler=False,
    upscaler_strength=0.55,
    upscale_by=1.5,
    progress=gr.Progress()
):
    """Generate one image, save it with its parameters, and return it.

    Args:
        prompt: User prompt text.
        negative_prompt: User negative prompt text.
        use_quality_preset: When true, append the "Standard" quality preset to
            both prompts.
        resolution: ``"<width>x<height>"`` string, e.g. ``"832x1216"``.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: When true, ignore ``seed`` and draw a random one.
        style_preset: Key into ``styles``.
        use_upscaler: When true, run a second img2img pass over enlarged
            latents.
        upscaler_strength: img2img ``strength`` for the second pass.
        upscale_by: Enlargement factor for the second pass.
        progress: Gradio progress tracker, updated from the pipeline callback.

    Returns:
        ``(image, seed, metadata_json)``: the generated image, the seed that
        was actually used, and the generation parameters as indented JSON.

    Raises:
        KeyError: If ``style_preset`` is not a key of ``styles``.
        Exception: Any error raised during generation or saving is logged and
            re-raised.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may deliver the seed as a float; torch needs an int.
    seed = int(seed)

    # Apply style and quality presets. Blank fragments are skipped so an empty
    # prompt or negative prompt does not produce a leading ", ".
    style_prompt, style_negative = styles[style_preset]
    if use_quality_preset:
        quality_prompt, quality_negative = quality_presets["Standard"]
    else:
        quality_prompt, quality_negative = "", ""
    # Style prefixes carry their own trailing ", "; drop it before joining.
    prompt = _join_prompt_parts(style_prompt.rstrip(", "), prompt, quality_prompt)
    negative_prompt = _join_prompt_parts(negative_prompt, style_negative, quality_negative)

    generator = seed_everything(seed)
    width, height = map(int, resolution.split('x'))

    metadata = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "resolution": f"{width} x {height}",
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "seed": seed,
        "style_preset": style_preset,
        "use_quality_preset": use_quality_preset
    }

    # Bound before the try so the cleanup in `finally` is safe even when the
    # first pass fails before the img2img pipeline is created.
    upscaler_pipe = None
    try:
        if use_upscaler:
            # Initial generation, kept in latent space for the second pass.
            latents = pipe(
                prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                output_type="latent"
            ).images

            # Setup img2img pipeline for upscaling (reuses the loaded components).
            upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)

            # Calculate new dimensions, snapped down to a multiple of the VAE
            # scale factor so the reported size matches the decoded image.
            # NOTE(review): 8 is the usual SDXL factor — used only as fallback.
            vae_scale = getattr(pipe, "vae_scale_factor", 8)
            new_width = int(width * upscale_by) // vae_scale * vae_scale
            new_height = int(height * upscale_by) // vae_scale * vae_scale

            # Enlarge the latents; without this the second pass would run at
            # the original resolution.
            upscaled_latents = torch.nn.functional.interpolate(
                latents,
                size=(new_height // vae_scale, new_width // vae_scale),
                mode="nearest-exact",
            )

            # Upscale
            image = upscaler_pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=upscaled_latents,
                strength=upscaler_strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator
            ).images[0]

            metadata["upscaler"] = {
                "strength": upscaler_strength,
                "scale_factor": upscale_by,
                "final_resolution": f"{new_width}x{new_height}"
            }
        else:
            image = pipe(
                prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                callback=lambda step, timestep, latents: progress(step / num_inference_steps)
            ).images[0]

        # Save image with metadata
        image_path = save_image(image, metadata, OUTPUT_DIR)
        logger.info(f"Image saved as {image_path} with metadata")
        return image, seed, json.dumps(metadata, indent=2)
    except Exception as e:
        logger.exception(f"An error occurred: {e}")
        raise
    finally:
        # Drop the temporary img2img wrapper (if any) and release cached
        # GPU memory.
        upscaler_pipe = None
        torch.cuda.empty_cache()
# Define Gradio interface
# The theme is a Hub theme reference pinned to a version.
# NOTE(review): the captured theme string was garbled; "NoCrypt/miku@1.2.1"
# is the reconstructed value — confirm against the deployed app.
with gr.Blocks(title="irAsu 1.0 Enhanced Demo", theme="NoCrypt/miku@1.2.1") as demo:
    gr.HTML("<h1>irAsu 1.0 Enhanced Demo</h1>")

    with gr.Row():
        # Left column: all generation inputs.
        with gr.Column():
            prompt_input = gr.Textbox(lines=2, placeholder="Enter prompt here", label="Prompt")
            negative_prompt_input = gr.Textbox(lines=2, placeholder="Enter negative prompt here", label="Negative Prompt")

            with gr.Accordion("Style & Quality", open=True):
                style_selector = gr.Radio(
                    choices=list(styles.keys()),
                    value="(None)",
                    label="Style Preset"
                )
                use_quality_preset = gr.Checkbox(label="Use Quality Preset", value=True)

            resolution_input = gr.Radio(
                choices=[
                    "1024x1024", "1152x896", "896x1152", "1216x832", "832x1216",
                    "1344x768", "768x1344", "1536x640", "640x1536"
                ],
                label="Resolution",
                value="832x1216"
            )

            with gr.Accordion("Advanced Settings", open=False):
                guidance_scale_input = gr.Slider(minimum=1, maximum=20, step=0.5, label="Guidance Scale", value=4)
                num_inference_steps_input = gr.Slider(minimum=1, maximum=100, step=1, label="Number of Inference Steps", value=28)
                seed_input = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, label="Seed", value=0)
                randomize_seed_input = gr.Checkbox(label="Randomize Seed", value=True)
                use_upscaler_input = gr.Checkbox(label="Use Upscaler", value=False)

                # Hidden until "Use Upscaler" is ticked (see the change handler below).
                with gr.Group(visible=False) as upscaler_settings:
                    upscaler_strength_input = gr.Slider(minimum=0, maximum=1, step=0.05, label="Upscaler Strength", value=0.55)
                    upscale_by_input = gr.Slider(minimum=1, maximum=1.5, step=0.1, label="Upscale Factor", value=1.5)

            generate_button = gr.Button("Generate")
            reset_button = gr.Button("Reset")

        # Right column: result image and the parameters used to produce it.
        with gr.Column():
            output_image = gr.Image(type="pil", label="Generated Image")
            with gr.Accordion("Parameters", open=False):
                metadata_textbox = gr.Textbox(lines=6, label="Image Parameters", interactive=False)

    # Handle upscaler visibility
    use_upscaler_input.change(
        fn=lambda x: gr.Group(visible=x),
        inputs=[use_upscaler_input],
        outputs=[upscaler_settings]
    )

    # Generate button click event; the input order must match the positional
    # parameters of generate_image.
    generate_button.click(
        generate_image,
        inputs=[
            prompt_input,
            negative_prompt_input,
            use_quality_preset,
            resolution_input,
            guidance_scale_input,
            num_inference_steps_input,
            seed_input,
            randomize_seed_input,
            style_selector,
            use_upscaler_input,
            upscaler_strength_input,
            upscale_by_input
        ],
        outputs=[output_image, seed_input, metadata_textbox]
    )

    # Reset button click event: one value per component in `outputs`, in the
    # same order, restoring the initial widget values.
    reset_button.click(
        lambda: (
            "", "", True, "832x1216", 4, 28, 0, True,
            "(None)", False, 0.55, 1.5, None
        ),
        outputs=[
            prompt_input, negative_prompt_input, use_quality_preset,
            resolution_input, guidance_scale_input, num_inference_steps_input,
            seed_input, randomize_seed_input, style_selector,
            use_upscaler_input, upscaler_strength_input, upscale_by_input,
            metadata_textbox
        ]
    )

demo.queue(max_size=20).launch(share=False)