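"""Gradio demo for PonyDiffusion V6 XL.

Text-to-image generation with SDXL, optional latent upscaling via img2img,
and an in-memory history of recent generations.
"""
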
import base64
import gc
import json
import logging
import os
import random
from collections import deque
from datetime import datetime
from io import BytesIO

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
from diffusers.models import AutoencoderKL
from PIL import Image, PngImagePlugin

import config
import utils

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
DESCRIPTION = "PonyDiffusion V6 XL"
if not torch.cuda.is_available():
    DESCRIPTION += "\n\nRunning on CPU 🥶 This demo does not work on CPU."
IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
HF_TOKEN = os.getenv("HF_TOKEN")
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", "512"))
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
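# MODEL may be a single .safetensors checkpoint (URL or local path, loaded via
# from_single_file) or a Hugging Face repo id (loaded via from_pretrained);
# see load_pipeline below.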
MODEL = os.getenv(
    "MODEL",
    "https://huggingface.co/AstraliteHeart/pony-diffusion-v6/blob/main/v6.safetensors",
)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MAX_HISTORY_SIZE = 10
image_history = deque(maxlen=MAX_HISTORY_SIZE)
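
# madebyollin/sdxl-vae-fp16-fix is an SDXL VAE patched so float16 decoding
# does not produce NaN (black) images.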
def load_pipeline(model_name):
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=torch.float16,
    )
    pipeline = (
        StableDiffusionXLPipeline.from_single_file
        if model_name.endswith(".safetensors")
        else StableDiffusionXLPipeline.from_pretrained
    )
    pipe = pipeline(
        model_name,
        vae=vae,
        torch_dtype=torch.float16,
        custom_pipeline="lpw_stable_diffusion_xl",
        use_safetensors=True,
        add_watermarker=False,
        use_auth_token=HF_TOKEN,
        variant="fp16",
    )
    pipe.to(device)
    return pipe
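
# On Hugging Face ZeroGPU Spaces, @spaces.GPU attaches a GPU to the process
# for the duration of each call; outside Spaces the decorator is a no-op.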
@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    randomize_seed: bool = True,
    custom_width: int = 1024,
    custom_height: int = 1024,
    guidance_scale: float = 7.0,
    num_inference_steps: int = 30,
    sampler: str = "DPM++ 2M SDE Karras",
    aspect_ratio_selector: str = "1024 x 1024",
    use_upscaler: bool = False,
    upscaler_strength: float = 0.55,
    upscale_by: float = 1.5,
    progress=gr.Progress(track_tqdm=True),
):
    # Draw a fresh seed when requested via the "Randomize seed" checkbox
    if randomize_seed:
        seed = random.randint(0, np.iinfo(np.int32).max)
    generator = utils.seed_everything(seed)
    width, height = utils.aspect_ratio_handler(
        aspect_ratio_selector,
        custom_width,
        custom_height,
    )
    width, height = utils.preprocess_image_dimensions(width, height)
    backup_scheduler = pipe.scheduler
    pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
    if use_upscaler:
        # Reuse the loaded components for an img2img pass over upscaled latents
        upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
    metadata = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "resolution": f"{width} x {height}",
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "seed": seed,
        "sampler": sampler,
    }
    if use_upscaler:
        new_width = int(width * upscale_by)
        new_height = int(height * upscale_by)
        metadata["use_upscaler"] = {
            "upscale_method": "nearest-exact",
            "upscaler_strength": upscaler_strength,
            "upscale_by": upscale_by,
            "new_resolution": f"{new_width} x {new_height}",
        }
    else:
        metadata["use_upscaler"] = None
    logger.info(json.dumps(metadata, indent=4))
    try:
        if use_upscaler:
            # Generate latents at the base resolution, then upscale them and
            # refine with the img2img pipeline.
            latents = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                output_type="latent",
            ).images
            upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
            images = upscaler_pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=upscaled_latents,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                strength=upscaler_strength,
                generator=generator,
                output_type="pil",
            ).images
        else:
            images = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                output_type="pil",
            ).images
        if images:
            for image in images:
                # Build a small base64 thumbnail for the in-memory history panel
                thumbnail = image.copy()
                thumbnail.thumbnail((256, 256))
                buffered = BytesIO()
                thumbnail.save(buffered, format="PNG")
                img_str = base64.b64encode(buffered.getvalue()).decode()
                image_history.appendleft(
                    {
                        "thumbnail": f"data:image/png;base64,{img_str}",
                        "prompt": prompt,
                        "negative_prompt": negative_prompt,
                        "seed": seed,
                        "width": width,
                        "height": height,
                    }
                )
        if images and IS_COLAB:
            for image in images:
                filepath = utils.save_image(image, metadata, OUTPUT_DIR)
                logger.info(f"Image saved as {filepath} with metadata")
        # gr.Image expects a single image, so return only the first result
        return images[0], metadata, list(image_history)
    except Exception as e:
        logger.exception(f"An error occurred: {e}")
        raise
    finally:
        # Restore the original scheduler and free the upscaler pipeline
        if use_upscaler:
            del upscaler_pipe
        pipe.scheduler = backup_scheduler
        utils.free_memory()
if torch.cuda.is_available():
    pipe = load_pipeline(MODEL)
    logger.info("Loaded on Device!")
else:
    pipe = None
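
# --- Gradio UI: controls on the left, outputs and history on the right ---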
with gr.Blocks(css="style.css") as demo:
    title = gr.HTML(f"""<h1>{DESCRIPTION}</h1>""")
    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                label="Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter a negative prompt",
            )
            with gr.Row():
                seed = gr.Number(
                    label="Seed",
                    value=0,
                    precision=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                custom_width = gr.Slider(
                    label="Width",
                    minimum=MIN_IMAGE_SIZE,
                    maximum=MAX_IMAGE_SIZE,
                    step=8,
                    value=1024,
                )
                custom_height = gr.Slider(
                    label="Height",
                    minimum=MIN_IMAGE_SIZE,
                    maximum=MAX_IMAGE_SIZE,
                    step=8,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale", minimum=0, maximum=20, step=0.1, value=7
                )
                num_inference_steps = gr.Slider(
                    label="Num Inference Steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=30,
                )
            with gr.Row():
                sampler = gr.Dropdown(
                    label="Sampler",
                    choices=[
                        "DPM++ 2M SDE Karras",
                        "DPM++ 2M SDE",
                        "Euler a",
                        "Euler",
                        "DPM++ 2M Karras",
                        "DPM++ 2M",
                        "LMS Karras",
                        "Heun",
                        "DPM++ SDE Karras",
                        "DPM++ SDE",
                        "DPM2 Karras",
                        "DPM2",
                        "DPM2 a Karras",
                        "DPM2 a",
                        "LMS",
                        "DDIM",
                        "PLMS",
                    ],
                    value="DPM++ 2M SDE Karras",
                )
                aspect_ratio_selector = gr.Dropdown(
                    label="Aspect Ratio",
                    choices=[
                        "1024 x 1024",
                        "1152 x 896",
                        "896 x 1152",
                        "1216 x 832",
                        "832 x 1216",
                        "1344 x 768",
                        "768 x 1344",
                        "1536 x 640",
                        "640 x 1536",
                    ],
                    value="1024 x 1024",
                )
            with gr.Row():
                use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
                upscaler_strength = gr.Slider(
                    label="Upscaler Strength",
                    minimum=0,
                    maximum=1,
                    step=0.05,
                    value=0.55,
                )
                upscale_by = gr.Slider(
                    label="Upscale By",
                    minimum=1,
                    maximum=4,
                    step=0.1,
                    value=1.5,
                )
        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Image")
            output_text = gr.JSON(label="Generation Info")
            with gr.Row():
                generate_button = gr.Button("Generate")
            # Rendered history of recent generations
            history = gr.HTML(label="Generation History")

    # Raw history entries produced by generate(); rendered into `history` below
    history_state = gr.State([])

    def update_history(history_data):
        """Render history entries as HTML thumbnails with their settings."""
        html = "<div class='history-container'>"
        for item in history_data:
            html += f"""
            <div class='history-item'>
                <img src='{item['thumbnail']}' alt='Generated Image'>
                <div class='history-info'>
                    <p><strong>Prompt:</strong> {item['prompt']}</p>
                    <p><strong>Negative Prompt:</strong> {item['negative_prompt']}</p>
                    <p><strong>Seed:</strong> {item['seed']}</p>
                    <p><strong>Size:</strong> {item['width']}x{item['height']}</p>
                </div>
            </div>
            """
        html += "</div>"
        return html

    # Run generation, then render the updated history into the HTML panel
    generate_button.click(
        generate,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            custom_width,
            custom_height,
            guidance_scale,
            num_inference_steps,
            sampler,
            aspect_ratio_selector,
            use_upscaler,
            upscaler_strength,
            upscale_by,
        ],
        outputs=[output_image, output_text, history_state],
    ).then(
        update_history,
        inputs=[history_state],
        outputs=[history],
    )

# queue(concurrency_count=...) was removed in Gradio 4; events default to a
# single concurrent worker, so only the queue size needs to be set here.
demo.queue(max_size=20)
demo.launch(debug=True)