linoyts HF Staff commited on
Commit
9aa2a04
·
verified ·
1 Parent(s): b5c3f40

update aoti compile

Browse files
Files changed (1) hide show
  1. optimization.py +4 -9
optimization.py CHANGED
@@ -122,13 +122,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
122
  else:
123
  return cp2(*args, **kwargs)
124
 
125
- transformer_config = pipeline.transformer.config
126
- transformer_dtype = pipeline.transformer.dtype
127
 
128
- pipeline.transformer = combined_transformer_1
129
- pipeline.transformer.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
130
- pipeline.transformer.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
131
-
132
- pipeline.transformer_2 = combined_transformer_2
133
- pipeline.transformer_2.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
134
- pipeline.transformer_2.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
 
122
  else:
123
  return cp2(*args, **kwargs)
124
 
125
+ pipeline.transformer.forward = combined_transformer_1
126
+ drain_module_parameters(pipeline.transformer)
127
 
128
+ pipeline.transformer_2.forward = combined_transformer_2
129
+ drain_module_parameters(pipeline.transformer_2)