File size: 860 Bytes
eb339cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import copy
from typing import Optional
import torch.nn as nn
# Registry mapping lowercase activation names to module instances.
# NOTE: each value is constructed once at import time, so lookups hand out
# the same shared (stateless, parameter-free) module object.
ACTIVATION_FUNCTIONS = dict(
    swish=nn.SiLU(),
    silu=nn.SiLU(),
    mish=nn.Mish(),
    gelu=nn.GELU(),
    relu=nn.ReLU(),
)
def get_clone(module: nn.Module) -> nn.Module:
    """Return an independent deep copy of *module* (parameters included)."""
    cloned = copy.deepcopy(module)
    return cloned
def get_clones(module: nn.Module, N: int) -> nn.ModuleList:
    """Return an ``nn.ModuleList`` of *N* independent deep copies of *module*."""
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)
def get_activation_fn(act_fn: Optional[str] = None) -> nn.Module:
    """Look up an activation module by (case-insensitive) name.

    Args:
        act_fn: One of the names registered in ``ACTIVATION_FUNCTIONS``
            ("swish", "silu", "mish", "gelu", "relu"), or ``None`` for a
            pass-through.

    Returns:
        The matching activation module, or ``nn.Identity()`` when
        ``act_fn`` is ``None``.

    Raises:
        ValueError: If ``act_fn`` is not a supported activation name.
    """
    if act_fn is None:
        return nn.Identity()
    name = act_fn.lower()
    if name not in ACTIVATION_FUNCTIONS:
        raise ValueError(f"Unsupported activation function: {name}")
    return ACTIVATION_FUNCTIONS[name]
def zero_module(module: nn.Module) -> nn.Module:
    """Fill every parameter of *module* with zeros in place and return it."""
    for param in module.parameters():
        nn.init.constant_(param, 0.0)
    return module
|