linoyts HF Staff commited on
Commit
55e04d3
·
verified ·
1 Parent(s): b5c3f40

update aoti compile (#3)

Browse files

- update aoti compile (9aa2a04d546ddb4bb6deb5d4ea24e016061b33cd)
- Update optimization_utils.py (8e5bed5da4bd6b198ad33fdbb55dbe760f3ea32a)

Files changed (2) hide show
  1. optimization.py +4 -9
  2. optimization_utils.py +9 -0
optimization.py CHANGED
@@ -122,13 +122,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
122
  else:
123
  return cp2(*args, **kwargs)
124
 
125
- transformer_config = pipeline.transformer.config
126
- transformer_dtype = pipeline.transformer.dtype
127
 
128
- pipeline.transformer = combined_transformer_1
129
- pipeline.transformer.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
130
- pipeline.transformer.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
131
-
132
- pipeline.transformer_2 = combined_transformer_2
133
- pipeline.transformer_2.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
134
- pipeline.transformer_2.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
 
122
  else:
123
  return cp2(*args, **kwargs)
124
 
125
+ pipeline.transformer.forward = combined_transformer_1
126
+ drain_module_parameters(pipeline.transformer)
127
 
128
+ pipeline.transformer_2.forward = combined_transformer_2
129
+ drain_module_parameters(pipeline.transformer_2)
 
 
 
 
 
optimization_utils.py CHANGED
@@ -96,3 +96,12 @@ def capture_component_call(
96
  except CapturedCallException as e:
97
  captured_call.args = e.args
98
  captured_call.kwargs = e.kwargs
 
 
 
 
 
 
 
 
 
 
96
  except CapturedCallException as e:
97
  captured_call.args = e.args
98
  captured_call.kwargs = e.kwargs
99
+
100
+
101
+ def drain_module_parameters(module: torch.nn.Module):
102
+ state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
103
+ state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
104
+ module.load_state_dict(state_dict, assign=True)
105
+ for name, param in state_dict.items():
106
+ meta = state_dict_meta[name]
107
+ param.data = torch.Tensor([]).to(**meta)