tsqn commited on
Commit
560234d
·
verified ·
1 Parent(s): e5423cc

Delete scripts

Browse files
Files changed (1) hide show
  1. scripts/to_safetensors.py +0 -143
scripts/to_safetensors.py DELETED
@@ -1,143 +0,0 @@
1
- import argparse
2
- from pathlib import Path
3
- from typing import Dict
4
- import safetensors.torch
5
- import torch
6
- import json
7
- import shutil
8
-
9
-
10
- def load_text_encoder(index_path: Path) -> Dict:
11
- with open(index_path, "r") as f:
12
- index: Dict = json.load(f)
13
-
14
- loaded_tensors = {}
15
- for part_file in set(index.get("weight_map", {}).values()):
16
- tensors = safetensors.torch.load_file(
17
- index_path.parent / part_file, device="cpu"
18
- )
19
- for tensor_name in tensors:
20
- loaded_tensors[tensor_name] = tensors[tensor_name]
21
-
22
- return loaded_tensors
23
-
24
-
25
- def convert_unet(unet: Dict, add_prefix=True) -> Dict:
26
- if add_prefix:
27
- return {"model.diffusion_model." + key: value for key, value in unet.items()}
28
- return unet
29
-
30
-
31
- def convert_vae(vae_path: Path, add_prefix=True) -> Dict:
32
- state_dict = torch.load(vae_path / "autoencoder.pth", weights_only=True)
33
- stats_path = vae_path / "per_channel_statistics.json"
34
- if stats_path.exists():
35
- with open(stats_path, "r") as f:
36
- data = json.load(f)
37
- transposed_data = list(zip(*data["data"]))
38
- data_dict = {
39
- f"{'vae.' if add_prefix else ''}per_channel_statistics.{col}": torch.tensor(
40
- vals
41
- )
42
- for col, vals in zip(data["columns"], transposed_data)
43
- }
44
- else:
45
- data_dict = {}
46
-
47
- result = {
48
- ("vae." if add_prefix else "") + key: value for key, value in state_dict.items()
49
- }
50
- result.update(data_dict)
51
- return result
52
-
53
-
54
- def convert_encoder(encoder: Dict) -> Dict:
55
- return {
56
- "text_encoders.t5xxl.transformer." + key: value
57
- for key, value in encoder.items()
58
- }
59
-
60
-
61
- def save_config(config_src: str, config_dst: str):
62
- shutil.copy(config_src, config_dst)
63
-
64
-
65
- def load_vae_config(vae_path: Path) -> str:
66
- config_path = vae_path / "config.json"
67
- if not config_path.exists():
68
- raise FileNotFoundError(f"VAE config file {config_path} not found.")
69
- return str(config_path)
70
-
71
-
72
- def main(
73
- unet_path: str,
74
- vae_path: str,
75
- out_path: str,
76
- mode: str,
77
- unet_config_path: str = None,
78
- scheduler_config_path: str = None,
79
- ) -> None:
80
- unet = convert_unet(
81
- torch.load(unet_path, weights_only=True), add_prefix=(mode == "single")
82
- )
83
-
84
- # Load VAE from directory and config
85
- vae = convert_vae(Path(vae_path), add_prefix=(mode == "single"))
86
- vae_config_path = load_vae_config(Path(vae_path))
87
-
88
- if mode == "single":
89
- result = {**unet, **vae}
90
- safetensors.torch.save_file(result, out_path)
91
- elif mode == "separate":
92
- # Create directories for unet, vae, and scheduler
93
- unet_dir = Path(out_path) / "unet"
94
- vae_dir = Path(out_path) / "vae"
95
- scheduler_dir = Path(out_path) / "scheduler"
96
-
97
- unet_dir.mkdir(parents=True, exist_ok=True)
98
- vae_dir.mkdir(parents=True, exist_ok=True)
99
- scheduler_dir.mkdir(parents=True, exist_ok=True)
100
-
101
- # Save unet and vae safetensors with the name diffusion_pytorch_model.safetensors
102
- safetensors.torch.save_file(
103
- unet, unet_dir / "unet_diffusion_pytorch_model.safetensors"
104
- )
105
- safetensors.torch.save_file(
106
- vae, vae_dir / "vae_diffusion_pytorch_model.safetensors"
107
- )
108
-
109
- # Save config files for unet, vae, and scheduler
110
- if unet_config_path:
111
- save_config(unet_config_path, unet_dir / "config.json")
112
- if vae_config_path:
113
- save_config(vae_config_path, vae_dir / "config.json")
114
- if scheduler_config_path:
115
- save_config(scheduler_config_path, scheduler_dir / "scheduler_config.json")
116
-
117
-
118
- if __name__ == "__main__":
119
- parser = argparse.ArgumentParser()
120
- parser.add_argument("--unet_path", "-u", type=str, default="unet/ema-002.pt")
121
- parser.add_argument("--vae_path", "-v", type=str, default="vae/")
122
- parser.add_argument("--out_path", "-o", type=str, default="xora.safetensors")
123
- parser.add_argument(
124
- "--mode",
125
- "-m",
126
- type=str,
127
- choices=["single", "separate"],
128
- default="single",
129
- help="Choose 'single' for the original behavior, 'separate' to save unet and vae separately.",
130
- )
131
- parser.add_argument(
132
- "--unet_config_path",
133
- type=str,
134
- help="Path to the UNet config file (for separate mode)",
135
- )
136
- parser.add_argument(
137
- "--scheduler_config_path",
138
- type=str,
139
- help="Path to the Scheduler config file (for separate mode)",
140
- )
141
-
142
- args = parser.parse_args()
143
- main(**args.__dict__)