blanchon commited on
Commit
317c8a2
·
verified ·
1 Parent(s): 6aaae5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -3
app.py CHANGED
@@ -13,9 +13,32 @@ from PIL import Image
13
  MAX_SEED = np.iinfo(np.int32).max
14
  MAX_IMAGE_SIZE = 2048
15
 
16
- pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16).to("cuda")
17
- pipe.load_lora_weights("blanchon/FluxFillFurniture", weight_name="lora_fill.safetensors")
18
- pipe.fuse_lora(lora_scale=1.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  pipe.to("cuda")
20
 
21
  def calculate_optimal_dimensions(image: Image.Image):
 
13
  MAX_SEED = np.iinfo(np.int32).max
14
  MAX_IMAGE_SIZE = 2048
15
 
16
+ # pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16).to("cuda")
17
+ transformer = FluxTransformer2DModel.from_pretrained(
18
+ "xiaozaa/flux1-fill-dev-diffusers", ## The official Flux-Fill weights
19
+ torch_dtype=torch.bfloat16
20
+ )
21
+ print("Start loading LoRA weights")
22
+ state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
23
+ pretrained_model_name_or_path_or_dict="blanchon/FluxFillFurniture",
24
+ weight_name="lora_fill.safetensors",
25
+ return_alphas=True
26
+ )
27
+ is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
28
+ if not is_correct_format:
29
+ raise ValueError("Invalid LoRA checkpoint.")
30
+ FluxFillPipeline.load_lora_into_transformer(
31
+ state_dict=state_dict,
32
+ network_alphas=network_alphas,
33
+ transformer=transformer,
34
+ )
35
+ pipe = FluxFillPipeline.from_pretrained(
36
+ "black-forest-labs/FLUX.1-Fill-dev",
37
+ torch_dtype=torch.bfloat16,
38
+ transformer=transformer,
39
+ ).to("cuda")
40
+ # pipe.load_lora_weights("blanchon/FluxFillFurniture", weight_name="lora_fill.safetensors")
41
+ # pipe.fuse_lora(lora_scale=1.0)
42
  pipe.to("cuda")
43
 
44
  def calculate_optimal_dimensions(image: Image.Image):