Spaces: Runtime error
import spaces
import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
from torchvision import transforms
from diffusers import DiffusionPipeline
# Constants
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
# Load FLUX model
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
if device == "cuda":
    # enable_model_cpu_offload() handles device placement itself; do not call .to("cuda") first.
    pipe.enable_model_cpu_offload()
else:
    pipe.to(device)
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
def preprocess_image(image, image_size):
    # Resize, convert to a tensor, and normalize to [-1, 1] for the VAE
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    image = preprocess(image).unsqueeze(0).to(device, dtype=dtype)
    return image
def check_shapes(latents):
    # Diagnostic helper: print the latent shape and compare its feature dimension
    # with what the FLUX transformer expects.
    latent_shape = latents.shape
    print(f"Latent shape: {latent_shape}")
    # The FLUX transformer consumes packed latents of shape (batch, seq_len, in_channels);
    # in_channels comes from the model config (the original CLIP-style
    # text_model/self_attn.q_proj lookup does not exist on FluxTransformer2DModel).
    in_channels = pipe.transformer.config.in_channels
    print(f"Expected transformer input shape: (1, seq_len, {in_channels})")
    # Compatible when the last (feature) dimension of the latents matches in_channels
    if latent_shape[-1] == in_channels:
        print("Shapes are compatible for matrix multiplication.")
    else:
        print("Warning: Shapes are not compatible for matrix multiplication.")
        print(f"Expected: {in_channels}, Got: {latent_shape[-1]}")
def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)
    try:
        if init_image is None:
            # text2img case
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0
            ).images[0]
        else:
            # img2img case
            init_image = init_image.convert("RGB")
            init_image = preprocess_image(init_image, 1024)  # Using 1024 as FLUX VAE sample size
            # With model CPU offload the VAE can still sit on the CPU when encode()
            # is called directly, so keep the input on the VAE's device and dtype
            init_image = init_image.to(pipe.vae.device, dtype=pipe.vae.dtype)
            # Encode with the FLUX VAE. 0.18215 is the Stable Diffusion 1.x constant;
            # use this VAE's own shift/scaling factors from its config instead.
            latents = pipe.vae.encode(init_image).latent_dist.sample()
            shift = getattr(pipe.vae.config, "shift_factor", 0.0) or 0.0
            latents = (latents - shift) * pipe.vae.config.scaling_factor
            # Ensure latents match the requested resolution (the VAE downsamples by 8x)
            latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')
            # Check shapes before reshaping
            check_shapes(latents)
            # Pack latents the way the FLUX transformer expects: fold 2x2 spatial
            # patches into the channel dimension -> (batch, (H/2)*(W/2), C*4)
            b, c, h, w = latents.shape
            latents = latents.view(b, c, h // 2, 2, w // 2, 2)
            latents = latents.permute(0, 2, 4, 1, 3, 5).reshape(b, (h // 2) * (w // 2), c * 4)
            # Check shapes after packing
            check_shapes(latents)
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0,
                latents=latents
            ).images[0]
        return image, seed
    except Exception as e:
        print(f"Error during inference: {e}")
        return Image.new("RGB", (width, height), (255, 0, 0)), seed  # Red fallback image
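# The Gradio UI that `demo.launch()` refers to is missing from the listing above.
# A minimal, assumed reconstruction that wires up the same infer() parameters:
with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    init_image = gr.Image(label="Initial image (optional)", type="pil")
    seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
    randomize_seed = gr.Checkbox(value=False, label="Randomize seed")
    width = gr.Slider(256, MAX_IMAGE_SIZE, value=1024, step=64, label="Width")
    height = gr.Slider(256, MAX_IMAGE_SIZE, value=1024, step=64, label="Height")
    num_inference_steps = gr.Slider(1, 8, value=4, step=1, label="Steps")
    result = gr.Image(label="Result")
    used_seed = gr.Number(label="Seed used")
    run = gr.Button("Generate")
    run.click(
        infer,
        inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, used_seed],
    )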
demo.launch()