"""Gradio demo: render a text prompt with FLUX.1-dev or Stable Diffusion 3.5.

Both pipelines are loaded onto the CUDA device when this module is imported.
The UI exposes one tab per model; each tab reports per-step progress while
the image is being denoised.
"""

from io import BytesIO  # retained from the original script; currently unused

import gradio as gr
import torch
from diffusers import FluxPipeline, StableDiffusion3Pipeline
from PIL import Image

# Initialize pipelines
stable_diffusion_pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    torch_dtype=torch.bfloat16,
).to("cuda")

flux_pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")


# Function to generate images with progress
def generate_image_with_progress(
    pipe,
    prompt,
    num_steps,
    guidance_scale=None,
    seed=None,
    progress=gr.Progress(),
) -> Image.Image:
    """Run ``pipe`` on ``prompt`` and report per-step progress to Gradio.

    Args:
        pipe: A ``StableDiffusion3Pipeline`` or ``FluxPipeline`` instance.
        prompt: Text prompt passed to the pipeline.
        num_steps: Number of inference steps; also the progress denominator.
        guidance_scale: Optional guidance scale. When ``None`` the keyword is
            omitted so the pipeline's own default applies.
        seed: Optional integer seed. When given, a CPU ``torch.Generator``
            seeded with it is passed to the pipeline.
        progress: Gradio progress tracker, called once per completed step.

    Returns:
        The first image produced by the pipeline (requested as PIL output).

    Raises:
        TypeError: If ``pipe`` is not one of the two supported pipeline types.
    """
    if not isinstance(pipe, (StableDiffusion3Pipeline, FluxPipeline)):
        # The original fell through to ``return image`` with ``image`` unbound.
        raise TypeError(
            f"Unsupported pipeline type: {type(pipe).__name__}"
        )

    generator = None
    if seed is not None:
        generator = torch.Generator("cpu").manual_seed(seed)

    print("Start generating")

    # Both pipelines expose `callback_on_step_end`; the legacy
    # `callback=` keyword is not part of their call signature.
    def on_step_end(_pipeline, step_index, _timestep, callback_kwargs):
        # The pipeline hands us a 0-based index after the step has finished.
        completed = step_index + 1
        cur_prg = completed / num_steps
        print(f"Progressing {cur_prg} ")
        progress(cur_prg, desc=f"Step {completed}/{num_steps}")
        # The pipeline reads (possibly modified) tensors back from this dict.
        return callback_kwargs

    call_kwargs = {
        "num_inference_steps": num_steps,
        "generator": generator,
        "output_type": "pil",
        "callback_on_step_end": on_step_end,
    }
    if guidance_scale is not None:
        call_kwargs["guidance_scale"] = guidance_scale

    return pipe(prompt, **call_kwargs).images[0]


# Gradio application
def main():
    """Build the two-tab Gradio UI and launch it."""

    # `progress=gr.Progress()` must be a default-valued parameter of the event
    # handler itself: that is how Gradio injects a tracker bound to the event.
    def tab1_logic(prompt_text, progress=gr.Progress()):
        """Generate with FLUX using a fixed seed; return (status, image)."""
        num_steps = 30
        seed = 42
        print(f"Start tab {prompt_text}")
        image = generate_image_with_progress(
            flux_pipe,
            prompt_text,
            num_steps=num_steps,
            seed=seed,
            progress=progress,
        )
        return f"Seed: {seed}", image

    def tab2_logic(prompt_text, progress=gr.Progress()):
        """Generate with Stable Diffusion 3 unseeded; return (status, image)."""
        num_steps = 28
        guidance_scale = 3.5
        print(f"Start tab {prompt_text}")
        image = generate_image_with_progress(
            stable_diffusion_pipe,
            prompt_text,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            progress=progress,
        )
        return "Seed: None", image

    with gr.Blocks() as app:
        gr.Markdown("# Multiple Model Image Generation with Progress Bar")

        # A single prompt box is shared by both tabs.
        prompt_text = gr.Textbox(label="Enter prompt")

        with gr.Tab("FLUX"):
            button_1 = gr.Button("Run FLUX")
            output_1 = gr.Textbox(label="Status")
            img_1 = gr.Image(label="FLUX", height=300)
            button_1.click(
                fn=tab1_logic,
                inputs=[prompt_text],
                outputs=[output_1, img_1],
            )

        with gr.Tab("StableDiffusion3"):
            button_2 = gr.Button("Run StableDiffusion3")
            output_2 = gr.Textbox(label="Status")
            img_2 = gr.Image(label="StableDiffusion3", height=300)
            button_2.click(
                fn=tab2_logic,
                inputs=[prompt_text],
                outputs=[output_2, img_2],
            )

    app.launch()


if __name__ == "__main__":
    main()