import spaces
import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
from torchvision import transforms
from diffusers import DiffusionPipeline
# Constants
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Load the FLUX.1-schnell pipeline.
# enable_model_cpu_offload() manages device placement itself, so the pipeline
# is intentionally not moved to the GPU with .to(device) beforehand.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
pipe.enable_model_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
def preprocess_image(image, image_size):
    # Resize to a square, normalize to [-1, 1], and move the tensor to the target device.
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    image = preprocess(image).unsqueeze(0).to(device, dtype=dtype)
    return image
def check_shapes(latents):
    # Diagnostic helper: print the latent shape and compare its channel
    # dimension against the channel width the transformer expects.
    latent_shape = latents.shape
    print(f"Latent shape: {latent_shape}")
    # The FLUX transformer consumes sequences of shape (batch, seq_len, in_channels),
    # so the last latent dimension must match pipe.transformer.config.in_channels.
    expected_channels = pipe.transformer.config.in_channels
    print(f"Expected transformer input channels: {expected_channels}")
    # Check whether the shapes are compatible for matrix multiplication
    if latent_shape[-1] == expected_channels:
        print("Shapes are compatible for matrix multiplication.")
    else:
        print("Warning: Shapes are not compatible for matrix multiplication.")
        print(f"Expected: {expected_channels}, Got: {latent_shape[-1]}")
@spaces.GPU  # Request a GPU for the duration of this call (ZeroGPU Spaces)
def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)
    try:
        if init_image is None:
            # text2img case
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0
            ).images[0]
        else:
            # img2img case
            init_image = init_image.convert("RGB")
            init_image = preprocess_image(init_image, 1024)  # Using 1024 as the FLUX VAE sample size
            # Encode the image with the FLUX VAE, applying the shift and scaling
            # factors from the VAE config (FLUX does not use the SD 1.x 0.18215 factor)
            latents = pipe.vae.encode(init_image).latent_dist.sample()
            latents = (latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
            # Resize latents to the requested output resolution (the VAE downscales by 8)
            latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')
            # Check shapes before reshaping
            check_shapes(latents)
            # Flatten the spatial dimensions into a (1, H*W, C) sequence for the transformer
            latents = latents.permute(0, 2, 3, 1).contiguous().view(1, -1, pipe.vae.config.latent_channels)
            # Check shapes after reshaping
            check_shapes(latents)
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0,
                latents=latents
            ).images[0]
        return image, seed
    except Exception as e:
        print(f"Error during inference: {e}")
        return Image.new("RGB", (width, height), (255, 0, 0)), seed  # Red fallback image
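
# demo.launch() below needs a `demo` object, which is missing from the file as
# shown. The Blocks UI here is a minimal sketch inferred from the infer()
# signature; widget names and layout are illustrative, not the original ones.
with gr.Blocks() as demo:
    gr.Markdown("# FLUX.1-schnell text2img / img2img")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(label="Initial image (optional)", type="pil")
    with gr.Row():
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
    with gr.Row():
        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
    num_inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=50, step=1, value=4)
    run_button = gr.Button("Generate")
    result = gr.Image(label="Result")
    used_seed = gr.Number(label="Seed used", interactive=False)
    run_button.click(
        fn=infer,
        inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, used_seed],
    )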
demo.launch()