import spaces
import gradio as gr
import torch
from PIL import Image
from diffusers import DiffusionPipeline, FluxImg2ImgPipeline


# Constants
MAX_SEED = 2**32 - 1
MAX_IMAGE_SIZE = 2048
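# bfloat16 matches the precision the released FLUX.1 weights use and halves memory versus float32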
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load FLUX model; enable_model_cpu_offload() manages device placement itself,
# so no explicit .to(device) call is needed.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
# The text-to-image pipeline has no `image` argument, so also build an img2img variant that
# reuses the same components (assumes a diffusers release with FluxImg2ImgPipeline and from_pipe).
img2img_pipe = FluxImg2ImgPipeline.from_pipe(pipe)
pipe.enable_model_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

def print_model_shapes(pipe):
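    # Debugging aid: dump a few component shapes so that "mat1 and mat2 shapes cannot be
    # multiplied" errors (caught in infer below) can be matched against the model layout.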
    print("Model component shapes:")
    print(f"VAE Encoder: {pipe.vae.encoder}")
    print(f"VAE Decoder: {pipe.vae.decoder}")
    print(f"x_embedder shape: {pipe.transformer.x_embedder.weight.shape}")
    print(f"First transformer block shape: {pipe.transformer.transformer_blocks[0].attn.to_q.weight.shape}")

print_model_shapes(pipe)

# spaces.GPU() requests a GPU for the duration of each call when running on ZeroGPU hardware
@spaces.GPU()
def infer(prompt, init_image=None, seed=None, width=1024, height=1024, num_inference_steps=4, guidance_scale=0.0):
    # Gradio sliders deliver floats; cast to the integers the pipeline, PIL and the RNG expect
    width, height, num_inference_steps = int(width), int(height), int(num_inference_steps)
    generator = torch.Generator(device=device).manual_seed(int(seed)) if seed is not None else None

    try:
        if init_image is None:
            # text2img case
            print("Running text-to-image generation")
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=guidance_scale
            ).images[0]
        else:
            # img2img case: route through the img2img variant built at load time,
            # since the text-to-image pipeline does not accept an init image
            print("Running image-to-image generation")
            init_image = init_image.convert("RGB").resize((width, height))
            image = img2img_pipe(
                prompt=prompt,
                image=init_image,
                # strength is left at the pipeline default; lower values stay closer to init_image
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=guidance_scale
            ).images[0]

        return image, seed
    except RuntimeError as e:
        if "mat1 and mat2 shapes cannot be multiplied" in str(e):
            print("Matrix multiplication error detected. Tensor shapes:")
            print(e)
            # Here you could add code to print shapes of specific tensors if needed
        else:
            print(f"RuntimeError during inference: {e}")
        import traceback
        traceback.print_exc()
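        # Return a solid red placeholder image so the failure is visible in the Gradio UI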
        return Image.new("RGB", (width, height), (255, 0, 0)), seed
    except Exception as e:
        print(f"Unexpected error during inference: {e}")
        import traceback
        traceback.print_exc()
        return Image.new("RGB", (width, height), (255, 0, 0)), seed

# Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(label="Initial Image (optional)", type="pil")

    with gr.Row():
        generate = gr.Button("Generate")
        
    with gr.Row():
        result = gr.Image(label="Result")
        seed_output = gr.Number(label="Seed")

    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=None)
        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)
        guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=20, step=0.1, value=0.0)
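        # FLUX.1-schnell is a few-step distilled model; 1-4 steps with guidance_scale 0 are typical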

    generate.click(
        infer,
        inputs=[prompt, init_image, seed, width, height, num_inference_steps, guidance_scale],
        outputs=[result, seed_output]
    )

demo.launch()
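
# For quick debugging outside the UI, infer can also be called directly (hypothetical example;
# swap it in for demo.launch() above):
#   image, used_seed = infer("a tiny robot watering a bonsai tree", seed=42)
#   image.save("sample.png")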