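"""Gradio demo for Pony Diffusion V6 XL: SDXL text-to-image with optional latent upscaling."""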
import os
import gc
import random
import gradio as gr
import numpy as np
import torch
import json
import spaces
import config
import utils
import logging
from PIL import Image, PngImagePlugin
from datetime import datetime
from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
DESCRIPTION = "PonyDiffusion V6 XL"
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
HF_TOKEN = os.getenv("HF_TOKEN")
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", "512"))
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
MODEL = os.getenv(
"MODEL",
"https://huggingface.co/AstraliteHeart/pony-diffusion-v6/blob/main/v6.safetensors",
)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
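# Load SDXL with the fp16-safe VAE. Single-file .safetensors checkpoints and
# diffusers repos need different loaders, and the "lpw_stable_diffusion_xl"
# community pipeline removes the 77-token prompt limit.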
def load_pipeline(model_name):
vae = AutoencoderKL.from_pretrained(
"madebyollin/sdxl-vae-fp16-fix",
torch_dtype=torch.float16,
)
    # Pick the loader: a single .safetensors checkpoint vs. a diffusers repo.
    loader = (
        StableDiffusionXLPipeline.from_single_file
        if model_name.endswith(".safetensors")
        else StableDiffusionXLPipeline.from_pretrained
    )
    pipe = loader(
model_name,
vae=vae,
torch_dtype=torch.float16,
custom_pipeline="lpw_stable_diffusion_xl",
use_safetensors=True,
add_watermarker=False,
use_auth_token=HF_TOKEN,
variant="fp16",
)
pipe.to(device)
return pipe
def parse_json_parameters(json_str):
    """Parse a JSON string of generation parameters; return None when invalid."""
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        return None
def apply_json_parameters(json_str):
params = parse_json_parameters(json_str)
if params:
return (
params.get("prompt", ""),
params.get("negative_prompt", ""),
params.get("seed", 0),
params.get("width", 1024),
params.get("height", 1024),
params.get("guidance_scale", 7.0),
params.get("num_inference_steps", 30),
params.get("sampler", "DPM++ 2M SDE Karras"),
params.get("aspect_ratio", "1024 x 1024"),
params.get("use_upscaler", False),
params.get("upscaler_strength", 0.55),
params.get("upscale_by", 1.5),
)
return [gr.update()] * 12
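# `spaces` is imported but never used, which suggests the ZeroGPU decorator was
# dropped; restoring it here is an assumption (it is a no-op outside HF Spaces).
@spaces.GPU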
def generate(
prompt: str,
negative_prompt: str = "",
seed: int = 0,
custom_width: int = 1024,
custom_height: int = 1024,
guidance_scale: float = 7.0,
num_inference_steps: int = 30,
sampler: str = "DPM++ 2M SDE Karras",
aspect_ratio_selector: str = "1024 x 1024",
use_upscaler: bool = False,
upscaler_strength: float = 0.55,
upscale_by: float = 1.5,
progress=gr.Progress(track_tqdm=True),
) -> tuple[list[Image.Image], str]:
generator = utils.seed_everything(seed)
width, height = utils.aspect_ratio_handler(
aspect_ratio_selector,
custom_width,
custom_height,
)
width, height = utils.preprocess_image_dimensions(width, height)
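    # Swap in the requested sampler for this run; the original scheduler is
    # restored in the `finally` block below.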
backup_scheduler = pipe.scheduler
pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
if use_upscaler:
upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
metadata = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"resolution": f"{width} x {height}",
"guidance_scale": guidance_scale,
"num_inference_steps": num_inference_steps,
"seed": seed,
"sampler": sampler,
}
if use_upscaler:
new_width = int(width * upscale_by)
new_height = int(height * upscale_by)
metadata["use_upscaler"] = {
"upscale_method": "nearest-exact",
"upscaler_strength": upscaler_strength,
"upscale_by": upscale_by,
"new_resolution": f"{new_width} x {new_height}",
}
else:
metadata["use_upscaler"] = None
logger.info(json.dumps(metadata, indent=4))
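    # Two-pass flow when upscaling: render latents at the base resolution,
    # resize them with nearest-exact, then refine via img2img at `upscaler_strength`.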
try:
if use_upscaler:
latents = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
output_type="latent",
).images
upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
images = upscaler_pipe(
prompt=prompt,
negative_prompt=negative_prompt,
image=upscaled_latents,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
strength=upscaler_strength,
generator=generator,
output_type="pil",
).images
else:
images = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
output_type="pil",
).images
if images and IS_COLAB:
for image in images:
filepath = utils.save_image(image, metadata, OUTPUT_DIR)
logger.info(f"Image saved as {filepath} with metadata")
return images, json.dumps(metadata) # Return metadata as a JSON string
except Exception as e:
logger.exception(f"An error occurred: {e}")
raise
finally:
if use_upscaler:
del upscaler_pipe
pipe.scheduler = backup_scheduler
utils.free_memory()
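# Compose a random demo prompt: PonyDiffusion "score_" quality tags (sampled
# with k=3 of 3, i.e. shuffled) plus a random character and art style.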
def get_random_prompt():
anime_characters = [
"Naruto Uzumaki", "Monkey D. Luffy", "Goku", "Eren Yeager", "Light Yagami",
"Lelouch Lamperouge", "Edward Elric", "Levi Ackerman", "Spike Spiegel",
"Sakura Haruno", "Mikasa Ackerman", "Asuka Langley Soryu", "Rem", "Megumin",
"Violet Evergarden"
]
styles = ["pixel art", "stylized anime", "digital art", "watercolor", "sketch"]
scores = ["score_9", "score_8_up", "score_7_up"]
character = random.choice(anime_characters)
style = random.choice(styles)
score = ", ".join(random.sample(scores, k=3))
return f"{score}, {character}, {style}, show accurate"
if torch.cuda.is_available():
pipe = load_pipeline(MODEL)
logger.info("Loaded on Device!")
else:
pipe = None
# Define the JavaScript code as a string
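# NOTE: Gradio renders gr.HTML via innerHTML, and browsers do not execute
# <script> tags injected that way, so this history hook may never run; the
# Blocks(js=...) parameter is a more reliable place for startup JS.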
js_code = """
<script>
document.addEventListener('DOMContentLoaded', (event) => {
const historyDropdown = document.getElementById('history-dropdown');
const resultGallery = document.querySelector('.gallery');
if (historyDropdown && resultGallery) {
const observer = new MutationObserver((mutations) => {
mutations.forEach((mutation) => {
if (mutation.type === 'childList' && mutation.addedNodes.length > 0) {
const newImage = mutation.addedNodes[0];
if (newImage.tagName === 'IMG') {
updateHistory(newImage.src);
}
}
});
});
observer.observe(resultGallery, { childList: true });
function updateHistory(imageSrc) {
const prompt = document.querySelector('#prompt textarea').value;
const option = document.createElement('option');
option.value = prompt;
option.textContent = prompt;
option.setAttribute('data-image', imageSrc);
historyDropdown.insertBefore(option, historyDropdown.firstChild);
if (historyDropdown.children.length > 10) {
historyDropdown.removeChild(historyDropdown.lastChild);
}
}
historyDropdown.addEventListener('change', (event) => {
const selectedOption = event.target.selectedOptions[0];
const imageSrc = selectedOption.getAttribute('data-image');
if (imageSrc) {
const img = document.createElement('img');
img.src = imageSrc;
resultGallery.innerHTML = '';
resultGallery.appendChild(img);
}
});
}
});
</script>
"""
with gr.Blocks(css="style.css") as demo:
gr.HTML(js_code) # Add the JavaScript code to the interface
title = gr.HTML(
f"""<h1><span>{DESCRIPTION}</span></h1>""",
elem_id="title",
)
gr.Markdown(
f"""Gradio demo for [Pony Diffusion V6](https://civitai.com/models/257749/pony-diffusion-v6-xl/)""",
elem_id="subtitle",
)
gr.DuplicateButton(
value="Duplicate Space for private use",
elem_id="duplicate-button",
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
)
with gr.Group():
with gr.Row():
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=5,
placeholder="Enter your prompt",
container=False,
)
run_button = gr.Button(
"Generate",
variant="primary",
scale=0
)
result = gr.Gallery(
label="Result",
columns=1,
preview=True,
show_label=False
)
with gr.Accordion(label="Advanced Settings", open=False):
negative_prompt = gr.Text(
label="Negative Prompt",
max_lines=5,
placeholder="Enter a negative prompt",
value=""
)
aspect_ratio_selector = gr.Radio(
label="Aspect Ratio",
choices=config.aspect_ratios,
value="1024 x 1024",
container=True,
)
with gr.Group(visible=False) as custom_resolution:
with gr.Row():
custom_width = gr.Slider(
label="Width",
minimum=MIN_IMAGE_SIZE,
maximum=MAX_IMAGE_SIZE,
step=8,
value=1024,
)
custom_height = gr.Slider(
label="Height",
minimum=MIN_IMAGE_SIZE,
maximum=MAX_IMAGE_SIZE,
step=8,
value=1024,
)
use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
with gr.Row() as upscaler_row:
upscaler_strength = gr.Slider(
label="Strength",
minimum=0,
maximum=1,
step=0.05,
value=0.55,
visible=False,
)
upscale_by = gr.Slider(
label="Upscale by",
minimum=1,
maximum=1.5,
step=0.1,
value=1.5,
visible=False,
)
sampler = gr.Dropdown(
label="Sampler",
choices=config.sampler_list,
interactive=True,
value="DPM++ 2M SDE Karras",
)
with gr.Row():
seed = gr.Slider(
label="Seed", minimum=0, maximum=utils.MAX_SEED, step=1, value=0
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Group():
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=1,
maximum=12,
step=0.1,
value=7.0,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=50,
step=1,
value=28,
)
with gr.Accordion(label="JSON Parameters", open=False):
json_input = gr.TextArea(label="Input JSON parameters")
apply_json_button = gr.Button("Apply JSON Parameters")
with gr.Row():
clear_button = gr.Button("Clear All")
random_prompt_button = gr.Button("Random Prompt")
history = gr.State([]) # Add a state component to store history
history_dropdown = gr.Dropdown(label="Generation History", choices=[], interactive=True, elem_id="history-dropdown")
with gr.Accordion(label="Generation Parameters", open=False):
gr_metadata = gr.JSON(label="Metadata", show_label=False)
    def update_history(images, metadata, history):
        # gr.JSON parses the metadata string on output, so it may arrive here as a dict.
        if images and metadata:
            params = json.loads(metadata) if isinstance(metadata, str) else metadata
            new_entry = {"prompt": params["prompt"], "image": images[0]}
            history = [new_entry] + history[:9]  # Keep only the last 10 entries
        return gr.update(choices=[h["prompt"] for h in history]), history
gr.Examples(
examples=config.examples,
inputs=prompt,
outputs=[result, gr_metadata],
fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
cache_examples=CACHE_EXAMPLES,
)
use_upscaler.change(
fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
inputs=use_upscaler,
outputs=[upscaler_strength, upscale_by],
queue=False,
api_name=False,
)
aspect_ratio_selector.change(
fn=lambda x: gr.update(visible=x == "Custom"),
inputs=aspect_ratio_selector,
outputs=custom_resolution,
queue=False,
api_name=False,
)
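    # Shared input list for the three generation triggers below.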
inputs = [
prompt,
negative_prompt,
seed,
custom_width,
custom_height,
guidance_scale,
num_inference_steps,
sampler,
aspect_ratio_selector,
use_upscaler,
upscaler_strength,
upscale_by,
]
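    # Each trigger chains three steps: randomize the seed (if enabled),
    # generate, then push the result into the history dropdown.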
prompt.submit(
fn=utils.randomize_seed_fn,
inputs=[seed, randomize_seed],
outputs=seed,
queue=False,
api_name=False,
).then(
fn=generate,
inputs=inputs,
outputs=[result, gr_metadata],
api_name="run",
).then(
fn=update_history,
inputs=[result, gr_metadata, history],
outputs=[history_dropdown, history],
)
negative_prompt.submit(
fn=utils.randomize_seed_fn,
inputs=[seed, randomize_seed],
outputs=seed,
queue=False,
api_name=False,
).then(
fn=generate,
inputs=inputs,
outputs=[result, gr_metadata],
api_name=False,
).then(
fn=update_history,
inputs=[result, gr_metadata, history],
outputs=[history_dropdown, history],
)
run_button.click(
fn=utils.randomize_seed_fn,
inputs=[seed, randomize_seed],
outputs=seed,
queue=False,
api_name=False,
).then(
fn=generate,
inputs=inputs,
outputs=[result, gr_metadata],
api_name=False,
).then(
fn=update_history,
inputs=[result, gr_metadata, history],
outputs=[history_dropdown, history],
)
apply_json_button.click(
fn=apply_json_parameters,
inputs=json_input,
outputs=[prompt, negative_prompt, seed, custom_width, custom_height,
guidance_scale, num_inference_steps, sampler,
aspect_ratio_selector, use_upscaler, upscaler_strength, upscale_by]
)
    clear_button.click(
        fn=lambda: (gr.update(value=""), gr.update(value=""), gr.update(value=0),
                    gr.update(value=1024), gr.update(value=1024),
                    gr.update(value=7.0), gr.update(value=28),
                    gr.update(value="DPM++ 2M SDE Karras"),
                    gr.update(value="1024 x 1024"), gr.update(value=False),
                    gr.update(value=0.55), gr.update(value=1.5)),
        inputs=[],
        outputs=[prompt, negative_prompt, seed, custom_width, custom_height,
                 guidance_scale, num_inference_steps, sampler,
                 aspect_ratio_selector, use_upscaler, upscaler_strength, upscale_by],
    )
random_prompt_button.click(
fn=get_random_prompt,
inputs=[],
outputs=prompt
)
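    # Selecting a history entry copies its prompt back into the prompt box.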
history_dropdown.change(
fn=lambda x, history: next((h["prompt"] for h in history if h["prompt"] == x), ""),
inputs=[history_dropdown, history],
outputs=prompt
)
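# Queue up to 20 requests; expose a public share link only when running in Colab.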
demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)