import spaces
import gradio as gr
import numpy as np
import random
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from diffusers import DiffusionPipeline

# Constants
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
LATENT_CHANNELS = 16          # FLUX VAE latent channels
TEXT_EMBED_DIM = 4096         # width of FLUX's T5 text embeddings
MAX_TEXT_EMBEDDINGS = 256     # max_sequence_length used by FLUX.1-schnell
SCALING_FACTOR = 0.3611       # FLUX VAE scaling_factor
SHIFT_FACTOR = 0.1159         # FLUX VAE shift_factor (applied before scaling)

# Load FLUX model. enable_model_cpu_offload() manages device placement on its
# own, so the pipeline is deliberately not moved to CUDA first.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
pipe.enable_model_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

# Projection layer mapping each 16-channel latent position to a token of the
# text-embedding width. NOTE: this layer is randomly initialized (untrained),
# so it only demonstrates the plumbing of the experiment.
projection = nn.Linear(LATENT_CHANNELS, TEXT_EMBED_DIM).to(device, dtype=dtype)
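
# Illustrative shape check (assumed values, not executed by the app): ten
# latent positions project to ten text-width tokens.
#   _demo = torch.randn(1, 10, LATENT_CHANNELS, device=device, dtype=dtype)
#   projection(_demo).shape  # -> torch.Size([1, 10, TEXT_EMBED_DIM])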

def preprocess_image(image, image_size):
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    image = preprocess(image).unsqueeze(0).to(device, dtype=dtype)
    return image
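
# Example usage (illustrative; "photo.jpg" is a hypothetical path):
#   tensor = preprocess_image(Image.open("photo.jpg").convert("RGB"), 1024)
#   tensor.shape  # -> torch.Size([1, 3, 1024, 1024]), normalized to [-1, 1]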

def process_latents(latents, height, width):
    print(f"Input latent shape: {latents.shape}")
    
    # Ensure latents are the correct shape
    if latents.shape[2:] != (height // 8, width // 8):
        latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')
    print(f"Latent shape after potential interpolation: {latents.shape}")
    
    # Reshape latents to [batch_size, seq_len, channels]
    latents = latents.permute(0, 2, 3, 1).reshape(1, -1, LATENT_CHANNELS)
    print(f"Reshaped latent shape: {latents.shape}")
    
    # Project latents to match text embedding dimension
    latents = projection(latents)
    print(f"Projected latent shape: {latents.shape}")
    
    # Adjust sequence length to match text embeddings
    seq_len = latents.shape[1]
    if seq_len > MAX_TEXT_EMBEDDINGS:
        latents = latents[:, :MAX_TEXT_EMBEDDINGS, :]
    elif seq_len < MAX_TEXT_EMBEDDINGS:
        pad_len = MAX_TEXT_EMBEDDINGS - seq_len
        # F.pad pairs run from the last dim backwards: (0, 0) leaves the
        # channel dim alone, (0, pad_len) right-pads the sequence dim.
        latents = torch.nn.functional.pad(latents, (0, 0, 0, pad_len))
    print(f"Final latent shape: {latents.shape}")
    
    return latents
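
def _process_latents_shape_demo():
    # Hypothetical sanity check (not wired into the app): a 1024x1024 image
    # encodes to a [1, 16, 128, 128] latent (1024 / 8 = 128), which flattens
    # to 128 * 128 = 16384 positions, is projected to TEXT_EMBED_DIM, and is
    # then truncated to MAX_TEXT_EMBEDDINGS tokens.
    fake = torch.randn(1, LATENT_CHANNELS, 128, 128, device=device, dtype=dtype)
    out = process_latents(fake, 1024, 1024)
    assert out.shape == (1, MAX_TEXT_EMBEDDINGS, TEXT_EMBED_DIM)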

@spaces.GPU()
def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        if init_image is None:
            # text2img case
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0
            ).images[0]
        else:
            # img2img case
            init_image = init_image.convert("RGB")
            init_image = preprocess_image(init_image, 1024)  # Using 1024 as FLUX VAE sample size
            
            # Encode the image using the FLUX VAE. FLUX's VAE applies a shift
            # before scaling: z = (z - shift_factor) * scaling_factor.
            latents = pipe.vae.encode(init_image).latent_dist.sample()
            latents = (latents - SHIFT_FACTOR) * SCALING_FACTOR
            print(f"Initial latent shape from VAE: {latents.shape}")
            
            # Process latents to match text embedding format
            latents = process_latents(latents, height, width)

            # Get text embeddings via the pipeline's own encoder;
            # FluxPipeline.encode_prompt returns (T5 sequence embeddings,
            # pooled CLIP embeddings, text ids).
            text_embeddings, pooled_embeddings, _ = pipe.encode_prompt(
                prompt=prompt,
                prompt_2=prompt,
                device=device,
                max_sequence_length=MAX_TEXT_EMBEDDINGS,
            )
            print(f"Text embedding shape: {text_embeddings.shape}")

            # Combine projected image latents and text embeddings along the
            # sequence dimension, then feed the result through the text stream
            # as prompt_embeds. (FluxPipeline's `latents` argument expects
            # packed VAE latents of shape [batch, (h//16)*(w//16), 64], so the
            # combined sequence cannot be passed there; riding the text stream
            # is an experimental conditioning hack, not standard img2img.)
            combined_embeddings = torch.cat([latents, text_embeddings], dim=1)
            print(f"Combined embedding shape: {combined_embeddings.shape}")

            image = pipe(
                prompt_embeds=combined_embeddings,
                pooled_prompt_embeds=pooled_embeddings,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0
            ).images[0]

        return image, seed
    except Exception as e:
        print(f"Error during inference: {e}")
        import traceback
        traceback.print_exc()
        return Image.new("RGB", (width, height), (255, 0, 0)), seed  # Red fallback image
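
# Example invocation outside the Gradio UI (illustrative prompt):
#   image, used_seed = infer("a tiny robot watering a bonsai tree",
#                            init_image=None, seed=0, randomize_seed=False,
#                            width=512, height=512, num_inference_steps=4)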

# Gradio interface setup
with gr.Blocks() as demo:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(label="Initial Image (optional)", type="pil")

    with gr.Row():
        generate = gr.Button("Generate")
        
    with gr.Row():
        result = gr.Image(label="Result")
        seed_output = gr.Number(label="Seed")

    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)

    generate.click(
        infer,
        inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, seed_output]
    )

demo.launch()