Spaces:
Running
on
Zero
Running
on
Zero
File size: 9,194 Bytes
46175dc 8f6d6cb 5544c41 8f6d6cb 5544c41 8f6d6cb 33fc81c 43f1e39 33fc81c cfb0d74 411cd1a cfb0d74 8f6d6cb d26acb1 8f6d6cb cfb0d74 8f6d6cb cfb0d74 33fc81c cfb0d74 8f6d6cb cfb0d74 33fc81c cfb0d74 8f6d6cb cfb0d74 33fc81c cfb0d74 8f6d6cb cfb0d74 8f6d6cb d26acb1 8f6d6cb cfb0d74 8f6d6cb cfb0d74 8f6d6cb 5ab7839 b6130e4 8f6d6cb 7c95d4a 45f491c cd06f65 8f6d6cb 7c95d4a 8f6d6cb c8e85df 8f6d6cb fdbf3a9 8f6d6cb 73a9679 8f6d6cb 33fc81c 8f6d6cb f3c9924 8f6d6cb 33fc81c 8f6d6cb 33fc81c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
import spaces
import os
import json
import time
import torch
from PIL import Image
from tqdm import tqdm
import gradio as gr
from safetensors.torch import save_file
from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
# Model locations: the base checkpoint identifier passed to from_pretrained,
# the local folder the control LoRA files are read from, and the prefix that
# style LoRA locations are built from.
base_path = "black-forest-labs/FLUX.1-dev"
lora_base_path = "./models"
style_lora_base_path = "Shakker-Labs"
# Build the generation pipeline in bfloat16, then replace its transformer with
# the project's FluxTransformer2DModel (loaded from the "transformer"
# subfolder of the same checkpoint) before moving the pipeline to the GPU.
pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe.transformer = transformer
pipe.to("cuda")
def clear_cache(transformer):
    """Empty the ``bank_kv`` store of every attention processor on *transformer*.

    Args:
        transformer: Object exposing an ``attn_processors`` mapping whose
            values each carry a ``bank_kv`` container with a ``clear()``
            method.
    """
    # Only the processors themselves are needed; their names are irrelevant.
    for attn_processor in transformer.attn_processors.values():
        attn_processor.bank_kv.clear()
# Control LoRA weight file (inside ``lora_base_path``) for each control type.
_CONTROL_LORA_FILES = {
    "subject": "subject.safetensors",
    "pose": "pose.safetensors",
    "inpainting": "inpainting.safetensors",
}

# Style name -> (folder under ``style_lora_base_path``, weight file name).
_STYLE_LORA_SOURCES = {
    "Simple_Sketch": (
        "FLUX.1-dev-LoRA-Children-Simple-Sketch",
        "FLUX-dev-lora-children-simple-sketch.safetensors",
    ),
    "Text_Poster": (
        "FLUX.1-dev-LoRA-Text-Poster",
        "FLUX-dev-lora-Text-Poster.safetensors",
    ),
    "Vector_Style": (
        "FLUX.1-dev-LoRA-Vector-Journey",
        "FLUX-dev-lora-Vector-Journey.safetensors",
    ),
}


# Handler for the "Single Condition Generation" tab
@spaces.GPU()
def single_condition_generate_image(prompt, subject_img, spatial_img, height, width, seed, control_type, style_lora=None):
    """Generate one image with a single control LoRA and an optional style LoRA.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        subject_img: Optional subject reference image; left out of the
            pipeline call when ``None``.
        spatial_img: Optional spatial condition image; left out of the
            pipeline call when ``None``.
        height: Output height; converted with ``int()``.
        width: Output width; converted with ``int()``.
        seed: Seed for the CPU random generator; converted with ``int()``
            because a UI number field may deliver a float.
        control_type: One of ``"subject"``, ``"pose"`` or ``"inpainting"``.
        style_lora: ``"Simple_Sketch"``, ``"Text_Poster"`` or
            ``"Vector_Style"``. Any other value (including ``None`` and the
            string ``"None"``) means no style LoRA is applied.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If ``control_type`` is not one of the supported values.
    """
    # Validate before touching the shared pipeline; previously an unselected
    # or unknown control type fell through and crashed on an unbound variable.
    control_file = _CONTROL_LORA_FILES.get(control_type)
    if control_file is None:
        supported = ", ".join(sorted(_CONTROL_LORA_FILES))
        raise gr.Error(
            f"Unsupported control type: {control_type!r}. Choose one of: {supported}."
        )
    lora_path = os.path.join(lora_base_path, control_file)
    set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)

    # Always drop whatever style adapter an earlier request left on the shared
    # pipeline, so choosing "None" really means "no style".
    pipe.unload_lora_weights()
    style_source = _STYLE_LORA_SOURCES.get(style_lora)
    if style_source is not None:
        style_folder, weight_name = style_source
        style_lora_path = os.path.join(style_lora_base_path, style_folder)
        pipe.load_lora_weights(style_lora_path, weight_name=weight_name)

    # Each condition is passed as a (possibly empty) list.
    subject_imgs = [subject_img] if subject_img is not None else []
    spatial_imgs = [spatial_img] if spatial_img is not None else []
    try:
        image = pipe(
            prompt,
            height=int(height),
            width=int(width),
            guidance_scale=3.5,
            num_inference_steps=25,
            max_sequence_length=512,
            generator=torch.Generator("cpu").manual_seed(int(seed)),
            subject_images=subject_imgs,
            spatial_images=spatial_imgs,
            cond_size=512,
        ).images[0]
    finally:
        # Empty the processors' bank_kv stores even when generation raises, so
        # nothing from this request is carried into the next one.
        clear_cache(pipe.transformer)
    return image
