File size: 9,194 Bytes
46175dc
8f6d6cb
 
 
 
 
 
5544c41
 
8f6d6cb
 
 
 
5544c41
8f6d6cb
33fc81c
43f1e39
33fc81c
cfb0d74
 
 
 
 
411cd1a
cfb0d74
 
 
 
8f6d6cb
 
d26acb1
8f6d6cb
 
 
 
 
 
 
 
cfb0d74
8f6d6cb
 
 
 
 
 
cfb0d74
33fc81c
cfb0d74
8f6d6cb
cfb0d74
33fc81c
cfb0d74
8f6d6cb
cfb0d74
33fc81c
cfb0d74
8f6d6cb
 
 
 
cfb0d74
 
 
 
 
 
 
 
 
 
 
 
 
8f6d6cb
 
 
d26acb1
8f6d6cb
 
 
cfb0d74
8f6d6cb
 
 
 
cfb0d74
 
 
 
 
 
 
 
 
 
 
 
 
8f6d6cb
 
 
5ab7839
b6130e4
8f6d6cb
 
 
7c95d4a
45f491c
cd06f65
8f6d6cb
 
7c95d4a
8f6d6cb
 
 
 
 
 
c8e85df
8f6d6cb
 
 
 
fdbf3a9
 
 
 
 
 
8f6d6cb
 
 
73a9679
 
8f6d6cb
 
 
 
 
 
 
 
 
 
 
 
 
33fc81c
8f6d6cb
 
 
 
 
 
 
 
 
 
f3c9924
 
8f6d6cb
 
 
 
 
 
 
 
 
 
 
33fc81c
8f6d6cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33fc81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import spaces
import os
import json
import time
import torch
from PIL import Image
from tqdm import tqdm
import gradio as gr

from safetensors.torch import save_file
from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
from src.lora_helper import set_single_lora, set_multi_lora, unset_lora

# --- Model initialization (runs once at import time) ---
# base_path: Hugging Face repo id of the base FLUX.1-dev model.
# lora_base_path: local directory holding the EasyControl control LoRAs
#   (subject / pose / inpainting .safetensors files).
# style_lora_base_path: HF namespace prefix for the Shakker-Labs style LoRAs.
base_path = "black-forest-labs/FLUX.1-dev"    
lora_base_path = "./models"
style_lora_base_path = "Shakker-Labs"


# Load the pipeline, then swap in the project-local transformer implementation
# (src.transformer_flux) so the EasyControl LoRA hooks are available.
pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe.transformer = transformer
pipe.to("cuda")

def clear_cache(transformer):
    """Empty the per-call key/value banks of every attention processor.

    The custom attention processors accumulate condition KV state
    (``bank_kv``) during a generation pass; clearing it after each call
    prevents state from one request leaking into the next.

    Args:
        transformer: Model exposing an ``attn_processors`` mapping whose
            values each carry a clearable ``bank_kv`` container.
    """
    # Only the processors are needed, so iterate the values directly
    # instead of unpacking (name, processor) pairs and ignoring the name.
    for attn_processor in transformer.attn_processors.values():
        attn_processor.bank_kv.clear()

