import torch
import torch.nn as nn

from ._ops import ops


class SiluAndMul(nn.Module):
    """Gated SiLU (SwiGLU-style) activation.

    Splits the last dimension of ``x`` in half and returns
    ``silu(x[..., :d]) * x[..., d:]`` with ``d = x.shape[-1] // 2``,
    computed by the custom op.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.silu_and_mul(out, x)
        return out


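# A minimal eager-mode reference for the gating pattern used by the fused
# ops above (an illustrative sketch, not part of the original module; it
# assumes ops.silu_and_mul gates the first half of the last dimension with
# SiLU and multiplies by the second half). Useful when testing the kernel:
def _silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return nn.functional.silu(x[..., :d]) * x[..., d:]

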
class GeluAndMul(nn.Module):
    """Gated GELU activation.

    Same split-and-gate pattern as :class:`SiluAndMul`, but with GELU as
    the gate; see :class:`GeluTanhAndMul` for the tanh approximation.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.gelu_and_mul(out, x)
        return out


class GeluTanhAndMul(nn.Module):
    """Gated GELU activation using the tanh approximation of GELU."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.gelu_tanh_and_mul(out, x)
        return out


class FatreluAndMul(nn.Module):
    """Gated FATReLU activation.

    Applies a thresholded ReLU (FATReLU) to the first half of the last
    dimension and multiplies the result by the second half; activations
    at or below ``threshold`` are zeroed.
    """

    def __init__(self, threshold: float = 0.0):
        super().__init__()
        self.threshold = threshold

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.fatrelu_and_mul(out, x, self.threshold)
        return out


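# Hedged reference for the FATReLU gate (an illustrative sketch, not part
# of the original module; it assumes the kernel implements the usual
# FATReLU definition f(x) = x if x > threshold else 0 before gating):
def _fatrelu_and_mul_reference(x: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return torch.where(x1 > threshold, x1, torch.zeros_like(x1)) * x2

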
class FastGELU(nn.Module):
    """Element-wise fast GELU (a tanh-based approximation); preserves the
    input shape."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        ops.gelu_fast(out, x)
        return out


class NewGELU(nn.Module):
    """Element-wise "new" GELU, the tanh approximation
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        ops.gelu_new(out, x)
        return out


class QuickGELU(nn.Module):
    """Element-wise quick GELU, the sigmoid approximation
    ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out
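

if __name__ == "__main__":
    # Usage sketch / smoke test (an illustrative addition, not part of the
    # original module). Assumes the extension ops are built for a CUDA
    # device; tensor sizes here are arbitrary examples.
    x = torch.randn(2, 2 * 128, device="cuda", dtype=torch.float16)
    # Gated activations halve the last dimension: (2, 256) -> (2, 128).
    print(SiluAndMul()(x).shape)
    # Element-wise variants preserve the input shape.
    print(QuickGELU()(torch.randn(2, 128, device="cuda", dtype=torch.float16)).shape)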