Spaces:
Running
on
Zero
Running
on
Zero
import gc | |
from copy import deepcopy | |
from typing import Union | |
import torch | |
from torch import nn as nn | |
from torch.nn import functional as F | |
def fused_rms_norm(x: torch.Tensor, weight: nn.Parameter, eps: float):
    """Root-mean-square normalise ``x`` over its last dimension, then apply ``weight``.

    Args:
        x: Input tensor. It is upcast with ``x.float()`` before any arithmetic
            and is never modified.
        weight: Gain multiplied into the normalised tensor via broadcasting
            (presumably one value per feature along the last dim — confirm
            against callers).
        eps: Added to the mean of squares before the reciprocal square root,
            so an all-zero row yields zeros rather than NaN/inf.

    Returns:
        ``x / sqrt(mean(x**2, dim=-1) + eps) * weight``. The dtype follows
        PyTorch type promotion between float32 and ``weight``'s dtype.
    """
    x_f32 = x.float()
    mean_square = x_f32.pow(2).mean(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    normalised = x_f32 * inv_rms
    return normalised * weight
def fused_ada_layer_norm(C: int, eps: float, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor):
    """Parameter-free LayerNorm followed by ``(1 + scale)`` / ``shift`` modulation.

    Args:
        C: Size of the normalised (last) dimension, forwarded to
            ``F.layer_norm`` as ``normalized_shape=(C,)``.
        eps: Variance stabiliser forwarded to ``F.layer_norm``.
        x: Input tensor. It is upcast with ``x.float()`` first and is never
            modified.
        scale: Modulation offset. The normalised tensor is multiplied by
            ``1 + scale``, so a zero ``scale`` leaves it unchanged.
        shift: Added after scaling. This is done in place on the scaled
            intermediate, so ``shift`` must broadcast to that tensor's shape,
            and the result keeps the intermediate's dtype.

    Returns:
        ``layer_norm(x) * (1 + scale) + shift``.
    """
    normalised = F.layer_norm(x.float(), (C,), weight=None, bias=None, eps=eps)
    gain = 1 + scale
    modulated = normalised * gain
    # In-place add on a freshly allocated intermediate: safe, and it fixes
    # the output dtype/shape to those of `modulated`.
    return modulated.add_(shift)
def fused_ada_rms_norm(C: int, eps: float, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor):
    """RMS-normalise ``x`` over its last dimension, then apply ``(1 + scale)`` / ``shift``.

    Args:
        C: Not read by this function. It is apparently kept so the signature
            matches ``fused_ada_layer_norm``.
        eps: Added to the mean of squares before the reciprocal square root,
            so an all-zero row stays finite.
        x: Input tensor. It is upcast with ``x.float()`` and is never modified.
        scale: Modulation offset. The normalised tensor is multiplied by
            ``1 + scale``.
        shift: Added in place to the scaled intermediate, so it must
            broadcast to that tensor's shape.

    Returns:
        ``x / sqrt(mean(x**2, dim=-1) + eps) * (1 + scale) + shift``, computed
        on ``x.float()``.
    """
    h = x.float()
    rms_inv = h.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
    out = (h * rms_inv).mul(1 + scale)
    # `out` is a fresh tensor, so the in-place add cannot touch the caller's data.
    out.add_(shift)
    return out