import spaces
import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
from torchvision import transforms
from diffusers import DiffusionPipeline

# Constants
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Load FLUX model
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
# Note: enable_model_cpu_offload() manages device placement itself and
# conflicts with an explicit .to(device); enable it only as an alternative
# when GPU memory is tight.
# pipe.enable_model_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

def preprocess_image(image, image_size):
    # Resize, convert to a tensor, and scale pixel values to [-1, 1],
    # the range the VAE encoder expects.
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    return preprocess(image).unsqueeze(0).to(device, dtype=dtype)
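
# A minimal inverse of preprocess_image, useful when debugging VAE round-trips.
# This helper is an illustrative sketch, not part of the original app; it
# assumes the same [-1, 1] normalization used above.
def tensor_to_pil(tensor):
    tensor = (tensor / 2 + 0.5).clamp(0, 1)  # undo Normalize([0.5], [0.5])
    array = (tensor.squeeze(0).permute(1, 2, 0).float().cpu().numpy() * 255).astype(np.uint8)
    return Image.fromarray(array)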

def check_shapes(latents):
    # Debug helper: compare the latents against what the FLUX transformer
    # expects. FluxTransformer2DModel projects packed latents through its
    # x_embedder linear layer, so the last dimension of packed latents must
    # match x_embedder.in_features.
    print(f"Latent shape: {latents.shape}")

    in_features = pipe.transformer.x_embedder.in_features
    print(f"Transformer input projection expects {in_features} channels")

    if latents.dim() == 3 and latents.shape[-1] == in_features:
        print("Shapes are compatible with the transformer input projection.")
    else:
        print("Warning: latents do not match the transformer input projection.")
        print(f"Expected last dim: {in_features}, got shape: {tuple(latents.shape)}")

@spaces.GPU()
def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        if init_image is None:
            # text2img case
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0
            ).images[0]
        else:
            # img2img case: encode the input image and hand it to the pipeline
            # as the initial latents. Note this is not strength-based img2img;
            # denoising starts directly from these latents.
            init_image = init_image.convert("RGB")
            init_image = preprocess_image(init_image, 1024)  # 1024 matches the FLUX VAE sample size

            # Encode with the FLUX VAE. FLUX scales latents with the VAE's
            # shift_factor/scaling_factor, not the SD 1.x constant 0.18215.
            latents = pipe.vae.encode(init_image).latent_dist.sample(generator=generator)
            latents = (latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor

            # Resize latents to the requested output size (the VAE downscales by 8)
            latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')

            # Check shapes before packing
            check_shapes(latents)

            # Pack latents into the (batch, tokens, channels * 4) layout the
            # FLUX transformer expects
            latents = pack_latents(latents)

            # Check shapes after packing
            check_shapes(latents)

            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0,
                latents=latents
            ).images[0]

        return image, seed
    except Exception as e:
        print(f"Error during inference: {e}")
        return Image.new("RGB", (width, height), (255, 0, 0)), seed  # Red fallback image
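
# Minimal Gradio UI wiring for infer(). The original file calls demo.launch()
# without ever defining `demo`, so this Blocks layout is an assumed
# reconstruction of the missing interface, not the original one.
with gr.Blocks() as demo:
    gr.Markdown("# FLUX.1-schnell text2img / img2img")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(label="Initial image (optional)", type="pil")
    with gr.Row():
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
    with gr.Row():
        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=8, step=1, value=4)
    run_button = gr.Button("Generate")
    result = gr.Image(label="Result")
    used_seed = gr.Number(label="Seed used")
    run_button.click(
        fn=infer,
        inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, used_seed],
    )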


demo.launch()