import spaces
import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
from torchvision import transforms
from diffusers import DiffusionPipeline

# Constants
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Load FLUX model; enable_model_cpu_offload() manages device placement
# itself, so the pipeline is not moved to CUDA explicitly here
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
pipe.enable_model_cpu_offload()
pipe.vae.enable_slicing()  # decode one image at a time to reduce peak memory
pipe.vae.enable_tiling()   # decode large latents in tiles

def preprocess_image(image, image_size):
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    image = preprocess(image).unsqueeze(0).to(device, dtype=dtype)
    return image
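
# Example usage of preprocess_image (the file name is illustrative):
#   img = Image.open("input.png").convert("RGB")
#   x = preprocess_image(img, 1024)  # -> (1, 3, 1024, 1024) tensor in [-1, 1]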

def check_shapes(latents):
    """Debug helper: print latent shapes before and after packing."""
    print(f"Latent shape: {latents.shape}")
    if len(latents.shape) == 4:
        # Unpacked VAE latents (B, C, H, W); the FLUX transformer expects
        # these packed into (B, (H//2)*(W//2), C*4)
        b, c, h, w = latents.shape
        print(f"Expected packed shape: {(b, (h // 2) * (w // 2), c * 4)}")
    elif len(latents.shape) == 3:
        print(f"Packed latent shape: {latents.shape}")
    else:
        print(f"Unexpected latent shape: {latents.shape}")

@spaces.GPU()
def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        if init_image is None:
            # text2img case
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0
            ).images[0]
        else:
            # img2img case
            init_image = init_image.convert("RGB")
            init_image = preprocess_image(init_image, 1024)  # Using 1024 as FLUX VAE sample size
            
            # Encode the image with the FLUX VAE, then apply its shift and
            # scaling factors (FLUX uses these instead of SD's 0.18215)
            latents = pipe.vae.encode(init_image).latent_dist.sample()
            latents = (latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
            
            # Ensure latents are the correct shape
            latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')
            
            # Check shapes before reshaping
            check_shapes(latents)
            
            # Pack latents into the (B, seq_len, C*4) layout the FLUX
            # transformer expects: 2x2 spatial patches flattened into channels
            b, c, h, w = latents.shape
            latents = latents.reshape(b, c, h // 2, 2, w // 2, 2)
            latents = latents.permute(0, 2, 4, 1, 3, 5)
            latents = latents.reshape(b, (h // 2) * (w // 2), c * 4)

            # Check shapes after packing
            check_shapes(latents)

            # Print the type and shape of each argument
            print(f"prompt type: {type(prompt)}, value: {prompt}")
            print(f"height type: {type(height)}, value: {height}")
            print(f"width type: {type(width)}, value: {width}")
            print(f"num_inference_steps type: {type(num_inference_steps)}, value: {num_inference_steps}")
            print(f"generator type: {type(generator)}")
            print(f"guidance_scale type: {type(0.0)}, value: 0.0")
            print(f"latents type: {type(latents)}, shape: {latents.shape}")

            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0,
                latents=latents
            ).images[0]

        return image, seed
    except Exception as e:
        print(f"Error during inference: {e}")
        import traceback
        traceback.print_exc()
        return Image.new("RGB", (width, height), (255, 0, 0)), seed  # Red fallback image
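
# The img2img path above hands the encoded latents straight to the pipeline,
# which denoises them over the full schedule. A strength-controlled variant
# would partially noise them first; a minimal sketch for a flow-matching
# model like FLUX (the helper and its linear schedule are assumptions, not
# pipeline API):
def noise_latents_for_strength(latents, strength, generator):
    # Flow-matching forward process: x_t = (1 - t) * x_0 + t * noise,
    # with t = strength in [0, 1]; strength = 1.0 gives pure noise
    noise = torch.randn(latents.shape, generator=generator,
                        device=latents.device, dtype=latents.dtype)
    return (1.0 - strength) * latents + strength * noise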

# Gradio interface setup
with gr.Blocks() as demo:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(label="Initial Image (optional)", type="pil")

    with gr.Row():
        generate = gr.Button("Generate")
        
    with gr.Row():
        result = gr.Image(label="Result")
        seed_output = gr.Number(label="Seed")

    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)

    generate.click(
        infer,
        inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, seed_output]
    )

demo.launch()
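
# Note: recent diffusers releases ship a dedicated FluxImg2ImgPipeline that
# handles encoding, packing, and strength-based noising internally. A minimal
# sketch, assuming your installed diffusers version provides it:
#   from diffusers import FluxImg2ImgPipeline
#   img2img = FluxImg2ImgPipeline.from_pipe(pipe)
#   out = img2img(prompt="a cat", image=pil_image, strength=0.6,
#                 num_inference_steps=4, guidance_scale=0.0).images[0]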