flux-lightning / app.py
Jordan Legg
moved import to the top
d9f1205
raw
history blame
3.7 kB
import spaces
import gradio as gr
import torch
from PIL import Image
from diffusers import DiffusionPipeline
# Constants
MAX_SEED = 2**32 - 1  # upper bound of the seed slider in the UI below
MAX_IMAGE_SIZE = 2048  # upper bound of the width/height sliders in the UI below
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load FLUX model
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
if device == "cuda":
    # Model CPU offload manages device placement itself (components are moved
    # to the GPU only while they run), so the pipeline is NOT moved to CUDA
    # first: an eager `.to("cuda")` would only add a full-model transfer that
    # offloading immediately undoes.
    pipe.enable_model_cpu_offload()
else:
    # Offloading targets an accelerator; without one, run everything on CPU.
    pipe = pipe.to(device)
# Decode latents in slices/tiles to cap peak VAE memory on large images.
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
def print_model_shapes(pipe):
    """Print a short diagnostic summary of the pipeline's components.

    Writes to stdout: a header line, the string form of the VAE encoder and
    decoder, the shape of the transformer's ``x_embedder`` weight, and the
    shape of the ``to_q`` weight in the first transformer block's attention.

    Args:
        pipe: Pipeline object exposing ``vae.encoder``, ``vae.decoder``,
            ``transformer.x_embedder.weight`` and
            ``transformer.transformer_blocks[0].attn.to_q.weight``.
    """
    print("Model component shapes:")

    # VAE halves are printed whole (their string form), not just a shape.
    vae = pipe.vae
    print(f"VAE Encoder: {vae.encoder}")
    print(f"VAE Decoder: {vae.decoder}")

    transformer = pipe.transformer
    embedder_weight = transformer.x_embedder.weight
    print(f"x_embedder shape: {embedder_weight.shape}")

    first_block_query_weight = transformer.transformer_blocks[0].attn.to_q.weight
    print(f"First transformer block shape: {first_block_query_weight.shape}")
# Print the loaded pipeline's component summary once, when the module is run.
print_model_shapes(pipe)
@spaces.GPU()
def infer(prompt, init_image=None, seed=None, width=1024, height=1024, num_inference_steps=4, guidance_scale=0.0):
    """Generate one image with the FLUX pipeline and report the seed used.

    When ``init_image`` is ``None`` the pipeline is called for text-to-image
    generation at ``width`` x ``height``. Otherwise ``init_image`` is converted
    to RGB, resized to ``(width, height)`` and passed to the pipeline as
    ``image``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        init_image: Optional PIL image used as the starting image.
        seed: Seed for the random generator. When ``None``, a seed is drawn
            from ``[0, MAX_SEED]`` so the returned value can reproduce the run.
        width: Requested image width in pixels.
        height: Requested image height in pixels.
        num_inference_steps: Forwarded to the pipeline as
            ``num_inference_steps``.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.

    Returns:
        A ``(image, seed)`` tuple. If inference fails, the error and traceback
        are printed and ``image`` is a solid red placeholder of the requested
        size; ``seed`` is still the seed that was attempted.
    """
    import random
    import traceback

    # Normalise numeric inputs up front: `manual_seed` and `Image.new` need
    # ints, and callers may deliver whole numbers as floats (e.g. 1024.0).
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)
    # Always work with a concrete seed so the value returned to the caller is
    # the one that actually drove the generator.
    seed = random.randint(0, MAX_SEED) if seed is None else int(seed)

    try:
        # Created inside the try so a device/generator failure also produces
        # the placeholder image instead of an unhandled exception.
        generator = torch.Generator(device=device).manual_seed(seed)

        if init_image is None:
            # text2img case
            print("Running text-to-image generation")
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=guidance_scale,
            ).images[0]
        else:
            # img2img case
            # NOTE(review): confirm that the pipeline class loaded for
            # FLUX.1-schnell accepts an `image` argument. The text-to-image
            # Flux pipeline does not appear to; if so, this branch always ends
            # in the error placeholder and needs an image-to-image pipeline
            # (e.g. diffusers' FluxImg2ImgPipeline) instead.
            print("Running image-to-image generation")
            init_image = init_image.convert("RGB").resize((width, height))
            image = pipe(
                prompt=prompt,
                image=init_image,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=guidance_scale,
            ).images[0]
        return image, seed
    except Exception as e:
        # Single boundary handler: log what went wrong (with extra detail for
        # the matmul shape mismatch) and fall back to a red placeholder so the
        # UI still receives an image.
        if isinstance(e, RuntimeError):
            if "mat1 and mat2 shapes cannot be multiplied" in str(e):
                print("Matrix multiplication error detected. Tensor shapes:")
                print(e)
                # Here you could add code to print shapes of specific tensors if needed
            else:
                print(f"RuntimeError during inference: {e}")
        else:
            print(f"Unexpected error during inference: {e}")
        traceback.print_exc()
        return Image.new("RGB", (width, height), (255, 0, 0)), seed
# Gradio interface
with gr.Blocks() as demo:
    # Top row: the text prompt next to the optional starting image.
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(label="Initial Image (optional)", type="pil")

    with gr.Row():
        generate = gr.Button(value="Generate")

    # Result row: the generated image and the seed reported by `infer`.
    with gr.Row():
        result = gr.Image(label="Result")
        seed_output = gr.Number(label="Seed")

    # Tuning knobs, collapsed by default.
    with gr.Accordion(label="Advanced Settings", open=False):
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=None,
        )
        width = gr.Slider(
            label="Width",
            minimum=256,
            maximum=MAX_IMAGE_SIZE,
            step=32,
            value=1024,
        )
        height = gr.Slider(
            label="Height",
            minimum=256,
            maximum=MAX_IMAGE_SIZE,
            step=32,
            value=1024,
        )
        num_inference_steps = gr.Slider(
            label="Number of inference steps",
            minimum=1,
            maximum=50,
            step=1,
            value=4,
        )
        guidance_scale = gr.Slider(
            label="Guidance scale",
            minimum=0,
            maximum=20,
            step=0.1,
            value=0.0,
        )

    # The inputs list is positional: its order must match infer's parameters.
    generate.click(
        fn=infer,
        inputs=[
            prompt,
            init_image,
            seed,
            width,
            height,
            num_inference_steps,
            guidance_scale,
        ],
        outputs=[result, seed_output],
    )

demo.launch()