import spaces
import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
from torchvision import transforms
from diffusers import DiffusionPipeline, AutoencoderKL

# Constants
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Load models
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
# Model CPU offload manages device placement on its own, so the pipeline is not
# also moved with .to(device); components are shuttled to the GPU on demand.
pipe.enable_model_cpu_offload()
# Slicing and tiling keep VAE decode memory in check at large resolutions.
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

# A separate SD 1.5 VAE is used to encode init images. Note it produces a
# 4-channel latent space, unlike FLUX's 16-channel one (handled in infer below).
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to(device)

def preprocess_image(image, image_size):
    """Resize a PIL image to a square and normalize it to [-1, 1] for the VAE."""
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])  # [0, 1] -> [-1, 1]
    ])
    image = preprocess(image).unsqueeze(0).to(device, dtype=torch.float32)  # add batch dimension
    return image
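
# Example usage (hypothetical file name):
#   preprocess_image(Image.open("photo.png").convert("RGB"), 1024)
# returns a (1, 3, 1024, 1024) float32 tensor with values scaled to [-1, 1].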

def encode_image(image):
    """Encode a preprocessed image into scaled SD 1.5 VAE latents."""
    with torch.no_grad():
        # 0.18215 is the SD 1.5 VAE scaling factor (vae.config.scaling_factor).
        latents = vae.encode(image).latent_dist.sample() * 0.18215
    return latents
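
# For scale: a (1, 3, 1024, 1024) input yields latents of shape (1, 4, 128, 128),
# since the SD 1.5 VAE downsamples by 8x and produces 4 latent channels.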

@spaces.GPU()
def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        if init_image is None:
            # text2img case
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0
            ).images[0]
        else:
            # img2img case: encode the init image and hand the result to FLUX as
            # its starting latents. This is a rough approximation -- the SD 1.5
            # VAE latent space differs from FLUX's, so fidelity is limited.
            init_image = init_image.convert("RGB")
            init_image = preprocess_image(init_image, 1024)  # 1024 matches the FLUX sample size
            latents = encode_image(init_image)

            # Resize the latent grid to the requested output resolution (the VAE downsamples by 8x).
            latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')

            # The SD 1.5 VAE yields 4 latent channels but FLUX expects 16; bridge
            # the gap with a randomly initialized (untrained) 1x1 convolution.
            if latents.shape[1] != pipe.vae.config.latent_channels:
                conv = torch.nn.Conv2d(latents.shape[1], pipe.vae.config.latent_channels, kernel_size=1).to(device, dtype=dtype)
                latents = conv(latents.to(dtype))

            # FLUX consumes "packed" latents of shape (batch, (H/16)*(W/16), 64):
            # every 2x2 patch of the 16-channel latent grid becomes one 64-dim token.
            batch_size, channels, latent_h, latent_w = latents.shape
            latents = latents.view(batch_size, channels, latent_h // 2, 2, latent_w // 2, 2)
            latents = latents.permute(0, 2, 4, 1, 3, 5)
            latents = latents.reshape(batch_size, (latent_h // 2) * (latent_w // 2), channels * 4)

            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0,
                latents=latents
            ).images[0]

        return image, seed
    except Exception as e:
        print(f"Error during inference: {e}")
        return Image.new("RGB", (width, height), (255, 0, 0)), seed  # Red fallback image
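
# A quick smoke test, bypassing the UI (hypothetical prompt and output path;
# spaces.GPU is a no-op outside Hugging Face Spaces, so this also runs locally):
#   image, used_seed = infer("a misty forest at dawn", init_image=None,
#                            seed=0, randomize_seed=True, width=512, height=512,
#                            num_inference_steps=4)
#   image.save("out.png")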

# Gradio interface setup
with gr.Blocks() as demo:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(label="Initial Image (optional)", type="pil")

    with gr.Row():
        generate = gr.Button("Generate")
        
    with gr.Row():
        result = gr.Image(label="Result")
        seed_output = gr.Number(label="Seed used")

    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)

    generate.click(
        infer,
        inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, seed_output]
    )

demo.launch()
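
# To run locally (assuming this file is saved as app.py and the dependencies
# above are installed):
#   python app.py
# Gradio serves the UI on http://127.0.0.1:7860 by default.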