# File-viewer metadata captured with this source (not part of the module):
# Spaces status: Running; file size: 412 bytes; 17 lines; revision 306b4ac.
import triton
import triton.language as tl
from packaging import version
# True when the installed Triton is 3.x or newer, which selects the Triton 3
# implementation below. Only the release segment is compared: PEP 440 orders
# pre-/dev-releases such as "3.0.0rc1" or "3.0.0.dev0" *below* "3.0.0", so a
# plain `>= version.parse("3.0.0")` check would route those builds to the
# Triton 2 code path.
TRITON3 = version.parse(triton.__version__).release >= (3,)
# softplus(dt) = log(1 + exp(dt)) for use inside Triton kernels. Inputs above
# 20.0 are returned unchanged: there softplus(dt) equals dt to within single
# precision, and skipping the formula avoids relying on exp() for large dt.
if TRITON3:
    # Triton 3 has no `tl.math.log1p`. The naive `log(exp(dt) + 1)` rounds
    # `1 + exp(dt)` to exactly 1 once exp(dt) drops below half an ulp of 1
    # (for fp32 inputs, dt below roughly -16.6), returning 0 where softplus is
    # strictly positive. Kahan's identity recovers log1p from log alone:
    #     log1p(u) = u                               if 1 + u == 1
    #              = log(1 + u) * u / ((1 + u) - 1)  otherwise
    # The division by the computed `(1 + u) - 1` cancels the rounding error
    # committed when forming `1 + u`.
    @triton.jit
    def softplus(dt):
        e = tl.exp(dt)
        w = 1.0 + e
        # Where w == 1 the right-hand operand divides by zero, and for very
        # large dt it can be inf/nan; tl.where only keeps the selected lane.
        log1p_e = tl.where(w == 1.0, e, tl.log(w) * (e / (w - 1.0)))
        dt = tl.where(dt <= 20.0, log1p_e, dt)
        return dt
else:
    # Triton 2.x still exposes `tl.math.log1p`, which handles small exp(dt)
    # accurately on its own.
    @triton.jit
    def softplus(dt):
        dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
        return dt