import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
import spaces
from gradio_imageslider import ImageSlider  # before/after comparison slider
from diffusers import FluxControlNetPipeline, FluxControlNetModel

# Source: https://github.com/XLabs-AI/x-flux.git
device = torch.device("cuda")

base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Union'

# Load the Union ControlNet and the FLUX ControlNet pipeline in bfloat16
controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
pipe.to(device)

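# Global strength of the ControlNet conditioning (0 disables it, 1 applies it fully)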
controlnet_conditioning_scale = 0.5

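# Control-mode indices expected by the Union ControlNet's `control_mode` argument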
control_modes = {
    "canny": 0,
    "tile": 1,
    "depth": 2,
    "blur": 3,
    "pose": 4,
    "gray": 5,
    "lq": 6,
}

def preprocess_image(image, target_width, target_height, crop=True):
    if crop:
        original_width, original_height = image.size

        # Resize to match the target size without stretching
        scale = max(target_width / original_width, target_height / original_height)
        resized_width = int(scale * original_width)
        resized_height = int(scale * original_height)

        image = image.resize((resized_width, resized_height), Image.LANCZOS)
        
        # Center crop to match the target dimensions
        left = (resized_width - target_width) // 2
        top = (resized_height - target_height) // 2
        image = image.crop((left, top, left + target_width, top + target_height))
    else:
        image = image.resize((target_width, target_height), Image.LANCZOS)
    
    return image

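# Runs on GPU via Hugging Face ZeroGPU; each call gets up to 120 s of GPU time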
@spaces.GPU(duration=120)
def generate_image(prompt, control_image, control_mode, num_steps=50, guidance=4, width=512, height=512, seed=42, random_seed=False):
    if random_seed:
        seed = np.random.randint(0, 10000)
    seed = int(seed)  # gr.Number may hand us a float; the generator needs an int

    os.makedirs("./controlnet_results/", exist_ok=True)

    control_image = preprocess_image(control_image, width, height)

    generator = torch.Generator(device=device).manual_seed(seed)
    with torch.no_grad():
        image = pipe(
            prompt,
            control_image=control_image,
            control_mode=control_modes[control_mode],
            width=width,
            height=height,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            num_inference_steps=num_steps,
            guidance_scale=guidance,
            generator=generator,
        ).images[0]
    
    return [control_image, image]  # Return both images for slider

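# Gradio UI: prompt + control image in, before/after comparison slider out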
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Image(type="pil", label="Control Image"),
        gr.Dropdown(choices=list(control_modes.keys()), label="Control Mode", value="canny"),
        gr.Slider(step=1, minimum=1, maximum=64, value=28, label="Num Steps"),
        gr.Slider(minimum=0.1, maximum=10, value=4, label="Guidance"),
        gr.Slider(minimum=128, maximum=2048, step=128, value=1024, label="Width"),
        gr.Slider(minimum=128, maximum=2048, step=128, value=1024, label="Height"),
        gr.Number(value=42, label="Seed"),
        gr.Checkbox(label="Random Seed")
    ],
    outputs=ImageSlider(label="Before / After"),  # Use ImageSlider as the output
    title="FLUX.1 Controlnet Canny",
    description="Generate images using ControlNet and a text prompt.\n[[non-commercial license, Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)]"
)

if __name__ == "__main__":
    interface.launch()