flux-lightning / app.py
Jordan Legg
move back to complex code
383a90d
raw
history blame
4.88 kB
import spaces
import gradio as gr
import numpy as np
import random
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from diffusers import DiffusionPipeline
# Constants
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
LATENT_CHANNELS = 16
TRANSFORMER_IN_CHANNELS = 64
# VAE latent scaling factor; assumed to equal pipe.vae.config.scaling_factor — confirm if the model changes.
SCALING_FACTOR = 0.3611

# Load FLUX model.
# Model CPU offload keeps sub-models on the CPU and moves each one to the GPU
# only while it runs; enabling it moves the pipeline back to the CPU anyway, so
# a prior `.to("cuda")` would only add a full-pipeline GPU memory peak.
# The offload hooks target CUDA, so on CPU-only hosts just place it on the CPU.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
if device == "cuda":
    pipe.enable_model_cpu_offload()
else:
    pipe.to(device)
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

# Add a projection layer to match transformer input.
# NOTE(review): this layer is randomly initialised and never trained, so its
# output carries no meaning for the FLUX transformer. It is kept so existing
# callers still find it, and frozen so calling it never records autograd history.
projection = nn.Linear(LATENT_CHANNELS, TRANSFORMER_IN_CHANNELS).to(device).to(dtype)
projection.requires_grad_(False)
def preprocess_image(image, image_size):
    """Convert a PIL image into a normalised, batched tensor for the VAE.

    The image is resized to ``image_size`` x ``image_size`` with Lanczos
    resampling, converted to a float tensor in [0, 1], and mapped to [-1, 1]
    via ``(x - 0.5) / 0.5``. A leading batch axis is added and the result is
    moved to the module-level ``device`` with the module-level ``dtype``.
    """
    square = (image_size, image_size)
    to_vae_input = transforms.Compose(
        [
            transforms.Resize(square, interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    batched = to_vae_input(image)[None, ...]
    return batched.to(device, dtype=dtype)
def process_latents(latents, height, width):
    """Resize VAE latents to the requested output size and pack them for FLUX.

    The FLUX transformer consumes a sequence of 2x2 latent patches: each token
    concatenates the ``C`` channels of one 2x2 spatial patch into ``4 * C``
    features (16 -> 64 for the FLUX VAE). This is the same deterministic
    packing diffusers' ``FluxPipeline`` applies to its own latents, so the
    result can be passed straight to ``pipe(latents=...)``. (Previously an
    untrained random linear layer did the 16 -> 64 mapping and the token count
    was 4x too large.)

    Args:
        latents: VAE latents of shape ``(batch, C, h, w)``.
        height: Output image height in pixels.
        width: Output image width in pixels.

    Returns:
        Tensor of shape ``(batch, (H // 2) * (W // 2), 4 * C)`` where
        ``H = 2 * (height // 16)`` and ``W = 2 * (width // 16)``.
    """
    print(f"Input latent shape: {latents.shape}")
    # Latent grid for an 8x-downsampling VAE, rounded down to an even size so
    # it splits cleanly into 2x2 patches.
    latent_h = 2 * (int(height) // 16)
    latent_w = 2 * (int(width) // 16)
    if tuple(latents.shape[2:]) != (latent_h, latent_w):
        latents = torch.nn.functional.interpolate(latents, size=(latent_h, latent_w), mode="bilinear")
    print(f"Latent shape after potential interpolation: {latents.shape}")

    batch, channels = latents.shape[0], latents.shape[1]
    # (B, C, H, W) -> (B, C, H/2, 2, W/2, 2) -> (B, H/2, W/2, C, 2, 2) -> (B, tokens, 4C)
    packed = (
        latents.reshape(batch, channels, latent_h // 2, 2, latent_w // 2, 2)
        .permute(0, 2, 4, 1, 3, 5)
        .reshape(batch, (latent_h // 2) * (latent_w // 2), channels * 4)
    )
    print(f"Packed latent shape: {packed.shape}")
    return packed
@spaces.GPU()
def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    """Generate an image with FLUX.1-schnell, optionally starting from an init image.

    Args:
        prompt: Text prompt passed to the pipeline.
        init_image: Optional PIL image. When given, it is VAE-encoded and its
            latents are handed to the pipeline via ``latents=``.
        seed: Seed for the torch generator (replaced when ``randomize_seed``).
        randomize_seed: Draw a fresh seed from ``[0, MAX_SEED]`` instead.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker (mirrors tqdm bars while running).

    Returns:
        ``(image, seed)``: the generated PIL image and the seed actually used.
        On any error the traceback is printed and a solid red image of the
        requested size is returned instead, so the UI never crashes.
    """
    # Gradio sliders may hand back floats; PIL, torch.Generator and the
    # pipeline all expect integers.
    width, height = int(width), int(height)
    num_inference_steps = int(num_inference_steps)
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)
    try:
        pipe_kwargs = dict(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            generator=generator,
            guidance_scale=0.0,
        )
        if init_image is not None:
            # img2img case: seed the pipeline with the encoded init image.
            # NOTE(review): FluxPipeline appears to treat `latents` as the
            # initial (fully noised) sample, so this is not a strength-controlled
            # img2img; consider FluxImg2ImgPipeline if that is the intent.
            pipe_kwargs["latents"] = _encode_init_image(init_image, height, width, generator)
        image = pipe(**pipe_kwargs).images[0]
        return image, seed
    except Exception as e:
        print(f"Error during inference: {e}")
        import traceback
        traceback.print_exc()
        return Image.new("RGB", (width, height), (255, 0, 0)), seed  # Red fallback image


def _encode_init_image(init_image, height, width, generator):
    """VAE-encode a PIL init image into packed latents for ``pipe(latents=...)``.

    Runs without autograd: neither the VAE encode nor the latent
    post-processing needs gradients, and recording them only wastes GPU memory.
    """
    image = preprocess_image(init_image.convert("RGB"), 1024)  # Using 1024 as FLUX VAE sample size
    with torch.no_grad():
        # Sample the posterior with the request's generator so a fixed seed
        # reproduces the same result.
        latents = pipe.vae.encode(image).latent_dist.sample(generator=generator)
        # Normalise as (z - shift_factor) * scaling_factor. The FLUX VAE config
        # defines a shift_factor that diffusers' Flux pipelines add back when
        # decoding (z / scaling_factor + shift_factor), so it must be removed here.
        shift = getattr(pipe.vae.config, "shift_factor", None) or 0.0
        latents = (latents - shift) * SCALING_FACTOR
        print(f"Initial latent shape from VAE: {latents.shape}")
        return process_latents(latents, height, width)
# Gradio interface setup
_size_slider_kwargs = dict(minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)

with gr.Blocks() as demo:
    # Inputs: the prompt plus an optional starting image.
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(label="Initial Image (optional)", type="pil")
    with gr.Row():
        generate = gr.Button("Generate")
    # Outputs: the generated image and the seed that produced it.
    with gr.Row():
        result = gr.Image(label="Result")
        seed_output = gr.Number(label="Seed")
    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        width = gr.Slider(label="Width", **_size_slider_kwargs)
        height = gr.Slider(label="Height", **_size_slider_kwargs)
        num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)

    # Order must match infer's positional parameters.
    infer_inputs = [prompt, init_image, seed, randomize_seed, width, height, num_inference_steps]
    generate.click(fn=infer, inputs=infer_inputs, outputs=[result, seed_output])

demo.launch()