Spaces:
Runtime error
Runtime error
File size: 3,606 Bytes
6cbacd9 7f891bb 2e306db 044186b a7d057d 7f891bb d2c614b d2b0012 d2c614b d2cb214 7f891bb a7d057d d2c614b d2b0012 d2c614b 9d86930 d2c614b 044186b d2c614b 044186b d2cb214 d2c614b 69e75b1 6af450a d2c614b d2b0012 5b33905 d2c614b d2b0012 6af450a d2c614b d2b0012 29a504c d2c614b 29a504c d2c614b d2b0012 6af450a d2c614b 6af450a d2c614b 6b927be d2c614b 2e306db d2c614b d53ee34 d2c614b d53ee34 d2c614b d53ee34 d2c614b d53ee34 d2b0012 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import functools
import random
import traceback

# Keep `spaces` first among third-party imports: ZeroGPU requires it to be
# imported before CUDA is initialised.
import spaces
import gradio as gr
import torch
from PIL import Image
from diffusers import DiffusionPipeline
# Constants
MAX_SEED = 2**32 - 1  # upper bound of the seed slider
MAX_IMAGE_SIZE = 2048  # upper bound of the width/height sliders

# Load FLUX model.
# bfloat16 is the dtype used on the FLUX.1-schnell model card; float16 has a
# much narrower exponent range and risks overflow (NaN / black images) here.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
# Model CPU offload keeps sub-models on the CPU and moves each one to the GPU
# only while it runs. The pipeline is deliberately NOT moved with
# pipe.to("cuda") first: offloading moves it back to the CPU anyway, and
# placing the whole model on the GPU beforehand can exhaust GPU memory.
pipe.enable_model_cpu_offload()
# Decode the VAE input in slices to lower peak memory during decoding.
pipe.enable_vae_slicing()
def print_model_shapes(pipe):
    """Print a short structural summary of a FLUX pipeline's components.

    Emits the VAE encoder and decoder module reprs, then the weight shapes of
    the transformer's ``x_embedder`` and of the first transformer block's
    query projection.
    """
    print("Model component shapes:")
    vae = pipe.vae
    print(f"VAE Encoder: {vae.encoder}")
    print(f"VAE Decoder: {vae.decoder}")
    transformer = pipe.transformer
    print(f"x_embedder shape: {transformer.x_embedder.weight.shape}")
    first_q_weight = transformer.transformer_blocks[0].attn.to_q.weight
    print(f"First transformer block shape: {first_q_weight.shape}")
# Startup diagnostic: print the loaded pipeline's component structure once, at import time.
print_model_shapes(pipe)
@functools.lru_cache(maxsize=1)
def _get_img2img_pipe():
    """Return an image-to-image pipeline that reuses ``pipe``'s loaded components.

    ``pipe`` is a FLUX text-to-image pipeline whose ``__call__`` has no
    ``image`` argument, so calling it with ``image=`` raises ``TypeError``.
    ``from_pipe`` builds the matching img2img pipeline without loading the
    weights again. It is built lazily, on the first img2img request, and
    cached so the cost is paid once.
    """
    # Local import: an older diffusers without FLUX img2img support then fails
    # only this request (caught in `infer`), not the whole app at startup.
    from diffusers import AutoPipelineForImage2Image

    # NOTE(review): assumes the CPU-offload hooks installed on `pipe`'s shared
    # modules keep working through the derived pipeline; confirm on the
    # deployed diffusers version.
    return AutoPipelineForImage2Image.from_pipe(pipe)


@spaces.GPU()
def infer(prompt, init_image=None, seed=None, width=1024, height=1024, num_inference_steps=4, guidance_scale=0.0):
    """Generate an image with FLUX from a prompt, optionally starting from an image.

    Args:
        prompt: Text prompt for the pipeline.
        init_image: Optional PIL image. When given, it is converted to RGB,
            resized to ``(width, height)`` and used for image-to-image generation.
        seed: Optional RNG seed. When ``None``, a random seed in
            ``[0, MAX_SEED]`` is drawn, so the returned seed always reproduces
            the image.
        width, height: Output size in pixels. They are cast to ``int`` because
            Gradio sliders may deliver floats.
        num_inference_steps: Number of denoising steps (cast to ``int``).
        guidance_scale: Guidance scale forwarded to the pipeline.

    Returns:
        ``(image, seed)``, where ``seed`` is the seed actually used. On any
        exception, the error and its traceback are printed, and a solid red
        placeholder of the requested size is returned instead of raising.
    """
    width, height = int(width), int(height)
    num_inference_steps = int(num_inference_steps)
    # Always seed explicitly so the seed shown in the UI reproduces the result.
    seed = random.randint(0, MAX_SEED) if seed is None else int(seed)
    generator = torch.Generator(device="cuda").manual_seed(seed)
    try:
        if init_image is None:
            # text2img case
            print("Running text-to-image generation")
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=guidance_scale,
            ).images[0]
        else:
            # img2img case: needs a dedicated pipeline (see _get_img2img_pipe).
            print("Running image-to-image generation")
            init_image = init_image.convert("RGB").resize((width, height))
            image = _get_img2img_pipe()(
                prompt=prompt,
                image=init_image,
                # Pass the size explicitly; otherwise the pipeline falls back
                # to its own default resolution and ignores the sliders.
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=guidance_scale,
            ).images[0]
        return image, seed
    except RuntimeError as e:
        if "mat1 and mat2 shapes cannot be multiplied" in str(e):
            print(f"Matrix multiplication error detected: {e}")
        else:
            print(f"RuntimeError during inference: {e}")
        traceback.print_exc()
    except Exception as e:
        print(f"Unexpected error during inference: {e}")
        traceback.print_exc()
    # Best-effort fallback so the UI still receives an image: solid red placeholder.
    return Image.new("RGB", (width, height), (255, 0, 0)), seed
# Gradio interface
with gr.Blocks() as demo:
    # Prompt and optional starting image, side by side.
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(label="Initial Image (optional)", type="pil")
    with gr.Row():
        generate = gr.Button("Generate")
    # Generated image plus the seed returned by `infer`.
    with gr.Row():
        result = gr.Image(label="Result")
        seed_output = gr.Number(label="Seed")
    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=None)
        # Width and height share one range, step and default.
        size_slider_kwargs = dict(minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        width = gr.Slider(label="Width", **size_slider_kwargs)
        height = gr.Slider(label="Height", **size_slider_kwargs)
        num_inference_steps = gr.Slider(
            label="Number of inference steps", minimum=1, maximum=50, step=1, value=4
        )
        guidance_scale = gr.Slider(
            label="Guidance scale", minimum=0, maximum=20, step=0.1, value=0.0
        )
    # The order of `inputs` must match infer's positional parameters.
    generate.click(
        fn=infer,
        inputs=[prompt, init_image, seed, width, height, num_inference_steps, guidance_scale],
        outputs=[result, seed_output],
    )

demo.launch()