linoyts HF Staff commited on
Commit
988720a
·
verified ·
1 Parent(s): 89a8ba5

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +12 -11
optimization.py CHANGED
@@ -20,7 +20,13 @@ from optimization_utils import ZeroGPUCompiledModel
20
  P = ParamSpec('P')
21
 
22
 
23
- TRANSFORMER_DYNAMIC_SHAPES = {}
 
 
 
 
 
 
24
 
25
  INDUCTOR_CONFIGS = {
26
  'conv_1x1_as_mm': True,
@@ -36,7 +42,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
36
 
37
  @spaces.GPU(duration=1500)
38
  def compile_transformer():
39
-
40
  pipeline.load_lora_weights(
41
  "vrgamedevgirl84/Wan14BT2VFusioniX",
42
  weight_name="FusionX_LoRa/Phantom_Wan_14B_FusionX_LoRA.safetensors",
@@ -116,13 +122,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
116
  else:
117
  return cp2(*args, **kwargs)
118
 
119
- transformer_config = pipeline.transformer.config
120
- transformer_dtype = pipeline.transformer.dtype
121
-
122
- pipeline.transformer = combined_transformer_1
123
- pipeline.transformer.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
124
- pipeline.transformer.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
125
 
126
- pipeline.transformer_2 = combined_transformer_2
127
- pipeline.transformer_2.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
128
- pipeline.transformer_2.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
 
20
  P = ParamSpec('P')
21
 
22
 
23
+ TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)
24
+
25
+ TRANSFORMER_DYNAMIC_SHAPES = {
26
+ 'hidden_states': {
27
+ 2: TRANSFORMER_NUM_FRAMES_DIM,
28
+ },
29
+ }
30
 
31
  INDUCTOR_CONFIGS = {
32
  'conv_1x1_as_mm': True,
 
42
 
43
  @spaces.GPU(duration=1500)
44
  def compile_transformer():
45
+
46
  pipeline.load_lora_weights(
47
  "vrgamedevgirl84/Wan14BT2VFusioniX",
48
  weight_name="FusionX_LoRa/Phantom_Wan_14B_FusionX_LoRA.safetensors",
 
122
  else:
123
  return cp2(*args, **kwargs)
124
 
125
+ pipeline.transformer.forward = combined_transformer_1
126
+ drain_module_parameters(pipeline.transformer)
 
 
 
 
127
 
128
+ pipeline.transformer_2.forward = combined_transformer_2
129
+ drain_module_parameters(pipeline.transformer_2)