Tech-Meld's picture
Update app.py
5614c61 verified
raw
history blame
4.95 kB
import gradio as gr
import numpy as np
import random
import torch
from diffusers import StableDiffusion3Pipeline, SD3Transformer2DModel, FlowMatchEulerDiscreteScheduler
import spaces
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 halves GPU memory for the weights; on the CPU fallback half precision
# is slow and not all ops support it, so use float32 there.
dtype = torch.float16 if device == "cuda" else torch.float32
repo = "stabilityai/stable-diffusion-3-medium-diffusers"
# Loaded once at import time and shared by every request.
pipe = StableDiffusion3Pipeline.from_pretrained(repo, torch_dtype=dtype).to(device)
# Upper bound for user-visible and randomly drawn seeds (largest int32).
MAX_SEED = np.iinfo(np.int32).max
# Largest width/height offered by the size sliders, in pixels.
MAX_IMAGE_SIZE = 1344
# On Hugging Face ZeroGPU hardware this decorator requests a GPU for the
# duration of the call; per the `spaces` docs it has no effect elsewhere.
@spaces.GPU
def infer(prompts, negative_prompts, seeds, randomize_seeds, widths, heights, guidance_scales, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
    """Generate one image per prompt with the module-level SD3 pipeline.

    Accepts either a batch (parallel lists with one entry per prompt) or a
    single request (a ``str`` prompt with scalar settings, which is what a
    Gradio event wired directly to the individual input components sends).

    Args:
        prompts: Prompt string, or list of prompt strings.
        negative_prompts: Negative prompt(s), parallel to ``prompts``.
        seeds: Seed(s) used when the matching ``randomize_seeds`` flag is false.
        randomize_seeds: Flag(s); when true a fresh seed in ``[0, MAX_SEED]``
            is drawn instead of using ``seeds``.
        widths: Output width(s) in pixels.
        heights: Output height(s) in pixels.
        guidance_scales: Guidance scale(s) passed to the pipeline.
        num_inference_steps: Inference step count(s) passed to the pipeline.
        progress: Gradio progress tracker. It is not used in the body; it is
            declared so Gradio mirrors tqdm bars (``track_tqdm=True``).

    Returns:
        ``(images, seeds)``: the generated images (always a list) and the
        seeds actually used. The seeds are a list for batch input and a
        single int for single-request input. The caller's ``seeds`` list is
        not modified.
    """
    single = isinstance(prompts, str)
    if single:
        # Wrap every scalar setting so the batch loop below serves both styles.
        (prompts, negative_prompts, seeds, randomize_seeds, widths, heights,
         guidance_scales, num_inference_steps) = (
            [value]
            for value in (prompts, negative_prompts, seeds, randomize_seeds,
                          widths, heights, guidance_scales, num_inference_steps)
        )

    images = []
    used_seeds = []
    for i, prompt in enumerate(prompts):
        # Slider values can arrive as floats; manual_seed requires an int.
        seed = random.randint(0, MAX_SEED) if randomize_seeds[i] else int(seeds[i])
        used_seeds.append(seed)
        generator = torch.Generator().manual_seed(seed)
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompts[i],
            guidance_scale=guidance_scales[i],
            # Step count and image size are used as integer counts/shapes.
            num_inference_steps=int(num_inference_steps[i]),
            width=int(widths[i]),
            height=int(heights[i]),
            generator=generator,
        ).images[0]
        images.append(image)
    return images, (used_seeds[0] if single else used_seeds)
# Sample rows for gr.Examples. The column order matches its `inputs` list:
# prompt, negative prompt, seed, randomize-seed flag, width, height, guidance
# scale, inference steps. All rows share the same settings; only the text
# pair differs.
examples = [
    [prompt, negative_prompt, 0, True, 512, 512, 7.5, 28]
    for prompt, negative_prompt in (
        ("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", "A blurry astronaut"),
        ("An astronaut riding a green horse", "Astronaut on a regular horse"),
        ("A delicious ceviche cheesecake slice", "A cheesecake that looks boring"),
    )
]
# Stylesheet for the Blocks app: centers the element with id "col-container"
# and caps its width at 580px.
css = "\n".join([
    "",
    "#col-container {",
    "margin: 0 auto;",
    "max-width: 580px;",
    "}",
    "",
])
with gr.Blocks(css=css) as demo:
    # Per-session queue of extra generation requests. Each entry holds the
    # eight settings in infer()'s positional order (prompt ... inference steps).
    prompt_queue = gr.State([])

    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
# Demo [Automated Stable Diffusion 3 Medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
""")
        with gr.Row():
            prompt_group = gr.Group(elem_id="prompt_group")
            with prompt_group:
                prompt_input = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                )
                negative_prompt_input = gr.Text(
                    label="Negative prompt",
                    max_lines=1,
                    placeholder="Enter a negative prompt",
                )
                seed_input = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed_input = gr.Checkbox(label="Randomize seed", value=True)
                width_input = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=512,
                )
                height_input = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=512,
                )
                guidance_scale_input = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.5,
                )
                num_inference_steps_input = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )
            run_button = gr.Button("Run", scale=0)

        result = gr.Gallery(label="Results", show_label=False, columns=4, rows=1)
        with gr.Row():
            add_button = gr.Button("Add Prompt")
            clear_button = gr.Button("Clear Prompts")
        queue_status = gr.Markdown("Queued prompts: 0")

        # Placeholder: this accordion is currently empty.
        with gr.Accordion("Advanced Settings", open=False):
            pass

        # The eight setting components, in infer()'s positional order.
        generation_inputs = [
            prompt_input,
            negative_prompt_input,
            seed_input,
            randomize_seed_input,
            width_input,
            height_input,
            guidance_scale_input,
            num_inference_steps_input,
        ]
        gr.Examples(examples=examples, inputs=generation_inputs)

    def _queue_status(queue):
        """Return the status line describing how many prompts are queued."""
        return f"Queued prompts: {len(queue)}"

    def add_prompt(queue, prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
        """Append the current settings to the queue; empty prompts are ignored.

        Returns the new queue (a fresh list, so Gradio sees the update) and
        the refreshed status text.
        """
        if not prompt:
            return queue, _queue_status(queue)
        updated = queue + [[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]]
        return updated, _queue_status(updated)

    def clear_prompts():
        """Drop every queued prompt and reset the status text."""
        return [], _queue_status([])

    def run_batch(queue, prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
        """Generate images for all queued prompts plus the one in the textboxes.

        The current prompt is included when it is non-empty, or when the queue
        is empty (matching the single-prompt behaviour). infer() is always
        called with parallel lists. Returns the images, the seed used for the
        last prompt (shown in the seed slider), an emptied queue and the
        status text. ``progress`` is declared so Gradio mirrors tqdm bars.
        """
        jobs = list(queue)
        if prompt or not jobs:
            jobs.append([prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps])
        # Transpose per-job rows into the per-setting lists infer() expects.
        columns = [list(column) for column in zip(*jobs)]
        images, used_seeds = infer(*columns)
        return images, used_seeds[-1], [], _queue_status([])

    add_button.click(
        add_prompt,
        inputs=[prompt_queue, *generation_inputs],
        outputs=[prompt_queue, queue_status],
    )
    clear_button.click(clear_prompts, outputs=[prompt_queue, queue_status])
    gr.on(
        triggers=[run_button.click, prompt_input.submit, negative_prompt_input.submit],
        fn=run_batch,
        inputs=[prompt_queue, *generation_inputs],
        outputs=[result, seed_input, prompt_queue, queue_status],
        api_name="infer",
    )
demo.launch()