wangmengchao committed
Commit 3570591 · 1 Parent(s): c27cabb
diffsynth/models/__init__.py ADDED
@@ -0,0 +1 @@
+from .model_manager import *
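Because `__init__.py` simply re-exports everything from `model_manager`, downstream code can import the manager from the package root. A minimal sketch, assuming the package is importable as `diffsynth`:

```python
# Hypothetical import path; works because __init__.py does `from .model_manager import *`.
from diffsynth.models import ModelManager
```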
diffsynth/models/downloader.py ADDED
@@ -0,0 +1,111 @@
+from huggingface_hub import hf_hub_download
+from modelscope import snapshot_download
+import os, shutil
+from typing_extensions import Literal, TypeAlias
+from typing import List
+from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
+
+
+def download_from_modelscope(model_id, origin_file_path, local_dir):
+    os.makedirs(local_dir, exist_ok=True)
+    file_name = os.path.basename(origin_file_path)
+    if file_name in os.listdir(local_dir):
+        print(f" {file_name} is already in {local_dir}.")
+    else:
+        print(f" Start downloading {os.path.join(local_dir, file_name)}")
+        snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
+        downloaded_file_path = os.path.join(local_dir, origin_file_path)
+        target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
+        if downloaded_file_path != target_file_path:
+            shutil.move(downloaded_file_path, target_file_path)
+            shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
+
+
+def download_from_huggingface(model_id, origin_file_path, local_dir):
+    os.makedirs(local_dir, exist_ok=True)
+    file_name = os.path.basename(origin_file_path)
+    if file_name in os.listdir(local_dir):
+        print(f" {file_name} is already in {local_dir}.")
+    else:
+        print(f" Start downloading {os.path.join(local_dir, file_name)}")
+        hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
+        downloaded_file_path = os.path.join(local_dir, origin_file_path)
+        target_file_path = os.path.join(local_dir, file_name)
+        if downloaded_file_path != target_file_path:
+            shutil.move(downloaded_file_path, target_file_path)
+            shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
+
+
+Preset_model_website: TypeAlias = Literal[
+    "HuggingFace",
+    "ModelScope",
+]
+website_to_preset_models = {
+    "HuggingFace": preset_models_on_huggingface,
+    "ModelScope": preset_models_on_modelscope,
+}
+website_to_download_fn = {
+    "HuggingFace": download_from_huggingface,
+    "ModelScope": download_from_modelscope,
+}
+
+
+def download_customized_models(
+    model_id,
+    origin_file_path,
+    local_dir,
+    downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
+):
+    downloaded_files = []
+    for website in downloading_priority:
+        # Check if the file is downloaded.
+        file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
+        if file_to_download in downloaded_files:
+            continue
+        # Download
+        website_to_download_fn[website](model_id, origin_file_path, local_dir)
+        if os.path.basename(origin_file_path) in os.listdir(local_dir):
+            downloaded_files.append(file_to_download)
+    return downloaded_files
+
+
+def download_models(
+    model_id_list: List[Preset_model_id] = [],
+    downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
+):
+    print(f"Downloading models: {model_id_list}")
+    downloaded_files = []
+    load_files = []
+
+    for model_id in model_id_list:
+        for website in downloading_priority:
+            if model_id in website_to_preset_models[website]:
+
+                # Parse model metadata
+                model_metadata = website_to_preset_models[website][model_id]
+                if isinstance(model_metadata, list):
+                    file_data = model_metadata
+                else:
+                    file_data = model_metadata.get("file_list", [])
+
+                # Try downloading the model from this website.
+                model_files = []
+                for model_id, origin_file_path, local_dir in file_data:
+                    # Check if the file is downloaded.
+                    file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
+                    if file_to_download in downloaded_files:
+                        continue
+                    # Download
+                    website_to_download_fn[website](model_id, origin_file_path, local_dir)
+                    if os.path.basename(origin_file_path) in os.listdir(local_dir):
+                        downloaded_files.append(file_to_download)
+                        model_files.append(file_to_download)
+
+                # If the model is successfully downloaded, break.
+                if len(model_files) > 0:
+                    if isinstance(model_metadata, dict) and "load_path" in model_metadata:
+                        model_files = model_metadata["load_path"]
+                    load_files.extend(model_files)
+                    break
+
+    return load_files
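A minimal usage sketch of the download helpers above; the repository id, file path, and local directory are illustrative placeholders rather than entries from this repo's preset tables:

```python
# Hypothetical call to download_customized_models; all ids/paths below are placeholders.
from diffsynth.models.downloader import download_customized_models

downloaded = download_customized_models(
    model_id="some-org/some-model",                 # placeholder repo id
    origin_file_path="weights/model.safetensors",   # placeholder file inside that repo
    local_dir="models/some-model",                  # where the flattened file should land
    downloading_priority=["ModelScope", "HuggingFace"],
)
print(downloaded)  # ["models/some-model/model.safetensors"] if a source succeeded
```

Note that the helpers flatten the repo layout: the downloaded file is moved to the top of `local_dir` and the intermediate directory is removed.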
diffsynth/models/model_manager.py ADDED
@@ -0,0 +1,404 @@
1
+ import os, torch, json, importlib
2
+ from typing import List
3
+
4
+ from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
5
+
6
+ from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
7
+ from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
8
+
9
+
10
+ def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
11
+ loaded_model_names, loaded_models = [], []
12
+ for model_name, model_class in zip(model_names, model_classes):
13
+ print(f" model_name: {model_name} model_class: {model_class.__name__}")
14
+ state_dict_converter = model_class.state_dict_converter()
15
+ if model_resource == "civitai":
16
+ state_dict_results = state_dict_converter.from_civitai(state_dict)
17
+ elif model_resource == "diffusers":
18
+ state_dict_results = state_dict_converter.from_diffusers(state_dict)
19
+ if isinstance(state_dict_results, tuple):
20
+ model_state_dict, extra_kwargs = state_dict_results
21
+ print(f" This model is initialized with extra kwargs: {extra_kwargs}")
22
+ else:
23
+ model_state_dict, extra_kwargs = state_dict_results, {}
24
+ torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
25
+ with init_weights_on_device():
26
+ model = model_class(**extra_kwargs)
27
+ if hasattr(model, "eval"):
28
+ model = model.eval()
29
+ model.load_state_dict(model_state_dict, assign=True)
30
+ model = model.to(dtype=torch_dtype, device=device)
31
+ loaded_model_names.append(model_name)
32
+ loaded_models.append(model)
33
+ return loaded_model_names, loaded_models
34
+
35
+
36
+ def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
37
+ loaded_model_names, loaded_models = [], []
38
+ for model_name, model_class in zip(model_names, model_classes):
39
+ if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
40
+ model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
41
+ else:
42
+ model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
43
+ if torch_dtype == torch.float16 and hasattr(model, "half"):
44
+ model = model.half()
45
+ try:
46
+ model = model.to(device=device)
47
+ except Exception:
48
+ pass
49
+ loaded_model_names.append(model_name)
50
+ loaded_models.append(model)
51
+ return loaded_model_names, loaded_models
52
+
53
+
54
+ def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
55
+ print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
56
+ base_state_dict = base_model.state_dict()
57
+ base_model.to("cpu")
58
+ del base_model
59
+ model = model_class(**extra_kwargs)
60
+ model.load_state_dict(base_state_dict, strict=False)
61
+ model.load_state_dict(state_dict, strict=False)
62
+ model.to(dtype=torch_dtype, device=device)
63
+ return model
64
+
65
+
66
+ def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
67
+ loaded_model_names, loaded_models = [], []
68
+ for model_name, model_class in zip(model_names, model_classes):
69
+ while True:
70
+ for model_id in range(len(model_manager.model)):
71
+ base_model_name = model_manager.model_name[model_id]
72
+ if base_model_name == model_name:
73
+ base_model_path = model_manager.model_path[model_id]
74
+ base_model = model_manager.model[model_id]
75
+ print(f" Adding patch model to {base_model_name} ({base_model_path})")
76
+ patched_model = load_single_patch_model_from_single_file(
77
+ state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
78
+ loaded_model_names.append(base_model_name)
79
+ loaded_models.append(patched_model)
80
+ model_manager.model.pop(model_id)
81
+ model_manager.model_path.pop(model_id)
82
+ model_manager.model_name.pop(model_id)
83
+ break
84
+ else:
85
+ break
86
+ return loaded_model_names, loaded_models
87
+
88
+
89
+
90
+ class ModelDetectorTemplate:
91
+ def __init__(self):
92
+ pass
93
+
94
+ def match(self, file_path="", state_dict={}):
95
+ return False
96
+
97
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
98
+ return [], []
99
+
100
+
101
+
102
+ class ModelDetectorFromSingleFile:
103
+ def __init__(self, model_loader_configs=[]):
104
+ self.keys_hash_with_shape_dict = {}
105
+ self.keys_hash_dict = {}
106
+ for metadata in model_loader_configs:
107
+ self.add_model_metadata(*metadata)
108
+
109
+
110
+ def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
111
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
112
+ if keys_hash is not None:
113
+ self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
114
+
115
+
116
+ def match(self, file_path="", state_dict={}):
117
+ if isinstance(file_path, str) and os.path.isdir(file_path):
118
+ return False
119
+ if len(state_dict) == 0:
120
+ state_dict = load_state_dict(file_path)
121
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
122
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
123
+ return True
124
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
125
+ if keys_hash in self.keys_hash_dict:
126
+ return True
127
+ return False
128
+
129
+
130
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
131
+ if len(state_dict) == 0:
132
+ state_dict = load_state_dict(file_path)
133
+
134
+ # Load models with strict matching
135
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
136
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
137
+ model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
138
+ loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
139
+ return loaded_model_names, loaded_models
140
+
141
+ # Load models without strict matching
142
+ # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
143
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
144
+ if keys_hash in self.keys_hash_dict:
145
+ model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
146
+ loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
147
+ return loaded_model_names, loaded_models
148
+
149
+ return [], []
150
+
151
+
152
+
153
+ class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
154
+ def __init__(self, model_loader_configs=[]):
155
+ super().__init__(model_loader_configs)
156
+
157
+
158
+ def match(self, file_path="", state_dict={}):
159
+ if isinstance(file_path, str) and os.path.isdir(file_path):
160
+ return False
161
+ if len(state_dict) == 0:
162
+ state_dict = load_state_dict(file_path)
163
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
164
+ for sub_state_dict in splited_state_dict:
165
+ if super().match(file_path, sub_state_dict):
166
+ return True
167
+ return False
168
+
169
+
170
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
171
+ # Split the state_dict and load from each component
172
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
173
+ valid_state_dict = {}
174
+ for sub_state_dict in splited_state_dict:
175
+ if super().match(file_path, sub_state_dict):
176
+ valid_state_dict.update(sub_state_dict)
177
+ if super().match(file_path, valid_state_dict):
178
+ loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
179
+ else:
180
+ loaded_model_names, loaded_models = [], []
181
+ for sub_state_dict in splited_state_dict:
182
+ if super().match(file_path, sub_state_dict):
183
+ loaded_model_names_, loaded_models_ = super().load(file_path, sub_state_dict, device, torch_dtype)
184
+ loaded_model_names += loaded_model_names_
185
+ loaded_models += loaded_models_
186
+ return loaded_model_names, loaded_models
187
+
188
+
189
+
190
+ class ModelDetectorFromHuggingfaceFolder:
191
+ def __init__(self, model_loader_configs=[]):
192
+ self.architecture_dict = {}
193
+ for metadata in model_loader_configs:
194
+ self.add_model_metadata(*metadata)
195
+
196
+
197
+ def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
198
+ self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
199
+
200
+
201
+ def match(self, file_path="", state_dict={}):
202
+ if not isinstance(file_path, str) or os.path.isfile(file_path):
203
+ return False
204
+ file_list = os.listdir(file_path)
205
+ if "config.json" not in file_list:
206
+ return False
207
+ with open(os.path.join(file_path, "config.json"), "r") as f:
208
+ config = json.load(f)
209
+ if "architectures" not in config and "_class_name" not in config:
210
+ return False
211
+ return True
212
+
213
+
214
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
215
+ with open(os.path.join(file_path, "config.json"), "r") as f:
216
+ config = json.load(f)
217
+ loaded_model_names, loaded_models = [], []
218
+ architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
219
+ for architecture in architectures:
220
+ huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
221
+ if redirected_architecture is not None:
222
+ architecture = redirected_architecture
223
+ model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
224
+ loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
225
+ loaded_model_names += loaded_model_names_
226
+ loaded_models += loaded_models_
227
+ return loaded_model_names, loaded_models
228
+
229
+
230
+
231
+ class ModelDetectorFromPatchedSingleFile:
232
+ def __init__(self, model_loader_configs=[]):
233
+ self.keys_hash_with_shape_dict = {}
234
+ for metadata in model_loader_configs:
235
+ self.add_model_metadata(*metadata)
236
+
237
+
238
+ def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
239
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
240
+
241
+
242
+ def match(self, file_path="", state_dict={}):
243
+ if not isinstance(file_path, str) or os.path.isdir(file_path):
244
+ return False
245
+ if len(state_dict) == 0:
246
+ state_dict = load_state_dict(file_path)
247
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
248
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
249
+ return True
250
+ return False
251
+
252
+
253
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
254
+ if len(state_dict) == 0:
255
+ state_dict = load_state_dict(file_path)
256
+
257
+ # Load models with strict matching
258
+ loaded_model_names, loaded_models = [], []
259
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
260
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
261
+ model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
262
+ loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
263
+ state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
264
+ loaded_model_names += loaded_model_names_
265
+ loaded_models += loaded_models_
266
+ return loaded_model_names, loaded_models
267
+
268
+
269
+
270
+ class ModelManager:
271
+ def __init__(
272
+ self,
273
+ torch_dtype=torch.float16,
274
+ device="cuda",
275
+ model_id_list: List[Preset_model_id] = [],
276
+ downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
277
+ file_path_list: List[str] = [],
278
+ ):
279
+ self.torch_dtype = torch_dtype
280
+ self.device = device
281
+ self.model = []
282
+ self.model_path = []
283
+ self.model_name = []
284
+ downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
285
+ self.model_detector = [
286
+ ModelDetectorFromSingleFile(model_loader_configs),
287
+ ModelDetectorFromSplitedSingleFile(model_loader_configs),
288
+ ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
289
+ ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
290
+ ]
291
+ self.load_models(downloaded_files + file_path_list)
292
+
293
+
294
+ def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
295
+ print(f"Loading models from file: {file_path}")
296
+ if len(state_dict) == 0:
297
+ state_dict = load_state_dict(file_path)
298
+ model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
299
+ for model_name, model in zip(model_names, models):
300
+ self.model.append(model)
301
+ self.model_path.append(file_path)
302
+ self.model_name.append(model_name)
303
+ print(f" The following models are loaded: {model_names}.")
304
+
305
+
306
+ def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
307
+ print(f"Loading models from folder: {file_path}")
308
+ model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
309
+ for model_name, model in zip(model_names, models):
310
+ self.model.append(model)
311
+ self.model_path.append(file_path)
312
+ self.model_name.append(model_name)
313
+ print(f" The following models are loaded: {model_names}.")
314
+
315
+
316
+ def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
317
+ print(f"Loading patch models from file: {file_path}")
318
+ model_names, models = load_patch_model_from_single_file(
319
+ state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
320
+ for model_name, model in zip(model_names, models):
321
+ self.model.append(model)
322
+ self.model_path.append(file_path)
323
+ self.model_name.append(model_name)
324
+ print(f" The following patched models are loaded: {model_names}.")
325
+
326
+
327
+ def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
328
+ if isinstance(file_path, list):
329
+ for file_path_ in file_path:
330
+ self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
331
+ else:
332
+ print(f"Loading LoRA models from file: {file_path}")
333
+ if len(state_dict) == 0:
334
+ state_dict = load_state_dict(file_path)
335
+ for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
336
+ for lora in get_lora_loaders():
337
+ match_results = lora.match(model, state_dict)
338
+ if match_results is not None:
339
+ print(f" Adding LoRA to {model_name} ({model_path}).")
340
+ lora_prefix, model_resource = match_results
341
+ lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
342
+ break
343
+
344
+
345
+ def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
346
+ print(f"Loading models from: {file_path}")
347
+ if device is None: device = self.device
348
+ if torch_dtype is None: torch_dtype = self.torch_dtype
349
+ if isinstance(file_path, list):
350
+ state_dict = {}
351
+ for path in file_path:
352
+ state_dict.update(load_state_dict(path))
353
+ elif os.path.isfile(file_path):
354
+ state_dict = load_state_dict(file_path)
355
+ else:
356
+ state_dict = None
357
+ for model_detector in self.model_detector:
358
+ if model_detector.match(file_path, state_dict):
359
+ model_names, models = model_detector.load(
360
+ file_path, state_dict,
361
+ device=device, torch_dtype=torch_dtype,
362
+ allowed_model_names=model_names, model_manager=self
363
+ )
364
+ for model_name, model in zip(model_names, models):
365
+ self.model.append(model)
366
+ self.model_path.append(file_path)
367
+ self.model_name.append(model_name)
368
+ print(f" The following models are loaded: {model_names}.")
369
+ break
370
+ else:
371
+ print(f" We cannot detect the model type. No models are loaded.")
372
+
373
+
374
+ def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
375
+ for file_path in file_path_list:
376
+ self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
377
+
378
+
379
+ def fetch_model(self, model_name, file_path=None, require_model_path=False):
380
+ fetched_models = []
381
+ fetched_model_paths = []
382
+ for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
383
+ if file_path is not None and file_path != model_path:
384
+ continue
385
+ if model_name == model_name_:
386
+ fetched_models.append(model)
387
+ fetched_model_paths.append(model_path)
388
+ if len(fetched_models) == 0:
389
+ print(f"No {model_name} models available.")
390
+ return None
391
+ if len(fetched_models) == 1:
392
+ print(f"Using {model_name} from {fetched_model_paths[0]}.")
393
+ else:
394
+ print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
395
+ if require_model_path:
396
+ return fetched_models[0], fetched_model_paths[0]
397
+ else:
398
+ return fetched_models[0]
399
+
400
+
401
+ def to(self, device):
402
+ for model in self.model:
403
+ model.to(device)
404
+
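A sketch of how `ModelManager` is meant to be driven; the checkpoint path is a placeholder, and the name passed to `fetch_model` is assumed to be one registered in `model_loader_configs`:

```python
# Hypothetical ModelManager usage; the checkpoint path and fetched model name are placeholders.
import torch
from diffsynth.models.model_manager import ModelManager

manager = ModelManager(torch_dtype=torch.float16, device="cuda")
# load_model() hashes the checkpoint's state-dict keys and routes the file to the
# first ModelDetector* whose match() accepts it.
manager.load_model("models/placeholder_checkpoint.safetensors")
# fetch_model() returns the first loaded model registered under the given name.
model = manager.fetch_model("wan_video_dit")  # assumed config name for the DiT in this commit
```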
diffsynth/models/utils.py ADDED
@@ -0,0 +1,182 @@
+import torch, os
+from safetensors import safe_open
+from contextlib import contextmanager
+import hashlib
+
+@contextmanager
+def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
+
+    old_register_parameter = torch.nn.Module.register_parameter
+    if include_buffers:
+        old_register_buffer = torch.nn.Module.register_buffer
+
+    def register_empty_parameter(module, name, param):
+        old_register_parameter(module, name, param)
+        if param is not None:
+            param_cls = type(module._parameters[name])
+            kwargs = module._parameters[name].__dict__
+            kwargs["requires_grad"] = param.requires_grad
+            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
+
+    def register_empty_buffer(module, name, buffer, persistent=True):
+        old_register_buffer(module, name, buffer, persistent=persistent)
+        if buffer is not None:
+            module._buffers[name] = module._buffers[name].to(device)
+
+    def patch_tensor_constructor(fn):
+        def wrapper(*args, **kwargs):
+            kwargs["device"] = device
+            return fn(*args, **kwargs)
+
+        return wrapper
+
+    if include_buffers:
+        tensor_constructors_to_patch = {
+            torch_function_name: getattr(torch, torch_function_name)
+            for torch_function_name in ["empty", "zeros", "ones", "full"]
+        }
+    else:
+        tensor_constructors_to_patch = {}
+
+    try:
+        torch.nn.Module.register_parameter = register_empty_parameter
+        if include_buffers:
+            torch.nn.Module.register_buffer = register_empty_buffer
+        for torch_function_name in tensor_constructors_to_patch.keys():
+            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
+        yield
+    finally:
+        torch.nn.Module.register_parameter = old_register_parameter
+        if include_buffers:
+            torch.nn.Module.register_buffer = old_register_buffer
+        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
+            setattr(torch, torch_function_name, old_torch_function)
+
+def load_state_dict_from_folder(file_path, torch_dtype=None):
+    state_dict = {}
+    for file_name in os.listdir(file_path):
+        if "." in file_name and file_name.split(".")[-1] in [
+            "safetensors", "bin", "ckpt", "pth", "pt"
+        ]:
+            state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
+    return state_dict
+
+
+def load_state_dict(file_path, torch_dtype=None):
+    if file_path.endswith(".safetensors"):
+        return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
+    else:
+        return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
+
+
+def load_state_dict_from_safetensors(file_path, torch_dtype=None):
+    state_dict = {}
+    with safe_open(file_path, framework="pt", device="cpu") as f:
+        for k in f.keys():
+            state_dict[k] = f.get_tensor(k)
+            if torch_dtype is not None:
+                state_dict[k] = state_dict[k].to(torch_dtype)
+    return state_dict
+
+
+def load_state_dict_from_bin(file_path, torch_dtype=None):
+    state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
+    if torch_dtype is not None:
+        for i in state_dict:
+            if isinstance(state_dict[i], torch.Tensor):
+                state_dict[i] = state_dict[i].to(torch_dtype)
+    return state_dict
+
+
+def search_for_embeddings(state_dict):
+    embeddings = []
+    for k in state_dict:
+        if isinstance(state_dict[k], torch.Tensor):
+            embeddings.append(state_dict[k])
+        elif isinstance(state_dict[k], dict):
+            embeddings += search_for_embeddings(state_dict[k])
+    return embeddings
+
+
+def search_parameter(param, state_dict):
+    for name, param_ in state_dict.items():
+        if param.numel() == param_.numel():
+            if param.shape == param_.shape:
+                if torch.dist(param, param_) < 1e-3:
+                    return name
+            else:
+                if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
+                    return name
+    return None
+
+
+def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
+    matched_keys = set()
+    with torch.no_grad():
+        for name in source_state_dict:
+            rename = search_parameter(source_state_dict[name], target_state_dict)
+            if rename is not None:
+                print(f'"{name}": "{rename}",')
+                matched_keys.add(rename)
+            elif split_qkv and len(source_state_dict[name].shape) >= 1 and source_state_dict[name].shape[0] % 3 == 0:
+                length = source_state_dict[name].shape[0] // 3
+                rename = []
+                for i in range(3):
+                    rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
+                if None not in rename:
+                    print(f'"{name}": {rename},')
+                    for rename_ in rename:
+                        matched_keys.add(rename_)
+    for name in target_state_dict:
+        if name not in matched_keys:
+            print("Cannot find", name, target_state_dict[name].shape)
+
+
+def search_for_files(folder, extensions):
+    files = []
+    if os.path.isdir(folder):
+        for file in sorted(os.listdir(folder)):
+            files += search_for_files(os.path.join(folder, file), extensions)
+    elif os.path.isfile(folder):
+        for extension in extensions:
+            if folder.endswith(extension):
+                files.append(folder)
+                break
+    return files
+
+
+def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
+    keys = []
+    for key, value in state_dict.items():
+        if isinstance(key, str):
+            if isinstance(value, torch.Tensor):
+                if with_shape:
+                    shape = "_".join(map(str, list(value.shape)))
+                    keys.append(key + ":" + shape)
+                keys.append(key)
+            elif isinstance(value, dict):
+                keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
+    keys.sort()
+    keys_str = ",".join(keys)
+    return keys_str
+
+
+def split_state_dict_with_prefix(state_dict):
+    keys = sorted([key for key in state_dict if isinstance(key, str)])
+    prefix_dict = {}
+    for key in keys:
+        prefix = key if "." not in key else key.split(".")[0]
+        if prefix not in prefix_dict:
+            prefix_dict[prefix] = []
+        prefix_dict[prefix].append(key)
+    state_dicts = []
+    for prefix, keys in prefix_dict.items():
+        sub_state_dict = {key: state_dict[key] for key in keys}
+        state_dicts.append(sub_state_dict)
+    return state_dicts
+
+
+def hash_state_dict_keys(state_dict, with_shape=True):
+    keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
+    keys_str = keys_str.encode(encoding="UTF-8")
+    return hashlib.md5(keys_str).hexdigest()
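The detectors in `model_manager.py` key entirely off `hash_state_dict_keys`; a small self-contained example of the hashing and prefix-splitting helpers:

```python
# Self-contained illustration of the state-dict helpers above (no checkpoint required).
import torch
from diffsynth.models.utils import hash_state_dict_keys, split_state_dict_with_prefix

state_dict = {
    "encoder.conv.weight": torch.zeros(4, 3, 3, 3),
    "encoder.conv.bias": torch.zeros(4),
    "decoder.proj.weight": torch.zeros(3, 4),
}
# The hash depends only on the sorted key names (plus shapes when with_shape=True),
# so two checkpoints with identical architectures produce identical hashes.
print(hash_state_dict_keys(state_dict, with_shape=True))
print(hash_state_dict_keys(state_dict, with_shape=False))
# Splitting by the first dotted prefix yields one sub-dict per top-level module.
for sub in split_state_dict_with_prefix(state_dict):
    print(sorted(sub.keys()))
```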
diffsynth/models/wan_video_dit.py ADDED
@@ -0,0 +1,881 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.amp as amp
5
+ import torch.nn as nn
6
+ from tqdm import tqdm
7
+ from .utils import hash_state_dict_keys
8
+
9
+ try:
10
+ import flash_attn_interface
11
+ FLASH_ATTN_3_AVAILABLE = True
12
+ except ModuleNotFoundError:
13
+ FLASH_ATTN_3_AVAILABLE = False
14
+
15
+ try:
16
+ import flash_attn
17
+ FLASH_ATTN_2_AVAILABLE = True
18
+ except ModuleNotFoundError:
19
+ FLASH_ATTN_2_AVAILABLE = False
20
+
21
+ try:
22
+ from sageattention import sageattn
23
+ SAGE_ATTN_AVAILABLE = True
24
+ except ModuleNotFoundError:
25
+ SAGE_ATTN_AVAILABLE = False
26
+
27
+ import warnings
28
+
29
+
30
+ __all__ = ['WanModel']
31
+
32
+
33
+ def flash_attention(
34
+ q,
35
+ k,
36
+ v,
37
+ q_lens=None,
38
+ k_lens=None,
39
+ dropout_p=0.,
40
+ softmax_scale=None,
41
+ q_scale=None,
42
+ causal=False,
43
+ window_size=(-1, -1),
44
+ deterministic=False,
45
+ dtype=torch.bfloat16,
46
+ version=None,
47
+ ):
48
+ """
49
+ q: [B, Lq, Nq, C1].
50
+ k: [B, Lk, Nk, C1].
51
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
52
+ q_lens: [B].
53
+ k_lens: [B].
54
+ dropout_p: float. Dropout probability.
55
+ softmax_scale: float. The scaling of QK^T before applying softmax.
56
+ causal: bool. Whether to apply causal attention mask.
57
+ window_size: (left, right). If not (-1, -1), apply sliding window local attention.
58
+ deterministic: bool. If True, slightly slower and uses more memory.
59
+ dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
60
+ """
61
+ half_dtypes = (torch.float16, torch.bfloat16)
62
+ assert dtype in half_dtypes
63
+ assert q.device.type == 'cuda' and q.size(-1) <= 256
64
+
65
+ # params
66
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
67
+
68
+ def half(x):
69
+ return x if x.dtype in half_dtypes else x.to(dtype)
70
+
71
+ # preprocess query
72
+ if q_lens is None:
73
+ q = half(q.flatten(0, 1))
74
+ q_lens = torch.tensor(
75
+ [lq] * b, dtype=torch.int32).to(
76
+ device=q.device, non_blocking=True)
77
+ else:
78
+ q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
79
+
80
+ # preprocess key, value
81
+ if k_lens is None:
82
+ k = half(k.flatten(0, 1))
83
+ v = half(v.flatten(0, 1))
84
+ k_lens = torch.tensor(
85
+ [lk] * b, dtype=torch.int32).to(
86
+ device=k.device, non_blocking=True)
87
+ else:
88
+ k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
89
+ v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
90
+
91
+ q = q.to(v.dtype)
92
+ k = k.to(v.dtype)
93
+
94
+ if q_scale is not None:
95
+ q = q * q_scale
96
+
97
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
98
+ warnings.warn(
99
+ 'Flash attention 3 is not available, use flash attention 2 instead.'
100
+ )
101
+
102
+ # apply attention
103
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
104
+ # Note: dropout_p, window_size are not supported in FA3 now.
105
+ x = flash_attn_interface.flash_attn_varlen_func(
106
+ q=q,
107
+ k=k,
108
+ v=v,
109
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
110
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
111
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
112
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
113
+ seqused_q=None,
114
+ seqused_k=None,
115
+ max_seqlen_q=lq,
116
+ max_seqlen_k=lk,
117
+ softmax_scale=softmax_scale,
118
+ causal=causal,
119
+ deterministic=deterministic)[0].unflatten(0, (b, lq))
120
+ elif FLASH_ATTN_2_AVAILABLE:
121
+ x = flash_attn.flash_attn_varlen_func(
122
+ q=q,
123
+ k=k,
124
+ v=v,
125
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
126
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
127
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
128
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
129
+ max_seqlen_q=lq,
130
+ max_seqlen_k=lk,
131
+ dropout_p=dropout_p,
132
+ softmax_scale=softmax_scale,
133
+ causal=causal,
134
+ window_size=window_size,
135
+ deterministic=deterministic).unflatten(0, (b, lq))
136
+ elif SAGE_ATTN_AVAILABLE:
137
+ q = q.unsqueeze(0).transpose(1, 2).to(dtype)
138
+ k = k.unsqueeze(0).transpose(1, 2).to(dtype)
139
+ v = v.unsqueeze(0).transpose(1, 2).to(dtype)
140
+ x = sageattn(q, k, v, dropout_p=dropout_p, is_causal=causal)
141
+ x = x.transpose(1, 2).contiguous()
142
+ else:
143
+ q = q.unsqueeze(0).transpose(1, 2).to(dtype)
144
+ k = k.unsqueeze(0).transpose(1, 2).to(dtype)
145
+ v = v.unsqueeze(0).transpose(1, 2).to(dtype)
146
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
147
+ x = x.transpose(1, 2).contiguous()
148
+
149
+ # output
150
+ return x.type(out_dtype)
151
+
152
+
153
+ def create_sdpa_mask(q, k, q_lens, k_lens, causal=False):
154
+ b, lq, lk = q.size(0), q.size(1), k.size(1)
155
+ if q_lens is None:
156
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32)
157
+ if k_lens is None:
158
+ k_lens = torch.tensor([lk] * b, dtype=torch.int32)
159
+ attn_mask = torch.zeros((b, lq, lk), dtype=torch.bool)
160
+ for i in range(b):
161
+ q_len, k_len = q_lens[i], k_lens[i]
162
+ attn_mask[i, q_len:, :] = True
163
+ attn_mask[i, :, k_len:] = True
164
+
165
+ if causal:
166
+ causal_mask = torch.triu(torch.ones((lq, lk), dtype=torch.bool), diagonal=1)
167
+ attn_mask[i, :, :] = torch.logical_or(attn_mask[i, :, :], causal_mask)
168
+
169
+ attn_mask = attn_mask.logical_not().to(q.device, non_blocking=True)
170
+ return attn_mask
171
+
172
+
173
+ def attention(
174
+ q,
175
+ k,
176
+ v,
177
+ q_lens=None,
178
+ k_lens=None,
179
+ dropout_p=0.,
180
+ softmax_scale=None,
181
+ q_scale=None,
182
+ causal=False,
183
+ window_size=(-1, -1),
184
+ deterministic=False,
185
+ dtype=torch.bfloat16,
186
+ fa_version=None,
187
+ ):
188
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
189
+ return flash_attention(
190
+ q=q,
191
+ k=k,
192
+ v=v,
193
+ q_lens=q_lens,
194
+ k_lens=k_lens,
195
+ dropout_p=dropout_p,
196
+ softmax_scale=softmax_scale,
197
+ q_scale=q_scale,
198
+ causal=causal,
199
+ window_size=window_size,
200
+ deterministic=deterministic,
201
+ dtype=dtype,
202
+ version=fa_version,
203
+ )
204
+ else:
205
+ if q_lens is not None or k_lens is not None:
206
+ warnings.warn('Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.')
207
+ attn_mask = None
208
+
209
+ q = q.transpose(1, 2).to(dtype)
210
+ k = k.transpose(1, 2).to(dtype)
211
+ v = v.transpose(1, 2).to(dtype)
212
+
213
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
214
+
215
+ out = out.transpose(1, 2).contiguous()
216
+ return out
217
+
218
+
219
+
220
+ def sinusoidal_embedding_1d(dim, position):
221
+ # preprocess
222
+ assert dim % 2 == 0
223
+ half = dim // 2
224
+ position = position.type(torch.float64)
225
+
226
+ # calculation
227
+ sinusoid = torch.outer(
228
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
229
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
230
+ return x
231
+
232
+
233
+ @amp.autocast(enabled=False, device_type="cuda")
234
+ def rope_params(max_seq_len, dim, theta=10000):
235
+ assert dim % 2 == 0
236
+ freqs = torch.outer(
237
+ torch.arange(max_seq_len),
238
+ 1.0 / torch.pow(theta,
239
+ torch.arange(0, dim, 2).to(torch.float64).div(dim)))
240
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
241
+ return freqs
242
+
243
+
244
+ @amp.autocast(enabled=False, device_type="cuda")
245
+ def rope_apply(x, grid_sizes, freqs):
246
+ n, c = x.size(2), x.size(3) // 2
247
+
248
+ # split freqs
249
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
250
+
251
+ # loop over samples
252
+ output = []
253
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
254
+ seq_len = f * h * w
255
+
256
+ # precompute multipliers
257
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
258
+ seq_len, n, -1, 2))
259
+ freqs_i = torch.cat([
260
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
261
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
262
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
263
+ ],
264
+ dim=-1).reshape(seq_len, 1, -1)
265
+
266
+ # apply rotary embedding
267
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
268
+ x_i = torch.cat([x_i, x[i, seq_len:]])
269
+
270
+ # append to collection
271
+ output.append(x_i)
272
+ return torch.stack(output).float()
273
+
274
+
275
+ class WanRMSNorm(nn.Module):
276
+
277
+ def __init__(self, dim, eps=1e-5):
278
+ super().__init__()
279
+ self.dim = dim
280
+ self.eps = eps
281
+ self.weight = nn.Parameter(torch.ones(dim))
282
+
283
+ def forward(self, x):
284
+ return self._norm(x.float()).type_as(x) * self.weight
285
+
286
+ def _norm(self, x):
287
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
288
+
289
+
290
+ class WanLayerNorm(nn.LayerNorm):
291
+
292
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
293
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
294
+
295
+ def forward(self, x):
296
+ return super().forward(x.float()).type_as(x)
297
+
298
+
299
+ class WanSelfAttention(nn.Module):
300
+
301
+ def __init__(self,
302
+ dim,
303
+ num_heads,
304
+ window_size=(-1, -1),
305
+ qk_norm=True,
306
+ eps=1e-6):
307
+ assert dim % num_heads == 0
308
+ super().__init__()
309
+ self.dim = dim
310
+ self.num_heads = num_heads
311
+ self.head_dim = dim // num_heads
312
+ self.window_size = window_size
313
+ self.qk_norm = qk_norm
314
+ self.eps = eps
315
+
316
+ # layers
317
+ self.q = nn.Linear(dim, dim)
318
+ self.k = nn.Linear(dim, dim)
319
+ self.v = nn.Linear(dim, dim)
320
+ self.o = nn.Linear(dim, dim)
321
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
322
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
323
+
324
+ def forward(self, x, seq_lens, grid_sizes, freqs):
325
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
326
+
327
+ # query, key, value function
328
+ def qkv_fn(x):
329
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
330
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
331
+ v = self.v(x).view(b, s, n, d)
332
+ return q, k, v
333
+
334
+ q, k, v = qkv_fn(x)
335
+
336
+ x = flash_attention(
337
+ q=rope_apply(q, grid_sizes, freqs),
338
+ k=rope_apply(k, grid_sizes, freqs),
339
+ v=v,
340
+ k_lens=seq_lens,
341
+ window_size=self.window_size)
342
+
343
+ # output
344
+ x = x.flatten(2)
345
+ x = self.o(x)
346
+ return x
347
+
348
+
349
+ class WanT2VCrossAttention(WanSelfAttention):
350
+
351
+ def forward(self, x, context, context_lens):
352
+ """
353
+ x: [B, L1, C].
354
+ context: [B, L2, C].
355
+ context_lens: [B].
356
+ """
357
+ b, n, d = x.size(0), self.num_heads, self.head_dim
358
+
359
+ # compute query, key, value
360
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
361
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
362
+ v = self.v(context).view(b, -1, n, d)
363
+
364
+ # compute attention
365
+ x = flash_attention(q, k, v, k_lens=context_lens)
366
+
367
+ # output
368
+ x = x.flatten(2)
369
+ x = self.o(x)
370
+ return x
371
+
372
+ class WanI2VCrossAttentionProcessor:
373
+ def __call__(self, attn, x, context, context_lens) -> torch.Tensor:
374
+ """
375
+ x: [B, L1, C].
376
+ context: [B, L2, C].
377
+ context_lens: [B].
378
+ """
379
+ context_img = context[:, :257]
380
+ context = context[:, 257:]
381
+ b, n, d = x.size(0), attn.num_heads, attn.head_dim
382
+
383
+ # compute query, key, value
384
+ q = attn.norm_q(attn.q(x)).view(b, -1, n, d)
385
+ k = attn.norm_k(attn.k(context)).view(b, -1, n, d)
386
+ v = attn.v(context).view(b, -1, n, d)
387
+ k_img = attn.norm_k_img(attn.k_img(context_img)).view(b, -1, n, d)
388
+ v_img = attn.v_img(context_img).view(b, -1, n, d)
389
+ img_x = flash_attention(q, k_img, v_img, k_lens=None)
390
+ # compute attention
391
+ x = flash_attention(q, k, v, k_lens=context_lens)
392
+
393
+ # output
394
+ x = x.flatten(2)
395
+ img_x = img_x.flatten(2)
396
+ x = x + img_x
397
+ x = attn.o(x)
398
+ return x
399
+
400
+ class WanI2VCrossAttention(WanSelfAttention):
401
+
402
+ def __init__(self,
403
+ dim,
404
+ num_heads,
405
+ window_size=(-1, -1),
406
+ qk_norm=True,
407
+ eps=1e-6):
408
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
409
+
410
+ self.k_img = nn.Linear(dim, dim)
411
+ self.v_img = nn.Linear(dim, dim)
412
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
413
+ self.norm_k_img = WanRMSNorm(
414
+ dim, eps=eps) if qk_norm else nn.Identity()
415
+
416
+ processor = WanI2VCrossAttentionProcessor()
417
+ self.set_processor(processor)
418
+
419
+ def set_processor(self, processor) -> None:
420
+ self.processor = processor
421
+
422
+ def get_processor(self):
423
+ return self.processor
424
+
425
+ def forward(self, x, context, context_lens, audio_proj, audio_context_lens, latents_num_frames, audio_scale: float = 1.0, **kwargs):
426
+ """
427
+ x: [B, L1, C].
428
+ context: [B, L2, C].
429
+ context_lens: [B].
430
+ """
431
+ if audio_proj is None:
432
+ return self.processor(self, x, context, context_lens)
433
+ else:
434
+ return self.processor(self, x, context, context_lens, audio_proj, audio_context_lens, latents_num_frames, audio_scale)
435
+
436
+ WANX_CROSSATTENTION_CLASSES = {
437
+ 't2v_cross_attn': WanT2VCrossAttention,
438
+ 'i2v_cross_attn': WanI2VCrossAttention,
439
+ }
440
+
441
+
442
+ class WanAttentionBlock(nn.Module):
443
+
444
+ def __init__(self,
445
+ cross_attn_type,
446
+ dim,
447
+ ffn_dim,
448
+ num_heads,
449
+ window_size=(-1, -1),
450
+ qk_norm=True,
451
+ cross_attn_norm=False,
452
+ eps=1e-6):
453
+ super().__init__()
454
+ self.dim = dim
455
+ self.ffn_dim = ffn_dim
456
+ self.num_heads = num_heads
457
+ self.window_size = window_size
458
+ self.qk_norm = qk_norm
459
+ self.cross_attn_norm = cross_attn_norm
460
+ self.eps = eps
461
+
462
+ # layers
463
+ self.norm1 = WanLayerNorm(dim, eps)
464
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
465
+ eps)
466
+ self.norm3 = WanLayerNorm(
467
+ dim, eps,
468
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
469
+ self.cross_attn = WANX_CROSSATTENTION_CLASSES[cross_attn_type](
470
+ dim, num_heads, (-1, -1), qk_norm, eps)
471
+ self.norm2 = WanLayerNorm(dim, eps)
472
+ self.ffn = nn.Sequential(
473
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
474
+ nn.Linear(ffn_dim, dim))
475
+
476
+ # modulation
477
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
478
+
479
+ def forward(
480
+ self,
481
+ x,
482
+ e,
483
+ seq_lens,
484
+ grid_sizes,
485
+ freqs,
486
+ context,
487
+ context_lens,
488
+ **kwargs,
489
+ ):
490
+ assert e.dtype == torch.float32
491
+ with amp.autocast(dtype=torch.float32, device_type="cuda"):
492
+ e = (self.modulation.to(dtype=e.dtype, device=e.device) + e).chunk(6, dim=1)
493
+ assert e[0].dtype == torch.float32
494
+
495
+ # self-attention
496
+ y = self.self_attn(
497
+ self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
498
+ freqs)
499
+ with amp.autocast(dtype=torch.float32, device_type="cuda"):
500
+ x = x + y * e[2]
501
+
502
+ # cross-attention & ffn function
503
+ def cross_attn_ffn(x, context, context_lens, e, **kwargs):
504
+ x = x + self.cross_attn(self.norm3(x), context, context_lens, **kwargs)
505
+ y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
506
+ with amp.autocast(dtype=torch.float32, device_type="cuda"):
507
+ x = x + y * e[5]
508
+ return x
509
+
510
+ x = cross_attn_ffn(x, context, context_lens, e, **kwargs)
511
+ return x
512
+
513
+
514
+ class Head(nn.Module):
515
+
516
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
517
+ super().__init__()
518
+ self.dim = dim
519
+ self.out_dim = out_dim
520
+ self.patch_size = patch_size
521
+ self.eps = eps
522
+
523
+ # layers
524
+ out_dim = math.prod(patch_size) * out_dim
525
+ self.norm = WanLayerNorm(dim, eps)
526
+ self.head = nn.Linear(dim, out_dim)
527
+
528
+ # modulation
529
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
530
+
531
+ def forward(self, x, e):
532
+ assert e.dtype == torch.float32
533
+ with amp.autocast(dtype=torch.float32, device_type="cuda"):
534
+ e = (self.modulation.to(dtype=e.dtype, device=e.device) + e.unsqueeze(1)).chunk(2, dim=1)
535
+ x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
536
+ return x
537
+
538
+
539
+ class MLPProj(torch.nn.Module):
540
+
541
+ def __init__(self, in_dim, out_dim):
542
+ super().__init__()
543
+
544
+ self.proj = torch.nn.Sequential(
545
+ torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
546
+ torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
547
+ torch.nn.LayerNorm(out_dim))
548
+
549
+ def forward(self, image_embeds):
550
+ clip_extra_context_tokens = self.proj(image_embeds)
551
+ return clip_extra_context_tokens
552
+
553
+
554
+ class WanModel(nn.Module):
555
+
556
+ def __init__(self,
557
+ model_type='t2v',
558
+ patch_size=(1, 2, 2),
559
+ text_len=512,
560
+ in_dim=16,
561
+ dim=2048,
562
+ ffn_dim=8192,
563
+ freq_dim=256,
564
+ text_dim=4096,
565
+ out_dim=16,
566
+ num_heads=16,
567
+ num_layers=32,
568
+ window_size=(-1, -1),
569
+ qk_norm=True,
570
+ cross_attn_norm=False,
571
+ eps=1e-6):
572
+ super().__init__()
573
+
574
+ assert model_type in ['t2v', 'i2v']
575
+ self.model_type = model_type
576
+
577
+ self.patch_size = patch_size
578
+ self.text_len = text_len
579
+ self.in_dim = in_dim
580
+ self.dim = dim
581
+ self.ffn_dim = ffn_dim
582
+ self.freq_dim = freq_dim
583
+ self.text_dim = text_dim
584
+ self.out_dim = out_dim
585
+ self.num_heads = num_heads
586
+ self.num_layers = num_layers
587
+ self.window_size = window_size
588
+ self.qk_norm = qk_norm
589
+ self.cross_attn_norm = cross_attn_norm
590
+ self.eps = eps
591
+
592
+ # embeddings
593
+ self.patch_embedding = nn.Conv3d(
594
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
595
+ self.text_embedding = nn.Sequential(
596
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
597
+ nn.Linear(dim, dim))
598
+
599
+ self.time_embedding = nn.Sequential(
600
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
601
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
602
+
603
+ # blocks
604
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
605
+ self.blocks = nn.ModuleList([
606
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
607
+ window_size, qk_norm, cross_attn_norm, eps)
608
+ for _ in range(num_layers)
609
+ ])
610
+
611
+ # head
612
+ self.head = Head(dim, out_dim, patch_size, eps)
613
+
614
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
615
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
616
+ d = dim // num_heads
617
+ self.freqs = torch.cat([
618
+ rope_params(1024, d - 4 * (d // 6)),
619
+ rope_params(1024, 2 * (d // 6)),
620
+ rope_params(1024, 2 * (d // 6))
621
+ ],
622
+ dim=1)
623
+
624
+ if model_type == 'i2v':
625
+ self.img_emb = MLPProj(1280, dim)
626
+
627
+ # initialize weights
628
+ self.init_weights()
629
+
630
+ def forward(
631
+ self,
632
+ x,
633
+ timestep,
634
+ context,
635
+ seq_len,
636
+ clip_fea=None,
637
+ y=None,
638
+ use_gradient_checkpointing=False,
639
+ audio_proj=None,
640
+ audio_context_lens=None,
641
+ latents_num_frames=None,
642
+ audio_scale=1.0,
643
+ **kwargs,
644
+ ):
645
+ """
646
+ x: A list of videos each with shape [C, T, H, W].
647
+ t: [B].
648
+ context: A list of text embeddings each with shape [L, C].
649
+ """
650
+ if self.model_type == 'i2v':
651
+ assert clip_fea is not None and y is not None
652
+ # params
653
+ device = x[0].device
654
+ if self.freqs.device != device:
655
+ self.freqs = self.freqs.to(device)
656
+
657
+ if y is not None:
658
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
659
+
660
+ # embeddings
661
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
662
+ grid_sizes = torch.stack(
663
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) # [B,2]
664
+ x = [u.flatten(2).transpose(1, 2) for u in x] # [[C, L, T],,]
665
+ # print(f"x0.shape:{x[0].shape}")
666
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
667
+ assert seq_lens.max() <= seq_len
668
+ x = torch.cat([
669
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
670
+ dim=1) for u in x
671
+ ])
672
+
673
+ # time embeddings
674
+ with amp.autocast(dtype=torch.float32, device_type="cuda"):
675
+ e = self.time_embedding(
676
+ sinusoidal_embedding_1d(self.freq_dim, timestep).float())
677
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
678
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
679
+
680
+ # context
681
+ context_lens = None
682
+ context = self.text_embedding(
683
+ torch.stack([
684
+ torch.cat(
685
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
686
+ for u in context
687
+ ]))
688
+
689
+ if clip_fea is not None:
690
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
691
+ context = torch.concat([context_clip, context], dim=1)
692
+
693
+ # arguments
694
+ kwargs = dict(
695
+ e=e0,
696
+ seq_lens=seq_lens,
697
+ grid_sizes=grid_sizes,
698
+ freqs=self.freqs,
699
+ context=context,
700
+ context_lens=context_lens,
701
+ audio_proj=audio_proj,
702
+ audio_context_lens=audio_context_lens,
703
+ latents_num_frames=latents_num_frames,
704
+ audio_scale=audio_scale)
705
+
706
+ def create_custom_forward(module):
707
+ def custom_forward(*inputs, **kwargs):
708
+ return module(*inputs, **kwargs)
709
+ return custom_forward
710
+
711
+ for block in self.blocks:
712
+ if self.training and use_gradient_checkpointing:
713
+ x = torch.utils.checkpoint.checkpoint(
714
+ create_custom_forward(block),
715
+ x, **kwargs,
716
+ use_reentrant=False,
717
+ )
718
+ else:
719
+ x = block(x, **kwargs)
720
+
721
+ # head
722
+ x = self.head(x, e)
723
+
724
+ # unpatchify
725
+ x = self.unpatchify(x, grid_sizes)
726
+ x = torch.stack(x).float()
727
+ return x
728
+
729
+ def unpatchify(self, x, grid_sizes):
730
+ c = self.out_dim
731
+ out = []
732
+ for u, v in zip(x, grid_sizes.tolist()):
733
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
734
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
735
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
736
+ out.append(u)
737
+ return out
738
+
739
+ def init_weights(self):
740
+ # basic init
741
+ for m in self.modules():
742
+ if isinstance(m, nn.Linear):
743
+ nn.init.xavier_uniform_(m.weight)
744
+ if m.bias is not None:
745
+ nn.init.zeros_(m.bias)
746
+
747
+ # init embeddings
748
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
749
+ for m in self.text_embedding.modules():
750
+ if isinstance(m, nn.Linear):
751
+ nn.init.normal_(m.weight, std=.02)
752
+ for m in self.time_embedding.modules():
753
+ if isinstance(m, nn.Linear):
754
+ nn.init.normal_(m.weight, std=.02)
755
+
756
+ # init output layer
757
+ nn.init.zeros_(self.head.head.weight)
758
+
759
+ @staticmethod
760
+ def state_dict_converter():
761
+ return WanModelStateDictConverter()
762
+
763
+ @property
764
+ def attn_processors(self): #copy from https://github.com/XLabs-AI/x-flux/blob/main/src/flux/model.py
765
+ # set recursively
766
+ processors = {}
767
+
768
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
769
+ if hasattr(module, "set_processor"):
770
+ processors[f"{name}.processor"] = module.processor
771
+
772
+ for sub_name, child in module.named_children():
773
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
774
+
775
+ return processors
776
+
777
+ for name, module in self.named_children():
778
+ fn_recursive_add_processors(name, module, processors)
779
+
780
+ return processors
781
+
782
+ def set_attn_processor(self, processor):
783
+ r""" copy from https://github.com/XLabs-AI/x-flux/blob/main/src/flux/model.py
784
+ Sets the attention processor to use to compute attention.
785
+
786
+ Parameters:
787
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
788
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
789
+ for **all** `Attention` layers.
790
+
791
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
792
+ processor. This is strongly recommended when setting trainable attention processors.
793
+
794
+ """
795
+ count = len(self.attn_processors.keys())
796
+
797
+ if isinstance(processor, dict) and len(processor) != count:
798
+ raise ValueError(
799
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
800
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
801
+ )
802
+
803
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
804
+ if hasattr(module, "set_processor"):
805
+ if not isinstance(processor, dict):
806
+ module.set_processor(processor)
807
+ else:
808
+ module.set_processor(processor.pop(f"{name}.processor"))
809
+
810
+ for sub_name, child in module.named_children():
811
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
812
+
813
+ for name, module in self.named_children():
814
+ fn_recursive_attn_processor(name, module, processor)
815
+
816
+
817
+ class WanModelStateDictConverter:
818
+ def __init__(self):
819
+ pass
820
+
821
+ def from_diffusers(self, state_dict):
822
+ return state_dict
823
+
824
+ def from_civitai(self, state_dict):
825
+ if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
826
+ config = {
827
+ "model_type": "t2v",
828
+ "patch_size": (1, 2, 2),
829
+ "text_len": 512,
830
+ "in_dim": 16,
831
+ "dim": 1536,
832
+ "ffn_dim": 8960,
833
+ "freq_dim": 256,
834
+ "text_dim": 4096,
835
+ "out_dim": 16,
836
+ "num_heads": 12,
837
+ "num_layers": 30,
838
+ "window_size": (-1, -1),
839
+ "qk_norm": True,
840
+ "cross_attn_norm": True,
841
+ "eps": 1e-6,
842
+ }
843
+ elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
844
+ config = {
845
+ "model_type": "t2v",
846
+ "patch_size": (1, 2, 2),
847
+ "text_len": 512,
848
+ "in_dim": 16,
849
+ "dim": 5120,
850
+ "ffn_dim": 13824,
851
+ "freq_dim": 256,
852
+ "text_dim": 4096,
853
+ "out_dim": 16,
854
+ "num_heads": 40,
855
+ "num_layers": 40,
856
+ "window_size": (-1, -1),
857
+ "qk_norm": True,
858
+ "cross_attn_norm": True,
859
+ "eps": 1e-6,
860
+ }
861
+ elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
862
+ config = {
863
+ "model_type": "i2v",
864
+ "patch_size": (1, 2, 2),
865
+ "text_len": 512,
866
+ "in_dim": 36,
867
+ "dim": 5120,
868
+ "ffn_dim": 13824,
869
+ "freq_dim": 256,
870
+ "text_dim": 4096,
871
+ "out_dim": 16,
872
+ "num_heads": 40,
873
+ "num_layers": 40,
874
+ "window_size": (-1, -1),
875
+ "qk_norm": True,
876
+ "cross_attn_norm": True,
877
+ "eps": 1e-6,
878
+ }
879
+ else:
880
+ config = {}
881
+ return state_dict, config
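The converter above picks a model configuration purely from a hash of the checkpoint's key names and returns an empty config for unrecognized checkpoints. A minimal loading sketch, assuming the DiT class defined earlier in this file is named WanModel and that the checkpoint path and key layout are hypothetical:

import torch

state_dict = torch.load("wan_dit.pth", map_location="cpu")             # hypothetical checkpoint path
state_dict, config = WanModelStateDictConverter().from_civitai(state_dict)
model = WanModel(**config)         # assumed class name; config is {} when the key hash is unknown
model.load_state_dict(state_dict)  # assumes checkpoint keys already match the module names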
diffsynth/models/wan_video_image_encoder.py ADDED
@@ -0,0 +1,904 @@
1
+ """
2
+ Concise re-implementation of
3
+ ``https://github.com/openai/CLIP'' and
4
+ ``https://github.com/mlfoundations/open_clip''.
5
+ """
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torchvision.transforms as T
11
+ from .wan_video_dit import flash_attention
12
+
13
+
14
+ class SelfAttention(nn.Module):
15
+
16
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
17
+ assert dim % num_heads == 0
18
+ super().__init__()
19
+ self.dim = dim
20
+ self.num_heads = num_heads
21
+ self.head_dim = dim // num_heads
22
+ self.eps = eps
23
+
24
+ # layers
25
+ self.q = nn.Linear(dim, dim)
26
+ self.k = nn.Linear(dim, dim)
27
+ self.v = nn.Linear(dim, dim)
28
+ self.o = nn.Linear(dim, dim)
29
+ self.dropout = nn.Dropout(dropout)
30
+
31
+ def forward(self, x, mask):
32
+ """
33
+ x: [B, L, C].
34
+ """
35
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
36
+
37
+ # compute query, key, value
38
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
39
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
40
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
41
+
42
+ # compute attention
43
+ p = self.dropout.p if self.training else 0.0
44
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
45
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
46
+
47
+ # output
48
+ x = self.o(x)
49
+ x = self.dropout(x)
50
+ return x
51
+
52
+
53
+ class AttentionBlock(nn.Module):
54
+
55
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
56
+ super().__init__()
57
+ self.dim = dim
58
+ self.num_heads = num_heads
59
+ self.post_norm = post_norm
60
+ self.eps = eps
61
+
62
+ # layers
63
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
64
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
65
+ self.ffn = nn.Sequential(
66
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
67
+ nn.Dropout(dropout))
68
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
69
+
70
+ def forward(self, x, mask):
71
+ if self.post_norm:
72
+ x = self.norm1(x + self.attn(x, mask))
73
+ x = self.norm2(x + self.ffn(x))
74
+ else:
75
+ x = x + self.attn(self.norm1(x), mask)
76
+ x = x + self.ffn(self.norm2(x))
77
+ return x
78
+
79
+
80
+ class XLMRoberta(nn.Module):
81
+ """
82
+ XLMRobertaModel with no pooler and no LM head.
83
+ """
84
+
85
+ def __init__(self,
86
+ vocab_size=250002,
87
+ max_seq_len=514,
88
+ type_size=1,
89
+ pad_id=1,
90
+ dim=1024,
91
+ num_heads=16,
92
+ num_layers=24,
93
+ post_norm=True,
94
+ dropout=0.1,
95
+ eps=1e-5):
96
+ super().__init__()
97
+ self.vocab_size = vocab_size
98
+ self.max_seq_len = max_seq_len
99
+ self.type_size = type_size
100
+ self.pad_id = pad_id
101
+ self.dim = dim
102
+ self.num_heads = num_heads
103
+ self.num_layers = num_layers
104
+ self.post_norm = post_norm
105
+ self.eps = eps
106
+
107
+ # embeddings
108
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
109
+ self.type_embedding = nn.Embedding(type_size, dim)
110
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
111
+ self.dropout = nn.Dropout(dropout)
112
+
113
+ # blocks
114
+ self.blocks = nn.ModuleList([
115
+ AttentionBlock(dim, num_heads, post_norm, dropout, eps)
116
+ for _ in range(num_layers)
117
+ ])
118
+
119
+ # norm layer
120
+ self.norm = nn.LayerNorm(dim, eps=eps)
121
+
122
+ def forward(self, ids):
123
+ """
124
+ ids: [B, L] of torch.LongTensor.
125
+ """
126
+ b, s = ids.shape
127
+ mask = ids.ne(self.pad_id).long()
128
+
129
+ # embeddings
130
+ x = self.token_embedding(ids) + \
131
+ self.type_embedding(torch.zeros_like(ids)) + \
132
+ self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
133
+ if self.post_norm:
134
+ x = self.norm(x)
135
+ x = self.dropout(x)
136
+
137
+ # blocks
138
+ mask = torch.where(
139
+ mask.view(b, 1, 1, s).gt(0), 0.0,
140
+ torch.finfo(x.dtype).min)
141
+ for block in self.blocks:
142
+ x = block(x, mask)
143
+
144
+ # output
145
+ if not self.post_norm:
146
+ x = self.norm(x)
147
+ return x
148
+
149
+
150
+ def xlm_roberta_large(pretrained=False,
151
+ return_tokenizer=False,
152
+ device='cpu',
153
+ **kwargs):
154
+ """
155
+ XLMRobertaLarge adapted from Huggingface.
156
+ """
157
+ # params
158
+ cfg = dict(
159
+ vocab_size=250002,
160
+ max_seq_len=514,
161
+ type_size=1,
162
+ pad_id=1,
163
+ dim=1024,
164
+ num_heads=16,
165
+ num_layers=24,
166
+ post_norm=True,
167
+ dropout=0.1,
168
+ eps=1e-5)
169
+ cfg.update(**kwargs)
170
+
171
+ # init model
172
+ if pretrained:
173
+ from sora import DOWNLOAD_TO_CACHE
174
+
175
+ # init a meta model
176
+ with torch.device('meta'):
177
+ model = XLMRoberta(**cfg)
178
+
179
+ # load checkpoint
180
+ model.load_state_dict(
181
+ torch.load(
182
+ DOWNLOAD_TO_CACHE('models/xlm_roberta/xlm_roberta_large.pth'),
183
+ map_location=device),
184
+ assign=True)
185
+ else:
186
+ # init a model on device
187
+ with torch.device(device):
188
+ model = XLMRoberta(**cfg)
189
+
190
+ # init tokenizer
191
+ if return_tokenizer:
192
+ from sora.data import HuggingfaceTokenizer
193
+ tokenizer = HuggingfaceTokenizer(
194
+ name='xlm-roberta-large',
195
+ seq_len=model.text_len,
196
+ clean='whitespace')
197
+ return model, tokenizer
198
+ else:
199
+ return model
200
+
201
+
202
+
203
+ def pos_interpolate(pos, seq_len):
204
+ if pos.size(1) == seq_len:
205
+ return pos
206
+ else:
207
+ src_grid = int(math.sqrt(pos.size(1)))
208
+ tar_grid = int(math.sqrt(seq_len))
209
+ n = pos.size(1) - src_grid * src_grid
210
+ return torch.cat([
211
+ pos[:, :n],
212
+ F.interpolate(
213
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
214
+ 0, 3, 1, 2),
215
+ size=(tar_grid, tar_grid),
216
+ mode='bicubic',
217
+ align_corners=False).flatten(2).transpose(1, 2)
218
+ ],
219
+ dim=1)
220
+
221
+
222
+ class QuickGELU(nn.Module):
223
+
224
+ def forward(self, x):
225
+ return x * torch.sigmoid(1.702 * x)
226
+
227
+
228
+ class LayerNorm(nn.LayerNorm):
229
+
230
+ def forward(self, x):
231
+ return super().forward(x.float()).type_as(x)
232
+
233
+
234
+ class SelfAttention(nn.Module):
235
+
236
+ def __init__(self,
237
+ dim,
238
+ num_heads,
239
+ causal=False,
240
+ attn_dropout=0.0,
241
+ proj_dropout=0.0):
242
+ assert dim % num_heads == 0
243
+ super().__init__()
244
+ self.dim = dim
245
+ self.num_heads = num_heads
246
+ self.head_dim = dim // num_heads
247
+ self.causal = causal
248
+ self.attn_dropout = attn_dropout
249
+ self.proj_dropout = proj_dropout
250
+
251
+ # layers
252
+ self.to_qkv = nn.Linear(dim, dim * 3)
253
+ self.proj = nn.Linear(dim, dim)
254
+
255
+ def forward(self, x):
256
+ """
257
+ x: [B, L, C].
258
+ """
259
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
260
+
261
+ # compute query, key, value
262
+ q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
263
+
264
+ # compute attention
265
+ p = self.attn_dropout if self.training else 0.0
266
+ x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
267
+ x = x.reshape(b, s, c)
268
+
269
+ # output
270
+ x = self.proj(x)
271
+ x = F.dropout(x, self.proj_dropout, self.training)
272
+ return x
273
+
274
+
275
+ class SwiGLU(nn.Module):
276
+
277
+ def __init__(self, dim, mid_dim):
278
+ super().__init__()
279
+ self.dim = dim
280
+ self.mid_dim = mid_dim
281
+
282
+ # layers
283
+ self.fc1 = nn.Linear(dim, mid_dim)
284
+ self.fc2 = nn.Linear(dim, mid_dim)
285
+ self.fc3 = nn.Linear(mid_dim, dim)
286
+
287
+ def forward(self, x):
288
+ x = F.silu(self.fc1(x)) * self.fc2(x)
289
+ x = self.fc3(x)
290
+ return x
291
+
292
+
293
+ class AttentionBlock(nn.Module):
294
+
295
+ def __init__(self,
296
+ dim,
297
+ mlp_ratio,
298
+ num_heads,
299
+ post_norm=False,
300
+ causal=False,
301
+ activation='quick_gelu',
302
+ attn_dropout=0.0,
303
+ proj_dropout=0.0,
304
+ norm_eps=1e-5):
305
+ assert activation in ['quick_gelu', 'gelu', 'swi_glu']
306
+ super().__init__()
307
+ self.dim = dim
308
+ self.mlp_ratio = mlp_ratio
309
+ self.num_heads = num_heads
310
+ self.post_norm = post_norm
311
+ self.causal = causal
312
+ self.norm_eps = norm_eps
313
+
314
+ # layers
315
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
316
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
317
+ proj_dropout)
318
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
319
+ if activation == 'swi_glu':
320
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
321
+ else:
322
+ self.mlp = nn.Sequential(
323
+ nn.Linear(dim, int(dim * mlp_ratio)),
324
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
325
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
326
+
327
+ def forward(self, x):
328
+ if self.post_norm:
329
+ x = x + self.norm1(self.attn(x))
330
+ x = x + self.norm2(self.mlp(x))
331
+ else:
332
+ x = x + self.attn(self.norm1(x))
333
+ x = x + self.mlp(self.norm2(x))
334
+ return x
335
+
336
+
337
+ class AttentionPool(nn.Module):
338
+
339
+ def __init__(self,
340
+ dim,
341
+ mlp_ratio,
342
+ num_heads,
343
+ activation='gelu',
344
+ proj_dropout=0.0,
345
+ norm_eps=1e-5):
346
+ assert dim % num_heads == 0
347
+ super().__init__()
348
+ self.dim = dim
349
+ self.mlp_ratio = mlp_ratio
350
+ self.num_heads = num_heads
351
+ self.head_dim = dim // num_heads
352
+ self.proj_dropout = proj_dropout
353
+ self.norm_eps = norm_eps
354
+
355
+ # layers
356
+ gain = 1.0 / math.sqrt(dim)
357
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
358
+ self.to_q = nn.Linear(dim, dim)
359
+ self.to_kv = nn.Linear(dim, dim * 2)
360
+ self.proj = nn.Linear(dim, dim)
361
+ self.norm = LayerNorm(dim, eps=norm_eps)
362
+ self.mlp = nn.Sequential(
363
+ nn.Linear(dim, int(dim * mlp_ratio)),
364
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
365
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
366
+
367
+ def forward(self, x):
368
+ """
369
+ x: [B, L, C].
370
+ """
371
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
372
+
373
+ # compute query, key, value
374
+ q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
375
+ k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
376
+
377
+ # compute attention
378
+ x = flash_attention(q, k, v, version=2)
379
+ x = x.reshape(b, 1, c)
380
+
381
+ # output
382
+ x = self.proj(x)
383
+ x = F.dropout(x, self.proj_dropout, self.training)
384
+
385
+ # mlp
386
+ x = x + self.mlp(self.norm(x))
387
+ return x[:, 0]
388
+
389
+
390
+ class VisionTransformer(nn.Module):
391
+
392
+ def __init__(self,
393
+ image_size=224,
394
+ patch_size=16,
395
+ dim=768,
396
+ mlp_ratio=4,
397
+ out_dim=512,
398
+ num_heads=12,
399
+ num_layers=12,
400
+ pool_type='token',
401
+ pre_norm=True,
402
+ post_norm=False,
403
+ activation='quick_gelu',
404
+ attn_dropout=0.0,
405
+ proj_dropout=0.0,
406
+ embedding_dropout=0.0,
407
+ norm_eps=1e-5):
408
+ if image_size % patch_size != 0:
409
+ print(
410
+ '[WARNING] image_size is not divisible by patch_size',
411
+ flush=True)
412
+ assert pool_type in ('token', 'token_fc', 'attn_pool')
413
+ out_dim = out_dim or dim
414
+ super().__init__()
415
+ self.image_size = image_size
416
+ self.patch_size = patch_size
417
+ self.num_patches = (image_size // patch_size)**2
418
+ self.dim = dim
419
+ self.mlp_ratio = mlp_ratio
420
+ self.out_dim = out_dim
421
+ self.num_heads = num_heads
422
+ self.num_layers = num_layers
423
+ self.pool_type = pool_type
424
+ self.post_norm = post_norm
425
+ self.norm_eps = norm_eps
426
+
427
+ # embeddings
428
+ gain = 1.0 / math.sqrt(dim)
429
+ self.patch_embedding = nn.Conv2d(
430
+ 3,
431
+ dim,
432
+ kernel_size=patch_size,
433
+ stride=patch_size,
434
+ bias=not pre_norm)
435
+ if pool_type in ('token', 'token_fc'):
436
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
437
+ self.pos_embedding = nn.Parameter(gain * torch.randn(
438
+ 1, self.num_patches +
439
+ (1 if pool_type in ('token', 'token_fc') else 0), dim))
440
+ self.dropout = nn.Dropout(embedding_dropout)
441
+
442
+ # transformer
443
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
444
+ self.transformer = nn.Sequential(*[
445
+ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
446
+ activation, attn_dropout, proj_dropout, norm_eps)
447
+ for _ in range(num_layers)
448
+ ])
449
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
450
+
451
+ # head
452
+ if pool_type == 'token':
453
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
454
+ elif pool_type == 'token_fc':
455
+ self.head = nn.Linear(dim, out_dim)
456
+ elif pool_type == 'attn_pool':
457
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
458
+ proj_dropout, norm_eps)
459
+
460
+ def forward(self, x, interpolation=False, use_31_block=False):
461
+ b = x.size(0)
462
+
463
+ # embeddings
464
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
465
+ if self.pool_type in ('token', 'token_fc'):
466
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1).to(dtype=x.dtype, device=x.device), x], dim=1)
467
+ if interpolation:
468
+ e = pos_interpolate(self.pos_embedding, x.size(1))
469
+ else:
470
+ e = self.pos_embedding
471
+ e = e.to(dtype=x.dtype, device=x.device)
472
+ x = self.dropout(x + e)
473
+ if self.pre_norm is not None:
474
+ x = self.pre_norm(x)
475
+
476
+ # transformer
477
+ if use_31_block:
478
+ x = self.transformer[:-1](x)
479
+ return x
480
+ else:
481
+ x = self.transformer(x)
482
+ return x
483
+
484
+
485
+ class CLIP(nn.Module):
486
+
487
+ def __init__(self,
488
+ embed_dim=512,
489
+ image_size=224,
490
+ patch_size=16,
491
+ vision_dim=768,
492
+ vision_mlp_ratio=4,
493
+ vision_heads=12,
494
+ vision_layers=12,
495
+ vision_pool='token',
496
+ vision_pre_norm=True,
497
+ vision_post_norm=False,
498
+ vocab_size=49408,
499
+ text_len=77,
500
+ text_dim=512,
501
+ text_mlp_ratio=4,
502
+ text_heads=8,
503
+ text_layers=12,
504
+ text_causal=True,
505
+ text_pool='argmax',
506
+ text_head_bias=False,
507
+ logit_bias=None,
508
+ activation='quick_gelu',
509
+ attn_dropout=0.0,
510
+ proj_dropout=0.0,
511
+ embedding_dropout=0.0,
512
+ norm_eps=1e-5):
513
+ super().__init__()
514
+ self.embed_dim = embed_dim
515
+ self.image_size = image_size
516
+ self.patch_size = patch_size
517
+ self.vision_dim = vision_dim
518
+ self.vision_mlp_ratio = vision_mlp_ratio
519
+ self.vision_heads = vision_heads
520
+ self.vision_layers = vision_layers
521
+ self.vision_pool = vision_pool
522
+ self.vision_pre_norm = vision_pre_norm
523
+ self.vision_post_norm = vision_post_norm
524
+ self.vocab_size = vocab_size
525
+ self.text_len = text_len
526
+ self.text_dim = text_dim
527
+ self.text_mlp_ratio = text_mlp_ratio
528
+ self.text_heads = text_heads
529
+ self.text_layers = text_layers
530
+ self.text_causal = text_causal
531
+ self.text_pool = text_pool
532
+ self.text_head_bias = text_head_bias
533
+ self.norm_eps = norm_eps
534
+
535
+ # models
536
+ self.visual = VisionTransformer(
537
+ image_size=image_size,
538
+ patch_size=patch_size,
539
+ dim=vision_dim,
540
+ mlp_ratio=vision_mlp_ratio,
541
+ out_dim=embed_dim,
542
+ num_heads=vision_heads,
543
+ num_layers=vision_layers,
544
+ pool_type=vision_pool,
545
+ pre_norm=vision_pre_norm,
546
+ post_norm=vision_post_norm,
547
+ activation=activation,
548
+ attn_dropout=attn_dropout,
549
+ proj_dropout=proj_dropout,
550
+ embedding_dropout=embedding_dropout,
551
+ norm_eps=norm_eps)
552
+ self.textual = TextTransformer(
553
+ vocab_size=vocab_size,
554
+ text_len=text_len,
555
+ dim=text_dim,
556
+ mlp_ratio=text_mlp_ratio,
557
+ out_dim=embed_dim,
558
+ num_heads=text_heads,
559
+ num_layers=text_layers,
560
+ causal=text_causal,
561
+ pool_type=text_pool,
562
+ head_bias=text_head_bias,
563
+ activation=activation,
564
+ attn_dropout=attn_dropout,
565
+ proj_dropout=proj_dropout,
566
+ embedding_dropout=embedding_dropout,
567
+ norm_eps=norm_eps)
568
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
569
+ if logit_bias is not None:
570
+ self.logit_bias = nn.Parameter(logit_bias * torch.ones([]))
571
+
572
+ # initialize weights
573
+ self.init_weights()
574
+
575
+ def forward(self, imgs, txt_ids):
576
+ """
577
+ imgs: [B, 3, H, W] of torch.float32.
578
+ - mean: [0.48145466, 0.4578275, 0.40821073]
579
+ - std: [0.26862954, 0.26130258, 0.27577711]
580
+ txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer.
581
+ """
582
+ xi = self.visual(imgs)
583
+ xt = self.textual(txt_ids)
584
+ return xi, xt
585
+
586
+ def init_weights(self):
587
+ # embeddings
588
+ nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
589
+ nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)
590
+
591
+ # attentions
592
+ for modality in ['visual', 'textual']:
593
+ dim = self.vision_dim if modality == 'visual' else self.text_dim
594
+ transformer = getattr(self, modality).transformer
595
+ proj_gain = (1.0 / math.sqrt(dim)) * (
596
+ 1.0 / math.sqrt(2 * len(transformer)))
597
+ attn_gain = 1.0 / math.sqrt(dim)
598
+ mlp_gain = 1.0 / math.sqrt(2.0 * dim)
599
+ for block in transformer:
600
+ nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
601
+ nn.init.normal_(block.attn.proj.weight, std=proj_gain)
602
+ nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
603
+ nn.init.normal_(block.mlp[2].weight, std=proj_gain)
604
+
605
+ def param_groups(self):
606
+ groups = [{
607
+ 'params': [
608
+ p for n, p in self.named_parameters()
609
+ if 'norm' in n or n.endswith('bias')
610
+ ],
611
+ 'weight_decay': 0.0
612
+ }, {
613
+ 'params': [
614
+ p for n, p in self.named_parameters()
615
+ if not ('norm' in n or n.endswith('bias'))
616
+ ]
617
+ }]
618
+ return groups
619
+
620
+
621
+ class XLMRobertaWithHead(XLMRoberta):
622
+
623
+ def __init__(self, **kwargs):
624
+ self.out_dim = kwargs.pop('out_dim')
625
+ super().__init__(**kwargs)
626
+
627
+ # head
628
+ mid_dim = (self.dim + self.out_dim) // 2
629
+ self.head = nn.Sequential(
630
+ nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
631
+ nn.Linear(mid_dim, self.out_dim, bias=False))
632
+
633
+ def forward(self, ids):
634
+ # xlm-roberta
635
+ x = super().forward(ids)
636
+
637
+ # average pooling
638
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
639
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
640
+
641
+ # head
642
+ x = self.head(x)
643
+ return x
644
+
645
+
646
+ class XLMRobertaCLIP(nn.Module):
647
+
648
+ def __init__(self,
649
+ embed_dim=1024,
650
+ image_size=224,
651
+ patch_size=14,
652
+ vision_dim=1280,
653
+ vision_mlp_ratio=4,
654
+ vision_heads=16,
655
+ vision_layers=32,
656
+ vision_pool='token',
657
+ vision_pre_norm=True,
658
+ vision_post_norm=False,
659
+ activation='gelu',
660
+ vocab_size=250002,
661
+ max_text_len=514,
662
+ type_size=1,
663
+ pad_id=1,
664
+ text_dim=1024,
665
+ text_heads=16,
666
+ text_layers=24,
667
+ text_post_norm=True,
668
+ text_dropout=0.1,
669
+ attn_dropout=0.0,
670
+ proj_dropout=0.0,
671
+ embedding_dropout=0.0,
672
+ norm_eps=1e-5):
673
+ super().__init__()
674
+ self.embed_dim = embed_dim
675
+ self.image_size = image_size
676
+ self.patch_size = patch_size
677
+ self.vision_dim = vision_dim
678
+ self.vision_mlp_ratio = vision_mlp_ratio
679
+ self.vision_heads = vision_heads
680
+ self.vision_layers = vision_layers
681
+ self.vision_pre_norm = vision_pre_norm
682
+ self.vision_post_norm = vision_post_norm
683
+ self.activation = activation
684
+ self.vocab_size = vocab_size
685
+ self.max_text_len = max_text_len
686
+ self.type_size = type_size
687
+ self.pad_id = pad_id
688
+ self.text_dim = text_dim
689
+ self.text_heads = text_heads
690
+ self.text_layers = text_layers
691
+ self.text_post_norm = text_post_norm
692
+ self.norm_eps = norm_eps
693
+
694
+ # models
695
+ self.visual = VisionTransformer(
696
+ image_size=image_size,
697
+ patch_size=patch_size,
698
+ dim=vision_dim,
699
+ mlp_ratio=vision_mlp_ratio,
700
+ out_dim=embed_dim,
701
+ num_heads=vision_heads,
702
+ num_layers=vision_layers,
703
+ pool_type=vision_pool,
704
+ pre_norm=vision_pre_norm,
705
+ post_norm=vision_post_norm,
706
+ activation=activation,
707
+ attn_dropout=attn_dropout,
708
+ proj_dropout=proj_dropout,
709
+ embedding_dropout=embedding_dropout,
710
+ norm_eps=norm_eps)
711
+ self.textual = None
712
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
713
+
714
+ def forward(self, imgs, txt_ids):
715
+ """
716
+ imgs: [B, 3, H, W] of torch.float32.
717
+ - mean: [0.48145466, 0.4578275, 0.40821073]
718
+ - std: [0.26862954, 0.26130258, 0.27577711]
719
+ txt_ids: [B, L] of torch.long.
720
+ Encoded by data.CLIPTokenizer.
721
+ """
722
+ xi = self.visual(imgs)
723
+ xt = self.textual(txt_ids)
724
+ return xi, xt
725
+
726
+ def param_groups(self):
727
+ groups = [{
728
+ 'params': [
729
+ p for n, p in self.named_parameters()
730
+ if 'norm' in n or n.endswith('bias')
731
+ ],
732
+ 'weight_decay': 0.0
733
+ }, {
734
+ 'params': [
735
+ p for n, p in self.named_parameters()
736
+ if not ('norm' in n or n.endswith('bias'))
737
+ ]
738
+ }]
739
+ return groups
740
+
741
+
742
+ def _clip(pretrained=False,
743
+ pretrained_name=None,
744
+ model_cls=CLIP,
745
+ return_transforms=False,
746
+ return_tokenizer=False,
747
+ tokenizer_padding='eos',
748
+ dtype=torch.float32,
749
+ device='cpu',
750
+ **kwargs):
751
+ # init model
752
+ if pretrained and pretrained_name:
753
+ from sora import BUCKET, DOWNLOAD_TO_CACHE
754
+
755
+ # init a meta model
756
+ with torch.device('meta'):
757
+ model = model_cls(**kwargs)
758
+
759
+ # checkpoint path
760
+ checkpoint = f'models/clip/{pretrained_name}'
761
+ if dtype in (torch.float16, torch.bfloat16):
762
+ suffix = '-' + {
763
+ torch.float16: 'fp16',
764
+ torch.bfloat16: 'bf16'
765
+ }[dtype]
766
+ if object_exists(BUCKET, f'{checkpoint}{suffix}.pth'):
767
+ checkpoint = f'{checkpoint}{suffix}'
768
+ checkpoint += '.pth'
769
+
770
+ # load
771
+ model.load_state_dict(
772
+ torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device),
773
+ assign=True,
774
+ strict=False)
775
+ else:
776
+ # init a model on device
777
+ with torch.device(device):
778
+ model = model_cls(**kwargs)
779
+
780
+ # set device
781
+ output = (model,)
782
+
783
+ # init transforms
784
+ if return_transforms:
785
+ # mean and std
786
+ if 'siglip' in pretrained_name.lower():
787
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
788
+ else:
789
+ mean = [0.48145466, 0.4578275, 0.40821073]
790
+ std = [0.26862954, 0.26130258, 0.27577711]
791
+
792
+ # transforms
793
+ transforms = T.Compose([
794
+ T.Resize((model.image_size, model.image_size),
795
+ interpolation=T.InterpolationMode.BICUBIC),
796
+ T.ToTensor(),
797
+ T.Normalize(mean=mean, std=std)
798
+ ])
799
+ output += (transforms,)
800
+
801
+ # init tokenizer
802
+ if return_tokenizer:
803
+ from sora import data
804
+ if 'siglip' in pretrained_name.lower():
805
+ tokenizer = data.HuggingfaceTokenizer(
806
+ name=f'timm/{pretrained_name}',
807
+ seq_len=model.text_len,
808
+ clean='canonicalize')
809
+ elif 'xlm' in pretrained_name.lower():
810
+ tokenizer = data.HuggingfaceTokenizer(
811
+ name='xlm-roberta-large',
812
+ seq_len=model.max_text_len - 2,
813
+ clean='whitespace')
814
+ elif 'mba' in pretrained_name.lower():
815
+ tokenizer = data.HuggingfaceTokenizer(
816
+ name='facebook/xlm-roberta-xl',
817
+ seq_len=model.max_text_len - 2,
818
+ clean='whitespace')
819
+ else:
820
+ tokenizer = data.CLIPTokenizer(
821
+ seq_len=model.text_len, padding=tokenizer_padding)
822
+ output += (tokenizer,)
823
+ return output[0] if len(output) == 1 else output
824
+
825
+
826
+ def clip_xlm_roberta_vit_h_14(
827
+ pretrained=False,
828
+ pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
829
+ **kwargs):
830
+ cfg = dict(
831
+ embed_dim=1024,
832
+ image_size=224,
833
+ patch_size=14,
834
+ vision_dim=1280,
835
+ vision_mlp_ratio=4,
836
+ vision_heads=16,
837
+ vision_layers=32,
838
+ vision_pool='token',
839
+ activation='gelu',
840
+ vocab_size=250002,
841
+ max_text_len=514,
842
+ type_size=1,
843
+ pad_id=1,
844
+ text_dim=1024,
845
+ text_heads=16,
846
+ text_layers=24,
847
+ text_post_norm=True,
848
+ text_dropout=0.1,
849
+ attn_dropout=0.0,
850
+ proj_dropout=0.0,
851
+ embedding_dropout=0.0)
852
+ cfg.update(**kwargs)
853
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
854
+
855
+
856
+ class WanImageEncoder(torch.nn.Module):
857
+
858
+ def __init__(self):
859
+ super().__init__()
860
+ # init model
861
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
862
+ pretrained=False,
863
+ return_transforms=True,
864
+ return_tokenizer=False,
865
+ dtype=torch.float32,
866
+ device="cpu")
867
+
868
+ def encode_image(self, videos):
869
+ # preprocess
870
+ size = (self.model.image_size,) * 2
871
+ videos = torch.cat([
872
+ F.interpolate(
873
+ u,
874
+ size=size,
875
+ mode='bicubic',
876
+ align_corners=False) for u in videos
877
+ ])
878
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
879
+
880
+ # forward
881
+ out = self.model.visual(videos, use_31_block=True)
882
+ return out
883
+
884
+ @staticmethod
885
+ def state_dict_converter():
886
+ return WanImageEncoderStateDictConverter()
887
+
888
+
889
+ class WanImageEncoderStateDictConverter:
890
+ def __init__(self):
891
+ pass
892
+
893
+ def from_diffusers(self, state_dict):
894
+ return state_dict
895
+
896
+ def from_civitai(self, state_dict):
897
+ state_dict_ = {}
898
+ for name, param in state_dict.items():
899
+ if name.startswith("textual."):
900
+ continue
901
+ name = "model." + name
902
+ state_dict_[name] = param
903
+ return state_dict_
904
+
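A minimal usage sketch for WanImageEncoder.encode_image, assuming (from the mul_(0.5).add_(0.5) preprocessing above) that inputs are [N, 3, H, W] frames in [-1, 1]; the frame size is hypothetical, and in practice the randomly initialized weights would first be replaced by a checkpoint converted with WanImageEncoderStateDictConverter.from_civitai, which prefixes keys with "model." and drops the unused text tower.

import torch

encoder = WanImageEncoder().eval()
frame = torch.rand(1, 3, 480, 832) * 2 - 1           # hypothetical frames in [-1, 1]
with torch.no_grad():
    clip_features = encoder.encode_image([frame])    # resized to 224 and CLIP-normalized internally
print(clip_features.shape)                           # [1, 257, 1280]: penultimate ViT-H/14 block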
diffsynth/models/wan_video_text_encoder.py ADDED
@@ -0,0 +1,269 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def fp16_clamp(x):
9
+ if x.dtype == torch.float16 and torch.isinf(x).any():
10
+ clamp = torch.finfo(x.dtype).max - 1000
11
+ x = torch.clamp(x, min=-clamp, max=clamp)
12
+ return x
13
+
14
+
15
+ class GELU(nn.Module):
16
+
17
+ def forward(self, x):
18
+ return 0.5 * x * (1.0 + torch.tanh(
19
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
20
+
21
+
22
+ class T5LayerNorm(nn.Module):
23
+
24
+ def __init__(self, dim, eps=1e-6):
25
+ super(T5LayerNorm, self).__init__()
26
+ self.dim = dim
27
+ self.eps = eps
28
+ self.weight = nn.Parameter(torch.ones(dim))
29
+
30
+ def forward(self, x):
31
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
32
+ self.eps)
33
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
34
+ x = x.type_as(self.weight)
35
+ return self.weight * x
36
+
37
+
38
+ class T5Attention(nn.Module):
39
+
40
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
41
+ assert dim_attn % num_heads == 0
42
+ super(T5Attention, self).__init__()
43
+ self.dim = dim
44
+ self.dim_attn = dim_attn
45
+ self.num_heads = num_heads
46
+ self.head_dim = dim_attn // num_heads
47
+
48
+ # layers
49
+ self.q = nn.Linear(dim, dim_attn, bias=False)
50
+ self.k = nn.Linear(dim, dim_attn, bias=False)
51
+ self.v = nn.Linear(dim, dim_attn, bias=False)
52
+ self.o = nn.Linear(dim_attn, dim, bias=False)
53
+ self.dropout = nn.Dropout(dropout)
54
+
55
+ def forward(self, x, context=None, mask=None, pos_bias=None):
56
+ """
57
+ x: [B, L1, C].
58
+ context: [B, L2, C] or None.
59
+ mask: [B, L2] or [B, L1, L2] or None.
60
+ """
61
+ # check inputs
62
+ context = x if context is None else context
63
+ b, n, c = x.size(0), self.num_heads, self.head_dim
64
+
65
+ # compute query, key, value
66
+ q = self.q(x).view(b, -1, n, c)
67
+ k = self.k(context).view(b, -1, n, c)
68
+ v = self.v(context).view(b, -1, n, c)
69
+
70
+ # attention bias
71
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
72
+ if pos_bias is not None:
73
+ attn_bias += pos_bias
74
+ if mask is not None:
75
+ assert mask.ndim in [2, 3]
76
+ mask = mask.view(b, 1, 1,
77
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
78
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
79
+
80
+ # compute attention (T5 does not use scaling)
81
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
82
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
83
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
84
+
85
+ # output
86
+ x = x.reshape(b, -1, n * c)
87
+ x = self.o(x)
88
+ x = self.dropout(x)
89
+ return x
90
+
91
+
92
+ class T5FeedForward(nn.Module):
93
+
94
+ def __init__(self, dim, dim_ffn, dropout=0.1):
95
+ super(T5FeedForward, self).__init__()
96
+ self.dim = dim
97
+ self.dim_ffn = dim_ffn
98
+
99
+ # layers
100
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
101
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
102
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
103
+ self.dropout = nn.Dropout(dropout)
104
+
105
+ def forward(self, x):
106
+ x = self.fc1(x) * self.gate(x)
107
+ x = self.dropout(x)
108
+ x = self.fc2(x)
109
+ x = self.dropout(x)
110
+ return x
111
+
112
+
113
+ class T5SelfAttention(nn.Module):
114
+
115
+ def __init__(self,
116
+ dim,
117
+ dim_attn,
118
+ dim_ffn,
119
+ num_heads,
120
+ num_buckets,
121
+ shared_pos=True,
122
+ dropout=0.1):
123
+ super(T5SelfAttention, self).__init__()
124
+ self.dim = dim
125
+ self.dim_attn = dim_attn
126
+ self.dim_ffn = dim_ffn
127
+ self.num_heads = num_heads
128
+ self.num_buckets = num_buckets
129
+ self.shared_pos = shared_pos
130
+
131
+ # layers
132
+ self.norm1 = T5LayerNorm(dim)
133
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
134
+ self.norm2 = T5LayerNorm(dim)
135
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
136
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
137
+ num_buckets, num_heads, bidirectional=True)
138
+
139
+ def forward(self, x, mask=None, pos_bias=None):
140
+ e = pos_bias if self.shared_pos else self.pos_embedding(
141
+ x.size(1), x.size(1))
142
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
143
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
144
+ return x
145
+
146
+
147
+ class T5RelativeEmbedding(nn.Module):
148
+
149
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
150
+ super(T5RelativeEmbedding, self).__init__()
151
+ self.num_buckets = num_buckets
152
+ self.num_heads = num_heads
153
+ self.bidirectional = bidirectional
154
+ self.max_dist = max_dist
155
+
156
+ # layers
157
+ self.embedding = nn.Embedding(num_buckets, num_heads)
158
+
159
+ def forward(self, lq, lk):
160
+ device = self.embedding.weight.device
161
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
162
+ # torch.arange(lq).unsqueeze(1).to(device)
163
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
164
+ torch.arange(lq, device=device).unsqueeze(1)
165
+ rel_pos = self._relative_position_bucket(rel_pos)
166
+ rel_pos_embeds = self.embedding(rel_pos)
167
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
168
+ 0) # [1, N, Lq, Lk]
169
+ return rel_pos_embeds.contiguous()
170
+
171
+ def _relative_position_bucket(self, rel_pos):
172
+ # preprocess
173
+ if self.bidirectional:
174
+ num_buckets = self.num_buckets // 2
175
+ rel_buckets = (rel_pos > 0).long() * num_buckets
176
+ rel_pos = torch.abs(rel_pos)
177
+ else:
178
+ num_buckets = self.num_buckets
179
+ rel_buckets = 0
180
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
181
+
182
+ # embeddings for small and large positions
183
+ max_exact = num_buckets // 2
184
+ rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
185
+ math.log(self.max_dist / max_exact) *
186
+ (num_buckets - max_exact)).long()
187
+ rel_pos_large = torch.min(
188
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
189
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
190
+ return rel_buckets
191
+
192
+ def init_weights(m):
193
+ if isinstance(m, T5LayerNorm):
194
+ nn.init.ones_(m.weight)
195
+ elif isinstance(m, T5FeedForward):
196
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
197
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
198
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
199
+ elif isinstance(m, T5Attention):
200
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
201
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
202
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
203
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
204
+ elif isinstance(m, T5RelativeEmbedding):
205
+ nn.init.normal_(
206
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
207
+
208
+
209
+ class WanTextEncoder(torch.nn.Module):
210
+
211
+ def __init__(self,
212
+ vocab=256384,
213
+ dim=4096,
214
+ dim_attn=4096,
215
+ dim_ffn=10240,
216
+ num_heads=64,
217
+ num_layers=24,
218
+ num_buckets=32,
219
+ shared_pos=False,
220
+ dropout=0.1):
221
+ super(WanTextEncoder, self).__init__()
222
+ self.dim = dim
223
+ self.dim_attn = dim_attn
224
+ self.dim_ffn = dim_ffn
225
+ self.num_heads = num_heads
226
+ self.num_layers = num_layers
227
+ self.num_buckets = num_buckets
228
+ self.shared_pos = shared_pos
229
+
230
+ # layers
231
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
232
+ else nn.Embedding(vocab, dim)
233
+ self.pos_embedding = T5RelativeEmbedding(
234
+ num_buckets, num_heads, bidirectional=True) if shared_pos else None
235
+ self.dropout = nn.Dropout(dropout)
236
+ self.blocks = nn.ModuleList([
237
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
238
+ shared_pos, dropout) for _ in range(num_layers)
239
+ ])
240
+ self.norm = T5LayerNorm(dim)
241
+
242
+ # initialize weights
243
+ self.apply(init_weights)
244
+
245
+ def forward(self, ids, mask=None):
246
+ x = self.token_embedding(ids)
247
+ x = self.dropout(x)
248
+ e = self.pos_embedding(x.size(1),
249
+ x.size(1)) if self.shared_pos else None
250
+ for block in self.blocks:
251
+ x = block(x, mask, pos_bias=e)
252
+ x = self.norm(x)
253
+ x = self.dropout(x)
254
+ return x
255
+
256
+ @staticmethod
257
+ def state_dict_converter():
258
+ return WanTextEncoderStateDictConverter()
259
+
260
+
261
+ class WanTextEncoderStateDictConverter:
262
+ def __init__(self):
263
+ pass
264
+
265
+ def from_diffusers(self, state_dict):
266
+ return state_dict
267
+
268
+ def from_civitai(self, state_dict):
269
+ return state_dict
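A shape-level sketch of the encoder's forward pass, using a deliberately tiny configuration so it runs quickly; the defaults above correspond to a UMT5-XXL-scale stack (dim=4096, 24 layers), and real token ids would come from the matching tokenizer rather than torch.randint.

import torch

text_encoder = WanTextEncoder(
    vocab=1000, dim=64, dim_attn=64, dim_ffn=128,
    num_heads=4, num_layers=2, num_buckets=8,
).eval()
ids = torch.randint(0, 1000, (1, 32))          # [B, L] hypothetical token ids
mask = torch.ones(1, 32, dtype=torch.long)     # [B, L] attention mask
with torch.no_grad():
    embeddings = text_encoder(ids, mask)       # [1, 32, 64]; the full model yields [B, L, 4096]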
diffsynth/models/wan_video_vae.py ADDED
@@ -0,0 +1,808 @@
1
+ from einops import rearrange, repeat
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from tqdm import tqdm
7
+
8
+ CACHE_T = 2
9
+
10
+
11
+ def check_is_instance(model, module_class):
12
+ if isinstance(model, module_class):
13
+ return True
14
+ if hasattr(model, "module") and isinstance(model.module, module_class):
15
+ return True
16
+ return False
17
+
18
+
19
+ def block_causal_mask(x, block_size):
20
+ # params
21
+ b, n, s, _, device = *x.size(), x.device
22
+ assert s % block_size == 0
23
+ num_blocks = s // block_size
24
+
25
+ # build mask
26
+ mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
27
+ for i in range(num_blocks):
28
+ mask[:, :,
29
+ i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
30
+ return mask
31
+
32
+
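A quick sketch of the mask this helper builds (it is referenced only in the commented-out attn_mask further below): each query attends to every key in its own temporal block and in all earlier blocks, i.e. a lower block-triangular pattern.

import torch

x = torch.randn(1, 1, 6, 8)                    # [B, heads, seq, head_dim]
mask = block_causal_mask(x, block_size=2)      # [1, 1, 6, 6] boolean mask
print(mask[0, 0].int())                        # 2x2 blocks, lower block-triangular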
33
+ class CausalConv3d(nn.Conv3d):
34
+ """
35
+ Causal 3D convolution.
36
+ """
37
+
38
+ def __init__(self, *args, **kwargs):
39
+ super().__init__(*args, **kwargs)
40
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
41
+ self.padding[1], 2 * self.padding[0], 0)
42
+ self.padding = (0, 0, 0)
43
+
44
+ def forward(self, x, cache_x=None):
45
+ padding = list(self._padding)
46
+ if cache_x is not None and self._padding[4] > 0:
47
+ cache_x = cache_x.to(x.device)
48
+ x = torch.cat([cache_x, x], dim=2)
49
+ padding[4] -= cache_x.shape[2]
50
+ x = F.pad(x, padding)
51
+
52
+ return super().forward(x)
53
+
54
+
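A quick sketch of the causal padding above: with a temporal kernel of size 3 and padding=1, the _padding tuple becomes two frames of left (past) padding and none on the right, so output frame t never depends on later frames, and cache_x can stand in for that padding with real frames from the previous chunk during streamed encoding or decoding.

import torch

conv = CausalConv3d(4, 4, kernel_size=3, padding=1)
x = torch.randn(1, 4, 8, 16, 16)               # [B, C, T, H, W]
y = conv(x)                                    # temporal length preserved: [1, 4, 8, 16, 16]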
55
+ class RMS_norm(nn.Module):
56
+
57
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
58
+ super().__init__()
59
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
60
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
61
+
62
+ self.channel_first = channel_first
63
+ self.scale = dim**0.5
64
+ self.gamma = nn.Parameter(torch.ones(shape))
65
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
66
+
67
+ def forward(self, x):
68
+ return F.normalize(
69
+ x, dim=(1 if self.channel_first else
70
+ -1)) * self.scale * self.gamma + self.bias
71
+
72
+
73
+ class Upsample(nn.Upsample):
74
+
75
+ def forward(self, x):
76
+ """
77
+ Fix bfloat16 support for nearest neighbor interpolation.
78
+ """
79
+ return super().forward(x.float()).type_as(x)
80
+
81
+
82
+ class Resample(nn.Module):
83
+
84
+ def __init__(self, dim, mode):
85
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
86
+ 'downsample3d')
87
+ super().__init__()
88
+ self.dim = dim
89
+ self.mode = mode
90
+
91
+ # layers
92
+ if mode == 'upsample2d':
93
+ self.resample = nn.Sequential(
94
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
95
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
96
+ elif mode == 'upsample3d':
97
+ self.resample = nn.Sequential(
98
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
99
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
100
+ self.time_conv = CausalConv3d(dim,
101
+ dim * 2, (3, 1, 1),
102
+ padding=(1, 0, 0))
103
+
104
+ elif mode == 'downsample2d':
105
+ self.resample = nn.Sequential(
106
+ nn.ZeroPad2d((0, 1, 0, 1)),
107
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
108
+ elif mode == 'downsample3d':
109
+ self.resample = nn.Sequential(
110
+ nn.ZeroPad2d((0, 1, 0, 1)),
111
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
112
+ self.time_conv = CausalConv3d(dim,
113
+ dim, (3, 1, 1),
114
+ stride=(2, 1, 1),
115
+ padding=(0, 0, 0))
116
+
117
+ else:
118
+ self.resample = nn.Identity()
119
+
120
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
121
+ b, c, t, h, w = x.size()
122
+ if self.mode == 'upsample3d':
123
+ if feat_cache is not None:
124
+ idx = feat_idx[0]
125
+ if feat_cache[idx] is None:
126
+ feat_cache[idx] = 'Rep'
127
+ feat_idx[0] += 1
128
+ else:
129
+
130
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
131
+ if cache_x.shape[2] < 2 and feat_cache[
132
+ idx] is not None and feat_cache[idx] != 'Rep':
133
+ # pad the cache with the last frame of the previous chunk
134
+ cache_x = torch.cat([
135
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
136
+ cache_x.device), cache_x
137
+ ],
138
+ dim=2)
139
+ if cache_x.shape[2] < 2 and feat_cache[
140
+ idx] is not None and feat_cache[idx] == 'Rep':
141
+ cache_x = torch.cat([
142
+ torch.zeros_like(cache_x).to(cache_x.device),
143
+ cache_x
144
+ ],
145
+ dim=2)
146
+ if feat_cache[idx] == 'Rep':
147
+ x = self.time_conv(x)
148
+ else:
149
+ x = self.time_conv(x, feat_cache[idx])
150
+ feat_cache[idx] = cache_x
151
+ feat_idx[0] += 1
152
+
153
+ x = x.reshape(b, 2, c, t, h, w)
154
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
155
+ 3)
156
+ x = x.reshape(b, c, t * 2, h, w)
157
+ t = x.shape[2]
158
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
159
+ x = self.resample(x)
160
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
161
+
162
+ if self.mode == 'downsample3d':
163
+ if feat_cache is not None:
164
+ idx = feat_idx[0]
165
+ if feat_cache[idx] is None:
166
+ feat_cache[idx] = x.clone()
167
+ feat_idx[0] += 1
168
+ else:
169
+ cache_x = x[:, :, -1:, :, :].clone()
170
+ x = self.time_conv(
171
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
172
+ feat_cache[idx] = cache_x
173
+ feat_idx[0] += 1
174
+ return x
175
+
176
+ def init_weight(self, conv):
177
+ conv_weight = conv.weight
178
+ nn.init.zeros_(conv_weight)
179
+ c1, c2, t, h, w = conv_weight.size()
180
+ one_matrix = torch.eye(c1, c2)
181
+ init_matrix = one_matrix
182
+ nn.init.zeros_(conv_weight)
183
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix
184
+ conv.weight.data.copy_(conv_weight)
185
+ nn.init.zeros_(conv.bias.data)
186
+
187
+ def init_weight2(self, conv):
188
+ conv_weight = conv.weight.data
189
+ nn.init.zeros_(conv_weight)
190
+ c1, c2, t, h, w = conv_weight.size()
191
+ init_matrix = torch.eye(c1 // 2, c2)
192
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
193
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
194
+ conv.weight.data.copy_(conv_weight)
195
+ nn.init.zeros_(conv.bias.data)
196
+
197
+
198
+ class ResidualBlock(nn.Module):
199
+
200
+ def __init__(self, in_dim, out_dim, dropout=0.0):
201
+ super().__init__()
202
+ self.in_dim = in_dim
203
+ self.out_dim = out_dim
204
+
205
+ # layers
206
+ self.residual = nn.Sequential(
207
+ RMS_norm(in_dim, images=False), nn.SiLU(),
208
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
209
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
210
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
211
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
212
+ if in_dim != out_dim else nn.Identity()
213
+
214
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
215
+ h = self.shortcut(x)
216
+ for layer in self.residual:
217
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
218
+ idx = feat_idx[0]
219
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
220
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
221
+ # pad the cache with the last frame of the previous chunk
222
+ cache_x = torch.cat([
223
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
224
+ cache_x.device), cache_x
225
+ ],
226
+ dim=2)
227
+ x = layer(x, feat_cache[idx])
228
+ feat_cache[idx] = cache_x
229
+ feat_idx[0] += 1
230
+ else:
231
+ x = layer(x)
232
+ return x + h
233
+
234
+
235
+ class AttentionBlock(nn.Module):
236
+ """
237
+ Single-head self-attention over spatial positions (the block-causal mask is disabled below).
238
+ """
239
+
240
+ def __init__(self, dim):
241
+ super().__init__()
242
+ self.dim = dim
243
+
244
+ # layers
245
+ self.norm = RMS_norm(dim)
246
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
247
+ self.proj = nn.Conv2d(dim, dim, 1)
248
+
249
+ # zero out the last layer params
250
+ nn.init.zeros_(self.proj.weight)
251
+
252
+ def forward(self, x):
253
+ identity = x
254
+ b, c, t, h, w = x.size()
255
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
256
+ x = self.norm(x)
257
+ # compute query, key, value
258
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
259
+ 0, 1, 3, 2).contiguous().chunk(3, dim=-1)
260
+
261
+ # apply attention
262
+ x = F.scaled_dot_product_attention(
263
+ q,
264
+ k,
265
+ v,
266
+ #attn_mask=block_causal_mask(q, block_size=h * w)
267
+ )
268
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
269
+
270
+ # output
271
+ x = self.proj(x)
272
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
273
+ return x + identity
274
+
275
+
276
+ class Encoder3d(nn.Module):
277
+
278
+ def __init__(self,
279
+ dim=128,
280
+ z_dim=4,
281
+ dim_mult=[1, 2, 4, 4],
282
+ num_res_blocks=2,
283
+ attn_scales=[],
284
+ temperal_downsample=[True, True, False],
285
+ dropout=0.0):
286
+ super().__init__()
287
+ self.dim = dim
288
+ self.z_dim = z_dim
289
+ self.dim_mult = dim_mult
290
+ self.num_res_blocks = num_res_blocks
291
+ self.attn_scales = attn_scales
292
+ self.temperal_downsample = temperal_downsample
293
+
294
+ # dimensions
295
+ dims = [dim * u for u in [1] + dim_mult]
296
+ scale = 1.0
297
+
298
+ # init block
299
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
300
+
301
+ # downsample blocks
302
+ downsamples = []
303
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
304
+ # residual (+attention) blocks
305
+ for _ in range(num_res_blocks):
306
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
307
+ if scale in attn_scales:
308
+ downsamples.append(AttentionBlock(out_dim))
309
+ in_dim = out_dim
310
+
311
+ # downsample block
312
+ if i != len(dim_mult) - 1:
313
+ mode = 'downsample3d' if temperal_downsample[
314
+ i] else 'downsample2d'
315
+ downsamples.append(Resample(out_dim, mode=mode))
316
+ scale /= 2.0
317
+ self.downsamples = nn.Sequential(*downsamples)
318
+
319
+ # middle blocks
320
+ self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
321
+ AttentionBlock(out_dim),
322
+ ResidualBlock(out_dim, out_dim, dropout))
323
+
324
+ # output blocks
325
+ self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
326
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
327
+
328
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
329
+ if feat_cache is not None:
330
+ idx = feat_idx[0]
331
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
332
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
333
+ # pad the cache with the last frame of the previous chunk
334
+ cache_x = torch.cat([
335
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
336
+ cache_x.device), cache_x
337
+ ],
338
+ dim=2)
339
+ x = self.conv1(x, feat_cache[idx])
340
+ feat_cache[idx] = cache_x
341
+ feat_idx[0] += 1
342
+ else:
343
+ x = self.conv1(x)
344
+
345
+ ## downsamples
346
+ for layer in self.downsamples:
347
+ if feat_cache is not None:
348
+ x = layer(x, feat_cache, feat_idx)
349
+ else:
350
+ x = layer(x)
351
+
352
+ ## middle
353
+ for layer in self.middle:
354
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
355
+ x = layer(x, feat_cache, feat_idx)
356
+ else:
357
+ x = layer(x)
358
+
359
+ ## head
360
+ for layer in self.head:
361
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
362
+ idx = feat_idx[0]
363
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
364
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
365
+ # pad the cache with the last frame of the previous chunk
366
+ cache_x = torch.cat([
367
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
368
+ cache_x.device), cache_x
369
+ ],
370
+ dim=2)
371
+ x = layer(x, feat_cache[idx])
372
+ feat_cache[idx] = cache_x
373
+ feat_idx[0] += 1
374
+ else:
375
+ x = layer(x)
376
+ return x
377
+
378
+
379
+ class Decoder3d(nn.Module):
380
+
381
+ def __init__(self,
382
+ dim=128,
383
+ z_dim=4,
384
+ dim_mult=[1, 2, 4, 4],
385
+ num_res_blocks=2,
386
+ attn_scales=[],
387
+ temperal_upsample=[False, True, True],
388
+ dropout=0.0):
389
+ super().__init__()
390
+ self.dim = dim
391
+ self.z_dim = z_dim
392
+ self.dim_mult = dim_mult
393
+ self.num_res_blocks = num_res_blocks
394
+ self.attn_scales = attn_scales
395
+ self.temperal_upsample = temperal_upsample
396
+
397
+ # dimensions
398
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
399
+ scale = 1.0 / 2**(len(dim_mult) - 2)
400
+
401
+ # init block
402
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
403
+
404
+ # middle blocks
405
+ self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
406
+ AttentionBlock(dims[0]),
407
+ ResidualBlock(dims[0], dims[0], dropout))
408
+
409
+ # upsample blocks
410
+ upsamples = []
411
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
412
+ # residual (+attention) blocks
413
+ if i == 1 or i == 2 or i == 3:
414
+ in_dim = in_dim // 2
415
+ for _ in range(num_res_blocks + 1):
416
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
417
+ if scale in attn_scales:
418
+ upsamples.append(AttentionBlock(out_dim))
419
+ in_dim = out_dim
420
+
421
+ # upsample block
422
+ if i != len(dim_mult) - 1:
423
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
424
+ upsamples.append(Resample(out_dim, mode=mode))
425
+ scale *= 2.0
426
+ self.upsamples = nn.Sequential(*upsamples)
427
+
428
+ # output blocks
429
+ self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
430
+ CausalConv3d(out_dim, 3, 3, padding=1))
431
+
432
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
433
+ ## conv1
434
+ if feat_cache is not None:
435
+ idx = feat_idx[0]
436
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
437
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
438
+ # pad the cache with the last frame of the previous chunk
439
+ cache_x = torch.cat([
440
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
441
+ cache_x.device), cache_x
442
+ ],
443
+ dim=2)
444
+ x = self.conv1(x, feat_cache[idx])
445
+ feat_cache[idx] = cache_x
446
+ feat_idx[0] += 1
447
+ else:
448
+ x = self.conv1(x)
449
+
450
+ ## middle
451
+ for layer in self.middle:
452
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
453
+ x = layer(x, feat_cache, feat_idx)
454
+ else:
455
+ x = layer(x)
456
+
457
+ ## upsamples
458
+ for layer in self.upsamples:
459
+ if feat_cache is not None:
460
+ x = layer(x, feat_cache, feat_idx)
461
+ else:
462
+ x = layer(x)
463
+
464
+ ## head
465
+ for layer in self.head:
466
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
467
+ idx = feat_idx[0]
468
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
469
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
470
+ # pad the cache with the last frame of the previous chunk
471
+ cache_x = torch.cat([
472
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
473
+ cache_x.device), cache_x
474
+ ],
475
+ dim=2)
476
+ x = layer(x, feat_cache[idx])
477
+ feat_cache[idx] = cache_x
478
+ feat_idx[0] += 1
479
+ else:
480
+ x = layer(x)
481
+ return x
482
+
483
+
484
+ def count_conv3d(model):
485
+ count = 0
486
+ for m in model.modules():
487
+ if check_is_instance(m, CausalConv3d):
488
+ count += 1
489
+ return count
490
+
491
+
492
+ class VideoVAE_(nn.Module):
493
+
494
+ def __init__(self,
495
+ dim=96,
496
+ z_dim=16,
497
+ dim_mult=[1, 2, 4, 4],
498
+ num_res_blocks=2,
499
+ attn_scales=[],
500
+ temperal_downsample=[False, True, True],
501
+ dropout=0.0):
502
+ super().__init__()
503
+ self.dim = dim
504
+ self.z_dim = z_dim
505
+ self.dim_mult = dim_mult
506
+ self.num_res_blocks = num_res_blocks
507
+ self.attn_scales = attn_scales
508
+ self.temperal_downsample = temperal_downsample
509
+ self.temperal_upsample = temperal_downsample[::-1]
510
+
511
+ # modules
512
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
513
+ attn_scales, self.temperal_downsample, dropout)
514
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
515
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
516
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
517
+ attn_scales, self.temperal_upsample, dropout)
518
+
519
+ def forward(self, x):
520
+ mu, log_var = self.encode(x)
521
+ z = self.reparameterize(mu, log_var)
522
+ x_recon = self.decode(z)
523
+ return x_recon, mu, log_var
524
+
525
+ def encode(self, x, scale): # x: B, C, T, H, W
526
+ self.clear_cache()
527
+ ## cache
528
+ t = x.shape[2]
529
+ iter_ = 1 + (t - 1) // 4
530
+
531
+ for i in range(iter_):
532
+ self._enc_conv_idx = [0]
533
+ if i == 0:
534
+ out = self.encoder(x[:, :, :1, :, :],
535
+ feat_cache=self._enc_feat_map,
536
+ feat_idx=self._enc_conv_idx)
537
+ else:
538
+ out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
539
+ feat_cache=self._enc_feat_map,
540
+ feat_idx=self._enc_conv_idx)
541
+ out = torch.cat([out, out_], 2)
542
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
543
+ if isinstance(scale[0], torch.Tensor):
544
+ scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
545
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
546
+ 1, self.z_dim, 1, 1, 1)
547
+ else:
548
+ scale = scale.to(dtype=mu.dtype, device=mu.device)
549
+ mu = (mu - scale[0]) * scale[1]
550
+ return mu
551
+
552
+ def decode(self, z, scale):
553
+ self.clear_cache()
554
+ # z: [b,c,t,h,w]
555
+ if isinstance(scale[0], torch.Tensor):
556
+ scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
557
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
558
+ 1, self.z_dim, 1, 1, 1)
559
+ else:
560
+ scale = scale.to(dtype=z.dtype, device=z.device)
561
+ z = z / scale[1] + scale[0]
562
+ iter_ = z.shape[2]
563
+ x = self.conv2(z)
564
+ for i in range(iter_):
565
+ self._conv_idx = [0]
566
+ if i == 0:
567
+ out = self.decoder(x[:, :, i:i + 1, :, :],
568
+ feat_cache=self._feat_map,
569
+ feat_idx=self._conv_idx)
570
+ else:
571
+ out_ = self.decoder(x[:, :, i:i + 1, :, :],
572
+ feat_cache=self._feat_map,
573
+ feat_idx=self._conv_idx)
574
+ out = torch.cat([out, out_], 2) # may add tensor offload
575
+ return out
576
+
577
+ def reparameterize(self, mu, log_var):
578
+ std = torch.exp(0.5 * log_var)
579
+ eps = torch.randn_like(std)
580
+ return eps * std + mu
581
+
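reparameterize is the standard VAE reparameterization trick: a sample z ~ N(mu, sigma^2) is drawn as mu + sigma * eps with eps ~ N(0, 1), which keeps the sample differentiable with respect to mu and log_var. A one-line sketch with dummy statistics:

    import torch
    mu, log_var = torch.zeros(2, 16), torch.zeros(2, 16)        # sigma = exp(0.5 * 0) = 1
    z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)    # same expression as reparameterize
    print(z.shape)                                              # torch.Size([2, 16]); ~unit-normal samples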
582
+ def sample(self, imgs, deterministic=False):
583
+ mu, log_var = self.encode(imgs)
584
+ if deterministic:
585
+ return mu
586
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
587
+ return mu + std * torch.randn_like(std)
588
+
589
+ def clear_cache(self):
590
+ self._conv_num = count_conv3d(self.decoder)
591
+ self._conv_idx = [0]
592
+ self._feat_map = [None] * self._conv_num
593
+ # cache encode
594
+ self._enc_conv_num = count_conv3d(self.encoder)
595
+ self._enc_conv_idx = [0]
596
+ self._enc_feat_map = [None] * self._enc_conv_num
597
+
598
+
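The chunked causal schedule above fixes the frame/latent bookkeeping: encode processes the first frame alone and then groups of four, giving 1 + (T - 1) // 4 latent frames, while decode feeds one latent frame at a time and the causal temporal upsampling emits one frame for the first latent and four for each later one. A quick arithmetic check for a typical 81-frame clip (sketch):

    T = 81
    latent_T = 1 + (T - 1) // 4            # 21 latent frames, as iter_ in encode()
    decoded_T = 1 + (latent_T - 1) * 4     # 1 frame for the first latent, 4 for each of the rest
    assert decoded_T == latent_T * 4 - 3 == T   # matches out_T = T * 4 - 3 in tiled_decode below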
599
+ class WanVideoVAE(nn.Module):
600
+
601
+ def __init__(self, z_dim=16):
602
+ super().__init__()
603
+
604
+ mean = [
605
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
606
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
607
+ ]
608
+ std = [
609
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
610
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
611
+ ]
612
+ self.mean = torch.tensor(mean)
613
+ self.std = torch.tensor(std)
614
+ self.scale = [self.mean, 1.0 / self.std]
615
+
616
+ # init model
617
+ self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
618
+ self.upsampling_factor = 8
619
+
620
+
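self.scale = [mean, 1 / std] is exactly what VideoVAE_.encode and decode consume: encoding normalizes each of the 16 latent channels to (mu - mean) / std, and decoding undoes it with z * std + mean before running the decoder. A round-trip sketch with hypothetical statistics:

    import torch
    mean, std = torch.randn(16), torch.rand(16) + 0.5
    scale = [mean, 1.0 / std]                                    # same structure as self.scale
    mu = torch.randn(1, 16, 5, 8, 8)                             # hypothetical raw latents
    view = (1, 16, 1, 1, 1)
    z = (mu - scale[0].view(view)) * scale[1].view(view)         # what encode() applies
    back = z / scale[1].view(view) + scale[0].view(view)         # what decode() undoes
    assert torch.allclose(back, mu, atol=1e-5)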
621
+ def build_1d_mask(self, length, left_bound, right_bound, border_width):
622
+ x = torch.ones((length,))
623
+ if not left_bound:
624
+ x[:border_width] = (torch.arange(border_width) + 1) / border_width
625
+ if not right_bound:
626
+ x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
627
+ return x
628
+
629
+
630
+ def build_mask(self, data, is_bound, border_width):
631
+ _, _, _, H, W = data.shape
632
+ h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
633
+ w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
634
+
635
+ h = repeat(h, "H -> H W", H=H, W=W)
636
+ w = repeat(w, "W -> H W", H=H, W=W)
637
+
638
+ mask = torch.stack([h, w]).min(dim=0).values
639
+ mask = rearrange(mask, "H W -> 1 1 1 H W")
640
+ return mask
641
+
642
+
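build_1d_mask puts a linear ramp on every tile edge that is not on the image border, and build_mask takes the elementwise minimum of the row and column ramps, so overlapping tiles fade into each other instead of meeting at a hard seam. For example (sketch), a length-8 axis with border_width=3 and neither side on the boundary:

    import torch
    border = 3
    x = torch.ones(8)
    x[:border] = (torch.arange(border) + 1) / border
    x[-border:] = torch.flip((torch.arange(border) + 1) / border, dims=(0,))
    print(x)   # tensor([0.3333, 0.6667, 1.0000, 1.0000, 1.0000, 1.0000, 0.6667, 0.3333])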
643
+ def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
644
+ _, _, T, H, W = hidden_states.shape
645
+ size_h, size_w = tile_size
646
+ stride_h, stride_w = tile_stride
647
+
648
+ # Split tasks
649
+ tasks = []
650
+ for h in range(0, H, stride_h):
651
+ if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
652
+ for w in range(0, W, stride_w):
653
+ if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
654
+ h_, w_ = h + size_h, w + size_w
655
+ tasks.append((h, h_, w, w_))
656
+
657
+ data_device = "cpu"
658
+ computation_device = device
659
+
660
+ out_T = T * 4 - 3
661
+ weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
662
+ values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
663
+
664
+ for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
665
+ hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
666
+ hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
667
+
668
+ mask = self.build_mask(
669
+ hidden_states_batch,
670
+ is_bound=(h==0, h_>=H, w==0, w_>=W),
671
+ border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
672
+ ).to(dtype=hidden_states.dtype, device=data_device)
673
+
674
+ target_h = h * self.upsampling_factor
675
+ target_w = w * self.upsampling_factor
676
+ values[
677
+ :,
678
+ :,
679
+ :,
680
+ target_h:target_h + hidden_states_batch.shape[3],
681
+ target_w:target_w + hidden_states_batch.shape[4],
682
+ ] += hidden_states_batch * mask
683
+ weight[
684
+ :,
685
+ :,
686
+ :,
687
+ target_h: target_h + hidden_states_batch.shape[3],
688
+ target_w: target_w + hidden_states_batch.shape[4],
689
+ ] += mask
690
+ values = values / weight
691
+ values = values.float().clamp_(-1, 1)
692
+ return values
693
+
694
+
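Two details of tiled_decode are easy to miss: the task split skips a tile whenever the previous tile in that direction already reaches the image edge, and the mask-weighted values / weight accumulators turn overlapping tiles into a weighted average rather than an overwrite. A worked example of the split for a hypothetical 60x60 latent grid with the default tile_size=(34, 34) and tile_stride=(18, 16):

    H, W = 60, 60
    size_h, size_w, stride_h, stride_w = 34, 34, 18, 16
    tasks = []
    for h in range(0, H, stride_h):
        if h - stride_h >= 0 and h - stride_h + size_h >= H:
            continue                      # the previous row of tiles already covers the bottom edge
        for w in range(0, W, stride_w):
            if w - stride_w >= 0 and w - stride_w + size_w >= W:
                continue
            tasks.append((h, h + size_h, w, w + size_w))
    print(len(tasks), sorted({t[0] for t in tasks}), sorted({t[2] for t in tasks}))
    # 9 tiles, rows start at [0, 18, 36], columns start at [0, 16, 32]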
695
+ def tiled_encode(self, video, device, tile_size, tile_stride):
696
+ _, _, T, H, W = video.shape
697
+ size_h, size_w = tile_size
698
+ stride_h, stride_w = tile_stride
699
+
700
+ # Split tasks
701
+ tasks = []
702
+ for h in range(0, H, stride_h):
703
+ if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
704
+ for w in range(0, W, stride_w):
705
+ if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
706
+ h_, w_ = h + size_h, w + size_w
707
+ tasks.append((h, h_, w, w_))
708
+
709
+ data_device = "cpu"
710
+ computation_device = device
711
+
712
+ out_T = (T + 3) // 4
713
+ weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
714
+ values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
715
+
716
+ for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
717
+ hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
718
+ hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
719
+
720
+ mask = self.build_mask(
721
+ hidden_states_batch,
722
+ is_bound=(h==0, h_>=H, w==0, w_>=W),
723
+ border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
724
+ ).to(dtype=video.dtype, device=data_device)
725
+
726
+ target_h = h // self.upsampling_factor
727
+ target_w = w // self.upsampling_factor
728
+ values[
729
+ :,
730
+ :,
731
+ :,
732
+ target_h:target_h + hidden_states_batch.shape[3],
733
+ target_w:target_w + hidden_states_batch.shape[4],
734
+ ] += hidden_states_batch * mask
735
+ weight[
736
+ :,
737
+ :,
738
+ :,
739
+ target_h: target_h + hidden_states_batch.shape[3],
740
+ target_w: target_w + hidden_states_batch.shape[4],
741
+ ] += mask
742
+ values = values / weight
743
+ values = values.float()
744
+ return values
745
+
746
+
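tiled_encode mirrors tiled_decode's bookkeeping: the latent buffer has 16 channels, (T + 3) // 4 frames and 1/8 of the spatial resolution, and feeding it back through tiled_decode recovers the original frame count and resolution. A shape check for an 81-frame 480x832 clip (sketch):

    T, H, W, f = 81, 480, 832, 8                                 # f = upsampling_factor
    latent_shape = (1, 16, (T + 3) // 4, H // f, W // f)         # (1, 16, 21, 60, 104) from tiled_encode
    t_lat, h_lat, w_lat = latent_shape[2:]
    video_shape = (1, 3, t_lat * 4 - 3, h_lat * f, w_lat * f)    # (1, 3, 81, 480, 832) from tiled_decode
    assert video_shape[2:] == (T, H, W)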
747
+ def single_encode(self, video, device):
748
+ video = video.to(device)
749
+ x = self.model.encode(video, self.scale)
750
+ return x.float()
751
+
752
+
753
+ def single_decode(self, hidden_state, device):
754
+ hidden_state = hidden_state.to(device)
755
+ video = self.model.decode(hidden_state, self.scale)
756
+ return video.float().clamp_(-1, 1)
757
+
758
+
759
+ def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
760
+
761
+ videos = [video.to("cpu") for video in videos]
762
+ hidden_states = []
763
+ for video in videos:
764
+ video = video.unsqueeze(0)
765
+ if tiled:
766
+ tile_size_ = (tile_size[0] * 8, tile_size[1] * 8)  # latent -> pixel units; keep the arguments unchanged across videos
767
+ tile_stride_ = (tile_stride[0] * 8, tile_stride[1] * 8)
768
+ hidden_state = self.tiled_encode(video, device, tile_size_, tile_stride_)
769
+ else:
770
+ hidden_state = self.single_encode(video, device)
771
+ hidden_state = hidden_state.squeeze(0)
772
+ hidden_states.append(hidden_state)
773
+ hidden_states = torch.stack(hidden_states)
774
+ return hidden_states
775
+
776
+
777
+ def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
778
+ hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
779
+ videos = []
780
+ for hidden_state in hidden_states:
781
+ hidden_state = hidden_state.unsqueeze(0)
782
+ if tiled:
783
+ video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
784
+ else:
785
+ video = self.single_decode(hidden_state, device)
786
+ video = video.squeeze(0)
787
+ videos.append(video)
788
+ videos = torch.stack(videos)
789
+ return videos
790
+
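Taken together, encode and decode accept an iterable of [C, T, H, W] tensors in [-1, 1], run each through the inner VideoVAE_ on the requested device, and stack the results. A minimal usage sketch with randomly initialized weights (no checkpoint loaded; in practice the state dict arrives through the converter below), where the shapes in the comments are the expected values:

    import torch
    vae = WanVideoVAE()
    video = torch.rand(3, 1, 64, 64) * 2 - 1       # one RGB frame in [-1, 1]: [C, T, H, W]
    latents = vae.encode([video], device="cpu")    # expected: [1, 16, 1, 8, 8]
    frames = vae.decode(latents, device="cpu")     # expected: [1, 3, 1, 64, 64], clamped to [-1, 1]
    print(latents.shape, frames.shape)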
791
+
792
+ @staticmethod
793
+ def state_dict_converter():
794
+ return WanVideoVAEStateDictConverter()
795
+
796
+
797
+ class WanVideoVAEStateDictConverter:
798
+
799
+ def __init__(self):
800
+ pass
801
+
802
+ def from_civitai(self, state_dict):
803
+ state_dict_ = {}
804
+ if 'model_state' in state_dict:
805
+ state_dict = state_dict['model_state']
806
+ for name in state_dict:
807
+ state_dict_['model.' + name] = state_dict[name]
808
+ return state_dict_
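from_civitai only re-namespaces the checkpoint: every key gains a "model." prefix so the weights land on the inner VideoVAE_ held at WanVideoVAE.model. A hedged loading sketch (the local file name is hypothetical, and it assumes the checkpoint's keys match VideoVAE_; in the pipeline this conversion runs inside the model manager):

    import torch
    raw = torch.load("Wan2.1_VAE.pth", map_location="cpu")              # hypothetical local checkpoint
    converted = WanVideoVAE.state_dict_converter().from_civitai(raw)    # keys become "model.encoder...." etc.
    vae = WanVideoVAE()
    vae.load_state_dict(converted)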