blanchon commited on
Commit
e875d79
·
verified ·
1 Parent(s): 13949f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -21,22 +21,24 @@ transformer = FluxTransformer2DModel.from_pretrained(
21
  print("Start loading LoRA weights")
22
  state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
23
  pretrained_model_name_or_path_or_dict="blanchon/FluxFillFurniture",
24
- weight_name="lora_fill.safetensors",
25
  return_alphas=True
26
  )
27
  is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
28
  if not is_correct_format:
29
  raise ValueError("Invalid LoRA checkpoint.")
 
 
 
 
 
 
30
  FluxFillPipeline.load_lora_into_transformer(
31
  state_dict=state_dict,
32
  network_alphas=network_alphas,
33
- transformer=transformer,
34
  )
35
- pipe = FluxFillPipeline.from_pretrained(
36
- "black-forest-labs/FLUX.1-Fill-dev",
37
- torch_dtype=torch.bfloat16,
38
- transformer=transformer,
39
- ).to("cuda")
40
  # pipe.load_lora_weights("blanchon/FluxFillFurniture", weight_name="lora_fill.safetensors")
41
  # pipe.fuse_lora(lora_scale=1.0)
42
  pipe.to("cuda")
 
21
  print("Start loading LoRA weights")
22
  state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
23
  pretrained_model_name_or_path_or_dict="blanchon/FluxFillFurniture",
24
+ weight_name="pytorch_lora_weights.safetensors",
25
  return_alphas=True
26
  )
27
  is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
28
  if not is_correct_format:
29
  raise ValueError("Invalid LoRA checkpoint.")
30
+
31
+
32
+ pipe = FluxFillPipeline.from_pretrained(
33
+ "black-forest-labs/FLUX.1-Fill-dev",
34
+ torch_dtype=torch.bfloat16
35
+ ).to(device)
36
  FluxFillPipeline.load_lora_into_transformer(
37
  state_dict=state_dict,
38
  network_alphas=network_alphas,
39
+ transformer=pipe.transformer,
40
  )
41
+
 
 
 
 
42
  # pipe.load_lora_weights("blanchon/FluxFillFurniture", weight_name="lora_fill.safetensors")
43
  # pipe.fuse_lora(lora_scale=1.0)
44
  pipe.to("cuda")