# Define the Gradio interface
@spaces.GPU()
def single_condition_generate_image(prompt, subject_img, spatial_img, height, width, seed, control_type, style_lora=None):
    """Generate an image from a single control condition, optionally styled.

    Args:
        prompt: Text prompt for generation.
        subject_img: Optional PIL subject-reference image (or None).
        spatial_img: Optional PIL spatial-control image (or None).
        height / width: Output resolution; may arrive as floats from Gradio
            sliders and are coerced to int.
        seed: RNG seed; may arrive as a float from ``gr.Number``.
        control_type: One of "subject", "pose", "inpainting".
        style_lora: Style LoRA name from the dropdown; "None" or None
            disables the style LoRA.

    Returns:
        The generated PIL image.

    Raises:
        ValueError: If ``control_type`` is not a recognized option.
    """
    # Map each control type to its control-LoRA weight file.
    control_lora_files = {
        "subject": "subject.safetensors",
        "pose": "pose.safetensors",
        "inpainting": "inpainting.safetensors",
    }
    try:
        lora_path = os.path.join(lora_base_path, control_lora_files[control_type])
    except KeyError:
        # The original code left lora_path unbound here and crashed with a
        # NameError; fail with an explicit, actionable error instead.
        raise ValueError(f"Unknown control type: {control_type!r}") from None
    set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)

    # Optionally stack a Shakker-Labs style LoRA on top of the control LoRA.
    # (repo directory, weight file) per dropdown choice; "None"/None skips.
    style_lora_repos = {
        "Simple_Sketch": ("FLUX.1-dev-LoRA-Children-Simple-Sketch", "FLUX-dev-lora-children-simple-sketch.safetensors"),
        "Text_Poster": ("FLUX.1-dev-LoRA-Text-Poster", "FLUX-dev-lora-Text-Poster.safetensors"),
        "Vector_Style": ("FLUX.1-dev-LoRA-Vector-Journey", "FLUX-dev-lora-Vector-Journey.safetensors"),
    }
    if style_lora in style_lora_repos:
        repo_name, weight_name = style_lora_repos[style_lora]
        # Drop any style LoRA left over from a previous request first.
        pipe.unload_lora_weights()
        pipe.load_lora_weights(os.path.join(style_lora_base_path, repo_name), weight_name=weight_name)

    # Run the pipeline. Empty lists (not [None]) when an image is absent.
    subject_imgs = [subject_img] if subject_img else []
    spatial_imgs = [spatial_img] if spatial_img else []
    image = pipe(
        prompt,
        height=int(height),
        width=int(width),
        guidance_scale=3.5,
        num_inference_steps=25,
        max_sequence_length=512,
        # gr.Number delivers a float; manual_seed requires an int.
        generator=torch.Generator("cpu").manual_seed(int(seed)),
        subject_images=subject_imgs,
        spatial_images=spatial_imgs,
        cond_size=512,
    ).images[0]
    # Release the per-call attention KV banks so state does not leak
    # into the next request.
    clear_cache(pipe.transformer)
    return image

# Define the Gradio interface
@spaces.GPU()
def multi_condition_generate_image(prompt, subject_img, spatial_img, height, width, seed):
    """Generate an image combining the subject and inpainting control LoRAs.

    Args:
        prompt: Text prompt for generation.
        subject_img: Optional PIL subject-reference image (or None).
        spatial_img: Optional PIL spatial-control image (or None).
        height / width: Output resolution; coerced to int (Gradio sliders
            may deliver floats).
        seed: RNG seed; may arrive as a float from ``gr.Number``.

    Returns:
        The generated PIL image.
    """
    # Load both control LoRAs with unit weight each.
    subject_path = os.path.join(lora_base_path, "subject.safetensors")
    inpainting_path = os.path.join(lora_base_path, "inpainting.safetensors")
    set_multi_lora(pipe.transformer, [subject_path, inpainting_path], lora_weights=[[1], [1]], cond_size=512)

    # Run the pipeline. Empty lists (not [None]) when an image is absent.
    subject_imgs = [subject_img] if subject_img else []
    spatial_imgs = [spatial_img] if spatial_img else []
    image = pipe(
        prompt,
        height=int(height),
        width=int(width),
        guidance_scale=3.5,
        num_inference_steps=25,
        max_sequence_length=512,
        # gr.Number delivers a float; manual_seed requires an int.
        generator=torch.Generator("cpu").manual_seed(int(seed)),
        subject_images=subject_imgs,
        spatial_images=spatial_imgs,
        cond_size=512,
    ).images[0]
    # Release the per-call attention KV banks so state does not leak
    # into the next request.
    clear_cache(pipe.transformer)
    return image

# Define the Gradio interface components.
# Dropdown choice lists; must match the keys handled inside
# single_condition_generate_image.
control_types = ["pose", "subject", "inpainting"]
style_loras = ["None", "Simple_Sketch", "Text_Poster", "Vector_Style"]

# Example data. Each row follows the input order of the corresponding
# gr.Examples wiring: [prompt, subject_img, spatial_img, height, width,
# seed, (control_type, style_lora for single)]. Images are opened eagerly
# at import time, so the ./test_imgs files must exist.
single_examples = [
    ["A SKS in the library", Image.open("./test_imgs/subject1.png"), None, 768, 768, 5, "subject", "None"],
    ["sketched style,A joyful girl with balloons floats above a city wearing a hat and striped pants", None, Image.open("./test_imgs/spatial0.png"), 768, 512, 42, "pose", "Simple_Sketch"],
    ["In a picturesque village, a narrow cobblestone street with rustic stone buildings, colorful blinds, and lush green spaces, a cartoon man drawn with simple lines and solid colors stands in the foreground, wearing a red shirt, beige work pants, and brown shoes, carrying a strap on his shoulder. The scene features warm and enticing colors, a pleasant fusion of nature and architecture, and the camera's perspective on the street clearly shows the charming and quaint environment., Integrating elements of reality and cartoon.", None, Image.open("./test_imgs/spatial1.png"), 768, 768, 1, "pose", "Vector_Style"],
]
multi_examples = [
    ["A SKS on the car", Image.open("./test_imgs/subject2.png"), Image.open("./test_imgs/spatial2.png"), 768, 768, 7],
]


