"""Gradio demo that generates images from text prompts with SDXL-Turbo."""

import random

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import (
    AutoencoderKL,
    EulerDiscreteScheduler,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on GPU; fall back to float32 on CPU, where
# float16 inference is slow or unsupported for several ops.
dtype = torch.float16 if device == "cuda" else torch.float32

# Checkpoint to load. SDXL checkpoints ship two text encoders and two
# tokenizers, so they must be loaded through the SDXL pipeline class.
repo = "stabilityai/sdxl-turbo"

# Keep the explicit Euler scheduler override, loaded from the repo's own config.
scheduler = EulerDiscreteScheduler.from_pretrained(repo, subfolder="scheduler")

# from_pretrained wires up every component (VAE, UNet, both text encoders,
# both tokenizers); only the scheduler is overridden.
pipe = StableDiffusionXLPipeline.from_pretrained(
    repo,
    scheduler=scheduler,
    torch_dtype=dtype,
    variant="fp16" if dtype == torch.float16 else None,
).to(device)

# Module-level aliases kept for backward compatibility with code that
# referenced the individually loaded components.
vae = pipe.vae
text_encoder = pipe.text_encoder
unet = pipe.unet

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1344


def _as_list(value):
    """Return *value* as a new list, wrapping a scalar in a one-item list."""
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _broadcast(value, count, name):
    """Return *value* as a list with at least *count* items.

    A scalar or single-item sequence is repeated *count* times.

    Raises:
        ValueError: if a multi-item sequence is shorter than *count*.
    """
    items = _as_list(value)
    if len(items) == 1 and count > 1:
        items = items * count
    if len(items) < count:
        raise ValueError(
            f"{name} has {len(items)} item(s) but {count} prompt(s) were given"
        )
    return items


@spaces.GPU
def infer(
    prompts,
    negative_prompts,
    seeds,
    randomize_seeds,
    widths,
    heights,
    guidance_scales,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image per prompt.

    Every argument may be a single value (as delivered by the Gradio
    components below) or a list/tuple of per-prompt values. Single values are
    applied to every prompt.

    Args:
        prompts: Prompt text, or a list of prompts.
        negative_prompts: Negative prompt(s).
        seeds: Seed(s) for the random generator; ignored for a prompt whose
            randomize flag is true.
        randomize_seeds: Flag(s); when true a fresh seed in [0, MAX_SEED] is
            drawn for that prompt.
        widths: Output width(s) in pixels.
        heights: Output height(s) in pixels.
        guidance_scales: Guidance scale(s) passed to the pipeline.
        num_inference_steps: Number of denoising steps per prompt.
        progress: Gradio progress tracker; unused directly, it lets Gradio
            mirror the pipeline's tqdm bar.

    Returns:
        A tuple ``(images, seeds)``. ``images`` is a list with one image per
        prompt. ``seeds`` holds the seeds actually used: a list when
        ``prompts`` was a list/tuple, otherwise a single int (so it can be
        written back to the seed slider).

    Raises:
        ValueError: if a per-prompt list is shorter than the prompt list.
    """
    batched = isinstance(prompts, (list, tuple))
    prompt_list = _as_list(prompts)
    count = len(prompt_list)

    negative_list = _broadcast(negative_prompts, count, "negative_prompts")
    # _broadcast returns a copy, so the caller's seed list is never mutated.
    seed_list = _broadcast(seeds, count, "seeds")
    randomize_list = _broadcast(randomize_seeds, count, "randomize_seeds")
    width_list = _broadcast(widths, count, "widths")
    height_list = _broadcast(heights, count, "heights")
    guidance_list = _broadcast(guidance_scales, count, "guidance_scales")
    steps_list = _broadcast(num_inference_steps, count, "num_inference_steps")

    images = []
    for i, prompt in enumerate(prompt_list):
        if randomize_list[i]:
            seed_list[i] = random.randint(0, MAX_SEED)
        # Sliders may deliver floats; seed and size/step arguments must be ints.
        seed_list[i] = int(seed_list[i])
        generator = torch.Generator().manual_seed(seed_list[i])

        image = pipe(
            prompt=prompt,
            negative_prompt=negative_list[i],
            guidance_scale=float(guidance_list[i]),
            num_inference_steps=int(steps_list[i]),
            width=int(width_list[i]),
            height=int(height_list[i]),
            generator=generator,
        ).images[0]
        images.append(image)

    if batched:
        return images, seed_list
    return images, seed_list[0]


# Each row: prompt, negative prompt, seed, randomize seed, width, height,
# guidance scale, inference steps.
examples = [
    [
        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
        "A blurry astronaut",
        0,
        True,
        512,
        512,
        7.5,
        28,
    ],
    [
        "An astronaut riding a green horse",
        "Astronaut on a regular horse",
        0,
        True,
        512,
        512,
        7.5,
        28,
    ],
    [
        "A delicious ceviche cheesecake slice",
        "A cheesecake that looks boring",
        0,
        True,
        512,
        512,
        7.5,
        28,
    ],
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 580px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Link points at the checkpoint this demo actually loads.
        gr.Markdown(
            "# Demo [Automated Stable Diffusion XL]"
            "(https://huggingface.co/stabilityai/sdxl-turbo)"
        )

        with gr.Row():
            prompt_group = gr.Group(elem_id="prompt_group")
            with prompt_group:
                prompt_input = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                )
                negative_prompt_input = gr.Text(
                    label="Negative prompt",
                    max_lines=1,
                    placeholder="Enter a negative prompt",
                )
                seed_input = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed_input = gr.Checkbox(
                    label="Randomize seed", value=True
                )
                width_input = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=512,
                )
                height_input = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=512,
                )
                # NOTE(review): the SDXL-Turbo model card recommends
                # guidance_scale=0.0 and 1-4 steps; defaults left as the
                # author set them — confirm they are intentional.
                guidance_scale_input = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.5,
                )
                num_inference_steps_input = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )
            run_button = gr.Button("Run", scale=0)

        result = gr.Gallery(label="Results", show_label=False, columns=4, rows=1)
        add_button = gr.Button("Add Prompt")

        with gr.Accordion("Advanced Settings", open=False):
            pass

        gr.Examples(
            examples=examples,
            inputs=[
                prompt_input,
                negative_prompt_input,
                seed_input,
                randomize_seed_input,
                width_input,
                height_input,
                guidance_scale_input,
                num_inference_steps_input,
            ],
        )

    def add_prompt():
        # NOTE(review): gr.Group does not appear to provide duplicate();
        # confirm against the installed Gradio version — this handler will
        # likely raise AttributeError when "Add Prompt" is clicked.
        prompt_group.duplicate()

    def clear_prompts():
        # NOTE(review): not wired to any event, and gr.Group does not appear
        # to provide clear() — confirm before using.
        prompt_group.clear()

    add_button.click(add_prompt)

    gr.on(
        triggers=[
            run_button.click,
            prompt_input.submit,
            negative_prompt_input.submit,
        ],
        fn=infer,
        inputs=[
            prompt_input,
            negative_prompt_input,
            seed_input,
            randomize_seed_input,
            width_input,
            height_input,
            guidance_scale_input,
            num_inference_steps_input,
        ],
        outputs=[result, seed_input],
        api_name="infer",
    )

if __name__ == "__main__":
    demo.launch()