File size: 1,920 Bytes
d9f1205
7f891bb
2e306db
a7d057d
d9f1205
d2b0012
1c4aefd
383a90d
7f891bb
a7d057d
1c4aefd
 
d2b0012
b473829
 
9d86930
d2cb214
1c4aefd
 
 
6af450a
1c4aefd
 
 
 
 
 
 
 
 
6af450a
 
1c4aefd
6b927be
 
1c4aefd
2e306db
1c4aefd
d53ee34
 
 
1c4aefd
d53ee34
 
1c4aefd
d53ee34
1c4aefd
 
 
d53ee34
1c4aefd
d53ee34
 
 
 
 
1c4aefd
 
d53ee34
 
 
d2b0012
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import spaces
import gradio as gr
import torch
from diffusers import DiffusionPipeline

# Constants
MAX_SEED = 2**32 - 1  # upper bound of the "Seed" slider in the UI
MAX_IMAGE_SIZE = 2048  # upper bound (pixels) of the "Width"/"Height" sliders

# Load the FLUX.1-schnell pipeline with bfloat16 weights.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
# Model CPU offload manages device placement itself. It moves the pipeline to
# the CPU and puts each sub-model on the GPU only while that sub-model runs.
# A preceding pipe.to("cuda") would copy every weight to the GPU just to have
# offload move it straight back, and it can run out of memory on GPUs too small
# to hold the whole pipeline at once. So no explicit .to("cuda") here.
pipe.enable_model_cpu_offload()
# Decode latents in slices/tiles to reduce VAE peak memory at large resolutions.
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

@spaces.GPU()
def generate_image(prompt, seed, width, height, num_inference_steps):
    """Generate a single image from ``prompt`` using the module-level ``pipe``.

    Args:
        prompt: Text prompt forwarded to the FLUX pipeline.
        seed: Seed for the CUDA ``torch.Generator`` used by the pipeline.
        width: Output image width in pixels.
        height: Output image height in pixels.
        num_inference_steps: Number of denoising steps.

    Returns:
        A ``(image, seed)`` tuple. ``image`` is the first generated image, or
        ``None`` if the pipeline raised. In that case the error message and
        traceback are printed instead of being propagated.
    """
    # Gradio sliders may hand back numbers as floats. torch.Generator.manual_seed
    # and the scheduler's step count require ints, and the pipeline expects
    # integer pixel sizes.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    generator = torch.Generator(device="cuda").manual_seed(seed)

    try:
        image = pipe(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            generator=generator,
            # Guidance disabled, matching the FLUX.1-schnell model card usage.
            guidance_scale=0.0,
        ).images[0]
    except Exception as e:
        # Best-effort UI boundary: log the failure and return an empty image
        # rather than crashing the request.
        import traceback

        print(f"Error during image generation: {e}")
        traceback.print_exc()
        return None, seed

    return image, seed

# Gradio interface
# Gradio interface. Components render top-to-bottom in creation order:
# prompt box, Generate button, output row, then a collapsed settings accordion.
with gr.Blocks() as demo:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")

    with gr.Row():
        generate = gr.Button("Generate")

    with gr.Row():
        result = gr.Image(label="Generated Image")
        seed_output = gr.Number(label="Seed Used")

    with gr.Accordion("Advanced Settings", open=False):
        # randomize=True makes Gradio start the slider at a random value in
        # [minimum, maximum] instead of the literal value=42.
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=42,
            randomize=True,
        )
        # Width and height share identical bounds; build both in order.
        width, height = (
            gr.Slider(label=dim_label, minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
            for dim_label in ("Width", "Height")
        )
        num_inference_steps = gr.Slider(
            label="Number of inference steps",
            minimum=1,
            maximum=50,
            step=1,
            value=4,
        )

    # The button feeds the prompt and slider values to generate_image and
    # routes its (image, seed) result into the output row.
    generate.click(
        fn=generate_image,
        inputs=[prompt, seed, width, height, num_inference_steps],
        outputs=[result, seed_output],
    )

demo.launch()