# Handler for the "Multi-Condition Generation" tab
@spaces.GPU()
def multi_condition_generate_image(prompt, subject_img, spatial_img, height, width, seed):
    """Generate one image using the subject and inpainting control LoRAs together.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        subject_img: Optional subject reference image; left out of the
            pipeline call when ``None``.
        spatial_img: Optional spatial condition image; left out of the
            pipeline call when ``None``.
        height: Output height; converted with ``int()``.
        width: Output width; converted with ``int()``.
        seed: Seed for the CPU random generator; converted with ``int()``
            because a UI number field may deliver a float.

    Returns:
        The first image produced by the pipeline.
    """
    subject_path = os.path.join(lora_base_path, "subject.safetensors")
    inpainting_path = os.path.join(lora_base_path, "inpainting.safetensors")
    set_multi_lora(pipe.transformer, [subject_path, inpainting_path], lora_weights=[[1], [1]], cond_size=512)

    # This handler has no style option, so remove any style adapter a
    # single-condition request may have left on the shared pipeline.
    pipe.unload_lora_weights()

    # Each condition is passed as a (possibly empty) list.
    subject_imgs = [subject_img] if subject_img is not None else []
    spatial_imgs = [spatial_img] if spatial_img is not None else []
    try:
        image = pipe(
            prompt,
            height=int(height),
            width=int(width),
            guidance_scale=3.5,
            num_inference_steps=25,
            max_sequence_length=512,
            generator=torch.Generator("cpu").manual_seed(int(seed)),
            subject_images=subject_imgs,
            spatial_images=spatial_imgs,
            cond_size=512,
        ).images[0]
    finally:
        # Empty the processors' bank_kv stores even when generation raises, so
        # nothing from this request is carried into the next one.
        clear_cache(pipe.transformer)
    return image
