17 lines
412 B
Python
17 lines
412 B
Python
import triton
|
|
import triton.language as tl
|
|
from packaging import version
|
|
|
|
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
|
|
|
|
|
|
if TRITON3:
|
|
@triton.jit
|
|
def softplus(dt):
|
|
dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
|
|
return dt
|
|
else:
|
|
@triton.jit
|
|
def softplus(dt):
|
|
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
|
|
return dt |