# Create the Gradio Blocks interface: two tabs (single- and multi-condition
# generation), each with an input column, an output image, and example rows.
with gr.Blocks() as demo:
    gr.Markdown("# Image Generation with EasyControl")
    gr.Markdown("Generate images using EasyControl with different control types and style LoRAs.(Due to hardware constraints, only low-resolution images can be generated. For high-resolution (1024+), please set up your own environment.)")

    with gr.Tab("Single Condition Generation"):
        with gr.Row():
            with gr.Column():
                gr.Markdown("""
                **Prompt** (When using LoRA, please try the recommended prompts available at the following links:  
                [FLUX.1-dev-LoRA-Text-Poster](https://huggingface.co/Shakker-Labs/FLUX.1-dev-LoRA-Text-Poster),  
                [FLUX.1-dev-LoRA-Children-Simple-Sketch](https://huggingface.co/Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch),  
                [FLUX.1-dev-LoRA-Vector-Journey](https://huggingface.co/Shakker-Labs/FLUX.1-dev-LoRA-Vector-Journey))
                """)
                prompt = gr.Textbox(label="Prompt")
                subject_img = gr.Image(label="Subject Image", type="pil")  # image upload widget (PIL)
                spatial_img = gr.Image(label="Spatial Image", type="pil")  # image upload widget (PIL)
                height = gr.Slider(minimum=256, maximum=1024, step=64, label="Height", value=768)
                width = gr.Slider(minimum=256, maximum=1024, step=64, label="Width", value=768)
                seed = gr.Number(label="Seed", value=42)
                control_type = gr.Dropdown(choices=control_types, label="Control Type")
                style_lora = gr.Dropdown(choices=style_loras, label="Style LoRA")
                single_generate_btn = gr.Button("Generate Image")
            with gr.Column():
                single_output_image = gr.Image(label="Generated Image")

        # Add examples for Single Condition Generation
        gr.Examples(
            examples=single_examples,
            inputs=[prompt, subject_img, spatial_img, height, width, seed, control_type, style_lora],
            outputs=single_output_image,
            fn=single_condition_generate_image,
            cache_examples=False,  # do not pre-render example outputs at startup
            label="Single Condition Examples"
        )


    with gr.Tab("Multi-Condition Generation"):
        with gr.Row():
            with gr.Column():
                multi_prompt = gr.Textbox(label="Prompt")
                multi_subject_img = gr.Image(label="Subject Image", type="pil")  # image upload widget (PIL)
                multi_spatial_img = gr.Image(label="Spatial Image", type="pil")  # image upload widget (PIL)
                multi_height = gr.Slider(minimum=256, maximum=1024, step=64, label="Height", value=768)
                multi_width = gr.Slider(minimum=256, maximum=1024, step=64, label="Width", value=768)
                multi_seed = gr.Number(label="Seed", value=42)
                multi_generate_btn = gr.Button("Generate Image")
            with gr.Column():
                multi_output_image = gr.Image(label="Generated Image")
                
        # Add examples for Multi-Condition Generation
        gr.Examples(
            examples=multi_examples,
            inputs=[multi_prompt, multi_subject_img, multi_spatial_img, multi_height, multi_width, multi_seed],
            outputs=multi_output_image,
            fn=multi_condition_generate_image,
            cache_examples=False,  # do not pre-render example outputs at startup
            label="Multi-Condition Examples"
        )


    # Link the buttons to the functions
    single_generate_btn.click(
        single_condition_generate_image,
        inputs=[prompt, subject_img, spatial_img, height, width, seed, control_type, style_lora],
        outputs=single_output_image
    )
    multi_generate_btn.click(
        multi_condition_generate_image,
        inputs=[multi_prompt, multi_subject_img, multi_spatial_img, multi_height, multi_width, multi_seed],
        outputs=multi_output_image
    )

# Launch the Gradio app (queue() serializes GPU-bound requests)
demo.queue().launch()