# Source: app.py from a Hugging Face Space by Tech-Meld
# (commit 3df52b6, "Update app.py", 5.69 kB).
import random

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler, AutoencoderKL, UNet2DConditionModel
from diffusers import StableDiffusionXLPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on CUDA; the CPU fallback loads full-precision
# weights because float16 inference is poorly supported there.
dtype = torch.float16 if device == "cuda" else torch.float32

# SDXL-Turbo checkpoint on the Hugging Face Hub.
repo = "stabilityai/sdxl-turbo"

# Schedulers carry no weights, so no dtype is passed when loading this one.
scheduler = EulerDiscreteScheduler.from_pretrained(repo, subfolder="scheduler")

# An SDXL checkpoint ships two text encoders and two tokenizers, which the
# plain StableDiffusionPipeline constructor cannot accept. Loading through
# StableDiffusionXLPipeline.from_pretrained wires every component up from the
# repo, with only the scheduler overridden by the one chosen above.
pipe = StableDiffusionXLPipeline.from_pretrained(
    repo,
    scheduler=scheduler,
    torch_dtype=dtype,
    # The "fp16" weight files are only wanted when running in half precision.
    variant="fp16" if dtype == torch.float16 else None,
).to(device)

# Module-level aliases for the individual components of the loaded pipeline.
vae = pipe.vae
text_encoder = pipe.text_encoder
unet = pipe.unet

# Largest seed offered by the UI and used for random draws (int32 maximum).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound, in pixels, for the width and height sliders.
MAX_IMAGE_SIZE = 1344
def infer(prompts, negative_prompts, seeds, randomize_seeds, widths, heights, guidance_scales, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
    """Generate one image per prompt with the module-level ``pipe``.

    Every argument may be either a single value (what the Gradio form sends:
    one string, one number, one bool) or a list/tuple holding one entry per
    prompt. When ``prompts`` is a sequence, any other argument given as a
    single value is applied to every prompt.

    Args:
        prompts: Prompt text, or a sequence of prompt texts.
        negative_prompts: Negative prompt(s).
        seeds: Seed(s) used when the matching ``randomize_seeds`` entry is
            falsy.
        randomize_seeds: Flag(s); a truthy entry replaces the seed with a
            random integer in ``[0, MAX_SEED]``.
        widths: Image width(s) in pixels.
        heights: Image height(s) in pixels.
        guidance_scales: Guidance scale(s) forwarded to the pipeline.
        num_inference_steps: Step count(s) forwarded to the pipeline.
        progress: Gradio progress tracker. It is not referenced in the body;
            it is declared so Gradio can mirror tqdm progress in the UI.

    Returns:
        A tuple ``(images, seeds_used)``. ``images`` is always a list with
        one image per prompt. ``seeds_used`` is a list when ``prompts`` was a
        sequence and a single int otherwise, so it can be written straight
        back to a single seed slider.

    Raises:
        IndexError: If a per-prompt sequence is shorter than ``prompts``.
    """
    batched = isinstance(prompts, (list, tuple))
    prompt_list = list(prompts) if batched else [prompts]

    def per_prompt(value):
        # Sequences are used as given; a lone value applies to every prompt.
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value] * len(prompt_list)

    negative_list = per_prompt(negative_prompts)
    seed_list = per_prompt(seeds)
    randomize_list = per_prompt(randomize_seeds)
    width_list = per_prompt(widths)
    height_list = per_prompt(heights)
    guidance_list = per_prompt(guidance_scales)
    steps_list = per_prompt(num_inference_steps)

    images = []
    # Seeds are collected in a fresh list so the caller's input is not mutated.
    seeds_used = []
    for i, prompt in enumerate(prompt_list):
        if randomize_list[i]:
            seed = random.randint(0, MAX_SEED)
        else:
            # manual_seed needs an int; coerce in case the value is a float.
            seed = int(seed_list[i])
        generator = torch.Generator().manual_seed(seed)

        image = pipe(
            prompt=prompt,
            negative_prompt=negative_list[i],
            guidance_scale=guidance_list[i],
            # Step count and pixel sizes must be integers for the pipeline.
            num_inference_steps=int(steps_list[i]),
            width=int(width_list[i]),
            height=int(height_list[i]),
            generator=generator,
        ).images[0]

        images.append(image)
        seeds_used.append(seed)

    return images, (seeds_used if batched else seeds_used[0])
# Settings shared by every example row, in the order the form inputs expect
# after the two prompt fields:
# seed, randomize seed, width, height, guidance scale, inference steps.
_EXAMPLE_SETTINGS = (0, True, 512, 512, 7.5, 28)

# One row per example: prompt, negative prompt, then the shared settings.
examples = [
    [prompt_text, negative_text, *_EXAMPLE_SETTINGS]
    for prompt_text, negative_text in (
        (
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "A blurry astronaut",
        ),
        (
            "An astronaut riding a green horse",
            "Astronaut on a regular horse",
        ),
        (
            "A delicious ceviche cheesecake slice",
            "A cheesecake that looks boring",
        ),
    )
]

# Stylesheet handed to gr.Blocks: centres the main column and caps its width.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 580px;\n"
    "}\n"
)
# Build the web UI: a single centred column holding the input form, the Run
# button, the result gallery and the example prompts.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Header; the link targets the checkpoint this app loads.
        gr.Markdown("""
# Demo [Automated Stable Diffusion XL](https://huggingface.co/stabilityai/sdxl-turbo)
""")
        with gr.Row():
            # All generation inputs live in one group.
            prompt_group = gr.Group(elem_id="prompt_group")
            with prompt_group:
                prompt_input = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                )
                negative_prompt_input = gr.Text(
                    label="Negative prompt",
                    max_lines=1,
                    placeholder="Enter a negative prompt",
                )
                seed_input = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed_input = gr.Checkbox(label="Randomize seed", value=True)
                width_input = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=512,
                )
                height_input = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=512,
                )
                guidance_scale_input = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.5,
                )
                num_inference_steps_input = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )
            run_button = gr.Button("Run", scale=0)

        result = gr.Gallery(label="Results", show_label=False, columns=4, rows=1)
        add_button = gr.Button("Add Prompt")

        # Placeholder section: no advanced settings are defined yet.
        with gr.Accordion("Advanced Settings", open=False):
            pass

        # Clicking an example fills the form; it does not start a generation.
        gr.Examples(
            examples=examples,
            inputs=[
                prompt_input,
                negative_prompt_input,
                seed_input,
                randomize_seed_input,
                width_input,
                height_input,
                guidance_scale_input,
                num_inference_steps_input,
            ],
        )

    def add_prompt():
        """Tell the user that extra prompt forms cannot be added."""
        # gr.Group has no method for cloning itself at runtime, so the click
        # is answered with a notice instead of failing with AttributeError.
        gr.Warning("Adding more prompts is not supported yet.")

    def clear_prompts():
        """Tell the user that prompt forms cannot be cleared."""
        # gr.Group has no method for emptying itself at runtime either.
        gr.Warning("Clearing prompts is not supported yet.")

    add_button.click(add_prompt)

    # Run a generation on button click or on Enter in either prompt box. The
    # seed actually used is written back to the seed slider.
    gr.on(
        triggers=[run_button.click, prompt_input.submit, negative_prompt_input.submit],
        fn=infer,
        inputs=[
            prompt_input,
            negative_prompt_input,
            seed_input,
            randomize_seed_input,
            width_input,
            height_input,
            guidance_scale_input,
            num_inference_steps_input,
        ],
        outputs=[result, seed_input],
        api_name="infer",
    )

demo.launch()