import spaces
import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
from torchvision import transforms
from diffusers import DiffusionPipeline, AutoencoderKL
# Constants
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
# Load models
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
# enable_model_cpu_offload() conflicts with the explicit .to(device) above
# (offloading expects to manage device placement itself), so only the VAE
# memory savers are enabled here.
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
# SD 1.5 VAE (float32), used only to encode the optional init image.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to(device)
def preprocess_image(image, image_size):
    """Resize to a square, convert to a tensor, and scale pixels to [-1, 1]."""
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    return preprocess(image).unsqueeze(0).to(device, dtype=torch.float32)
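# Example (illustrative filename): preprocess_image(Image.open("photo.png").convert("RGB"), 512)
# returns a (1, 3, 512, 512) float32 tensor with values in [-1, 1].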
def encode_image(image):
    """Encode pixels to SD-style latents (0.18215 is the SD VAE scaling factor)."""
    with torch.no_grad():
        latents = vae.encode(image).latent_dist.sample() * 0.18215
    return latents
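
# Hedged illustration (not used by the app): the inverse of encode_image().
# `decode_latents_reference` is a name introduced here for clarity; it undoes
# the 0.18215 scaling and decodes back to [0, 1] pixel space.
def decode_latents_reference(latents):
    with torch.no_grad():
        image = vae.decode(latents / 0.18215).sample
    return (image / 2 + 0.5).clamp(0, 1)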
@spaces.GPU()
def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)
    try:
        if init_image is None:
            # text2img case
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0
            ).images[0]
        else:
            # img2img case: encode the init image with the SD 1.5 VAE, then
            # coerce the result into FLUX's latent layout. The 1x1 conv below
            # is randomly initialized, so the init image serves only as a
            # rough structural hint rather than a faithful img2img source.
            init_image = init_image.convert("RGB")
            init_image = preprocess_image(init_image, 1024)  # 1024 matches the FLUX VAE sample size
            latents = encode_image(init_image)
            latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')
            if latents.shape[1] != pipe.vae.config.latent_channels:
                conv = torch.nn.Conv2d(latents.shape[1], pipe.vae.config.latent_channels, kernel_size=1).to(device, dtype=dtype)
                latents = conv(latents.to(dtype))
            # FLUX consumes "packed" latents of shape (batch, (H/16)*(W/16), 64):
            # each 2x2 patch of the 16-channel latent grid becomes one 64-dim token.
            b, c, h, w = latents.shape
            latents = latents.view(b, c, h // 2, 2, w // 2, 2)
            latents = latents.permute(0, 2, 4, 1, 3, 5)
            latents = latents.reshape(b, (h // 2) * (w // 2), c * 4).to(dtype)
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0,
                latents=latents
            ).images[0]
        return image, seed
    except Exception as e:
        print(f"Error during inference: {e}")
        return Image.new("RGB", (width, height), (255, 0, 0)), seed  # Red fallback image
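
# Hedged alternative (sketch only, not wired into the UI): recent diffusers
# releases ship a dedicated FluxImg2ImgPipeline that noises the init image in
# FLUX's own latent space, avoiding the untrained projection above. The
# function name and the `strength` value are illustrative choices, not part
# of the original app.
def infer_img2img_reference(prompt, init_image, seed=42, width=1024, height=1024,
                            num_inference_steps=4, strength=0.95):
    from diffusers import FluxImg2ImgPipeline  # assumes a diffusers version that includes it
    pipe_i2i = FluxImg2ImgPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
    ).to(device)
    generator = torch.Generator(device=device).manual_seed(seed)
    return pipe_i2i(
        prompt=prompt,
        image=init_image.convert("RGB").resize((width, height)),
        strength=strength,  # how far to move away from the init image
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=0.0,
        generator=generator,
    ).images[0]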
# Gradio interface setup
with gr.Blocks() as demo:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(label="Initial Image (optional)", type="pil")
    with gr.Row():
        generate = gr.Button("Generate")
    with gr.Row():
        result = gr.Image(label="Result")
        seed_output = gr.Number(label="Seed")
    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)
    generate.click(
        infer,
        inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, seed_output]
    )
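# Note (hedged): on busier Spaces it is common to enable request queuing,
# e.g. demo.queue().launch(); a plain launch() also works.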
demo.launch()