Spaces:
Sleeping
Sleeping
File size: 3,904 Bytes
28c6826 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
""" Activation Factory
Hacked together by / Copyright 2020 Ross Wightman
"""
from .activations import *
from .activations_jit import *
from .activations_me import *
from .config import is_exportable, is_scriptable, is_no_jit
# Since PyTorch 1.7 there is an optimized, native 'silu' (aka 'swish') operator. Whenever
# it is present it is used in place of the custom Swish implementations; those custom
# layers are expected to be removed eventually, leaving only the native 'silu'.
_has_silu = hasattr(torch.nn.functional, 'silu')
# Name -> activation function. This is the table get_act_fn() falls back to when no
# specialised variant was selected. 'silu' and 'swish' are aliases: both resolve to the
# native F.silu when it exists, otherwise to the custom swish.
_ACT_FN_DEFAULT = {
    'silu': F.silu if _has_silu else swish,
    'swish': F.silu if _has_silu else swish,
    'mish': mish,
    'relu': F.relu,
    'relu6': F.relu6,
    'leaky_relu': F.leaky_relu,
    'elu': F.elu,
    'celu': F.celu,
    'selu': F.selu,
    'gelu': gelu,
    'sigmoid': sigmoid,
    'tanh': tanh,
    'hard_sigmoid': hard_sigmoid,
    'hard_swish': hard_swish,
    'hard_mish': hard_mish,
}
# Name -> `*_jit` activation function. get_act_fn() consults this table only when neither
# the no-jit nor the exportable config flag is set. The native F.silu, when available,
# takes precedence over swish_jit.
_ACT_FN_JIT = {
    'silu': F.silu if _has_silu else swish_jit,
    'swish': F.silu if _has_silu else swish_jit,
    'mish': mish_jit,
    'hard_sigmoid': hard_sigmoid_jit,
    'hard_swish': hard_swish_jit,
    'hard_mish': hard_mish_jit,
}
# Name -> `*_me` activation function (the memory-efficient, custom-autograd versions that
# get_act_fn() tries first when the model is not being exported or scripted and jit is
# not disabled). The native F.silu, when available, takes precedence over swish_me.
_ACT_FN_ME = {
    'silu': F.silu if _has_silu else swish_me,
    'swish': F.silu if _has_silu else swish_me,
    'mish': mish_me,
    'hard_sigmoid': hard_sigmoid_me,
    'hard_swish': hard_swish_me,
    'hard_mish': hard_mish_me,
}
# Name -> activation layer class. This is the table get_act_layer() falls back to when no
# specialised variant was selected. 'silu' and 'swish' are aliases: both resolve to the
# native nn.SiLU when it exists, otherwise to the custom Swish. Note that 'prelu' is only
# available as a layer; the function table has no such entry.
_ACT_LAYER_DEFAULT = {
    'silu': nn.SiLU if _has_silu else Swish,
    'swish': nn.SiLU if _has_silu else Swish,
    'mish': Mish,
    'relu': nn.ReLU,
    'relu6': nn.ReLU6,
    'leaky_relu': nn.LeakyReLU,
    'elu': nn.ELU,
    'prelu': PReLU,
    'celu': nn.CELU,
    'selu': nn.SELU,
    'gelu': GELU,
    'sigmoid': Sigmoid,
    'tanh': Tanh,
    'hard_sigmoid': HardSigmoid,
    'hard_swish': HardSwish,
    'hard_mish': HardMish,
}
# Name -> `*Jit` activation layer class. get_act_layer() consults this table only when
# neither the no-jit nor the exportable config flag is set. The native nn.SiLU, when
# available, takes precedence over SwishJit.
_ACT_LAYER_JIT = {
    'silu': nn.SiLU if _has_silu else SwishJit,
    'swish': nn.SiLU if _has_silu else SwishJit,
    'mish': MishJit,
    'hard_sigmoid': HardSigmoidJit,
    'hard_swish': HardSwishJit,
    'hard_mish': HardMishJit,
}
# Name -> `*Me` activation layer class. get_act_layer() tries this table first when the
# model is not being exported or scripted and jit is not disabled. The native nn.SiLU,
# when available, takes precedence over SwishMe.
_ACT_LAYER_ME = {
    'silu': nn.SiLU if _has_silu else SwishMe,
    'swish': nn.SiLU if _has_silu else SwishMe,
    'mish': MishMe,
    'hard_sigmoid': HardSigmoidMe,
    'hard_swish': HardSwishMe,
    'hard_mish': HardMishMe,
}
def get_act_fn(name='relu'):
    """ Activation Function Factory

    Fetching activation fns by name with this function allows export or torch script friendly
    functions to be returned dynamically based on current config.

    Args:
        name: Name of the activation function to look up, or an already resolved
            callable. A falsy value (``None``, ``''``) selects no activation.

    Returns:
        ``None`` when ``name`` is falsy, ``name`` itself when it is a callable, otherwise
        the function registered under ``name`` in the first applicable table.

    Raises:
        KeyError: If ``name`` is not callable and is not a key of the default table.
    """
    if not name:
        return None
    if callable(name):
        # The caller already holds an activation function; there is nothing to look up.
        # Without this, a callable would fall through to the string-keyed tables below
        # and fail with a KeyError.
        return name
    if not (is_no_jit() or is_exportable() or is_scriptable()):
        # If not exporting or scripting the model, first look for a memory-efficient version with
        # custom autograd, then fallback
        if name in _ACT_FN_ME:
            return _ACT_FN_ME[name]
    if is_exportable() and name in ('silu', 'swish'):
        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
        return swish
    if not (is_no_jit() or is_exportable()):
        # jit variants are allowed while scripting, but not when jit is disabled or the
        # model is being exported.
        if name in _ACT_FN_JIT:
            return _ACT_FN_JIT[name]
    return _ACT_FN_DEFAULT[name]
