Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -12,42 +12,20 @@ from src.pipeline import FluxPipeline
|
|
12 |
from src.transformer_flux import FluxTransformer2DModel
|
13 |
from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
|
14 |
|
15 |
-
|
16 |
-
class ImageProcessor:
|
17 |
-
def __init__(self, path):
|
18 |
-
device = "cuda"
|
19 |
-
self.pipe = FluxPipeline.from_pretrained(path, torch_dtype=torch.bfloat16, device=device)
|
20 |
-
transformer = FluxTransformer2DModel.from_pretrained(path, subfolder="transformer", torch_dtype=torch.bfloat16, device=device)
|
21 |
-
self.pipe.transformer = transformer
|
22 |
-
self.pipe.to(device)
|
23 |
-
|
24 |
-
def clear_cache(self, transformer):
|
25 |
-
for name, attn_processor in transformer.attn_processors.items():
|
26 |
-
attn_processor.bank_kv.clear()
|
27 |
-
|
28 |
-
def process_image(self, prompt='', subject_imgs=[], spatial_imgs=[], height=768, width=768, output_path=None, seed=42):
|
29 |
-
image = self.pipe(
|
30 |
-
prompt,
|
31 |
-
height=int(height),
|
32 |
-
width=int(width),
|
33 |
-
guidance_scale=3.5,
|
34 |
-
num_inference_steps=25,
|
35 |
-
max_sequence_length=512,
|
36 |
-
generator=torch.Generator("cpu").manual_seed(seed),
|
37 |
-
subject_images=subject_imgs,
|
38 |
-
spatial_images=spatial_imgs,
|
39 |
-
cond_size=512,
|
40 |
-
).images[0]
|
41 |
-
self.clear_cache(self.pipe.transformer)
|
42 |
-
if output_path:
|
43 |
-
image.save(output_path)
|
44 |
-
return image
|
45 |
-
|
46 |
# Initialize the image processor
|
47 |
base_path = "black-forest-labs/FLUX.1-dev"
|
48 |
lora_base_path = "./models"
|
49 |
style_lora_base_path = "Shakker-Labs"
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
# Define the Gradio interface
|
53 |
@spaces.GPU()
|
@@ -59,29 +37,41 @@ def single_condition_generate_image(prompt, subject_img, spatial_img, height, wi
|
|
59 |
lora_path = os.path.join(lora_base_path, "pose.safetensors")
|
60 |
elif control_type == "inpainting":
|
61 |
lora_path = os.path.join(lora_base_path, "inpainting.safetensors")
|
62 |
-
set_single_lora(
|
63 |
|
64 |
# Set the style LoRA
|
65 |
if style_lora=="None":
|
66 |
pass
|
67 |
else:
|
68 |
if style_lora == "Simple_Sketch":
|
69 |
-
|
70 |
style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Children-Simple-Sketch")
|
71 |
-
|
72 |
if style_lora == "Text_Poster":
|
73 |
-
|
74 |
style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Text-Poster")
|
75 |
-
|
76 |
if style_lora == "Vector_Style":
|
77 |
-
|
78 |
style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Vector-Journey")
|
79 |
-
|
80 |
|
81 |
# Process the image
|
82 |
subject_imgs = [subject_img] if subject_img else []
|
83 |
spatial_imgs = [spatial_img] if spatial_img else []
|
84 |
-
image =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
return image
|
86 |
|
87 |
# Define the Gradio interface
|
@@ -89,12 +79,24 @@ def single_condition_generate_image(prompt, subject_img, spatial_img, height, wi
|
|
89 |
def multi_condition_generate_image(prompt, subject_img, spatial_img, height, width, seed):
|
90 |
subject_path = os.path.join(lora_base_path, "subject.safetensors")
|
91 |
inpainting_path = os.path.join(lora_base_path, "inpainting.safetensors")
|
92 |
-
set_multi_lora(
|
93 |
|
94 |
# Process the image
|
95 |
subject_imgs = [subject_img] if subject_img else []
|
96 |
spatial_imgs = [spatial_img] if spatial_img else []
|
97 |
-
image =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
return image
|
99 |
|
100 |
# Define the Gradio interface components
|
|
|
12 |
from src.transformer_flux import FluxTransformer2DModel
|
13 |
from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
# Initialize the image processor
|
16 |
base_path = "black-forest-labs/FLUX.1-dev"
|
17 |
lora_base_path = "./models"
|
18 |
style_lora_base_path = "Shakker-Labs"
|
19 |
+
|
20 |
+
|
21 |
+
pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
|
22 |
+
transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
|
23 |
+
pipe.transformer = transformer
|
24 |
+
pipe.to(device)
|
25 |
+
|
26 |
+
def clear_cache(transformer):
|
27 |
+
for name, attn_processor in transformer.attn_processors.items():
|
28 |
+
attn_processor.bank_kv.clear()
|
29 |
|
30 |
# Define the Gradio interface
|
31 |
@spaces.GPU()
|
|
|
37 |
lora_path = os.path.join(lora_base_path, "pose.safetensors")
|
38 |
elif control_type == "inpainting":
|
39 |
lora_path = os.path.join(lora_base_path, "inpainting.safetensors")
|
40 |
+
set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
|
41 |
|
42 |
# Set the style LoRA
|
43 |
if style_lora=="None":
|
44 |
pass
|
45 |
else:
|
46 |
if style_lora == "Simple_Sketch":
|
47 |
+
pipe.unload_lora_weights()
|
48 |
style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Children-Simple-Sketch")
|
49 |
+
pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-children-simple-sketch.safetensors")
|
50 |
if style_lora == "Text_Poster":
|
51 |
+
pipe.unload_lora_weights()
|
52 |
style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Text-Poster")
|
53 |
+
pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-Text-Poster.safetensors")
|
54 |
if style_lora == "Vector_Style":
|
55 |
+
pipe.unload_lora_weights()
|
56 |
style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Vector-Journey")
|
57 |
+
pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-Vector-Journey.safetensors")
|
58 |
|
59 |
# Process the image
|
60 |
subject_imgs = [subject_img] if subject_img else []
|
61 |
spatial_imgs = [spatial_img] if spatial_img else []
|
62 |
+
image = pipe(
|
63 |
+
prompt,
|
64 |
+
height=int(height),
|
65 |
+
width=int(width),
|
66 |
+
guidance_scale=3.5,
|
67 |
+
num_inference_steps=25,
|
68 |
+
max_sequence_length=512,
|
69 |
+
generator=torch.Generator("cpu").manual_seed(seed),
|
70 |
+
subject_images=subject_imgs,
|
71 |
+
spatial_images=spatial_imgs,
|
72 |
+
cond_size=512,
|
73 |
+
).images[0]
|
74 |
+
clear_cache(pipe.transformer)
|
75 |
return image
|
76 |
|
77 |
# Define the Gradio interface
|
|
|
79 |
def multi_condition_generate_image(prompt, subject_img, spatial_img, height, width, seed):
|
80 |
subject_path = os.path.join(lora_base_path, "subject.safetensors")
|
81 |
inpainting_path = os.path.join(lora_base_path, "inpainting.safetensors")
|
82 |
+
set_multi_lora(pipe.transformer, [subject_path, inpainting_path], lora_weights=[[1],[1]],cond_size=512)
|
83 |
|
84 |
# Process the image
|
85 |
subject_imgs = [subject_img] if subject_img else []
|
86 |
spatial_imgs = [spatial_img] if spatial_img else []
|
87 |
+
image = pipe(
|
88 |
+
prompt,
|
89 |
+
height=int(height),
|
90 |
+
width=int(width),
|
91 |
+
guidance_scale=3.5,
|
92 |
+
num_inference_steps=25,
|
93 |
+
max_sequence_length=512,
|
94 |
+
generator=torch.Generator("cpu").manual_seed(seed),
|
95 |
+
subject_images=subject_imgs,
|
96 |
+
spatial_images=spatial_imgs,
|
97 |
+
cond_size=512,
|
98 |
+
).images[0]
|
99 |
+
clear_cache(pipe.transformer)
|
100 |
return image
|
101 |
|
102 |
# Define the Gradio interface components
|