""" Model / Layer Config singleton state | |
""" | |
import os | |
import warnings | |
from typing import Any, Optional | |
import torch | |
__all__ = [ | |
'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn', | |
'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn' | |
] | |
# Set to True to prefer layers with no jit optimization (includes activations)
_NO_JIT = False

# Set to True to prefer activation layers with no jit optimization
# NOTE: not currently used; no_jit and no_activation_jit behave identically because activations are
# the only layers obeying the jit flags so far. This will change as more layers are updated and/or added.
_NO_ACTIVATION_JIT = False

# Set to True if exporting a model with Same padding via ONNX
_EXPORTABLE = False

# Set to True if wanting to use torch.jit.script on a model
_SCRIPTABLE = False

# use torch.nn.functional.scaled_dot_product_attention where possible
_HAS_FUSED_ATTN = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if 'TIMM_FUSED_ATTN' in os.environ:
    _USE_FUSED_ATTN = int(os.environ['TIMM_FUSED_ATTN'])
else:
    _USE_FUSED_ATTN = 1  # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
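
# A minimal usage sketch (assumed invocation, not part of this module): TIMM_FUSED_ATTN is read
# once at import time, so set it in the environment before importing timm, e.g.
#   TIMM_FUSED_ATTN=0 python train.py   # force the non-fused attention path
#   TIMM_FUSED_ATTN=2 python train.py   # also enable experimental fused-attn call sites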


def is_no_jit():
    """Return True if layers should avoid jit scripting/optimization."""
    return _NO_JIT


class set_no_jit:
    """Set the global no_jit flag; usable as a plain call or as a context manager."""

    def __init__(self, mode: bool) -> None:
        global _NO_JIT
        self.prev = _NO_JIT
        _NO_JIT = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _NO_JIT
        _NO_JIT = self.prev
        return False
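
# A minimal usage sketch (hypothetical caller code, not part of this module): because the flag is
# written in __init__, set_no_jit works both as an immediate setter and as a scoped override.
#
#   set_no_jit(True)                  # flip the flag for the rest of the process
#
#   with set_no_jit(True):            # flip it only while building a model
#       model = build_model()         # `build_model` is a placeholder for your own constructor
#   # previous value restored on exit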


def is_exportable():
    """Return True if layers should be configured for export (e.g. ONNX) compatibility."""
    return _EXPORTABLE


class set_exportable:
    """Set the global exportable flag; usable as a plain call or as a context manager."""

    def __init__(self, mode: bool) -> None:
        global _EXPORTABLE
        self.prev = _EXPORTABLE
        _EXPORTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _EXPORTABLE
        _EXPORTABLE = self.prev
        return False


def is_scriptable():
    """Return True if layers should remain torch.jit.script compatible."""
    return _SCRIPTABLE


class set_scriptable:
    """Set the global scriptable flag; usable as a plain call or as a context manager."""

    def __init__(self, mode: bool) -> None:
        global _SCRIPTABLE
        self.prev = _SCRIPTABLE
        _SCRIPTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        _SCRIPTABLE = self.prev
        return False


class set_layer_config:
    """ Layer config context manager that allows setting all layer config flags at once.
    If a flag arg is None, it will not change the current value.
    """

    def __init__(
            self,
            scriptable: Optional[bool] = None,
            exportable: Optional[bool] = None,
            no_jit: Optional[bool] = None,
            no_activation_jit: Optional[bool] = None):
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
        if scriptable is not None:
            _SCRIPTABLE = scriptable
        if exportable is not None:
            _EXPORTABLE = exportable
        if no_jit is not None:
            _NO_JIT = no_jit
        if no_activation_jit is not None:
            _NO_ACTIVATION_JIT = no_activation_jit

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
        return False
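
# A minimal usage sketch (hypothetical caller code, not part of this module): override several
# flags for the duration of model construction, restoring the previous values afterwards.
#
#   with set_layer_config(scriptable=True, exportable=True, no_jit=True):
#       model = build_model()  # `build_model` is a placeholder for your own constructor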


def use_fused_attn(experimental: bool = False) -> bool:
    """Return True if F.scaled_dot_product_attention should be used for attention layers."""
    # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0
    if not _HAS_FUSED_ATTN or _EXPORTABLE:
        return False
    if experimental:
        return _USE_FUSED_ATTN > 1
    return _USE_FUSED_ATTN > 0
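
# A minimal usage sketch (hypothetical attention layer, for illustration only; `Attention` and its
# parameters are not part of this module): resolve the flag once at construction time via the
# use_fused_attn() helper above, then branch between the fused and math attention paths in forward().
#
#   import torch.nn as nn
#   import torch.nn.functional as F
#
#   class Attention(nn.Module):
#       def __init__(self, dim: int, num_heads: int = 8):
#           super().__init__()
#           self.num_heads = num_heads
#           self.head_dim = dim // num_heads
#           self.scale = self.head_dim ** -0.5
#           self.qkv = nn.Linear(dim, dim * 3)
#           self.proj = nn.Linear(dim, dim)
#           self.fused_attn = use_fused_attn()  # resolved once, at construction time
#
#       def forward(self, x: torch.Tensor) -> torch.Tensor:
#           B, N, C = x.shape
#           qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
#           q, k, v = qkv.unbind(0)  # each (B, num_heads, N, head_dim)
#           if self.fused_attn:
#               x = F.scaled_dot_product_attention(q, k, v)
#           else:
#               attn = (q @ k.transpose(-2, -1)) * self.scale
#               x = attn.softmax(dim=-1) @ v
#           x = x.transpose(1, 2).reshape(B, N, C)
#           return self.proj(x)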


def set_fused_attn(enable: bool = True, experimental: bool = False):
    """Enable or disable use of F.scaled_dot_product_attention for attention layers."""
    global _USE_FUSED_ATTN
    if not _HAS_FUSED_ATTN:
        warnings.warn('This version of pytorch does not have F.scaled_dot_product_attention, fused_attn flag ignored.')
        return
    if experimental and enable:
        _USE_FUSED_ATTN = 2
    elif enable:
        _USE_FUSED_ATTN = 1
    else:
        _USE_FUSED_ATTN = 0
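
# A minimal usage sketch (hypothetical calls, not part of this module): toggle fused attention at
# runtime; note this typically only affects layers constructed after the call, since callers usually
# read use_fused_attn() at layer construction time.
#
#   set_fused_attn(False)                     # force the math attention path
#   set_fused_attn(True, experimental=True)   # also opt in to experimental fused-attn call sites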