import gradio as gr
import spaces  # required for the @spaces.GPU() decorator used below
import numpy as np
import random
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from diffusers import DiffusionPipeline
# Constants
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
# Load FLUX model
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
if device == "cuda":
    pipe.enable_model_cpu_offload()  # manages device placement itself; do not also call .to("cuda")
else:
    pipe = pipe.to(device)
pipe.vae.enable_slicing()  # slicing and tiling lower peak VRAM use during VAE decode
pipe.vae.enable_tiling()
# Add a projection layer to match x_embedder input
projection = nn.Linear(16, 64).to(device).to(dtype)
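# Note: FLUX's x_embedder expects 64-channel tokens because the pipeline packs
# 2x2 patches of the 16-channel latent grid into single tokens. This randomly
# initialized Linear only matches the dimensionality, not the learned packing,
# so img2img output will not faithfully reflect the input image.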
def preprocess_image(image, image_size):
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])  # map pixels from [0, 1] to [-1, 1], the range the VAE expects
    ])
    image = preprocess(image).unsqueeze(0).to(device, dtype=dtype)
    return image
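# Because preprocessing always resizes to the VAE's 1024px sample size while the
# user can request a different width/height, process_latents below may need to
# interpolate the latents back to the requested resolution.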
def process_latents(latents, height, width):
    print(f"Input latent shape: {latents.shape}")
    # Resize latents if their spatial size does not match the requested output
    if latents.shape[2:] != (height // 8, width // 8):
        latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')
    print(f"Latent shape after potential interpolation: {latents.shape}")
    # Flatten spatial dimensions to [batch_size, seq_len, channels]
    channels = latents.shape[1]
    latents = latents.permute(0, 2, 3, 1).reshape(1, -1, channels)
    print(f"Reshaped latent shape: {latents.shape}")
    # Project each token from 16 to 64 channels to match x_embedder's input width
    latents = projection(latents)
    print(f"Projected latent shape: {latents.shape}")
    return latents
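# A deterministic alternative to the learned projection above, sketched under the
# assumption that FLUX tokens are 2x2 latent patches (as in diffusers'
# FluxPipeline._pack_latents): group each 2x2 neighborhood of the 16-channel
# latent grid into one 64-channel token. Unused here; kept for reference.
def pack_latents(latents):
    b, c, h, w = latents.shape
    # [B, C, H, W] -> [B, C, H/2, 2, W/2, 2]
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    # Move the 2x2 patch dims next to the channel dim, then flatten to tokens
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)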
@spaces.GPU()
def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)
    try:
        if init_image is None:
            # text2img case
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0
            ).images[0]
        else:
            # img2img case
            init_image = init_image.convert("RGB")
            init_image = preprocess_image(init_image, 1024)  # 1024 matches the FLUX VAE sample size
            # Encode the image with the FLUX VAE; FLUX defines its own shift and
            # scaling factors in the VAE config (the SD constant 0.18215 does not apply here)
            latents = pipe.vae.encode(init_image).latent_dist.sample()
            latents = (latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
            print(f"Initial latent shape from VAE: {latents.shape}")
            # Process latents to match x_embedder input
            latents = process_latents(latents, height, width)
            print(f"x_embedder weight shape: {pipe.transformer.x_embedder.weight.shape}")
            print(f"First transformer block input shape: {pipe.transformer.transformer_blocks[0].attn.to_q.weight.shape}")
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0,
                latents=latents
            ).images[0]
        return image, seed
    except Exception as e:
        print(f"Error during inference: {e}")
        import traceback
        traceback.print_exc()
        return Image.new("RGB", (width, height), (255, 0, 0)), seed  # Red fallback image
# Gradio interface setup
with gr.Blocks() as demo:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(label="Initial Image (optional)", type="pil")
    with gr.Row():
        generate = gr.Button("Generate")
    with gr.Row():
        result = gr.Image(label="Result")
        seed_output = gr.Number(label="Seed")
    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)
    generate.click(
        infer,
        inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, seed_output]
    )
demo.launch()