"""Gradio demo for FLUX.1-schnell text-to-image and image-to-image generation."""

import random
import traceback
from typing import Optional, Tuple

# `spaces` stays the first third-party import, as in the original file
# (presumably so it can set up GPU handling before torch loads - confirm).
import spaces
import gradio as gr
import torch
from PIL import Image
from diffusers import DiffusionPipeline, FluxImg2ImgPipeline

# Constants
MAX_SEED = 2**32 - 1
MAX_IMAGE_SIZE = 2048

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load FLUX model (text-to-image pipeline).
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
)

# The text-to-image pipeline has no `image` argument, so image-to-image needs
# its own pipeline class. `from_pipe` reuses the components loaded above
# instead of loading the weights a second time; `torch_dtype` is passed
# explicitly so the shared weights keep their dtype.
img2img_pipe = FluxImg2ImgPipeline.from_pipe(pipe, torch_dtype=dtype)

# The VAE is shared between both pipelines, so this applies to both.
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

if device == "cuda":
    # Model offload places each component on the GPU on demand, so the
    # pipelines are deliberately NOT moved to the GPU first.
    pipe.enable_model_cpu_offload()
    img2img_pipe.enable_model_cpu_offload()
else:
    # Offloading needs an accelerator; without one just run on `device`.
    pipe.to(device)
    img2img_pipe.to(device)


def print_model_shapes(pipe):
    """Print a short structural summary of the pipeline for debugging.

    Prints the VAE encoder and decoder modules, followed by the weight shapes
    of the transformer's ``x_embedder`` and of the query projection in the
    first transformer block.

    Args:
        pipe: Pipeline object exposing ``vae`` and ``transformer`` attributes.
    """
    print("Model component shapes:")
    print(f"VAE Encoder: {pipe.vae.encoder}")
    print(f"VAE Decoder: {pipe.vae.decoder}")
    print(f"x_embedder shape: {pipe.transformer.x_embedder.weight.shape}")
    print(
        "First transformer block shape: "
        f"{pipe.transformer.transformer_blocks[0].attn.to_q.weight.shape}"
    )


print_model_shapes(pipe)


def _error_image(width: int, height: int) -> Image.Image:
    """Return the solid red placeholder image shown when generation fails."""
    return Image.new("RGB", (width, height), (255, 0, 0))


@spaces.GPU()
def infer(
    prompt: str,
    init_image: Optional[Image.Image] = None,
    seed: Optional[int] = None,
    width: int = 1024,
    height: int = 1024,
    num_inference_steps: int = 4,
    guidance_scale: float = 0.0,
) -> Tuple[Image.Image, int]:
    """Generate an image from a prompt, optionally starting from an image.

    Args:
        prompt: Text prompt passed to the pipeline.
        init_image: Optional PIL image. When given, it is converted to RGB,
            resized to ``(width, height)`` and used for image-to-image
            generation; otherwise text-to-image generation is run.
        seed: Seed for the random generator. When ``None`` a random seed in
            ``[0, MAX_SEED]`` is drawn so the returned seed can reproduce
            the result.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale passed to the pipeline.

    Returns:
        A ``(image, seed)`` tuple. If generation fails, the error is printed
        with its traceback and ``image`` is a solid red placeholder of the
        requested size.
    """
    # UI widgets may hand numbers over as floats; the generator seed and the
    # image dimensions must be integers.
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)
    seed = random.randint(0, MAX_SEED) if seed is None else int(seed)

    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        if init_image is None:
            # text2img case
            print("Running text-to-image generation")
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=guidance_scale,
            ).images[0]
        else:
            # img2img case
            print("Running image-to-image generation")
            init_image = init_image.convert("RGB").resize((width, height))
            image = img2img_pipe(
                prompt=prompt,
                image=init_image,
                # Pass the size explicitly so the pipeline does not fall back
                # to its own default resolution.
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=guidance_scale,
            ).images[0]
        return image, seed
    except Exception as e:
        # UI boundary: report the failure and show a placeholder image rather
        # than propagating the exception to the caller.
        if isinstance(e, RuntimeError):
            if "mat1 and mat2 shapes cannot be multiplied" in str(e):
                print("Matrix multiplication error detected. Tensor shapes:")
                print(e)
                # Here you could add code to print shapes of specific tensors if needed
            else:
                print(f"RuntimeError during inference: {e}")
        else:
            print(f"Unexpected error during inference: {e}")
        traceback.print_exc()
        return _error_image(width, height), seed


# Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(label="Initial Image (optional)", type="pil")

    with gr.Row():
        generate = gr.Button("Generate")

    with gr.Row():
        result = gr.Image(label="Result")
        seed_output = gr.Number(label="Seed")

    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=None)
        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)
        guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=20, step=0.1, value=0.0)

    generate.click(
        infer,
        inputs=[prompt, init_image, seed, width, height, num_inference_steps, guidance_scale],
        outputs=[result, seed_output],
    )

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()