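"""Gradio app: generate images with FLUX.1-dev or Stable Diffusion 3.5 Large,
showing per-step progress via diffusers' callback_on_step_end."""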
import spaces
import torch
from diffusers import FluxPipeline, StableDiffusion3Pipeline
import gradio as gr


# Generate an image with the given pipeline, reporting per-step progress to the UI
def generate_image_with_progress(pipe, prompt, num_steps, guidance_scale=None, seed=None, progress=gr.Progress()):
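    # A seeded CUDA generator makes generation reproducible across runs.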
    generator = None
    if seed is not None:
        generator = torch.Generator("cuda").manual_seed(seed)
    print("Start generating")
    # Wrapper to track progress
    def callback(pipe, step_index, timestep, callback_kwargs): # pipe, step_index, timestep, callback_kwargs
        print(f" callback => {pipe}, {step_index}, {timestep}, {callback_kwargs} ")
        if step_index == None:
            step_index = 0
        cur_prg = step_index / num_steps
        print(f"Progressing {cur_prg} Step {step_index}/{num_steps}")
        progress(cur_prg, desc=f"Step {step_index}/{num_steps}")
        return callback_kwargs

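    # Dispatch on pipeline type; each branch forwards only the arguments that
    # are relevant for that pipeline.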
    if isinstance(pipe, StableDiffusion3Pipeline):
        image = pipe(
            prompt,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            callback_on_step_end=callback,
        ).images[0]
    elif isinstance(pipe, FluxPipeline):
        image = pipe(
            prompt,
            num_inference_steps=num_steps,
            generator=generator,
            output_type="pil",
            callback_on_step_end=callback,
        ).images[0]
    else:
        raise ValueError(f"Unsupported pipeline type: {type(pipe).__name__}")
    return image

# Gradio application
def main():
    
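    # ZeroGPU: each handler requests a GPU for up to 170 s per call and loads
    # its pipeline inside the handler, so the GPU is held only while generating.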
    @spaces.GPU(duration=170)
    def tab1_logic(prompt_text, progress=gr.Progress()):
        num_steps = 30
        seed = 42
        print(f"Starting FLUX generation for prompt: {prompt_text}")
        flux_pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
        ).to("cuda")
        image = generate_image_with_progress(
            flux_pipe, prompt_text, num_steps=num_steps, seed=seed, progress=progress
        )
        return f"Seed: {seed}", image
    
    @spaces.GPU(duration=170)
    def tab2_logic(prompt_text, progress=gr.Progress()):
        num_steps = 28
        guidance_scale = 3.5
        print(f"Starting StableDiffusion3 generation for prompt: {prompt_text}")
        # Initialize the pipeline on the allocated GPU
        stable_diffusion_pipe = StableDiffusion3Pipeline.from_pretrained(
            "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
        ).to("cuda")
        image = generate_image_with_progress(
            stable_diffusion_pipe, prompt_text, num_steps=num_steps, guidance_scale=guidance_scale, progress=progress
        )
        return "Seed: None", image

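    # Two tabs share one prompt textbox; each button triggers its own handler.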
    with gr.Blocks() as app:
        gr.Markdown("# Multiple Model Image Generation with Progress Bar")

        prompt_text = gr.Textbox(label="Enter prompt")

        with gr.Tab("FLUX"):
            button_1 = gr.Button("Run FLUX")
            output_1 = gr.Textbox(label="Status")
            img_1 = gr.Image(label="FLUX", height=300)
            button_1.click(fn=tab1_logic, inputs=[prompt_text], outputs=[output_1, img_1])

        with gr.Tab("StableDiffusion3"):
            button_2 = gr.Button("Run StableDiffusion3")
            output_2 = gr.Textbox(label="Status")
            img_2 = gr.Image(label="StableDiffusion3", height=300)
            button_2.click(fn=tab2_logic, inputs=[prompt_text], outputs=[output_2, img_2])

    app.launch()

if __name__ == "__main__":
    main()