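# NOTE: the following six lines are residue from the web file viewer this
# source was copied from; they are not part of the module.
# ipd's picture
# add pos-egnn
# dded3ca
# raw
# history blame
# 20.9 kB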
"""
This code was adapted from https://github.com/sarpaykent/GotenNet
Copyright (c) 2025 Sarp Aykent
MIT License
GotenNet: Rethinking Efficient 3D Equivariant Graph Neural Networks
Sarp Aykent and Tian Xia
https://openreview.net/pdf?id=5wxCQDtbMo
"""
from functools import partial
from typing import Callable, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.typing import OptTensor
from torch_geometric.utils import scatter, softmax
from .ops import (
MLP,
CosineCutoff,
Dense,
EdgeInit,
NodeInit,
TensorInit,
TensorLayerNorm,
get_weight_init_by_string,
parse_update_info,
str2act,
str2basis,
)
def lmax_tensor_size(lmax):
    """Return ``(lmax + 1) ** 2 - 1``.

    This equals the sum of ``2 * l + 1`` for ``l = 1 .. lmax``, i.e. the
    combined width of all degree blocks from 1 up to ``lmax`` (degree 0 is
    excluded, hence the ``- 1``).
    """
    side = lmax + 1
    return side * side - 1
def split_degree(tensor, lmax, dim=-1):  # default to last dim
    """Slice ``tensor`` along ``dim`` into one chunk per degree ``1 .. lmax``.

    The chunk for degree ``l`` covers ``2 * l + 1`` consecutive entries and the
    chunks are laid out back to back starting at offset 0.

    Args:
        tensor: Tensor to slice.
        lmax: Highest degree; the returned list has ``lmax`` entries.
        dim: Dimension along which the degree blocks are stored.

    Returns:
        List of slices of ``tensor`` (basic indexing, no explicit copy), ordered
        by increasing degree. Empty when ``lmax < 1``.
    """
    # Cumulative block boundaries: (l + 1) ** 2 - 1 for l = 0 .. lmax.
    boundaries = [(degree + 1) ** 2 - 1 for degree in range(lmax + 1)]

    chunks = []
    for start, stop in zip(boundaries[:-1], boundaries[1:]):
        # Full slice on every axis except the one holding the degree blocks.
        selector = [slice(None)] * tensor.ndim
        selector[dim] = slice(start, stop)
        chunks.append(tensor[tuple(selector)])
    return chunks
class GATA(MessagePassing):
    """Attention-based message-passing layer with scalar and tensor channels.

    The layer updates per-node scalar features ``s`` and higher-degree features
    ``t`` through multi-head attention over edges and, unless it is the last
    layer or edge updates are disabled, also produces updated edge features.
    """

    def __init__(
        self,
        n_atom_basis: int,
        activation: Callable,
        weight_init=nn.init.xavier_uniform_,
        bias_init=nn.init.zeros_,
        aggr="add",
        node_dim=0,
        epsilon: float = 1e-7,
        layer_norm=False,
        vector_norm=False,
        cutoff=5.0,
        num_heads=8,
        dropout=0.0,
        edge_updates=True,
        last_layer=False,
        scale_edge=True,
        edge_ln="",
        evec_dim=None,
        emlp_dim=None,
        sep_vecj=True,
        lmax=1,
    ):
        """
        Args:
            n_atom_basis (int): Number of features to describe atomic environments.
            activation (Callable): Activation function to be used. If None, no activation function is used.
            weight_init (Callable): Weight initialization function.
            bias_init (Callable): Bias initialization function.
            aggr (str): Aggregation method ('add', 'mean' or 'max').
            node_dim (int): The axis along which to aggregate.
            epsilon (float): Stored on the instance; not read by the code in this class.
            layer_norm (bool): Apply ``nn.LayerNorm`` to the scalar features at the start of ``forward``.
            vector_norm (bool): Apply ``TensorLayerNorm`` to the tensor features at the start of ``forward``.
            cutoff (float): Passed to ``CosineCutoff``.
            num_heads (int): Number of attention heads; the feature axis is reshaped to
                ``(num_heads, n_atom_basis // num_heads)``.
            dropout (float): Dropout probability applied to the attention weights.
            edge_updates: Enables the edge-feature update; also handed to ``parse_update_info``.
            last_layer (bool): When True, no edge-update modules are built and edges pass through unchanged.
            scale_edge (bool): Scale attention by ``sqrt(num_edges_expanded)`` in addition to ``1 / sqrt(n_atom_basis)``.
            edge_ln (str): Forwarded as ``norm`` to the edge MLP.
            evec_dim (int, optional): Width of the edge vector projections; defaults to ``n_atom_basis``.
            emlp_dim (int, optional): Hidden width of the edge MLP; defaults to ``n_atom_basis``.
            sep_vecj (bool): Use one key projection per degree instead of a single shared one.
            lmax (int): Highest degree carried by the tensor features.
        """
        super().__init__(aggr=aggr, node_dim=node_dim)
        self.lmax = lmax
        self.sep_vecj = sep_vecj
        self.epsilon = epsilon
        self.last_layer = last_layer
        self.edge_updates = edge_updates
        self.scale_edge = scale_edge
        self.activation = activation
        self.update_info = parse_update_info(edge_updates)
        self.dropout = dropout
        self.n_atom_basis = n_atom_basis

        InitDense = partial(Dense, weight_init=weight_init, bias_init=bias_init)

        # Scalar branch: produces 3 * n_atom_basis features that are later split
        # into three equally sized parts inside ``message``.
        self.gamma_s = nn.Sequential(
            InitDense(n_atom_basis, n_atom_basis, activation=activation),
            InitDense(n_atom_basis, 3 * n_atom_basis, activation=None),
        )
        self.num_heads = num_heads
        self.q_w = InitDense(n_atom_basis, n_atom_basis, activation=None)
        self.k_w = InitDense(n_atom_basis, n_atom_basis, activation=None)
        self.gamma_v = nn.Sequential(
            InitDense(n_atom_basis, n_atom_basis, activation=activation),
            InitDense(n_atom_basis, 3 * n_atom_basis, activation=None),
        )
        # Projects edge features into the attention score.
        self.phik_w_ra = InitDense(
            n_atom_basis,
            n_atom_basis,
            activation=activation,
        )

        InitMLP = partial(MLP, weight_init=weight_init, bias_init=bias_init)
        self.edge_vec_dim = n_atom_basis if evec_dim is None else evec_dim
        self.edge_mlp_dim = n_atom_basis if emlp_dim is None else emlp_dim

        # NOTE(review): the source this was recovered from had lost its
        # indentation; the nesting below mirrors reset_parameters(), where the
        # same modules are reset under the same condition.
        if not self.last_layer and self.edge_updates:
            if self.update_info["mlp"] or self.update_info["mlpa"]:
                dims = [n_atom_basis, self.edge_mlp_dim, n_atom_basis]
            else:
                dims = [n_atom_basis, n_atom_basis]
            self.edge_attr_up = InitMLP(
                dims,
                activation=activation,
                last_activation=None if self.update_info["mlp"] else self.activation,
                norm=edge_ln,
            )
            self.vecq_w = InitDense(n_atom_basis, self.edge_vec_dim, activation=None, bias=False)
            if self.sep_vecj:
                # One key projection per degree 1 .. lmax.
                self.veck_w = nn.ModuleList(
                    [InitDense(n_atom_basis, self.edge_vec_dim, activation=None, bias=False) for _ in range(self.lmax)]
                )
            else:
                self.veck_w = InitDense(n_atom_basis, self.edge_vec_dim, activation=None, bias=False)

            if self.update_info["lin_w"] > 0:
                modules = []
                if self.update_info["lin_w"] % 10 == 2:
                    modules.append(self.activation)
                self.lin_w_linear = InitDense(
                    self.edge_vec_dim,
                    n_atom_basis,
                    activation=None,
                    # The original code referred to `lin_ln` here, which raised an error.
                    norm="layer" if self.update_info["lin_w"] == 2 else "",
                )
                modules.append(self.lin_w_linear)
                self.lin_w = nn.Sequential(*modules)

        self.down_proj = nn.Identity()
        self.cutoff = CosineCutoff(cutoff)
        self._alpha = None

        # NOTE(review): the third argument is passed positionally exactly as in
        # the original; confirm which Dense parameter it binds to.
        self.w_re = InitDense(
            n_atom_basis,
            n_atom_basis * 3,
            None,
        )

        self.layernorm_ = layer_norm
        self.vector_norm_ = vector_norm
        if layer_norm:
            self.layernorm = nn.LayerNorm(n_atom_basis)
        else:
            self.layernorm = nn.Identity()
        if vector_norm:
            self.tln = TensorLayerNorm(n_atom_basis, trainable=False)
        else:
            self.tln = nn.Identity()

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every learnable sub-module built in ``__init__``."""
        if self.layernorm_:
            self.layernorm.reset_parameters()
        if self.vector_norm_:
            self.tln.reset_parameters()
        for layer in self.gamma_s:
            layer.reset_parameters()
        self.q_w.reset_parameters()
        self.k_w.reset_parameters()
        for layer in self.gamma_v:
            layer.reset_parameters()
        # The edge-attention projection was previously skipped here, so an
        # explicit reset left it at its old weights.
        self.phik_w_ra.reset_parameters()
        self.w_re.reset_parameters()

        if not self.last_layer and self.edge_updates:
            self.edge_attr_up.reset_parameters()
            self.vecq_w.reset_parameters()
            if self.sep_vecj:
                for proj in self.veck_w:
                    proj.reset_parameters()
            else:
                self.veck_w.reset_parameters()
            if self.update_info["lin_w"] > 0:
                self.lin_w_linear.reset_parameters()

    def forward(
        self,
        edge_index,
        s: torch.Tensor,
        t: torch.Tensor,
        dir_ij: torch.Tensor,
        r_ij: torch.Tensor,
        d_ij: torch.Tensor,
        num_edges_expanded: torch.Tensor,
    ):
        """Compute interaction output.

        Args:
            edge_index: Edge index handed to ``propagate`` / ``edge_updater``.
            s: Scalar node features (GotenNet.forward passes them with a
                singleton axis at dim 1).
            t: Higher-degree node features.
            dir_ij: Per-edge direction features.
            r_ij: Per-edge features; the incoming value is also the residual
                base of the edge update.
            d_ij: Per-edge distances, fed to the cutoff function.
            num_edges_expanded: Per-edge edge count used to scale attention.

        Returns:
            Tuple ``(s, t, edge_features)``. ``edge_features`` is the updated
            edge tensor, or the unchanged input ``r_ij`` when this is the last
            layer or edge updates are disabled.
        """
        s = self.layernorm(s)
        t = self.tln(t)

        head_dim = self.n_atom_basis // self.num_heads
        q = self.q_w(s).reshape(-1, self.num_heads, head_dim)
        k = self.k_w(s).reshape(-1, self.num_heads, head_dim)
        x = self.gamma_s(s)
        val = self.gamma_v(s)

        f_ij = r_ij  # untouched edge features, kept for the residual update
        r_ij_attn = self.phik_w_ra(r_ij)
        r_ij = self.w_re(r_ij)

        # propagate_type: (x: Tensor, ten: Tensor, q:Tensor, k:Tensor, val:Tensor, r_ij: Tensor, r_ij_attn: Tensor, d_ij:Tensor, dir_ij: Tensor, num_edges_expanded: Tensor)
        su, tu = self.propagate(
            edge_index=edge_index,
            x=x,
            q=q,
            k=k,
            val=val,
            ten=t,
            r_ij=r_ij,
            r_ij_attn=r_ij_attn,
            d_ij=d_ij,
            dir_ij=dir_ij,
            num_edges_expanded=num_edges_expanded,
        )
        s = s + su
        t = t + tu

        if self.last_layer or not self.edge_updates:
            self._alpha = None
            return s, t, f_ij

        vec = t
        w1 = self.vecq_w(vec)
        if self.sep_vecj:
            vec_split = split_degree(vec, self.lmax, dim=1)
            w_out = torch.cat([proj(part) for proj, part in zip(self.veck_w, vec_split)], dim=1)
        else:
            w_out = self.veck_w(vec)

        # edge_updater_type: (w1: Tensor, w2:Tensor, d_ij: Tensor, f_ij: Tensor)
        df_ij = self.edge_updater(edge_index, w1=w1, w2=w_out, d_ij=dir_ij, f_ij=f_ij)
        df_ij = f_ij + df_ij
        self._alpha = None
        return s, t, df_ij

    def message(
        self,
        edge_index,
        x_i: torch.Tensor,
        x_j: torch.Tensor,
        q_i: torch.Tensor,
        k_j: torch.Tensor,
        val_j: torch.Tensor,
        ten_j: torch.Tensor,
        r_ij: torch.Tensor,
        r_ij_attn: torch.Tensor,
        d_ij: torch.Tensor,
        dir_ij: torch.Tensor,
        num_edges_expanded: torch.Tensor,
        index: torch.Tensor,
        ptr: OptTensor,
        dim_size: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute message passing.

        The parameter names are part of the contract with ``propagate`` and
        must not be renamed; several of them are unused in the body.

        Returns:
            Tuple ``(o_s, dmu)``: the scalar message and the tensor message.
        """
        r_ij_attn = r_ij_attn.reshape(-1, self.num_heads, self.n_atom_basis // self.num_heads)
        attn = (q_i * k_j * r_ij_attn).sum(dim=-1, keepdim=True)
        attn = softmax(attn, index, ptr, dim_size)

        # Rescale the softmax-normalised attention weights.
        if self.scale_edge:
            norm = torch.sqrt(num_edges_expanded.reshape(-1, 1, 1)) / np.sqrt(self.n_atom_basis)
        else:
            norm = 1.0 / np.sqrt(self.n_atom_basis)
        attn = attn * norm
        self._alpha = attn
        attn = F.dropout(attn, p=self.dropout, training=self.training)

        self_attn = attn * val_j.reshape(-1, self.num_heads, (self.n_atom_basis * 3) // self.num_heads)
        sea = self_attn.reshape(-1, 1, self.n_atom_basis * 3)
        x = sea + (r_ij.unsqueeze(1) * x_j * self.cutoff(d_ij.unsqueeze(-1).unsqueeze(-1)))

        # Three equally sized parts: scalar update, direction gate, tensor gate.
        o_s, o_d, o_t = torch.split(x, self.n_atom_basis, dim=-1)
        dmu = o_d * dir_ij[..., None] + o_t * ten_j
        return o_s, dmu

    @staticmethod
    def rej(vec, d_ij):
        """Return ``vec`` minus its projection onto ``d_ij`` (summed over dim 1)."""
        vec_proj = (vec * d_ij.unsqueeze(2)).sum(dim=1, keepdim=True)
        return vec - vec_proj * d_ij.unsqueeze(2)

    def edge_update(self, w1_i, w2_j, w3_j, d_ij, f_ij):
        """Compute the edge-feature increment.

        ``w3_j`` is unused but kept in the signature, which ``edge_updater``
        inspects; do not remove it without checking the installed PyG version.
        """
        use_rej = self.update_info["rej"]
        if self.sep_vecj:
            vi_split = split_degree(w1_i, self.lmax, dim=1)
            vj_split = split_degree(w2_j, self.lmax, dim=1)
            if use_rej:
                # The direction tensor is only split when rejection is applied.
                d_ij_split = split_degree(d_ij, self.lmax, dim=1)
                pairs = [
                    (self.rej(vi, d), self.rej(vj, -d))
                    for vi, vj, d in zip(vi_split, vj_split, d_ij_split)
                ]
            else:
                pairs = list(zip(vi_split, vj_split))
        elif use_rej:
            pairs = [(self.rej(w1_i, d_ij), self.rej(w2_j, -d_ij))]
        else:
            pairs = [(w1_i, w2_j)]

        # Sum the dim-1 inner products over all pairs.
        w_dot = None
        for left, right in pairs:
            dot = (left * right).sum(dim=1)
            w_dot = dot if w_dot is None else w_dot + dot

        if self.update_info["lin_w"] > 0:
            w_dot = self.lin_w(w_dot)

        gate = self.update_info["gated"]
        if gate == "gatedt":
            w_dot = torch.tanh(w_dot)
        elif gate == "gated":
            w_dot = torch.sigmoid(w_dot)
        elif gate == "act":
            w_dot = self.activation(w_dot)

        df_ij = self.edge_attr_up(f_ij) * w_dot
        return df_ij

    # noinspection PyMethodOverriding
    def aggregate(
        self,
        features: Tuple[torch.Tensor, torch.Tensor],
        index: torch.Tensor,
        ptr: Optional[torch.Tensor],
        dim_size: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scatter-reduce both message tensors onto their target nodes."""
        x, vec = features
        x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
        vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
        return x, vec

    def update(self, inputs: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the aggregated messages unchanged."""
        return inputs
class EQFF(nn.Module):
    """Equivariant feed-forward block mixing scalar and tensor features."""

    def __init__(
        self,
        n_atom_basis: int,
        activation: Callable,
        epsilon: float = 1e-8,
        weight_init=nn.init.xavier_uniform_,
        bias_init=nn.init.zeros_,
        vec_dim=None,
    ):
        """Equivariant Feed Forward layer.

        Args:
            n_atom_basis (int): Feature width of the scalar channel.
            activation (Callable): Activation of the first layer of the gating MLP.
            epsilon (float): Added under the square root when taking the norm.
            weight_init (Callable): Weight initialization function.
            bias_init (Callable): Bias initialization function.
            vec_dim (int, optional): Output width of the tensor projection;
                defaults to ``n_atom_basis``.
                NOTE(review): the gating MLP expects ``2 * n_atom_basis`` inputs,
                so other values presumably do not work - confirm.
        """
        super().__init__()
        self.n_atom_basis = n_atom_basis
        if vec_dim is None:
            vec_dim = n_atom_basis

        make_dense = partial(Dense, weight_init=weight_init, bias_init=bias_init)

        # Gating MLP: [scalars | tensor norms] -> (scalar update, tensor gate).
        first = make_dense(2 * n_atom_basis, n_atom_basis, activation=activation)
        second = make_dense(n_atom_basis, 2 * n_atom_basis, activation=None)
        self.gamma_m = nn.Sequential(first, second)

        self.w_vu = make_dense(n_atom_basis, vec_dim, activation=None, bias=False)
        self.epsilon = epsilon

    def reset_parameters(self):
        """Re-initialise the tensor projection and both gating layers."""
        self.w_vu.reset_parameters()
        for layer in self.gamma_m:
            layer.reset_parameters()

    def forward(self, s, v):
        """Compute Equivariant Feed Forward output.

        Args:
            s: Scalar features.
            v: Tensor features; the norm is taken over the second-to-last axis.

        Returns:
            Tuple ``(s, v)`` with the residual updates applied.
        """
        projected = self.w_vu(v)

        # Norm over the component axis, stabilised by epsilon.
        squared = (projected**2).sum(dim=-2, keepdim=True)
        magnitude = (squared + self.epsilon).sqrt()

        gates = self.gamma_m(torch.cat((s, magnitude), dim=-1))
        scalar_update, vector_gate = gates.split(self.n_atom_basis, dim=-1)

        return s + scalar_update, v + vector_gate * projected
class GotenNet(nn.Module):
    """Stack of GATA interaction layers and EQFF mixing layers.

    ``forward`` returns the scalar node features collected after every layer.
    """

    def __init__(
        self,
        hidden_channels: int = 128,
        num_layers: int = 8,
        radial_basis: Union[Callable, str] = "BesselBasis",
        n_rbf: int = 20,
        cutoff: float = 5.0,
        activation: Optional[Union[Callable, str]] = F.silu,
        max_z: int = 100,
        epsilon: float = 1e-8,
        weight_init=nn.init.xavier_uniform_,
        bias_init=nn.init.zeros_,
        int_layer_norm=False,
        int_vector_norm=False,
        before_mixing_layer_norm=False,
        after_mixing_layer_norm=False,
        num_heads=8,
        attn_dropout=0.0,
        edge_updates=True,
        scale_edge=True,
        lmax=2,
        aggr="add",
        edge_ln="",
        evec_dim=None,
        emlp_dim=None,
        sep_int_vec=True,
    ):
        """
        Representation for GotenNet

        Args:
            hidden_channels (int): Feature width used throughout the network.
            num_layers (int): Number of (GATA, EQFF) layer pairs.
            radial_basis: Radial basis, resolved through ``str2basis``.
            n_rbf (int): Number of radial basis functions.
            cutoff (float): Cutoff passed to the basis, the embeddings and every GATA layer.
            activation: Activation callable, or a string resolved through ``str2act``.
            max_z (int): Size of the atom-type embedding table (index 0 is the padding index).
            epsilon (float): Forwarded to every GATA and EQFF layer.
            weight_init: Weight initializer, or a string resolved through ``get_weight_init_by_string``.
            bias_init: Bias initializer, or a string resolved through ``get_weight_init_by_string``.
            int_layer_norm (bool): Forwarded to GATA as ``layer_norm``.
            int_vector_norm (bool): Forwarded to GATA as ``vector_norm``.
            before_mixing_layer_norm (bool): LayerNorm on the scalars before each EQFF layer.
            after_mixing_layer_norm (bool): LayerNorm on the scalars after each EQFF layer.
            num_heads (int): Number of attention heads in GATA.
            attn_dropout (float): Forwarded to GATA as ``dropout``.
            edge_updates: Forwarded to GATA.
            scale_edge (bool): Forwarded to GATA.
            lmax (int): Highest degree; forwarded to ``TensorInit`` and GATA.
            aggr (str): Aggregation method forwarded to GATA.
            edge_ln (str): Forwarded to GATA.
            evec_dim (int, optional): Forwarded to GATA.
            emlp_dim (int, optional): Forwarded to GATA.
            sep_int_vec (bool): Forwarded to GATA as ``sep_vecj``.
        """
        super().__init__()
        self.scale_edge = scale_edge

        # String arguments are resolved to callables via the helper look-ups.
        if isinstance(weight_init, str):
            weight_init = get_weight_init_by_string(weight_init)
        if isinstance(bias_init, str):
            bias_init = get_weight_init_by_string(bias_init)
        if isinstance(activation, str):
            activation = str2act(activation)

        self.n_atom_basis = self.hidden_dim = hidden_channels
        self.n_interactions = num_layers
        self.cutoff = cutoff

        self.neighbor_embedding = NodeInit(
            [self.hidden_dim // 2, self.hidden_dim],
            n_rbf,
            self.cutoff,
            max_z=max_z,
            weight_init=weight_init,
            bias_init=bias_init,
            concat=False,
            proj_ln="layer",
            activation=activation,
        )
        self.edge_embedding = EdgeInit(
            n_rbf, [self.hidden_dim // 2, self.hidden_dim], weight_init=weight_init, bias_init=bias_init, proj_ln=""
        )

        radial_basis = str2basis(radial_basis)
        self.radial_basis = radial_basis(cutoff=self.cutoff, n_rbf=n_rbf)
        self.embedding = nn.Embedding(max_z, self.n_atom_basis, padding_idx=0)
        self.tensor_init = TensorInit(l=lmax)

        self.gata = nn.ModuleList(
            [
                GATA(
                    n_atom_basis=self.n_atom_basis,
                    activation=activation,
                    aggr=aggr,
                    weight_init=weight_init,
                    bias_init=bias_init,
                    layer_norm=int_layer_norm,
                    vector_norm=int_vector_norm,
                    cutoff=self.cutoff,
                    epsilon=epsilon,
                    num_heads=num_heads,
                    dropout=attn_dropout,
                    edge_updates=edge_updates,
                    # Only the final layer skips the edge update.
                    last_layer=(i == self.n_interactions - 1),
                    scale_edge=scale_edge,
                    edge_ln=edge_ln,
                    evec_dim=evec_dim,
                    emlp_dim=emlp_dim,
                    sep_vecj=sep_int_vec,
                    lmax=lmax,
                )
                for i in range(self.n_interactions)
            ]
        )
        self.eqff = nn.ModuleList(
            [
                EQFF(
                    n_atom_basis=self.n_atom_basis,
                    activation=activation,
                    epsilon=epsilon,
                    weight_init=weight_init,
                    bias_init=bias_init,
                )
                for _ in range(self.n_interactions)
            ]
        )

        # Extra layer norms for the scalar quantities
        if before_mixing_layer_norm:
            self.before_mixing_ln = nn.LayerNorm(self.n_atom_basis)
        else:
            self.before_mixing_ln = nn.Identity()
        if after_mixing_layer_norm:
            self.after_mixing_ln = nn.LayerNorm(self.n_atom_basis)
        else:
            self.after_mixing_ln = nn.Identity()

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the embeddings, all layers and the optional layer norms."""
        self.edge_embedding.reset_parameters()
        self.neighbor_embedding.reset_parameters()
        for layer in self.gata:
            layer.reset_parameters()
        for layer in self.eqff:
            layer.reset_parameters()
        if not isinstance(self.before_mixing_ln, nn.Identity):
            self.before_mixing_ln.reset_parameters()
        if not isinstance(self.after_mixing_ln, nn.Identity):
            self.after_mixing_ln.reset_parameters()

    def forward(self, z, pos, cutoff_edge_index, cutoff_edge_distance, cutoff_edge_vec):
        """Encode a graph and return the per-layer scalar node features.

        Args:
            z: Atom-type indices, looked up in the embedding table.
            pos: Not read by this method; kept so the call signature is unchanged.
            cutoff_edge_index: Edge index of shape ``[2, num_edges]`` (rows 0 and 1 are indexed).
            cutoff_edge_distance: Per-edge distances.
            cutoff_edge_vec: Per-edge displacement vectors. This tensor is no
                longer modified in place.

        Returns:
            dict with key ``"embedding_0"`` holding a tensor of shape
            ``[n_nodes, n_features, 1, n_layers]``.
        """
        q = self.embedding(z)
        edge_attr = self.radial_basis(cutoff_edge_distance)
        q = self.neighbor_embedding(z, q, cutoff_edge_index, cutoff_edge_distance, edge_attr)
        edge_attr = self.edge_embedding(cutoff_edge_index, edge_attr, q)

        # Direction vectors: normalise every edge that is not a self-loop.
        # Work on a clone so the caller's tensor is left untouched, and use a
        # masked assignment so self-loop rows are never divided by their norm.
        not_self_loop = cutoff_edge_index[0] != cutoff_edge_index[1]
        selected = cutoff_edge_vec[not_self_loop]
        dist = torch.norm(selected, dim=1).unsqueeze(1)
        unit_vec = cutoff_edge_vec.clone()
        unit_vec[not_self_loop] = selected / dist
        edge_dir = self.tensor_init(unit_vec)

        equi_dim = ((self.tensor_init.l + 1) ** 2) - 1

        # Count the edges of each node (grouped by row 0 of the edge index),
        # then map the count back onto every edge.
        num_edges = scatter(torch.ones_like(cutoff_edge_distance), cutoff_edge_index[0], dim=0, reduce="sum")
        num_edges_expanded = num_edges[cutoff_edge_index[0]]

        n_nodes, n_features = q.shape
        # Start the tensor channel at zero, matching q's device and dtype.
        mu = q.new_zeros((n_nodes, equi_dim, n_features))
        q = q.unsqueeze(1)

        layer_outputs = []
        for interaction, mixing in zip(self.gata, self.eqff):
            q, mu, edge_attr = interaction(
                cutoff_edge_index,
                q,
                mu,
                dir_ij=edge_dir,
                r_ij=edge_attr,
                d_ij=cutoff_edge_distance,
                num_edges_expanded=num_edges_expanded,
            )
            q = self.before_mixing_ln(q)
            q, mu = mixing(q, mu)
            q = self.after_mixing_ln(q)
            # Collect all scalars for inter-layer read-outs
            layer_outputs.append(q.squeeze(1))

        layer_outputs = torch.stack(layer_outputs, dim=-1)

        # [n_nodes, n_features, dimension of irrep, n_layers]
        # This is a scalar so a single irrep
        output_dict = {"embedding_0": layer_outputs.unsqueeze(2)}
        return output_dict