Spaces:
Running
on
Zero
Running
on
Zero
update aoti compile (#3)
Browse files- update aoti compile (9aa2a04d546ddb4bb6deb5d4ea24e016061b33cd)
- Update optimization_utils.py (8e5bed5da4bd6b198ad33fdbb55dbe760f3ea32a)
- optimization.py +4 -9
- optimization_utils.py +9 -0
optimization.py
CHANGED
@@ -122,13 +122,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
|
|
122 |
else:
|
123 |
return cp2(*args, **kwargs)
|
124 |
|
125 |
-
|
126 |
-
|
127 |
|
128 |
-
pipeline.
|
129 |
-
pipeline.
|
130 |
-
pipeline.transformer.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
|
131 |
-
|
132 |
-
pipeline.transformer_2 = combined_transformer_2
|
133 |
-
pipeline.transformer_2.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
|
134 |
-
pipeline.transformer_2.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
|
|
|
122 |
else:
|
123 |
return cp2(*args, **kwargs)
|
124 |
|
125 |
+
pipeline.transformer.forward = combined_transformer_1
|
126 |
+
drain_module_parameters(pipeline.transformer)
|
127 |
|
128 |
+
pipeline.transformer_2.forward = combined_transformer_2
|
129 |
+
drain_module_parameters(pipeline.transformer_2)
|
|
|
|
|
|
|
|
|
|
optimization_utils.py
CHANGED
@@ -96,3 +96,12 @@ def capture_component_call(
|
|
96 |
except CapturedCallException as e:
|
97 |
captured_call.args = e.args
|
98 |
captured_call.kwargs = e.kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
except CapturedCallException as e:
|
97 |
captured_call.args = e.args
|
98 |
captured_call.kwargs = e.kwargs
|
99 |
+
|
100 |
+
|
101 |
+
def drain_module_parameters(module: torch.nn.Module):
|
102 |
+
state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
|
103 |
+
state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
|
104 |
+
module.load_state_dict(state_dict, assign=True)
|
105 |
+
for name, param in state_dict.items():
|
106 |
+
meta = state_dict_meta[name]
|
107 |
+
param.data = torch.Tensor([]).to(**meta)
|