# flux-lightning / app.py
# Author: Jordan Legg
# Commit 6b927be: "added more console logging" (file size 4.92 kB)
import spaces
import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
from torchvision import transforms
from diffusers import DiffusionPipeline
# Constants
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
MODEL_ID = "black-forest-labs/FLUX.1-schnell"

# Load FLUX model.
# Do not call .to(device) before enable_model_cpu_offload(): that would copy the
# whole pipeline onto the GPU only for the offload setup to move it back to the
# CPU. Model offloading also needs an accelerator, so on a CPU-only host the
# pipeline simply stays on the CPU.
pipe = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)
if device == "cuda":
    pipe.enable_model_cpu_offload()
else:
    pipe.to(device)
# Process the VAE in slices/tiles to reduce peak memory for large images.
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
def preprocess_image(image, image_size):
    """Convert a PIL image into a normalized single-image batch tensor.

    Args:
        image: PIL image. It is converted to RGB first, so RGBA, grayscale and
            palette inputs still yield 3 channels (an RGBA image would otherwise
            produce 4).
        image_size: Target size. A scalar gives a square
            ``image_size x image_size`` resize. A ``(height, width)`` tuple/list
            resizes to exactly that size, e.g. to keep a requested aspect ratio.

    Returns:
        Tensor of shape (1, 3, H, W) on the module-level ``device`` with the
        module-level ``dtype``, with pixel values mapped from [0, 1] to [-1, 1].
    """
    if isinstance(image_size, (tuple, list)):
        height, width = image_size
        size = (height, width)
    else:
        size = (image_size, image_size)
    preprocess = transforms.Compose([
        transforms.Resize(size, interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.ToTensor(),
        # Single-element mean/std broadcast over every channel: x -> (x - 0.5) / 0.5.
        transforms.Normalize([0.5], [0.5]),
    ])
    return preprocess(image.convert("RGB")).unsqueeze(0).to(device, dtype=dtype)
def check_shapes(latents):
    """Print diagnostic shape information for a latent tensor.

    FLUX's transformer does not take (B, C, H, W) VAE latents directly. The
    diffusers Flux pipelines pack each 2x2 spatial patch into the feature axis,
    giving (B, (H // 2) * (W // 2), C * 4). For 4-D input this function reports
    that packed shape. Flattening to (1, C * H * W) is not an accepted layout.

    Args:
        latents: Any object exposing ``.shape`` (e.g. a torch tensor).

    Returns:
        None. All output goes to stdout.
    """
    shape = tuple(latents.shape)
    print(f"Latent shape: {latents.shape}")
    if len(shape) == 4:
        batch, channels, lat_h, lat_w = shape
        if lat_h % 2 or lat_w % 2:
            # 2x2 patch packing drops a trailing row/column on odd sizes.
            print(f"Warning: latent height/width ({lat_h}x{lat_w}) must be even for 2x2 packing")
        packed = (batch, (lat_h // 2) * (lat_w // 2), channels * 4)
        print(f"Expected transformer input shape: {packed}")
    elif len(shape) == 3:
        # Presumably already packed as (batch, tokens, features).
        print(f"Packed latent shape: {latents.shape}")
    elif len(shape) == 2:
        print(f"Reshaped latent shape: {latents.shape}")
    else:
        print(f"Unexpected latent shape: {latents.shape}")
def _run_img2img(prompt, init_image, height, width, num_inference_steps, generator, strength):
    """Run FLUX image-to-image on ``init_image`` with the components loaded in ``pipe``.

    Returns the first generated PIL image.
    """
    # Local import: only the img2img path needs this class.
    from diffusers import FluxImg2ImgPipeline

    # from_pipe() reuses the modules already held by `pipe` rather than loading
    # a second copy of the weights. The img2img pipeline does the VAE encoding
    # (with the FLUX VAE's own scaling/shift), latent packing, and noising of the
    # image latents according to `strength`.
    img2img_pipe = FluxImg2ImgPipeline.from_pipe(pipe)
    result = img2img_pipe(
        prompt=prompt,
        image=init_image.convert("RGB"),
        height=height,
        width=width,
        strength=strength,
        num_inference_steps=num_inference_steps,
        generator=generator,
        guidance_scale=0.0,
    )
    return result.images[0]


@spaces.GPU()
def infer(
    prompt,
    init_image=None,
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    num_inference_steps=4,
    progress=gr.Progress(track_tqdm=True),
    *,
    strength=0.6,
):
    """Generate an image with the FLUX pipeline, optionally starting from ``init_image``.

    Args:
        prompt: Text prompt.
        init_image: Optional PIL image. When given, image-to-image is run.
        seed: RNG seed. It is replaced by a random one when ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed from ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker (mirrors tqdm bars in the UI).
        strength: Image-to-image only. Controls how strongly ``init_image`` is
            noised before denoising (values near 0 keep it, 1 effectively
            ignores it). It is keyword-only so the positional Gradio inputs
            are unaffected.

    Returns:
        ``(image, seed)``. On any error, a solid red image of the requested
        size is returned instead and the traceback is printed to the console.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Cast defensively: the UI feeds these from gr.Slider values, while
    # manual_seed() and the Image.new() fallback below require ints.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)
    generator = torch.Generator(device=device).manual_seed(seed)
    try:
        if init_image is None:
            # text2img case; guidance_scale=0.0 as recommended for FLUX.1-schnell.
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0,
            ).images[0]
        else:
            # img2img case: the dedicated pipeline handles encoding, scaling and
            # packing the image latents into the transformer's input layout.
            print(
                f"img2img: prompt={prompt!r}, size={width}x{height}, "
                f"steps={num_inference_steps}, strength={strength}"
            )
            image = _run_img2img(
                prompt, init_image, height, width, num_inference_steps, generator, strength
            )
        return image, seed
    except Exception as e:
        # UI boundary: log the full traceback and show a placeholder image
        # rather than surfacing a raw exception to the user.
        import traceback

        print(f"Error during inference: {e}")
        traceback.print_exc()
        return Image.new("RGB", (width, height), (255, 0, 0)), seed  # Red fallback image
# Gradio interface setup.
# Layout, top to bottom: prompt + optional source image, the Generate button,
# then the result image next to the seed that was actually used. Sampling
# controls live in a collapsed "Advanced Settings" accordion.
with gr.Blocks() as demo:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(type="pil", label="Initial Image (optional)")

    with gr.Row():
        generate = gr.Button("Generate")

    with gr.Row():
        result = gr.Image(label="Result")
        seed_output = gr.Number(label="Seed")

    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=42, label="Seed")
        randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
        # Width and height share the same range, step and default.
        _size_slider_kwargs = dict(minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        width = gr.Slider(label="Width", **_size_slider_kwargs)
        height = gr.Slider(label="Height", **_size_slider_kwargs)
        num_inference_steps = gr.Slider(
            minimum=1,
            maximum=50,
            step=1,
            value=4,
            label="Number of inference steps",
        )

    # Order must match infer()'s positional parameters.
    _infer_inputs = [
        prompt,
        init_image,
        seed,
        randomize_seed,
        width,
        height,
        num_inference_steps,
    ]
    generate.click(fn=infer, inputs=_infer_inputs, outputs=[result, seed_output])

demo.launch()