RageshAntony's picture
return kw
7054a36 verified
raw
history blame
3.4 kB
import spaces
import torch
from diffusers import FluxPipeline, StableDiffusion3Pipeline
from PIL import Image
from io import BytesIO
import gradio as gr
# Function to generate images with progress
def generate_image_with_progress(pipe, prompt, num_steps, guidance_scale=None, seed=None, progress=gr.Progress()):
    """Run a diffusers text-to-image pipeline and report per-step progress.

    Args:
        pipe: A loaded ``FluxPipeline`` or ``StableDiffusion3Pipeline``.
        prompt: Text prompt to render.
        num_steps: Number of denoising steps (``num_inference_steps``).
        guidance_scale: Guidance scale forwarded to the pipeline. When ``None``
            it is not passed at all, so the pipeline's own default applies.
        seed: Optional integer seed. When given, a ``torch.Generator`` on the
            pipeline's device is seeded with it so the result is reproducible.
            Applied to both pipeline types.
        progress: Gradio progress tracker, updated once per finished step.

    Returns:
        The first image of the pipeline's ``.images`` output.

    Raises:
        TypeError: If ``pipe`` is not one of the supported pipeline types.
    """
    if not isinstance(pipe, (StableDiffusion3Pipeline, FluxPipeline)):
        # Previously an unknown type fell through to `return image` and
        # crashed with UnboundLocalError; fail with a clear message instead.
        raise TypeError(f"Unsupported pipeline type: {type(pipe).__name__}")

    generator = None
    if seed is not None:
        # Seed on the pipeline's own device rather than a hard-coded "cuda".
        generator = torch.Generator(device=pipe.device).manual_seed(seed)

    print("Start generating")

    def callback(_pipe, step_index, timestep, callback_kwargs):
        """diffusers ``callback_on_step_end`` hook: report progress, pass kwargs through."""
        # step_index is the 0-based index of the step that just finished, so
        # the number of completed steps is step_index + 1 (reaches 100%).
        completed = 0 if step_index is None else step_index + 1
        cur_prg = min(completed / num_steps, 1.0) if num_steps else 1.0
        print(f"Progressing {cur_prg} Step {completed}/{num_steps}")
        progress(cur_prg, desc=f"Step {completed}/{num_steps}")
        # diffusers requires the callback to return the (possibly modified) kwargs.
        return callback_kwargs

    call_kwargs = {
        "num_inference_steps": num_steps,
        # Passed for every pipeline type so `seed` is honoured for SD3 too.
        "generator": generator,
        "callback_on_step_end": callback,
    }
    if guidance_scale is not None:
        # Only forward an explicit value; passing None would override the
        # pipeline default with an invalid setting.
        call_kwargs["guidance_scale"] = guidance_scale
    if isinstance(pipe, FluxPipeline):
        call_kwargs["output_type"] = "pil"

    return pipe(prompt, **call_kwargs).images[0]
# Gradio application
def main():
    """Build and launch the two-tab Gradio app (FLUX.1-dev and SD 3.5 Large).

    Both tabs share a single prompt textbox. Each tab's button runs its
    model on a ZeroGPU allocation and shows a status string plus the image.
    """

    # Gradio only attaches a progress bar to an event when ``gr.Progress()``
    # is the *default value* of a handler parameter. An instance created in
    # the function body is never wired to the event, so no progress shows.
    @spaces.GPU(duration=170)
    def tab1_logic(prompt_text, progress=gr.Progress()):
        """Generate an image with FLUX.1-dev using a fixed seed."""
        num_steps = 30
        seed = 42
        print(f"Start tab {prompt_text}")
        # NOTE(review): the pipeline is reloaded on every click; consider
        # loading it once at startup if start-up latency matters.
        flux_pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
        ).to("cuda")
        image = generate_image_with_progress(
            flux_pipe, prompt_text, num_steps=num_steps, seed=seed, progress=progress
        )
        return f"Seed: {seed}", image

    @spaces.GPU(duration=170)
    def tab2_logic(prompt_text, progress=gr.Progress()):
        """Generate an image with Stable Diffusion 3.5 Large (unseeded)."""
        num_steps = 28
        guidance_scale = 3.5
        print(f"Start tab {prompt_text}")
        # Initialize pipelines
        stable_diffusion_pipe = StableDiffusion3Pipeline.from_pretrained(
            "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
        ).to("cuda")
        image = generate_image_with_progress(
            stable_diffusion_pipe, prompt_text, num_steps=num_steps, guidance_scale=guidance_scale, progress=progress
        )
        return "Seed: None", image

    with gr.Blocks() as app:
        gr.Markdown("# Multiple Model Image Generation with Progress Bar")
        prompt_text = gr.Textbox(label="Enter prompt")
        with gr.Tab("FLUX"):
            button_1 = gr.Button("Run FLUX")
            output_1 = gr.Textbox(label="Status")
            img_1 = gr.Image(label="FLUX", height=300)
            # The `progress` parameter is injected by Gradio, so only the
            # prompt is listed as an input.
            button_1.click(fn=tab1_logic, inputs=[prompt_text], outputs=[output_1, img_1])
        with gr.Tab("StableDiffusion3"):
            button_2 = gr.Button("Run StableDiffusion3")
            output_2 = gr.Textbox(label="Status")
            img_2 = gr.Image(label="StableDiffusion3", height=300)
            button_2.click(fn=tab2_logic, inputs=[prompt_text], outputs=[output_2, img_2])
    app.launch()
# Build and launch the app only when this file is executed as a script.
if __name__ == "__main__":
    main()