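"""Gradio Space for FLUX.1-schnell text-to-image and image-to-image generation.

Loads the model once at startup, logs a few component shapes for debugging,
and serves a gr.Blocks UI with error handling around inference.
"""
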
import spaces
import gradio as gr
import torch
from PIL import Image
from diffusers import DiffusionPipeline, FluxImg2ImgPipeline

# Constants
MAX_SEED = 2**32 - 1  # largest 32-bit unsigned seed value
MAX_IMAGE_SIZE = 2048

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load FLUX model. enable_model_cpu_offload() manages device placement itself
# (via accelerate hooks), so the pipeline is not moved with .to(device) here;
# VAE slicing/tiling keep memory use down at large resolutions.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
if device == "cuda":
    pipe.enable_model_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
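
# FluxPipeline (what from_pretrained resolves to for this checkpoint) only
# supports text-to-image, so the img2img branch below uses a second pipeline
# that shares the same weights. Assumption: a diffusers release that ships
# FluxImg2ImgPipeline and DiffusionPipeline.from_pipe (roughly >= 0.30).
img2img_pipe = FluxImg2ImgPipeline.from_pipe(pipe)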

def print_model_shapes(pipe):
    """Log a few component shapes to help debug shape-mismatch errors."""
    print("Model component shapes:")
    print(f"VAE Encoder: {pipe.vae.encoder}")
    print(f"VAE Decoder: {pipe.vae.decoder}")
    print(f"x_embedder shape: {pipe.transformer.x_embedder.weight.shape}")
    print(f"First transformer block shape: {pipe.transformer.transformer_blocks[0].attn.to_q.weight.shape}")

print_model_shapes(pipe)
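# Expected for FLUX.1-schnell (assumption, noted for reference when debugging):
# x_embedder weight ~ [3072, 64] -- 16 latent channels packed 2x2 into the
# 3072-dim transformer stream.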

@spaces.GPU()
def infer(prompt, init_image=None, seed=None, width=1024, height=1024, num_inference_steps=4, guidance_scale=0.0):
    # Gradio sliders deliver floats, so cast before seeding the generator.
    generator = torch.Generator(device=device).manual_seed(int(seed)) if seed is not None else None
    try:
        if init_image is None:
            # text2img case
            print("Running text-to-image generation")
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=guidance_scale,
            ).images[0]
        else:
            # img2img case
            print("Running image-to-image generation")
            init_image = init_image.convert("RGB").resize((width, height))
            image = img2img_pipe(
                prompt=prompt,
                image=init_image,
                strength=0.6,  # assumed denoise strength; not exposed in the UI
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=guidance_scale,
            ).images[0]
        return image, seed
    except RuntimeError as e:
        if "mat1 and mat2 shapes cannot be multiplied" in str(e):
            print("Matrix multiplication error detected. Tensor shapes:")
            print(e)
            # Shapes of specific tensors could be printed here if needed.
        else:
            print(f"RuntimeError during inference: {e}")
        import traceback
        traceback.print_exc()
        # Return a solid red image so the failure is visible in the UI.
        return Image.new("RGB", (width, height), (255, 0, 0)), seed
    except Exception as e:
        print(f"Unexpected error during inference: {e}")
        import traceback
        traceback.print_exc()
        return Image.new("RGB", (width, height), (255, 0, 0)), seed

# Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(label="Initial Image (optional)", type="pil")
    with gr.Row():
        generate = gr.Button("Generate")
    with gr.Row():
        result = gr.Image(label="Result")
        seed_output = gr.Number(label="Seed")
    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=None)
        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)
        # FLUX.1-schnell is guidance-distilled, so 0.0 is the expected setting.
        guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=20, step=0.1, value=0.0)

    generate.click(
        infer,
        inputs=[prompt, init_image, seed, width, height, num_inference_steps, guidance_scale],
        outputs=[result, seed_output],
    )

demo.launch()
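
# With Gradio's defaults this serves on http://127.0.0.1:7860 when run
# locally; on Hugging Face Spaces the host manages the launch settings.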