Spaces:
Running
on
Zero
Running
on
Zero
Update optimization.py
Browse files - optimization.py +3 -2
optimization.py
CHANGED
|
@@ -10,6 +10,7 @@ import torch
|
|
| 10 |
from torch.utils._pytree import tree_map_only
|
| 11 |
from torchao.quantization import quantize_
|
| 12 |
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
|
|
|
|
| 13 |
|
| 14 |
from optimization_utils import capture_component_call
|
| 15 |
from optimization_utils import aoti_compile
|
|
@@ -42,6 +43,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
|
|
| 42 |
@spaces.GPU(duration=1500)
|
| 43 |
def compile_transformer():
|
| 44 |
|
|
|
|
|
|
|
| 45 |
with capture_component_call(pipeline, 'transformer') as call:
|
| 46 |
pipeline(*args, **kwargs)
|
| 47 |
|
|
@@ -86,9 +89,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
|
|
| 86 |
compiled_portrait_2,
|
| 87 |
)
|
| 88 |
|
| 89 |
-
pipeline.text_encoder.to('cpu')
|
| 90 |
cl1, cl2, cp1, cp2 = compile_transformer()
|
| 91 |
-
pipeline.text_encoder.to('cuda')
|
| 92 |
|
| 93 |
def combined_transformer_1(*args, **kwargs):
|
| 94 |
hidden_states: torch.Tensor = kwargs['hidden_states']
|
|
|
|
| 10 |
from torch.utils._pytree import tree_map_only
|
| 11 |
from torchao.quantization import quantize_
|
| 12 |
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
|
| 13 |
+
from torchao.quantization import Int8WeightOnlyConfig
|
| 14 |
|
| 15 |
from optimization_utils import capture_component_call
|
| 16 |
from optimization_utils import aoti_compile
|
|
|
|
| 43 |
@spaces.GPU(duration=1500)
|
| 44 |
def compile_transformer():
|
| 45 |
|
| 46 |
+
quantize_(pipeline.text_encoder, Int8WeightOnlyConfig()) # Just to free-up some GPU memory
|
| 47 |
+
|
| 48 |
with capture_component_call(pipeline, 'transformer') as call:
|
| 49 |
pipeline(*args, **kwargs)
|
| 50 |
|
|
|
|
| 89 |
compiled_portrait_2,
|
| 90 |
)
|
| 91 |
|
|
|
|
| 92 |
cl1, cl2, cp1, cp2 = compile_transformer()
|
|
|
|
| 93 |
|
| 94 |
def combined_transformer_1(*args, **kwargs):
|
| 95 |
hidden_states: torch.Tensor = kwargs['hidden_states']
|