# tests/kernels/test_activation.py
import math
import random

import activation
import pytest
import torch
import torch.nn.functional as F

from .utils import opcheck
from .allclose_default import get_default_atol, get_default_rtol

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
D = [512, 13824] # Arbitrary values for testing
SEEDS = [0]
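# Test on cuda:0 only when a single GPU is present; otherwise also cover cuda:1
# so multi-device behaviour is exercised.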
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
def gelu_fast(x: torch.Tensor) -> torch.Tensor:
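    """Reference "fast" GELU: a rearranged form of the tanh approximation (same math as gelu_new)."""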
    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))

def gelu_new(x: torch.Tensor) -> torch.Tensor:
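    """Reference "new" GELU: the standard tanh approximation of GELU."""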
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))

def gelu_quick(x: torch.Tensor) -> torch.Tensor:
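    """Reference "quick" GELU: the sigmoid approximation x * sigmoid(1.702 * x)."""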
    return x * torch.sigmoid(1.702 * x)

def fatrelu_and_mul(x: torch.Tensor, threshold: float) -> torch.Tensor:
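    """Reference FatReLU-and-mul: zero values not above `threshold` in the first
    half of the last dim, then multiply elementwise by the second half."""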
    d = x.shape[-1] // 2
    x1 = x[..., :d]
    x2 = x[..., d:]
    x1 = F.threshold(x1, threshold, 0.0)
    return x1 * x2

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
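    """Reference for the fused kernel: SiLU(first half of the last dim) * second half."""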
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def gelu_and_mul(x: torch.Tensor, approximate: str) -> torch.Tensor:
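    """Reference for the fused kernel: GELU(first half of the last dim) * second half."""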
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]

@pytest.mark.parametrize("activation_name", ["silu", "gelu", "gelu_tanh", "fatrelu"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
    activation_name: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
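    """Check the fused *_and_mul kernels against the PyTorch reference functions
    above, then validate the registered op with opcheck."""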
    random.seed(seed)
    torch.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, 2 * d, dtype=dtype)
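    # Each fused kernel applies its activation to the first half of the last
    # dimension and multiplies the result elementwise by the second half,
    # writing into a caller-provided output buffer.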
    if activation_name == "silu":
        torch_fn = silu_and_mul
        fn = activation.silu_and_mul
        op = activation.ops.silu_and_mul
    elif activation_name == "gelu":
        torch_fn = lambda x: gelu_and_mul(x, "none")
        fn = activation.gelu_and_mul
        op = activation.ops.gelu_and_mul
    elif activation_name == "gelu_tanh":
        torch_fn = lambda x: gelu_and_mul(x, "tanh")
        fn = activation.gelu_tanh_and_mul
        op = activation.ops.gelu_tanh_and_mul
    elif activation_name == "fatrelu":
        threshold = random.uniform(0, 1)
        torch_fn = lambda x: fatrelu_and_mul(x, threshold)
        fn = lambda out, x: activation.fatrelu_and_mul(out, x, threshold)
        op = activation.ops.fatrelu_and_mul

    out_shape = x.shape[:-1] + (x.shape[-1] // 2,)
    out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
    out = fn(out, x)
    ref_out = torch_fn(x)

    # The SiLU, GELU and FatReLU implementations are equivalent to the native
    # PyTorch implementations, so we can do exact comparison.
    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
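    # Re-create the output buffer and run opcheck, which exercises the registered
    # custom op directly (presumably via torch.library.opcheck inside .utils).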
    d = x.shape[-1] // 2
    output_shape = x.shape[:-1] + (d,)
    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
    if activation_name == "fatrelu":
        opcheck(op, (out, x, threshold))
    else:
        opcheck(op, (out, x))

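# Each entry pairs a PyTorch reference implementation with the kernel wrapper
# and the underlying registered op it should match.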
@pytest.mark.parametrize(
"activation_fns",
[
(gelu_fast, activation.gelu_fast, activation.ops.gelu_fast),
(gelu_new, activation.gelu_new, activation.ops.gelu_new),
(gelu_quick, activation.gelu_quick, activation.ops.gelu_quick),
],
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation(
    activation_fns,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
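    """Compare each elementwise GELU kernel against its PyTorch reference and
    validate the registered op with opcheck."""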
    torch.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, d, dtype=dtype)
    torch_fn, fn, op = activation_fns
    out = fn(torch.empty_like(x), x)
    ref_out = torch_fn(x)
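    # Unlike the *_and_mul tests above, these kernels are not expected to match the
    # reference bit-for-bit, so use the dtype-dependent default tolerances.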
    torch.testing.assert_close(
        out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out)
    )

    out = torch.empty_like(x)
    opcheck(op, (out, x))