import gradio as gr
import numpy as np
import random
import spaces  # provided by the Hugging Face Spaces runtime; needed for the @spaces.GPU decorator below
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from diffusers import DiffusionPipeline

# Constants
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Load FLUX model
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
pipe.enable_model_cpu_offload()  # handles device placement itself; moving the pipeline to `device` first would be redundant
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

# Add a projection layer to match x_embedder input
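# NOTE: this projection is randomly initialized (untrained). FLUX normally builds its 64-channel
# transformer input by 2x2-packing the 16 VAE latent channels, so this layer is only a rough stand-in.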
projection = nn.Linear(16, 64).to(device).to(dtype)

def preprocess_image(image, image_size):
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    image = preprocess(image).unsqueeze(0).to(device, dtype=dtype)
    return image

def process_latents(latents, height, width):
    print(f"Input latent shape: {latents.shape}")
    
    # Ensure latents are the correct shape
    if latents.shape[2:] != (height // 8, width // 8):
        latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')
    print(f"Latent shape after potential interpolation: {latents.shape}")
    
    # Reshape latents to [batch_size, seq_len, channels]
    latents = latents.permute(0, 2, 3, 1).reshape(1, -1, latents.shape[1])
    print(f"Reshaped latent shape: {latents.shape}")
    
    # Project latents from 16 to 64 dimensions
    latents = projection(latents)
    print(f"Projected latent shape: {latents.shape}")
    
    return latents

@spaces.GPU()
def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        if init_image is None:
            # text2img case
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0
            ).images[0]
        else:
            # img2img case
            init_image = init_image.convert("RGB")
            init_image = preprocess_image(init_image, 1024)  # Using 1024 as FLUX VAE sample size
            
            # Encode the image with the FLUX VAE and apply its configured shift/scale
            # (0.18215 is the Stable Diffusion VAE factor; FLUX exposes its own values in the VAE config)
            latents = pipe.vae.encode(init_image).latent_dist.sample()
            latents = (latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
            print(f"Initial latent shape from VAE: {latents.shape}")
            
            # Process latents to match x_embedder input
            latents = process_latents(latents, height, width)

            print(f"x_embedder weight shape: {pipe.transformer.x_embedder.weight.shape}")
            print(f"First transformer block input shape: {pipe.transformer.transformer_blocks[0].attn.to_q.weight.shape}")

            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0,
                latents=latents
            ).images[0]

        return image, seed
    except Exception as e:
        print(f"Error during inference: {e}")
        import traceback
        traceback.print_exc()
        return Image.new("RGB", (width, height), (255, 0, 0)), seed  # Red fallback image

# Gradio interface setup
with gr.Blocks() as demo:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(label="Initial Image (optional)", type="pil")

    with gr.Row():
        generate = gr.Button("Generate")
        
    with gr.Row():
        result = gr.Image(label="Result")
        seed_output = gr.Number(label="Seed")

    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)

    generate.click(
        infer,
        inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, seed_output]
    )

demo.launch()