import torch
import torch.nn as nn

from ._ops import ops


class SiluAndMul(nn.Module):
    """Gated SiLU (SwiGLU) activation: silu(x[..., :d]) * x[..., d:], with d = x.shape[-1] // 2."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.silu_and_mul(out, x)
        return out
class GeluAndMul(nn.Module):
    """Gated GELU (GeGLU) activation: gelu(x[..., :d]) * x[..., d:], using the exact (erf) GELU."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.gelu_and_mul(out, x)
        return out
class GeluTanhAndMul(nn.Module):
    """Gated GELU activation using the tanh approximation of GELU."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.gelu_tanh_and_mul(out, x)
        return out
class FatreluAndMul(nn.Module):
    """Gated FATReLU activation: gate values at or below `threshold` are zeroed before the multiply."""

    def __init__(self, threshold: float = 0.0):
        super().__init__()
        self.threshold = threshold

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.fatrelu_and_mul(out, x, self.threshold)
        return out
class FastGELU(nn.Module):
    """Element-wise GELU using a fast tanh-based approximation."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        ops.gelu_fast(out, x)
        return out
class NewGELU(nn.Module):
    """Element-wise GELU using the GPT-2 style ("new") tanh approximation."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        ops.gelu_new(out, x)
        return out
class QuickGELU(nn.Module):
    """Element-wise QuickGELU: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out
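

# --- Illustrative usage sketch (not part of the original module) ---
# The *AndMul layers implement gated activations: the input's last dimension
# must be even (2 * d) and the output's last dimension is d. The shapes,
# dtype, and CUDA device below are assumptions for demonstration; the
# underlying kernels are expected to operate on GPU tensors.
if __name__ == "__main__":
    if torch.cuda.is_available():
        x = torch.randn(4, 2048, dtype=torch.float16, device="cuda")
        y = SiluAndMul()(x)          # silu(x[..., :1024]) * x[..., 1024:]
        z = FatreluAndMul(0.05)(x)   # gate values <= 0.05 are zeroed first
        print(y.shape, z.shape)      # both torch.Size([4, 1024])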