# Choices offered by the "Control Type" and "Style LoRA" dropdowns.
control_types = ["pose", "subject", "inpainting"]
style_loras = ["None", "Simple_Sketch", "Text_Poster", "Vector_Style"]
# Example rows for the single-condition tab. Column order follows the inputs
# list given to gr.Examples:
# prompt, subject image, spatial image, height, width, seed, control type, style LoRA.
single_examples = [
    ["A SKS in the library", Image.open("./test_imgs/subject1.png"), None, 768, 768, 5, "subject", "None"],
    ["sketched style,A joyful girl with balloons floats above a city wearing a hat and striped pants", None, Image.open("./test_imgs/spatial0.png"), 768, 512, 42, "pose", "Simple_Sketch"],
    ["In a picturesque village, a narrow cobblestone street with rustic stone buildings, colorful blinds, and lush green spaces, a cartoon man drawn with simple lines and solid colors stands in the foreground, wearing a red shirt, beige work pants, and brown shoes, carrying a strap on his shoulder. The scene features warm and enticing colors, a pleasant fusion of nature and architecture, and the camera's perspective on the street clearly shows the charming and quaint environment., Integrating elements of reality and cartoon.", None, Image.open("./test_imgs/spatial1.png"), 768, 768, 1, "pose", "Vector_Style"],
]
# Example rows for the multi-condition tab:
# prompt, subject image, spatial image, height, width, seed.
multi_examples = [
    ["A SKS on the car", Image.open("./test_imgs/subject2.png"), Image.open("./test_imgs/spatial2.png"), 768, 768, 7],
]
# Build the Gradio Blocks interface: one tab per generation mode, each with
# its inputs on the left, the generated image on the right, and a row of
# clickable examples underneath.
with gr.Blocks() as demo:
    gr.Markdown("# Image Generation with EasyControl")
    gr.Markdown("Generate images using EasyControl with different control types and style LoRAs.(Due to hardware constraints, only low-resolution images can be generated. For high-resolution (1024+), please set up your own environment.)")

    with gr.Tab("Single Condition Generation"):
        with gr.Row():
            with gr.Column():
                # The literal below is kept flush-left so its text is unchanged.
                gr.Markdown("""
**Prompt** (When using LoRA, please try the recommended prompts available at the following links:
[FLUX.1-dev-LoRA-Text-Poster](https://huggingface.co/Shakker-Labs/FLUX.1-dev-LoRA-Text-Poster),
[FLUX.1-dev-LoRA-Children-Simple-Sketch](https://huggingface.co/Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch),
[FLUX.1-dev-LoRA-Vector-Journey](https://huggingface.co/Shakker-Labs/FLUX.1-dev-LoRA-Vector-Journey))
""")
                prompt = gr.Textbox(label="Prompt")
                subject_img = gr.Image(label="Subject Image", type="pil")  # uploaded image file
                spatial_img = gr.Image(label="Spatial Image", type="pil")  # uploaded image file
                height = gr.Slider(minimum=256, maximum=1024, step=64, label="Height", value=768)
                width = gr.Slider(minimum=256, maximum=1024, step=64, label="Width", value=768)
                # precision=0 makes the field deliver an integer, which is
                # what torch's manual_seed requires.
                seed = gr.Number(label="Seed", value=42, precision=0)
                control_type = gr.Dropdown(choices=control_types, label="Control Type")
                # Show the "None" choice up front instead of an empty field;
                # the handler treats both as "no style LoRA".
                style_lora = gr.Dropdown(choices=style_loras, label="Style LoRA", value="None")
                single_generate_btn = gr.Button("Generate Image")
            with gr.Column():
                single_output_image = gr.Image(label="Generated Image")

        # Examples for Single Condition Generation
        gr.Examples(
            examples=single_examples,
            inputs=[prompt, subject_img, spatial_img, height, width, seed, control_type, style_lora],
            outputs=single_output_image,
            fn=single_condition_generate_image,
            cache_examples=False,  # caching is disabled: example outputs are not precomputed
            label="Single Condition Examples"
        )

    with gr.Tab("Multi-Condition Generation"):
        with gr.Row():
            with gr.Column():
                multi_prompt = gr.Textbox(label="Prompt")
                multi_subject_img = gr.Image(label="Subject Image", type="pil")  # uploaded image file
                multi_spatial_img = gr.Image(label="Spatial Image", type="pil")  # uploaded image file
                multi_height = gr.Slider(minimum=256, maximum=1024, step=64, label="Height", value=768)
                multi_width = gr.Slider(minimum=256, maximum=1024, step=64, label="Width", value=768)
                # precision=0 makes the field deliver an integer seed.
                multi_seed = gr.Number(label="Seed", value=42, precision=0)
                multi_generate_btn = gr.Button("Generate Image")
            with gr.Column():
                multi_output_image = gr.Image(label="Generated Image")

        # Examples for Multi-Condition Generation
        gr.Examples(
            examples=multi_examples,
            inputs=[multi_prompt, multi_subject_img, multi_spatial_img, multi_height, multi_width, multi_seed],
            outputs=multi_output_image,
            fn=multi_condition_generate_image,
            cache_examples=False,  # caching is disabled: example outputs are not precomputed
            label="Multi-Condition Examples"
        )

    # Wire each button to its handler; event listeners must be registered
    # while the Blocks context is still open.
    single_generate_btn.click(
        single_condition_generate_image,
        inputs=[prompt, subject_img, spatial_img, height, width, seed, control_type, style_lora],
        outputs=single_output_image
    )
    multi_generate_btn.click(
        multi_condition_generate_image,
        inputs=[multi_prompt, multi_subject_img, multi_spatial_img, multi_height, multi_width, multi_seed],
        outputs=multi_output_image
    )

# Launch the Gradio app with request queuing enabled
demo.queue().launch()