# Spaces: Runtime error  (status text captured along with this listing; apparently page chrome, not part of the program)
import spaces | |
import gradio as gr | |
import torch | |
from PIL import Image | |
from diffusers import DiffusionPipeline | |
# Constants
# Upper bounds used by the seed and width/height sliders of the UI.
MAX_SEED = 0xFFFFFFFF  # 2**32 - 1
MAX_IMAGE_SIZE = 2048

# Precision the pipeline weights are loaded in.
dtype = torch.bfloat16

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = "cpu" if not torch.cuda.is_available() else "cuda"
# Load FLUX model
# The weights are loaded without an explicit move to the GPU. On CUDA hosts
# enable_model_cpu_offload() takes over device placement; moving the whole
# pipeline to the GPU first would only be undone by the offload setup and can
# exhaust GPU memory before offloading is even active.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
if device == "cuda":
    pipe.enable_model_cpu_offload()
else:
    # No accelerator to offload from: keep the plain placement on `device`.
    pipe.to(device)

# Decode latents in slices/tiles to lower peak memory during VAE decoding.
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
def print_model_shapes(pipe):
    """Print a short diagnostic summary of selected pipeline components.

    Writes to stdout, in order: a header line, the VAE encoder and decoder
    objects, the shape of the transformer's ``x_embedder`` weight, and the
    shape of the ``attn.to_q`` weight of the first transformer block.

    Args:
        pipe: Object exposing ``vae.encoder``, ``vae.decoder``,
            ``transformer.x_embedder.weight`` and
            ``transformer.transformer_blocks[0].attn.to_q.weight``.
    """
    print("Model component shapes:")

    vae = pipe.vae
    print(f"VAE Encoder: {vae.encoder}")
    print(f"VAE Decoder: {vae.decoder}")

    transformer = pipe.transformer
    embedder_weight = transformer.x_embedder.weight
    print(f"x_embedder shape: {embedder_weight.shape}")

    first_block = transformer.transformer_blocks[0]
    query_weight = first_block.attn.to_q.weight
    print(f"First transformer block shape: {query_weight.shape}")
# Print the component summary of the loaded pipeline once, when the module runs.
print_model_shapes(pipe)
def infer(prompt, init_image=None, seed=None, width=1024, height=1024, num_inference_steps=4, guidance_scale=0.0):
    """Generate an image from a prompt, optionally starting from an initial image.

    Args:
        prompt: Text prompt passed to the pipeline.
        init_image: Optional PIL image. When given, it is converted to RGB,
            resized to ``(width, height)`` and passed to the pipeline as
            ``image``; otherwise plain text-to-image generation is run.
        seed: Optional seed for the random generator. ``None`` leaves the
            pipeline unseeded.
        width: Requested width in pixels.
        height: Requested height in pixels.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.

    Returns:
        A ``(image, seed)`` tuple. If the pipeline raises, the error and its
        traceback are printed and ``image`` is a solid red placeholder of size
        ``(width, height)`` so the caller still receives an image.
    """
    import traceback

    # UI sliders may hand numbers over as floats, while manual_seed() and the
    # size/step arguments require integers.
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    generator = None
    if seed is not None:
        seed = int(seed)
        generator = torch.Generator(device=device).manual_seed(seed)

    try:
        if init_image is None:
            # text2img case
            print("Running text-to-image generation")
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=guidance_scale,
            ).images[0]
        else:
            # img2img case
            # NOTE(review): this only works if the pipeline class resolved by
            # DiffusionPipeline.from_pretrained accepts an `image` keyword;
            # a text-to-image-only pipeline would raise here and fall through
            # to the red placeholder below — confirm against the loaded class.
            print("Running image-to-image generation")
            init_image = init_image.convert("RGB").resize((width, height))
            image = pipe(
                prompt=prompt,
                image=init_image,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=guidance_scale,
            ).images[0]
        return image, seed
    except Exception as e:
        # Deliberate best-effort boundary: report the failure and return a
        # placeholder instead of propagating the error to the UI.
        if not isinstance(e, RuntimeError):
            print(f"Unexpected error during inference: {e}")
        elif "mat1 and mat2 shapes cannot be multiplied" in str(e):
            print("Matrix multiplication error detected. Tensor shapes:")
            print(e)
            # Here you could add code to print shapes of specific tensors if needed
        else:
            print(f"RuntimeError during inference: {e}")
        traceback.print_exc()
        return Image.new("RGB", (width, height), (255, 0, 0)), seed
# Gradio interface
with gr.Blocks() as demo:
    # Inputs: the prompt next to the optional starting image.
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(type="pil", label="Initial Image (optional)")

    with gr.Row():
        generate = gr.Button("Generate")

    # Outputs: the generated image and the seed returned by infer.
    with gr.Row():
        result = gr.Image(label="Result")
        seed_output = gr.Number(label="Seed")

    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=None,
            label="Seed",
        )
        width = gr.Slider(
            minimum=256,
            maximum=MAX_IMAGE_SIZE,
            step=32,
            value=1024,
            label="Width",
        )
        height = gr.Slider(
            minimum=256,
            maximum=MAX_IMAGE_SIZE,
            step=32,
            value=1024,
            label="Height",
        )
        num_inference_steps = gr.Slider(
            minimum=1,
            maximum=50,
            step=1,
            value=4,
            label="Number of inference steps",
        )
        guidance_scale = gr.Slider(
            minimum=0,
            maximum=20,
            step=0.1,
            value=0.0,
            label="Guidance scale",
        )

    # The input order mirrors infer's positional parameters.
    generate.click(
        fn=infer,
        inputs=[prompt, init_image, seed, width, height, num_inference_steps, guidance_scale],
        outputs=[result, seed_output],
    )

demo.launch()