prithivMLmods commited on
Commit
d70bc68
·
verified ·
1 Parent(s): 472834d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +89 -1
README.md CHANGED
@@ -10,4 +10,92 @@ base_model:
10
  - black-forest-labs/FLUX.1-schnell
11
  base_model_relation: merge
12
  pipeline_tag: text-to-image
13
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  - black-forest-labs/FLUX.1-schnell
11
  base_model_relation: merge
12
  pipeline_tag: text-to-image
13
+ ---
14
+
15
+ # **FLUX.1-Merged**
16
+
17
+ This repository provides the merged params for [`black-forest-labs/FLUX.1-dev`](https://huggingface.co/black-forest-labs/FLUX.1-dev)
18
+ and [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell).
19
+
20
+ # **Merge & Upload**
21
+
22
+ ```python
23
+ from diffusers import FluxTransformer2DModel
24
+ from huggingface_hub import snapshot_download
25
+ from huggingface_hub import upload_folder
26
+ from accelerate import init_empty_weights
27
+ from diffusers.models.model_loading_utils import load_model_dict_into_meta
28
+ import safetensors.torch
29
+ import glob
30
+ import torch
31
+
32
+
33
+ # Initialize the model with empty weights
34
+ with init_empty_weights():
35
+ config = FluxTransformer2DModel.load_config("black-forest-labs/FLUX.1-dev", subfolder="transformer")
36
+ model = FluxTransformer2DModel.from_config(config)
37
+
38
+ # Download the model checkpoints
39
+ dev_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-dev", allow_patterns="transformer/*")
40
+ schnell_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-schnell", allow_patterns="transformer/*")
41
+
42
+ # Get the paths to the model shards
43
+ dev_shards = sorted(glob.glob(f"{dev_ckpt}/transformer/*.safetensors"))
44
+ schnell_shards = sorted(glob.glob(f"{schnell_ckpt}/transformer/*.safetensors"))
45
+
46
+ # Merge the state dictionaries
47
+ merged_state_dict = {}
48
+ guidance_state_dict = {}
49
+
50
+ for i in range(len(dev_shards)):
51
+ state_dict_dev_temp = safetensors.torch.load_file(dev_shards[i])
52
+ state_dict_schnell_temp = safetensors.torch.load_file(schnell_shards[i])
53
+
54
+ keys = list(state_dict_dev_temp.keys())
55
+ for k in keys:
56
+ if "guidance" not in k:
57
+ merged_state_dict[k] = (state_dict_dev_temp.pop(k) + state_dict_schnell_temp.pop(k)) / 2
58
+ else:
59
+ guidance_state_dict[k] = state_dict_dev_temp.pop(k)
60
+
61
+ if len(state_dict_dev_temp) > 0:
62
+ raise ValueError(f"There should not be any residue but got: {list(state_dict_dev_temp.keys())}.")
63
+ if len(state_dict_schnell_temp) > 0:
64
+ raise ValueError(f"There should not be any residue but got: {list(state_dict_schnell_temp.keys())}.")
65
+
66
+ # Update the merged state dictionary with the guidance state dictionary
67
+ merged_state_dict.update(guidance_state_dict)
68
+
69
+ # Load the merged state dictionary into the model
70
+ load_model_dict_into_meta(model, merged_state_dict)
71
+
72
+ # Save the merged model
73
+ model.to(torch.bfloat16).save_pretrained("transformer")
74
+
75
+ # Upload the merged model to the Hugging Face Hub
76
+ upload_folder(
77
+ repo_id="prithivMLmods/Flux.1-Merged", # Replace with your Hugging Face username and desired repo name
78
+ folder_path="transformer",
79
+ path_in_repo="transformer",
80
+ )
81
+ ```
82
+ # **Inference**
83
+
84
+ ```python
85
+ from diffusers import FluxPipeline
86
+ import torch
87
+
88
+ pipeline = FluxPipeline.from_pretrained(
89
+ "prithivMLmods/Flux.1-Merged", torch_dtype=torch.bfloat16
90
+ ).to("cuda")
91
+ image = pipeline(
92
+ prompt="a tiny astronaut hatching from an egg on the moon",
93
+ guidance_scale=3.5,
94
+ num_inference_steps=4,
95
+ height=880,
96
+ width=1184,
97
+ max_sequence_length=512,
98
+ generator=torch.manual_seed(0),
99
+ ).images[0]
100
+ image.save("merged_flux.png")
101
+ ```