Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -21,22 +21,24 @@ transformer = FluxTransformer2DModel.from_pretrained(
|
|
21 |
print("Start loading LoRA weights")
|
22 |
state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
|
23 |
pretrained_model_name_or_path_or_dict="blanchon/FluxFillFurniture",
|
24 |
-
weight_name="
|
25 |
return_alphas=True
|
26 |
)
|
27 |
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
|
28 |
if not is_correct_format:
|
29 |
raise ValueError("Invalid LoRA checkpoint.")
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
FluxFillPipeline.load_lora_into_transformer(
|
31 |
state_dict=state_dict,
|
32 |
network_alphas=network_alphas,
|
33 |
-
transformer=transformer,
|
34 |
)
|
35 |
-
|
36 |
-
"black-forest-labs/FLUX.1-Fill-dev",
|
37 |
-
torch_dtype=torch.bfloat16,
|
38 |
-
transformer=transformer,
|
39 |
-
).to("cuda")
|
40 |
# pipe.load_lora_weights("blanchon/FluxFillFurniture", weight_name="lora_fill.safetensors")
|
41 |
# pipe.fuse_lora(lora_scale=1.0)
|
42 |
pipe.to("cuda")
|
|
|
21 |
print("Start loading LoRA weights")
|
22 |
state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
|
23 |
pretrained_model_name_or_path_or_dict="blanchon/FluxFillFurniture",
|
24 |
+
weight_name="pytorch_lora_weights.safetensors",
|
25 |
return_alphas=True
|
26 |
)
|
27 |
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
|
28 |
if not is_correct_format:
|
29 |
raise ValueError("Invalid LoRA checkpoint.")
|
30 |
+
|
31 |
+
|
32 |
+
pipe = FluxFillPipeline.from_pretrained(
|
33 |
+
"black-forest-labs/FLUX.1-Fill-dev",
|
34 |
+
torch_dtype=torch.bfloat16
|
35 |
+
).to(device)
|
36 |
FluxFillPipeline.load_lora_into_transformer(
|
37 |
state_dict=state_dict,
|
38 |
network_alphas=network_alphas,
|
39 |
+
transformer=pipe.transformer,
|
40 |
)
|
41 |
+
|
|
|
|
|
|
|
|
|
42 |
# pipe.load_lora_weights("blanchon/FluxFillFurniture", weight_name="lora_fill.safetensors")
|
43 |
# pipe.fuse_lora(lora_scale=1.0)
|
44 |
pipe.to("cuda")
|