import spaces
import gradio as gr
import numpy as np
import torch
import random
import logging
import utils
from diffusers.models import AutoencoderKL
MAX_SEED = np.iinfo(np.int32).max
MIN_IMAGE_SIZE = 512
MAX_IMAGE_SIZE = 2048
# Enhanced logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
# PyTorch settings for better performance and determinism
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# Model initialization
if torch.cuda.is_available():
    try:
        logger.info("Loading VAE and pipeline...")
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        pipe = utils.load_pipeline("cagliostrolab/animagine-xl-4.0", device, vae=vae)
        logger.info("Pipeline loaded successfully on GPU!")
    except Exception as e:
        logger.error(f"Error loading VAE, falling back to default: {e}")
        pipe = utils.load_pipeline("cagliostrolab/animagine-xl-4.0", device)
else:
    logger.warning("CUDA not available, running on CPU")
    pipe = None
@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str,
    width: int,
    height: int,
    scheduler: str,
    upscaler_strength: float,
    upscale_by: float,
    seed: int,
    randomize_seed: bool,
    guidance_scale: float,
    num_inference_steps: int,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    if pipe is None:
        raise gr.Error("CUDA is not available; image generation requires a GPU.")
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Seed the generator explicitly so a given seed reproduces the same image.
    generator = torch.Generator().manual_seed(seed)
    # NOTE: `scheduler`, `upscaler_strength`, and `upscale_by` are not applied
    # in this function yet; see the hedged sketches below.
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]
    return image, seed
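# NOTE: a hedged sketch of the hires-fix pass implied by `upscaler_strength` and
# `upscale_by`, which `generate` accepts but does not yet use. It assumes the
# loaded pipeline's weights can be reused for img2img via diffusers'
# AutoPipelineForImage2Image; `hires_fix` itself is a hypothetical helper, not
# part of this Space's original code.
from diffusers import AutoPipelineForImage2Image
from PIL import Image

def hires_fix(base_image, prompt, negative_prompt, upscale_by, upscaler_strength, generator=None):
    """Resize the base output, then re-denoise it with an img2img pass."""
    target = (int(base_image.width * upscale_by), int(base_image.height * upscale_by))
    upscaled = base_image.resize(target, Image.LANCZOS)
    img2img = AutoPipelineForImage2Image.from_pipe(pipe)  # shares the loaded weights
    return img2img(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=upscaled,
        strength=upscaler_strength,  # fraction of noise re-added to the upscaled image
        generator=generator,
    ).images[0]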
scheduler_list = [
    "DPM++ 2M Karras",
    "DPM++ SDE Karras",
    "DPM++ 2M SDE Karras",
    "Euler",
    "Euler a",
    "DDIM",
]
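# NOTE: `scheduler` is collected by the UI but never applied in `generate`;
# `utils.load_pipeline` may already handle it. Below is a minimal sketch of the
# usual name-to-diffusers mapping for the list above (the SDE rows are
# approximate equivalents, not confirmed by this Space's code).
from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
)

def apply_scheduler(pipe, name: str):
    """Swap the pipeline's scheduler in place, reusing its existing config."""
    config = pipe.scheduler.config
    if name == "DPM++ 2M Karras":
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True)
    elif name == "DPM++ SDE Karras":
        pipe.scheduler = DPMSolverSinglestepScheduler.from_config(config, use_karras_sigmas=True)
    elif name == "DPM++ 2M SDE Karras":
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
        )
    elif name == "Euler":
        pipe.scheduler = EulerDiscreteScheduler.from_config(config)
    elif name == "Euler a":
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(config)
    elif name == "DDIM":
        pipe.scheduler = DDIMScheduler.from_config(config)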
title = "# Animagine XL 4.0 Demo"
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]
custom_css = """
#row-container {
align-items: stretch;
}
#output-image{
flex-grow: 1;
}
"""
with gr.Blocks(css=custom_css) as demo:  # queue is configured once at launch below
    gr.Markdown(title)
    with gr.Row(elem_id="row-container"):
        with gr.Column():
            gr.Markdown("### Input")
            with gr.Column():
                prompt = gr.Text(
                    label="Prompt",
                    max_lines=1,
                    placeholder="Enter your prompt",
                )
                negative_prompt = gr.Text(
                    label="Negative prompt",
                    max_lines=1,
                    placeholder="Enter a negative prompt",
                )
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=MIN_IMAGE_SIZE,
                        maximum=MAX_IMAGE_SIZE,
                        step=8,
                        value=832,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=MIN_IMAGE_SIZE,
                        maximum=MAX_IMAGE_SIZE,
                        step=8,
                        value=1216,
                    )
                with gr.Row():
                    upscaler_strength = gr.Slider(
                        label="Upscaler strength",
                        minimum=0,
                        maximum=1,
                        step=0.05,
                        value=0.55,
                    )
                    upscale_by = gr.Slider(
                        label="Upscale by",
                        minimum=1,
                        maximum=1.5,
                        step=0.1,
                        value=1.5,
                    )
                with gr.Column():
                    scheduler = gr.Dropdown(
                        label="Scheduler",
                        choices=scheduler_list,
                        interactive=True,
                        value="Euler a",
                    )
                with gr.Column():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=1.0,
                        maximum=12.0,
                        step=0.1,
                        value=6.0,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=25,
                    )
            run_button = gr.Button("Run", variant="primary")
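            # Hypothetical wiring for the otherwise unused `examples` list above,
            # assuming the entries are meant as one-click prompt presets.
            gr.Examples(examples=examples, inputs=[prompt])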
        with gr.Column():
            gr.Markdown("### Output")
            result = gr.Image(
                label="Generated Image",
                elem_id="output-image",
            )
    run_button.click(
        fn=generate,
        inputs=[
            prompt, negative_prompt,
            width, height,
            scheduler,
            upscaler_strength, upscale_by,
            seed, randomize_seed,
            guidance_scale, num_inference_steps,
        ],
        outputs=[result, seed],
    )
if __name__ == "__main__":
    demo.queue(max_size=20).launch()