def get_act_layer(name='relu'):
    """ Activation Layer Factory

    Fetching activation layers by name with this function allows export or torch script friendly
    functions to be returned dynamically based on current config.

    Args:
        name: Name of the activation layer to look up, or an already resolved callable
            (such as a layer class). A falsy value (``None``, ``''``) selects no layer.

    Returns:
        ``None`` when ``name`` is falsy, ``name`` itself when it is a callable, otherwise
        the layer class registered under ``name`` in the first applicable table.

    Raises:
        KeyError: If ``name`` is not callable and is not a key of the default table.
    """
    if not name:
        return None
    if callable(name):
        # The caller already holds a layer class or factory; there is nothing to look up.
        # Without this, a callable would fall through to the string-keyed tables below
        # and fail with a KeyError.
        return name
    if not (is_no_jit() or is_exportable() or is_scriptable()):
        # Not exporting or scripting and jit is not disabled: try the `*Me` layers first.
        if name in _ACT_LAYER_ME:
            return _ACT_LAYER_ME[name]
    if is_exportable() and name in ('silu', 'swish'):
        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
        return Swish
    if not (is_no_jit() or is_exportable()):
        # jit variants are allowed while scripting, but not when jit is disabled or the
        # model is being exported.
        if name in _ACT_LAYER_JIT:
            return _ACT_LAYER_JIT[name]
    return _ACT_LAYER_DEFAULT[name]
def create_act_layer(name, inplace=False, **kwargs):
    """ Activation Layer Instantiation

    Resolves `name` through get_act_layer() and constructs the result.

    Args:
        name: Activation layer name, forwarded unchanged to get_act_layer().
        inplace: Passed to the layer constructor as the `inplace` keyword argument.
        **kwargs: Additional keyword arguments passed to the layer constructor.

    Returns:
        A new layer instance, or ``None`` when get_act_layer() returns ``None``.
    """
    layer_cls = get_act_layer(name)
    if layer_cls is None:
        return None
    return layer_cls(inplace=inplace, **kwargs)
|