M3Site / esm /layers /regression_head.py
anonymousforpaper's picture
Upload 103 files
224a33f verified
raw
history blame contribute delete
642 Bytes
import torch.nn as nn
def RegressionHead(
d_model: int,
output_dim: int,
hidden_dim: int | None = None,
) -> nn.Module:
"""Single-hidden layer MLP for supervised output.
Args:
d_model: input dimension
output_dim: dimensionality of the output.
hidden_dim: optional dimension of hidden layer, defaults to d_model.
Returns:
output MLP module.
"""
hidden_dim = hidden_dim if hidden_dim is not None else d_model
return nn.Sequential(
nn.Linear(d_model, hidden_dim),
nn.GELU(),
nn.LayerNorm(hidden_dim),
nn.Linear(hidden_dim, output_dim),
)