# wxDai's picture
# [Init]
# eb339cb
# raw
# history blame
# 860 Bytes
import copy
from typing import Optional
import torch.nn as nn
# Registry of supported activations, keyed by lower-case name.
# Each value is a ready-built module instance created at import time.
ACTIVATION_FUNCTIONS = {
    # "swish" and "silu" both map to SiLU; each key holds its own instance.
    "swish": nn.SiLU(),
    "silu": nn.SiLU(),
    "mish": nn.Mish(),
    "gelu": nn.GELU(),
    "relu": nn.ReLU(),
}
def get_clone(module: nn.Module) -> nn.Module:
    """Return a deep copy of ``module``.

    Args:
        module: The module to duplicate.

    Returns:
        A new object produced by ``copy.deepcopy``; the input is left
        untouched.
    """
    duplicate = copy.deepcopy(module)
    return duplicate
def get_clones(module: nn.Module, N: int) -> nn.ModuleList:
    """Build an ``nn.ModuleList`` holding ``N`` deep copies of ``module``.

    Args:
        module: The template module; it is copied, never inserted itself.
        N: How many copies to create. ``range(N)`` drives the loop, so a
            value of zero or less yields an empty list.

    Returns:
        An ``nn.ModuleList`` of independent deep copies.
    """
    clones = []
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return nn.ModuleList(clones)
def get_activation_fn(act_fn: Optional[str] = None) -> nn.Module:
    """Return an activation module selected by name.

    Args:
        act_fn: Case-insensitive key into ``ACTIVATION_FUNCTIONS``.
            ``None`` (the default) selects no activation.

    Returns:
        A new ``nn.Identity`` when ``act_fn`` is ``None``; otherwise an
        independent deep copy of the module registered under the
        lower-cased name.

    Raises:
        ValueError: If the lower-cased name is not a key of
            ``ACTIVATION_FUNCTIONS``.
    """
    if act_fn is None:
        return nn.Identity()

    key = act_fn.lower()
    try:
        template = ACTIVATION_FUNCTIONS[key]
    except KeyError:
        raise ValueError(f"Unsupported activation function: {key}") from None

    # Return a copy rather than the registry entry itself: handing out the
    # same object to every caller would make unrelated models share one
    # module instance (shared hooks, identity-based de-duplication in
    # ``named_modules()``).
    return copy.deepcopy(template)
def zero_module(module: nn.Module) -> nn.Module:
    """Fill every parameter of ``module`` with zeros, in place.

    Args:
        module: The module whose parameters (as yielded by
            ``module.parameters()``) are overwritten.

    Returns:
        The same ``module`` object, so the call can be used inline.
    """
    for param in module.parameters():
        nn.init.zeros_(param)
    return module