jbilcke-hf committed e238883 (verified) · 1 Parent(s): 063139e

Upload 5 files

skyreelsinfer/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from enum import Enum
+
+
+ class TaskType(str, Enum):
+     T2V = "text2video"
+     I2V = "image2video"
skyreelsinfer/offload.py ADDED
@@ -0,0 +1,515 @@
+ import functools
+ import gc
+ import os
+ import time
+ from dataclasses import dataclass
+
+ import torch
+ from diffusers.pipelines import DiffusionPipeline
+ from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
+
+
+ @dataclass
+ class OffloadConfig:
+     # high_cpu_memory: Whether to keep offloaded weights in pinned host memory. This prevents the extra offload latency caused by memory swapping.
+     high_cpu_memory: bool = True
+     # parameters_level: Whether to enable parameter-level offload. This further reduces VRAM requirements but may increase latency.
+     parameters_level: bool = False
+     # compiler_transformer: Whether to enable compilation optimization for the transformer.
+     compiler_transformer: bool = False
+     compiler_cache: str = "/tmp/compile_cache"
+
+
+ class HfHook:
+     def __init__(self):
+         device_id = os.environ.get("LOCAL_RANK", 0)
+         self.execution_device = f"cuda:{device_id}"
+
+     def detach_hook(self, module):
+         pass
+
+
+ class Offload:
+     def __init__(self) -> None:
+         self.active_models = []
+         self.active_models_ids = []
+         self.active_subcaches = {}
+         self.models = {}
+         self.verboseLevel = 0
+         self.models_to_quantize = []
+         self.pinned_modules_data = {}
+         self.blocks_of_modules = {}
+         self.blocks_of_modules_sizes = {}
+         self.compile = False
+         self.device_mem_capacity = torch.cuda.get_device_properties(0).total_memory
+         self.last_reserved_mem_check = 0
+         self.loaded_blocks = {}
+         self.prev_blocks_names = {}
+         self.next_blocks_names = {}
+         device_id = os.environ.get("LOCAL_RANK", 0)
+         self.device_id = f"cuda:{device_id}"
+         self.default_stream = torch.cuda.default_stream(self.device_id)  # torch.cuda.current_stream()
+         self.transfer_stream = torch.cuda.Stream()
+         self.async_transfers = False
+         self.last_run_model = None
+
+     @classmethod
+     def offload(cls, pipeline: DiffusionPipeline, config: OffloadConfig = OffloadConfig()):
+         """
+         Enable offloading for the models in the pipeline, supporting video generation inference on consumer-grade GPUs.
+         pipeline: the pipeline object
+         config: offload strategy configuration
+         """
+         self = cls()
+         self.pinned_modules_data = {}
+         if config.parameters_level:
+             model_budgets = {
+                 "transformer": 600 * 1024 * 1024,
+                 "text_encoder": 3 * 1024 * 1024 * 1024,
+                 "text_encoder_2": 3 * 1024 * 1024 * 1024,
+             }
+             self.async_transfers = True
+         else:
+             model_budgets = {}
+
+         device_id = os.getenv("LOCAL_RANK", 0)
+         torch.set_default_device(f"cuda:{device_id}")
+         pipeline.hf_device_map = torch.device(f"cuda:{device_id}")
+         pipe_or_dict_of_modules = pipeline.components
+         if config.compiler_transformer:
+             pipeline.transformer.to("cuda")
+         models = {
+             k: v
+             for k, v in pipe_or_dict_of_modules.items()
+             if isinstance(v, torch.nn.Module) and not (config.compiler_transformer and k == "transformer")
+         }
+         print_info = {k: type(v) for k, v in models.items()}
+         print(f"offload models: {print_info}")
+         if config.compiler_transformer:
+             pipeline.text_encoder.to("cpu")
+             pipeline.text_encoder_2.to("cpu")
+             torch.cuda.empty_cache()
+             pipeline.transformer.to("cuda")
+             pipeline.vae.to("cuda")
+
+             def move_text_encoder_to_gpu(pipe):
+                 torch.cuda.empty_cache()
+                 pipe.text_encoder.to("cuda")
+                 pipe.text_encoder_2.to("cuda")
+
+             def move_text_encoder_to_cpu(pipe):
+                 pipe.text_encoder.to("cpu")
+                 pipe.text_encoder_2.to("cpu")
+                 torch.cuda.empty_cache()
+
+             setattr(pipeline, "text_encoder_to_cpu", functools.partial(move_text_encoder_to_cpu, pipeline))
+             setattr(pipeline, "text_encoder_to_gpu", functools.partial(move_text_encoder_to_gpu, pipeline))
+
+             for k, module in pipe_or_dict_of_modules.items():
+                 if isinstance(module, torch.nn.Module):
+                     for submodule_name, submodule in module.named_modules():
+                         if not hasattr(submodule, "_hf_hook"):
+                             setattr(submodule, "_hf_hook", HfHook())
+             return self
+
+         sizeofbfloat16 = torch.bfloat16.itemsize
+         modelPinned = config.high_cpu_memory
+         # Pin the models in RAM.
+         # Calculate the VRAM requirements of the computational modules to determine whether parameter-level offload is necessary.
+         for model_name, curr_model in models.items():
+             curr_model.to("cpu").eval()
+             pinned_parameters_data = {}
+             current_model_size = 0
+             print(f"{model_name} move to pinned memory:{modelPinned}")
+             for p in curr_model.parameters():
+                 if isinstance(p, AffineQuantizedTensor):
+                     if not modelPinned and p.tensor_impl.scale.dtype == torch.float32:
+                         p.tensor_impl.scale = p.tensor_impl.scale.to(torch.bfloat16)
+                     current_model_size += torch.numel(p.tensor_impl.scale) * sizeofbfloat16
+                     current_model_size += torch.numel(p.tensor_impl.float8_data) * sizeofbfloat16 / 2
+                     if modelPinned:
+                         p.tensor_impl.float8_data = p.tensor_impl.float8_data.pin_memory()
+                         p.tensor_impl.scale = p.tensor_impl.scale.pin_memory()
+                         pinned_parameters_data[p] = [p.tensor_impl.float8_data, p.tensor_impl.scale]
+                 else:
+                     p.data = p.data.to(torch.bfloat16) if p.data.dtype == torch.float32 else p.data.to(p.data.dtype)
+                     current_model_size += torch.numel(p.data) * p.data.element_size()
+                     if modelPinned:
+                         p.data = p.data.pin_memory()
+                         pinned_parameters_data[p] = p.data
+
+             for buffer in curr_model.buffers():
+                 buffer.data = (
+                     buffer.data.to(torch.bfloat16)
+                     if buffer.data.dtype == torch.float32
+                     else buffer.data.to(buffer.data.dtype)
+                 )
+                 current_model_size += torch.numel(buffer.data) * buffer.data.element_size()
+                 if modelPinned:
+                     buffer.data = buffer.data.pin_memory()
+
+             if model_name not in self.models:
+                 self.models[model_name] = curr_model
+
+             curr_model_budget = model_budgets.get(model_name, 0)
+             if curr_model_budget > 0 and curr_model_budget > current_model_size:
+                 model_budgets[model_name] = 0
+
+             if modelPinned:
+                 pinned_buffers_data = {b: b.data for b in curr_model.buffers()}
+                 pinned_parameters_data.update(pinned_buffers_data)
+                 self.pinned_modules_data[model_name] = pinned_parameters_data
+             gc.collect()
+             torch.cuda.empty_cache()
+
+         # if config.compiler_transformer:
+         #     module = pipeline.transformer
+         #     print("wrap transformer forward")
+         #     # gpu model wrap
+         #     for submodule_name, submodule in module.named_modules():
+         #         if not hasattr(submodule, "_hf_hook"):
+         #             setattr(submodule, "_hf_hook", HfHook())
+         #
+         #     forward_method = getattr(module, "forward")
+         #
+         #     def wrap_unload_all(*args, **kwargs):
+         #         self.unload_all("transformer")
+         #         return forward_method(*args, **kwargs)
+         #
+         #     setattr(module, "forward", functools.update_wrapper(wrap_unload_all, forward_method))
+
+         # wrap forward methods
+         for model_name, curr_model in models.items():
+             current_budget = model_budgets.get(model_name, 0)
+             current_size = 0
+             self.loaded_blocks[model_name] = None
+             cur_blocks_prefix, prev_blocks_name, cur_blocks_name, cur_blocks_seq = None, None, None, -1
+
+             for submodule_name, submodule in curr_model.named_modules():
+                 # create a fake accelerate hook so that the _execution_device property always returns "cuda"
+                 if not hasattr(submodule, "_hf_hook"):
+                     setattr(submodule, "_hf_hook", HfHook())
+
+                 if not submodule_name:
+                     continue
+
+                 # use parameter-level offload
+                 if current_budget > 0:
+                     if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+                         if cur_blocks_prefix == None:
+                             cur_blocks_prefix = submodule_name + "."
+                         else:
+                             if not submodule_name.startswith(cur_blocks_prefix):
+                                 cur_blocks_prefix = submodule_name + "."
+                                 cur_blocks_name, cur_blocks_seq = None, -1
+                     else:
+                         if cur_blocks_prefix is not None:
+                             if submodule_name.startswith(cur_blocks_prefix):
+                                 num = int(submodule_name[len(cur_blocks_prefix) :].split(".")[0])
+                                 if num != cur_blocks_seq and (cur_blocks_name == None or current_size > current_budget):
+                                     prev_blocks_name = cur_blocks_name
+                                     cur_blocks_name = cur_blocks_prefix + str(num)
+                                     cur_blocks_seq = num
+                             else:
+                                 cur_blocks_prefix = None
+                                 prev_blocks_name = None
+                                 cur_blocks_name = None
+                                 cur_blocks_seq = -1
+
+                 if hasattr(submodule, "forward"):
+                     submodule_forward = getattr(submodule, "forward")
+                     if not callable(submodule_forward):
+                         print("***")
+                         continue
+                     if len(submodule_name.split(".")) == 1:
+                         self.hook_me(submodule, curr_model, model_name, submodule_name, submodule_forward)
+                     else:
+                         self.hook_me_light(
+                             submodule, model_name, cur_blocks_name, submodule_forward, context=submodule_name
+                         )
+                     current_size = self.add_module_to_blocks(model_name, cur_blocks_name, submodule, prev_blocks_name)
+
+         gc.collect()
+         torch.cuda.empty_cache()
+         return self
+
+     def add_module_to_blocks(self, model_name, blocks_name, submodule, prev_block_name):
+
+         entry_name = model_name if blocks_name is None else model_name + "/" + blocks_name
+         if entry_name in self.blocks_of_modules:
+             blocks_params = self.blocks_of_modules[entry_name]
+             blocks_params_size = self.blocks_of_modules_sizes[entry_name]
+         else:
+             blocks_params = []
+             self.blocks_of_modules[entry_name] = blocks_params
+             blocks_params_size = 0
+             if blocks_name != None:
+                 prev_entry_name = None if prev_block_name == None else model_name + "/" + prev_block_name
+                 self.prev_blocks_names[entry_name] = prev_entry_name
+                 if not prev_block_name == None:
+                     self.next_blocks_names[prev_entry_name] = entry_name
+
+         for p in submodule.parameters(recurse=False):
+             blocks_params.append(p)
+             if isinstance(p, AffineQuantizedTensor):
+                 blocks_params_size += p.tensor_impl.float8_data.nbytes
+                 blocks_params_size += p.tensor_impl.scale.nbytes
+             else:
+                 blocks_params_size += p.data.nbytes
+
+         for p in submodule.buffers(recurse=False):
+             blocks_params.append(p)
+             blocks_params_size += p.data.nbytes
+
+         self.blocks_of_modules_sizes[entry_name] = blocks_params_size
+
+         return blocks_params_size
+
+     def can_model_be_cotenant(self, model_name):
+         cotenants_map = {
+             "text_encoder": ["vae", "text_encoder_2"],
+             "text_encoder_2": ["vae", "text_encoder"],
+         }
+         potential_cotenants = cotenants_map.get(model_name, None)
+         if potential_cotenants is None:
+             return False
+         for existing_cotenant in self.active_models_ids:
+             if existing_cotenant not in potential_cotenants:
+                 return False
+         return True
+
+     @torch.compiler.disable()
+     def gpu_load_blocks(self, model_name, blocks_name, async_load=False):
+         if blocks_name != None:
+             self.loaded_blocks[model_name] = blocks_name
+
+         def cpu_to_gpu(stream_to_use, blocks_params, record_for_stream=None):
+             with torch.cuda.stream(stream_to_use):
+                 for p in blocks_params:
+                     if isinstance(p, AffineQuantizedTensor):
+                         p.tensor_impl.float8_data = p.tensor_impl.float8_data.cuda(
+                             non_blocking=True, device=self.device_id
+                         )
+                         p.tensor_impl.scale = p.tensor_impl.scale.cuda(non_blocking=True, device=self.device_id)
+                     else:
+                         p.data = p.data.cuda(non_blocking=True, device=self.device_id)
+
+                     if record_for_stream != None:
+                         if isinstance(p, AffineQuantizedTensor):
+                             p.tensor_impl.float8_data.record_stream(record_for_stream)
+                             p.tensor_impl.scale.record_stream(record_for_stream)
+                         else:
+                             p.data.record_stream(record_for_stream)
+
+         entry_name = model_name if blocks_name is None else model_name + "/" + blocks_name
+         if self.verboseLevel >= 2:
+             model = self.models[model_name]
+             model_name = model._get_name()
+             print(f"Loading model {entry_name} ({model_name}) in GPU")
+
+         if self.async_transfers and blocks_name != None:
+             first = self.prev_blocks_names[entry_name] == None
+             next_blocks_entry = self.next_blocks_names[entry_name] if entry_name in self.next_blocks_names else None
+             if first:
+                 cpu_to_gpu(torch.cuda.current_stream(), self.blocks_of_modules[entry_name])
+             torch.cuda.synchronize()
+
+             if next_blocks_entry != None:
+                 cpu_to_gpu(self.transfer_stream, self.blocks_of_modules[next_blocks_entry])
+
+         else:
+             cpu_to_gpu(self.default_stream, self.blocks_of_modules[entry_name])
+             torch.cuda.synchronize()
+
+     @torch.compiler.disable()
+     def gpu_unload_blocks(self, model_name, blocks_name):
+         if blocks_name != None:
+             self.loaded_blocks[model_name] = None
+
+         blocks_name = model_name if blocks_name is None else model_name + "/" + blocks_name
+
+         if self.verboseLevel >= 2:
+             model = self.models[model_name]
+             model_label = model._get_name()
+             print(f"Unloading model {blocks_name} ({model_label}) from GPU")
+
+         blocks_params = self.blocks_of_modules[blocks_name]
+
+         if model_name in self.pinned_modules_data:
+             pinned_parameters_data = self.pinned_modules_data[model_name]
+             for p in blocks_params:
+                 if isinstance(p, AffineQuantizedTensor):
+                     data = pinned_parameters_data[p]
+                     p.tensor_impl.float8_data = data[0]
+                     p.tensor_impl.scale = data[1]
+                 else:
+                     p.data = pinned_parameters_data[p]
+         else:
+             for p in blocks_params:
+                 if isinstance(p, AffineQuantizedTensor):
+                     p.tensor_impl.float8_data = p.tensor_impl.float8_data.cpu()
+                     p.tensor_impl.scale = p.tensor_impl.scale.cpu()
+                 else:
+                     p.data = p.data.cpu()
+
+     @torch.compiler.disable()
+     def gpu_load(self, model_name):
+         model = self.models[model_name]
+         self.active_models.append(model)
+         self.active_models_ids.append(model_name)
+
+         self.gpu_load_blocks(model_name, None)
+
+         # torch.cuda.current_stream().synchronize()
+
+     @torch.compiler.disable()
+     def unload_all(self, model_name: str):
+         if len(self.active_models_ids) == 0 and self.last_run_model == model_name:
+             self.last_run_model = model_name
+             return
+         for model_name in self.active_models_ids:
+             self.gpu_unload_blocks(model_name, None)
+             loaded_block = self.loaded_blocks[model_name]
+             if loaded_block != None:
+                 self.gpu_unload_blocks(model_name, loaded_block)
+                 self.loaded_blocks[model_name] = None
+
+         self.active_models = []
+         self.active_models_ids = []
+         self.active_subcaches = []
+         torch.cuda.empty_cache()
+         gc.collect()
+         self.last_reserved_mem_check = time.time()
+         self.last_run_model = model_name
+
+     def move_args_to_gpu(self, *args, **kwargs):
+         new_args = []
+         new_kwargs = {}
+         for arg in args:
+             if torch.is_tensor(arg):
+                 if arg.dtype == torch.float32:
+                     arg = arg.to(torch.bfloat16).cuda(non_blocking=True, device=self.device_id)
+                 else:
+                     arg = arg.cuda(non_blocking=True, device=self.device_id)
+             new_args.append(arg)
+
+         for k in kwargs:
+             arg = kwargs[k]
+             if torch.is_tensor(arg):
+                 if arg.dtype == torch.float32:
+                     arg = arg.to(torch.bfloat16).cuda(non_blocking=True, device=self.device_id)
+                 else:
+                     arg = arg.cuda(non_blocking=True, device=self.device_id)
+             new_kwargs[k] = arg
+
+         return new_args, new_kwargs
+
+     def ready_to_check_mem(self):
+         if self.compile:
+             return
+         cur_clock = time.time()
+         # Don't check on every call whether the CUDA cache can be emptied: querying the reserved memory value is a time-consuming operation.
+         if (cur_clock - self.last_reserved_mem_check) < 0.200:
+             return False
+         self.last_reserved_mem_check = cur_clock
+         return True
+
+     def empty_cache_if_needed(self):
+         mem_reserved = torch.cuda.memory_reserved()
+         mem_threshold = 0.9 * self.device_mem_capacity
+         if mem_reserved >= mem_threshold:
+             mem_allocated = torch.cuda.memory_allocated()
+             if mem_allocated <= 0.70 * mem_reserved:
+                 torch.cuda.empty_cache()
+                 tm = time.time()
+                 if self.verboseLevel >= 2:
+                     print(f"Empty Cuda cache at {tm}")
+
+     def any_param_or_buffer(self, target_module: torch.nn.Module):
+
+         for _ in target_module.parameters(recurse=False):
+             return True
+
+         for _ in target_module.buffers(recurse=False):
+             return True
+
+         return False
+
+     def hook_me_light(self, target_module, model_name, blocks_name, previous_method, context):
+
+         anyParam = self.any_param_or_buffer(target_module)
+
+         def check_empty_cuda_cache(module, *args, **kwargs):
+             if self.ready_to_check_mem():
+                 self.empty_cache_if_needed()
+             return previous_method(*args, **kwargs)
+
+         def load_module_blocks(module, *args, **kwargs):
+             if blocks_name == None:
+                 if self.ready_to_check_mem():
+                     self.empty_cache_if_needed()
+             else:
+                 loaded_block = self.loaded_blocks[model_name]
+                 if loaded_block == None or loaded_block != blocks_name:
+                     if loaded_block != None:
+                         self.gpu_unload_blocks(model_name, loaded_block)
+                         if self.ready_to_check_mem():
+                             self.empty_cache_if_needed()
+                     self.loaded_blocks[model_name] = blocks_name
+                     self.gpu_load_blocks(model_name, blocks_name)
+             return previous_method(*args, **kwargs)
+
+         if hasattr(target_module, "_mm_id"):
+             orig_model_name = getattr(target_module, "_mm_id")
+             if self.verboseLevel >= 2:
+                 print(
+                     f"Model '{model_name}' shares module '{target_module._get_name()}' with model '{orig_model_name}'"
+                 )
+             assert not anyParam
+             return
+         setattr(target_module, "_mm_id", model_name)
+
+         if blocks_name != None and anyParam:
+             setattr(
+                 target_module,
+                 "forward",
+                 functools.update_wrapper(functools.partial(load_module_blocks, target_module), previous_method),
+             )
+             # print(f"new cache:{blocks_name}")
+         else:
+             setattr(
+                 target_module,
+                 "forward",
+                 functools.update_wrapper(functools.partial(check_empty_cuda_cache, target_module), previous_method),
+             )
+
+     def hook_me(self, target_module, model, model_name, module_id, previous_method):
+         def check_change_module(module, *args, **kwargs):
+             performEmptyCacheTest = False
+             if not model_name in self.active_models_ids:
+                 new_model_name = getattr(module, "_mm_id")
+                 if not self.can_model_be_cotenant(new_model_name):
+                     self.unload_all(model_name)
+                     performEmptyCacheTest = False
+                 self.gpu_load(new_model_name)
+             args, kwargs = self.move_args_to_gpu(*args, **kwargs)
+             if performEmptyCacheTest:
+                 self.empty_cache_if_needed()
+             return previous_method(*args, **kwargs)
+
+         if hasattr(target_module, "_mm_id"):
+             return
+         setattr(target_module, "_mm_id", model_name)
+
+         setattr(
+             target_module,
+             "forward",
+             functools.update_wrapper(functools.partial(check_change_module, target_module), previous_method),
+         )
+
+         if not self.verboseLevel >= 1:
+             return
+
+         if module_id == None or module_id == "":
+             model_name = model._get_name()
+             print(f"Hooked in model '{model_name}' ({model_name})")
skyreelsinfer/pipelines/__init__.py ADDED
@@ -0,0 +1 @@
+ from .pipeline_skyreels_video import SkyreelsVideoPipeline
skyreelsinfer/pipelines/pipeline_skyreels_video.py ADDED
@@ -0,0 +1,425 @@
+ from typing import Any
+ from typing import Callable
+ from typing import Dict
+ from typing import List
+ from typing import Optional
+ from typing import Union
+
+ import numpy as np
+ import torch
+ from diffusers import HunyuanVideoPipeline
+ from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
+ from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import HunyuanVideoPipelineOutput
+ from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import MultiPipelineCallbacks
+ from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import PipelineCallback
+ from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import retrieve_timesteps
+ from PIL import Image
+
+
+ def resizecrop(image, th, tw):
+     w, h = image.size
+     if h / w > th / tw:
+         new_w = int(w)
+         new_h = int(new_w * th / tw)
+     else:
+         new_h = int(h)
+         new_w = int(new_h * tw / th)
+     left = (w - new_w) / 2
+     top = (h - new_h) / 2
+     right = (w + new_w) / 2
+     bottom = (h + new_h) / 2
+     image = image.crop((left, top, right, bottom))
+     return image
+
+
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+     """
+     Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+     Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4.
+     """
+     std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+     # rescale the results from guidance (fixes overexposure)
+     noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+     # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+     noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+     return noise_cfg
+
+
+ class SkyreelsVideoPipeline(HunyuanVideoPipeline):
+     """
+     Supports both i2v and t2v.
+     Supports true_cfg.
+     """
+
+     @property
+     def guidance_rescale(self):
+         return self._guidance_rescale
+
+     @property
+     def clip_skip(self):
+         return self._clip_skip
+
+     @property
+     def do_classifier_free_guidance(self):
+         # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
+         return self._guidance_scale > 1
+
+     def encode_prompt(
+         self,
+         prompt: Union[str, List[str]],
+         do_classifier_free_guidance: bool,
+         negative_prompt: str = "",
+         prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
+         num_videos_per_prompt: int = 1,
+         prompt_embeds: Optional[torch.Tensor] = None,
+         pooled_prompt_embeds: Optional[torch.Tensor] = None,
+         prompt_attention_mask: Optional[torch.Tensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+         negative_attention_mask: Optional[torch.Tensor] = None,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+         max_sequence_length: int = 256,
+     ):
+         num_hidden_layers_to_skip = self.clip_skip if self.clip_skip is not None else 0
+         print(f"num_hidden_layers_to_skip: {num_hidden_layers_to_skip}")
+         if prompt_embeds is None:
+             prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
+                 prompt,
+                 prompt_template,
+                 num_videos_per_prompt,
+                 device=device,
+                 dtype=dtype,
+                 num_hidden_layers_to_skip=num_hidden_layers_to_skip,
+                 max_sequence_length=max_sequence_length,
+             )
+         if negative_prompt_embeds is None and do_classifier_free_guidance:
+             negative_prompt_embeds, negative_attention_mask = self._get_llama_prompt_embeds(
+                 negative_prompt,
+                 prompt_template,
+                 num_videos_per_prompt,
+                 device=device,
+                 dtype=dtype,
+                 num_hidden_layers_to_skip=num_hidden_layers_to_skip,
+                 max_sequence_length=max_sequence_length,
+             )
+         if self.text_encoder_2 is not None and pooled_prompt_embeds is None:
+             pooled_prompt_embeds = self._get_clip_prompt_embeds(
+                 prompt,
+                 num_videos_per_prompt,
+                 device=device,
+                 dtype=dtype,
+                 max_sequence_length=77,
+             )
+             if negative_pooled_prompt_embeds is None and do_classifier_free_guidance:
+                 negative_pooled_prompt_embeds = self._get_clip_prompt_embeds(
+                     negative_prompt,
+                     num_videos_per_prompt,
+                     device=device,
+                     dtype=dtype,
+                     max_sequence_length=77,
+                 )
+         return (
+             prompt_embeds,
+             prompt_attention_mask,
+             negative_prompt_embeds,
+             negative_attention_mask,
+             pooled_prompt_embeds,
+             negative_pooled_prompt_embeds,
+         )
+
+     def image_latents(
+         self,
+         initial_image,
+         batch_size,
+         height,
+         width,
+         device,
+         dtype,
+         num_channels_latents,
+         video_length,
+     ):
+         initial_image = initial_image.unsqueeze(2)
+         image_latents = self.vae.encode(initial_image).latent_dist.sample()
+         if hasattr(self.vae.config, "shift_factor") and self.vae.config.shift_factor:
+             image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+         else:
+             image_latents = image_latents * self.vae.config.scaling_factor
+         padding_shape = (
+             batch_size,
+             num_channels_latents,
+             video_length - 1,
+             int(height) // self.vae_scale_factor_spatial,
+             int(width) // self.vae_scale_factor_spatial,
+         )
+         latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
+         image_latents = torch.cat([image_latents, latent_padding], dim=2)
+         return image_latents
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         prompt: str,
+         negative_prompt: str = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
+         height: int = 720,
+         width: int = 1280,
+         num_frames: int = 129,
+         num_inference_steps: int = 50,
+         sigmas: List[float] = None,
+         guidance_scale: float = 1.0,
+         num_videos_per_prompt: Optional[int] = 1,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.Tensor] = None,
+         prompt_embeds: Optional[torch.Tensor] = None,
+         pooled_prompt_embeds: Optional[torch.Tensor] = None,
+         prompt_attention_mask: Optional[torch.Tensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         negative_attention_mask: Optional[torch.Tensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         attention_kwargs: Optional[Dict[str, Any]] = None,
+         guidance_rescale: float = 0.0,
+         clip_skip: Optional[int] = 2,
+         callback_on_step_end: Optional[
+             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+         ] = None,
+         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+         prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
+         max_sequence_length: int = 256,
+         embedded_guidance_scale: Optional[float] = 6.0,
+         image: Optional[Union[torch.Tensor, Image.Image]] = None,
+         cfg_for: bool = False,
+     ):
+         if hasattr(self, "text_encoder_to_gpu"):
+             self.text_encoder_to_gpu()
+
+         if image is not None and isinstance(image, Image.Image):
+             image = resizecrop(image, height, width)
+
+         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt,
+             None,
+             height,
+             width,
+             prompt_embeds,
+             callback_on_step_end_tensor_inputs,
+             prompt_template,
+         )
+         # add negative prompt check
+         if negative_prompt is not None and negative_prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+             )
+
+         if prompt_embeds is not None and negative_prompt_embeds is not None:
+             if prompt_embeds.shape != negative_prompt_embeds.shape:
+                 raise ValueError(
+                     "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                     f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                     f" {negative_prompt_embeds.shape}."
+                 )
+
+         self._guidance_scale = guidance_scale
+         self._guidance_rescale = guidance_rescale
+         self._clip_skip = clip_skip
+         self._attention_kwargs = attention_kwargs
+         self._interrupt = False
+
+         device = self._execution_device
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         # 3. Encode input prompt
+         (
+             prompt_embeds,
+             prompt_attention_mask,
+             negative_prompt_embeds,
+             negative_attention_mask,
+             pooled_prompt_embeds,
+             negative_pooled_prompt_embeds,
+         ) = self.encode_prompt(
+             prompt=prompt,
+             do_classifier_free_guidance=self.do_classifier_free_guidance,
+             negative_prompt=negative_prompt,
+             prompt_template=prompt_template,
+             num_videos_per_prompt=num_videos_per_prompt,
+             prompt_embeds=prompt_embeds,
+             prompt_attention_mask=prompt_attention_mask,
+             negative_prompt_embeds=negative_prompt_embeds,
+             negative_attention_mask=negative_attention_mask,
+             device=device,
+             max_sequence_length=max_sequence_length,
+         )
+
+         transformer_dtype = self.transformer.dtype
+         prompt_embeds = prompt_embeds.to(transformer_dtype)
+         prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
+         if pooled_prompt_embeds is not None:
+             pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
+
+         # Embeddings are concatenated to form a batch.
+         if self.do_classifier_free_guidance:
+             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+             negative_attention_mask = negative_attention_mask.to(transformer_dtype)
+             if negative_pooled_prompt_embeds is not None:
+                 negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+             if prompt_attention_mask is not None:
+                 prompt_attention_mask = torch.cat([negative_attention_mask, prompt_attention_mask])
+             if pooled_prompt_embeds is not None:
+                 pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
+
+         # 4. Prepare timesteps
+         sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
+         timesteps, num_inference_steps = retrieve_timesteps(
+             self.scheduler,
+             num_inference_steps,
+             device,
+             sigmas=sigmas,
+         )
+
+         # 5. Prepare latent variables
+         num_channels_latents = self.transformer.config.in_channels
+         if image is not None:
+             num_channels_latents = int(num_channels_latents / 2)
+             image = self.video_processor.preprocess(image, height=height, width=width).to(
+                 device, dtype=prompt_embeds.dtype
+             )
+         num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+         latents = self.prepare_latents(
+             batch_size * num_videos_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             num_latent_frames,
+             torch.float32,
+             device,
+             generator,
+             latents,
+         )
+         # add image latents
+         if image is not None:
+             image_latents = self.image_latents(
+                 image, batch_size, height, width, device, torch.float32, num_channels_latents, num_latent_frames
+             )
+
+             image_latents = image_latents.to(transformer_dtype)
+         else:
+             image_latents = None
+
+         # 6. Prepare guidance condition
+         if self.do_classifier_free_guidance:
+             guidance = (
+                 torch.tensor([embedded_guidance_scale] * latents.shape[0] * 2, dtype=transformer_dtype, device=device)
+                 * 1000.0
+             )
+         else:
+             guidance = (
+                 torch.tensor([embedded_guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device)
+                 * 1000.0
+             )
+
+         # 7. Denoising loop
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+         self._num_timesteps = len(timesteps)
+
+         if hasattr(self, "text_encoder_to_cpu"):
+             self.text_encoder_to_cpu()
+
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 if self.interrupt:
+                     continue
+
+                 latents = latents.to(transformer_dtype)
+                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                 # timestep = t.expand(latents.shape[0]).to(latents.dtype)
+                 if image_latents is not None:
+                     latent_image_input = (
+                         torch.cat([image_latents] * 2) if self.do_classifier_free_guidance else image_latents
+                     )
+                     latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=1)
+                 timestep = t.repeat(latent_model_input.shape[0]).to(torch.float32)
+                 if cfg_for and self.do_classifier_free_guidance:
+                     noise_pred_list = []
+                     for idx in range(latent_model_input.shape[0]):
+                         noise_pred_uncond = self.transformer(
+                             hidden_states=latent_model_input[idx].unsqueeze(0),
+                             timestep=timestep[idx].unsqueeze(0),
+                             encoder_hidden_states=prompt_embeds[idx].unsqueeze(0),
+                             encoder_attention_mask=prompt_attention_mask[idx].unsqueeze(0),
+                             pooled_projections=pooled_prompt_embeds[idx].unsqueeze(0),
+                             guidance=guidance[idx].unsqueeze(0),
+                             attention_kwargs=attention_kwargs,
+                             return_dict=False,
+                         )[0]
+                         noise_pred_list.append(noise_pred_uncond)
+                     noise_pred = torch.cat(noise_pred_list, dim=0)
+                 else:
+                     noise_pred = self.transformer(
+                         hidden_states=latent_model_input,
+                         timestep=timestep,
+                         encoder_hidden_states=prompt_embeds,
+                         encoder_attention_mask=prompt_attention_mask,
+                         pooled_projections=pooled_prompt_embeds,
+                         guidance=guidance,
+                         attention_kwargs=attention_kwargs,
+                         return_dict=False,
+                     )[0]
+
+                 # perform guidance
+                 if self.do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                 if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+                     # Based on Section 3.4 of https://arxiv.org/pdf/2305.08891.pdf
+                     noise_pred = rescale_noise_cfg(
+                         noise_pred,
+                         noise_pred_text,
+                         guidance_rescale=self.guidance_rescale,
+                     )
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                 if callback_on_step_end is not None:
+                     callback_kwargs = {}
+                     for k in callback_on_step_end_tensor_inputs:
+                         callback_kwargs[k] = locals()[k]
+                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                     latents = callback_outputs.pop("latents", latents)
+                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+
+         if not output_type == "latent":
+             latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
+             video = self.vae.decode(latents, return_dict=False)[0]
+             video = self.video_processor.postprocess_video(video, output_type=output_type)
+         else:
+             video = latents
+
+         # Offload all models
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return (video,)
+
+         return HunyuanVideoPipelineOutput(frames=video)
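
A minimal sketch of a text-to-video call against this pipeline, assuming an instance named `pipe` has been loaded via `SkyreelsVideoPipeline.from_pretrained` (the prompt, sizes, and step count mirror the warm-up defaults used elsewhere in this upload and are only illustrative):

import torch

output = pipe(
    prompt="A woman is dancing in a room",
    height=544,
    width=960,
    num_frames=97,
    num_inference_steps=30,
    guidance_scale=6.0,
    embedded_guidance_scale=1.0,
    generator=torch.Generator("cuda").manual_seed(42),
)
frames = output.frames[0]  # HunyuanVideoPipelineOutput.frames

For image-to-video, additionally pass `image=` a PIL image; `__call__` resize-crops it to the requested height and width and feeds it in as conditioning latents.
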
skyreelsinfer/skyreels_video_infer.py ADDED
@@ -0,0 +1,258 @@
+ import logging
+ import os
+ import threading
+ import time
+ from datetime import timedelta
+ from typing import Any
+ from typing import Dict
+
+ import torch
+ import torch.distributed as dist
+ import torch.multiprocessing as mp
+ from diffusers import HunyuanVideoTransformer3DModel
+ from PIL import Image
+ from torchao.quantization import float8_weight_only
+ from torchao.quantization import quantize_
+ from transformers import LlamaModel
+
+ from . import TaskType
+ from .offload import Offload
+ from .offload import OffloadConfig
+ from .pipelines import SkyreelsVideoPipeline
+
+ logger = logging.getLogger("SkyreelsVideoInfer")
+ logger.setLevel(logging.DEBUG)
+ console_handler = logging.StreamHandler()
+ console_handler.setLevel(logging.DEBUG)
+ formatter = logging.Formatter(
+     "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d - %(funcName)s] - %(message)s"
+ )
+ console_handler.setFormatter(formatter)
+ logger.addHandler(console_handler)
+
+
+ class SkyReelsVideoSingleGpuInfer:
+     def _load_model(
+         self,
+         model_id: str,
+         base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
+         quant_model: bool = True,
+         gpu_device: str = "cuda:0",
+     ) -> SkyreelsVideoPipeline:
+         logger.info(f"load model model_id:{model_id} quant_model:{quant_model} gpu_device:{gpu_device}")
+         text_encoder = LlamaModel.from_pretrained(
+             base_model_id,
+             subfolder="text_encoder",
+             torch_dtype=torch.bfloat16,
+         ).to("cpu")
+         transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+             model_id,
+             # subfolder="transformer",
+             torch_dtype=torch.bfloat16,
+             device="cpu",
+         ).to("cpu")
+         if quant_model:
+             quantize_(text_encoder, float8_weight_only(), device=gpu_device)
+             text_encoder.to("cpu")
+             torch.cuda.empty_cache()
+             quantize_(transformer, float8_weight_only(), device=gpu_device)
+             transformer.to("cpu")
+             torch.cuda.empty_cache()
+         pipe = SkyreelsVideoPipeline.from_pretrained(
+             base_model_id,
+             transformer=transformer,
+             text_encoder=text_encoder,
+             torch_dtype=torch.bfloat16,
+         ).to("cpu")
+         pipe.vae.enable_tiling()
+         torch.cuda.empty_cache()
+         return pipe
+
+     def __init__(
+         self,
+         task_type: TaskType,
+         model_id: str,
+         quant_model: bool = True,
+         local_rank: int = 0,
+         world_size: int = 1,
+         is_offload: bool = True,
+         offload_config: OffloadConfig = OffloadConfig(),
+         enable_cfg_parallel: bool = True,
+     ):
+         self.task_type = task_type
+         self.gpu_rank = local_rank
+         dist.init_process_group(
+             backend="nccl",
+             init_method="tcp://127.0.0.1:23456",
+             timeout=timedelta(seconds=600),
+             world_size=world_size,
+             rank=local_rank,
+         )
+         os.environ["LOCAL_RANK"] = str(local_rank)
+         logger.info(f"rank:{local_rank} Distributed backend: {dist.get_backend()}")
+         torch.cuda.set_device(dist.get_rank())
+         torch.backends.cuda.enable_cudnn_sdp(False)
+         gpu_device = f"cuda:{dist.get_rank()}"
+
+         self.pipe: SkyreelsVideoPipeline = self._load_model(
+             model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
+         )
+
+         from para_attn.context_parallel import init_context_parallel_mesh
+         from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
+         from para_attn.parallel_vae.diffusers_adapters import parallelize_vae
+
+         max_batch_dim_size = 2 if enable_cfg_parallel and world_size > 1 else 1
+         max_ulysses_dim_size = int(world_size / max_batch_dim_size)
+         logger.info(f"max_batch_dim_size: {max_batch_dim_size}, max_ulysses_dim_size:{max_ulysses_dim_size}")
+
+         mesh = init_context_parallel_mesh(
+             self.pipe.device.type,
+             max_ring_dim_size=1,
+             max_batch_dim_size=max_batch_dim_size,
+         )
+         parallelize_pipe(self.pipe, mesh=mesh)
+         parallelize_vae(self.pipe.vae, mesh=mesh._flatten())
+
+         if is_offload:
+             Offload.offload(
+                 pipeline=self.pipe,
+                 config=offload_config,
+             )
+         else:
+             self.pipe.to(gpu_device)
+
+         if offload_config.compiler_transformer:
+             torch._dynamo.config.suppress_errors = True
+             os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
+             os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{offload_config.compiler_cache}_{world_size}"
+             self.pipe.transformer = torch.compile(
+                 self.pipe.transformer,
+                 mode="max-autotune-no-cudagraphs",
+                 dynamic=True,
+             )
+             self.warm_up()
+
+     def warm_up(self):
+         init_kwargs = {
+             "prompt": "A woman is dancing in a room",
+             "height": 544,
+             "width": 960,
+             "guidance_scale": 6,
+             "num_inference_steps": 1,
+             "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
+             "num_frames": 97,
+             "generator": torch.Generator("cuda").manual_seed(42),
+             "embedded_guidance_scale": 1.0,
+         }
+         if self.task_type == TaskType.I2V:
+             init_kwargs["image"] = Image.new("RGB", (544, 960), color="black")
+         self.pipe(**init_kwargs)
+
+     def damon_inference(self, request_queue: mp.Queue, response_queue: mp.Queue):
+         response_queue.put(f"rank:{self.gpu_rank} ready")
+         logger.info(f"rank:{self.gpu_rank} finish init pipe")
+         while True:
+             logger.info(f"rank:{self.gpu_rank} waiting for request")
+             kwargs = request_queue.get()
+             logger.info(f"rank:{self.gpu_rank} kwargs: {kwargs}")
+             if "seed" in kwargs:
+                 kwargs["generator"] = torch.Generator("cuda").manual_seed(kwargs["seed"])
+                 del kwargs["seed"]
+             start_time = time.time()
+             assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
+             out = self.pipe(**kwargs).frames[0]
+             logger.info(f"rank:{dist.get_rank()} inference time: {time.time() - start_time}")
+             if dist.get_rank() == 0:
+                 response_queue.put(out)
+
+
+ def single_gpu_run(
+     rank,
+     task_type: TaskType,
+     model_id: str,
+     request_queue: mp.Queue,
+     response_queue: mp.Queue,
+     quant_model: bool = True,
+     world_size: int = 1,
+     is_offload: bool = True,
+     offload_config: OffloadConfig = OffloadConfig(),
+     enable_cfg_parallel: bool = True,
+ ):
+     pipe = SkyReelsVideoSingleGpuInfer(
+         task_type=task_type,
+         model_id=model_id,
+         quant_model=quant_model,
+         local_rank=rank,
+         world_size=world_size,
+         is_offload=is_offload,
+         offload_config=offload_config,
+         enable_cfg_parallel=enable_cfg_parallel,
+     )
+     pipe.damon_inference(request_queue, response_queue)
+
+
+ class SkyReelsVideoInfer:
+     def __init__(
+         self,
+         task_type: TaskType,
+         model_id: str,
+         quant_model: bool = True,
+         world_size: int = 1,
+         is_offload: bool = True,
+         offload_config: OffloadConfig = OffloadConfig(),
+         enable_cfg_parallel: bool = True,
+     ):
+         self.world_size = world_size
+         smp = mp.get_context("spawn")
+         self.REQ_QUEUES: mp.Queue = smp.Queue()
+         self.RESP_QUEUE: mp.Queue = smp.Queue()
+         assert self.world_size > 0, "gpu_num must be greater than 0"
+         spawn_thread = threading.Thread(
+             target=self.lauch_single_gpu_infer,
+             args=(task_type, model_id, quant_model, world_size, is_offload, offload_config, enable_cfg_parallel),
+             daemon=True,
+         )
+         spawn_thread.start()
+         logger.info(f"Started multi-GPU thread with GPU_NUM: {world_size}")
+         print(f"Started multi-GPU thread with GPU_NUM: {world_size}")
+         # Block and wait for the prediction processes to start
+         for _ in range(world_size):
+             msg = self.RESP_QUEUE.get()
+             logger.info(f"launch_multi_gpu get init msg: {msg}")
+             print(f"launch_multi_gpu get init msg: {msg}")
+
+     def lauch_single_gpu_infer(
+         self,
+         task_type: TaskType,
+         model_id: str,
+         quant_model: bool = True,
+         world_size: int = 1,
+         is_offload: bool = True,
+         offload_config: OffloadConfig = OffloadConfig(),
+         enable_cfg_parallel: bool = True,
+     ):
+         mp.spawn(
+             single_gpu_run,
+             nprocs=world_size,
+             join=True,
+             daemon=True,
+             args=(
+                 task_type,
+                 model_id,
+                 self.REQ_QUEUES,
+                 self.RESP_QUEUE,
+                 quant_model,
+                 world_size,
+                 is_offload,
+                 offload_config,
+                 enable_cfg_parallel,
+             ),
+         )
+         logger.info(f"finish launch multi gpu infer, world_size:{world_size}")
+
+     def inference(self, kwargs: Dict[str, Any]):
+         # send the request to every single-GPU worker
+         for _ in range(self.world_size):
+             self.REQ_QUEUES.put(kwargs)
+         return self.RESP_QUEUE.get()
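
A minimal end-to-end sketch using the multiprocessing wrapper above on a single GPU; the checkpoint id below is only illustrative of a SkyReels HunyuanVideo-compatible transformer, and the request keys mirror the pipeline's `__call__` parameters:

from skyreelsinfer import TaskType
from skyreelsinfer.offload import OffloadConfig
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer

infer = SkyReelsVideoInfer(
    task_type=TaskType.T2V,
    model_id="Skywork/SkyReels-V1-Hunyuan-T2V",  # illustrative checkpoint id
    world_size=1,
    is_offload=True,
    offload_config=OffloadConfig(high_cpu_memory=True, parameters_level=True),
)
frames = infer.inference({
    "prompt": "A woman is dancing in a room",
    "height": 544,
    "width": 960,
    "num_frames": 97,
    "num_inference_steps": 30,
    "seed": 42,  # converted to a CUDA generator inside the worker loop
})
# The result comes back from rank 0's response queue as the decoded frames of the first video.
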