import spaces
import gradio as gr
import numpy as np
import random
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from diffusers import DiffusionPipeline
# Constants
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
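
# FLUX-specific sizes (assumed from the published FLUX.1-schnell / CLIP configs):
# the FLUX VAE encodes 16 latent channels, the CLIP text encoder produces 768-dim
# embeddings for at most 77 tokens, and 0.3611 is the FLUX VAE scaling factor.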
LATENT_CHANNELS = 16
TEXT_EMBED_DIM = 768
MAX_TEXT_EMBEDDINGS = 77
SCALING_FACTOR = 0.3611
# Load FLUX model
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
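# Reduce GPU memory pressure: keep idle sub-models on the CPU and run the VAE in slices/tiles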
pipe.enable_model_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
# Add a projection layer to match text embedding dimension
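# (a randomly initialized nn.Linear from the 16 VAE latent channels to the 768-dim CLIP
#  embedding space; it is never trained, so the projected latents are only a rough signal)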
projection = nn.Linear(LATENT_CHANNELS, TEXT_EMBED_DIM).to(device).to(dtype)
def preprocess_image(image, image_size):
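    """Resize to image_size x image_size, scale to [-1, 1], and move to the model device/dtype."""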
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    image = preprocess(image).unsqueeze(0).to(device, dtype=dtype)
    return image

def process_latents(latents, height, width):
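    """Flatten VAE latents into a token-like sequence and project them to TEXT_EMBED_DIM.

    Rough shape flow for a 1024x1024 input:
        [1, 16, 128, 128] -> [1, 16384, 16] -> [1, 16384, 768] -> [1, 77, 768]
    """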
print(f"Input latent shape: {latents.shape}")
# Ensure latents are the correct shape
if latents.shape[2:] != (height // 8, width // 8):
latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')
print(f"Latent shape after potential interpolation: {latents.shape}")
# Reshape latents to [batch_size, seq_len, channels]
latents = latents.permute(0, 2, 3, 1).reshape(1, -1, LATENT_CHANNELS)
print(f"Reshaped latent shape: {latents.shape}")
# Project latents to match text embedding dimension
latents = projection(latents)
print(f"Projected latent shape: {latents.shape}")
# Adjust sequence length to match text embeddings
seq_len = latents.shape[1]
if seq_len > MAX_TEXT_EMBEDDINGS:
latents = latents[:, :MAX_TEXT_EMBEDDINGS, :]
elif seq_len < MAX_TEXT_EMBEDDINGS:
pad_len = MAX_TEXT_EMBEDDINGS - seq_len
latents = torch.nn.functional.pad(latents, (0, 0, 0, pad_len, 0, 0))
print(f"Final latent shape: {latents.shape}")
return latents
@spaces.GPU()
def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
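    """Run FLUX.1-schnell text2img, or the experimental latent-conditioned path when init_image is given."""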
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        if init_image is None:
            # text2img case
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0
            ).images[0]
        else:
            # img2img case
            init_image = init_image.convert("RGB")
            init_image = preprocess_image(init_image, 1024)  # Using 1024 as FLUX VAE sample size

            # Encode the image using FLUX VAE
            latents = pipe.vae.encode(init_image).latent_dist.sample() * SCALING_FACTOR
            print(f"Initial latent shape from VAE: {latents.shape}")

            # Process latents to match text embedding format
            latents = process_latents(latents, height, width)

            # Get text embeddings from the pipeline's CLIP tokenizer/text encoder (768-dim)
            text_inputs = pipe.tokenizer(
                [prompt], padding="max_length", max_length=MAX_TEXT_EMBEDDINGS,
                truncation=True, return_tensors="pt",
            )
            text_embeddings = pipe.text_encoder(text_inputs.input_ids.to(device))[0]
            print(f"Text embedding shape: {text_embeddings.shape}")

            # Combine image latents and text embeddings
            combined_embeddings = torch.cat([latents, text_embeddings], dim=1)
            print(f"Combined embedding shape: {combined_embeddings.shape}")
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0,
                latents=combined_embeddings
            ).images[0]

        return image, seed
    except Exception as e:
        print(f"Error during inference: {e}")
        import traceback
        traceback.print_exc()
        return Image.new("RGB", (width, height), (255, 0, 0)), seed  # Red fallback image
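
# For reference, a more conventional FLUX img2img path. This is a minimal sketch, assuming a
# diffusers release that ships FluxImg2ImgPipeline; the helper name is illustrative and it is
# not wired into the Gradio UI below.
def reference_img2img(prompt, init_image, strength=0.6, steps=4, seed=42):
    from diffusers import FluxImg2ImgPipeline
    img2img_pipe = FluxImg2ImgPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
    ).to(device)
    generator = torch.Generator(device=device).manual_seed(seed)
    return img2img_pipe(
        prompt=prompt,
        image=init_image.convert("RGB"),
        strength=strength,
        num_inference_steps=steps,
        guidance_scale=0.0,
        generator=generator,
    ).images[0]
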
# Gradio interface setup
with gr.Blocks() as demo:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(label="Initial Image (optional)", type="pil")

    with gr.Row():
        generate = gr.Button("Generate")

    with gr.Row():
        result = gr.Image(label="Result")
        seed_output = gr.Number(label="Seed")

    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)

    generate.click(
        infer,
        inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, seed_output]
    )

demo.launch()