""" | |
This code was adapted from https://github.com/sarpaykent/GotenNet | |
Copyright (c) 2025 Sarp Aykent | |
MIT License | |
GotenNet: Rethinking Efficient 3D Equivariant Graph Neural Networks | |
Sarp Aykent and Tian Xia | |
https://openreview.net/pdf?id=5wxCQDtbMo | |
""" | |
from functools import partial
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.typing import OptTensor
from torch_geometric.utils import scatter, softmax

from .ops import (
    MLP,
    CosineCutoff,
    Dense,
    EdgeInit,
    NodeInit,
    TensorInit,
    TensorLayerNorm,
    get_weight_init_by_string,
    parse_update_info,
    str2act,
    str2basis,
)
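
# Number of components in a flattened spherical tensor holding degrees 1..lmax:
# the sum of (2l + 1) over l = 1..lmax, which equals (lmax + 1)^2 - 1
# (e.g. lmax=2 -> 3 + 5 = 8).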
def lmax_tensor_size(lmax):
    return ((lmax + 1) ** 2) - 1
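
# Split the degree axis of a flattened spherical tensor into one block per degree
# (block sizes 3, 5, 7, ... for degrees 1, 2, 3, ...).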
def split_degree(tensor, lmax, dim=-1):  # default to last dim
    cumsum = 0
    tensors = []
    for i in range(1, lmax + 1):
        count = lmax_tensor_size(i) - lmax_tensor_size(i - 1)
        # Build a slice object for the requested dimension only.
        slc = [slice(None)] * tensor.ndim  # slice(None) for every dim
        slc[dim] = slice(cumsum, cumsum + count)  # replace the target dim with the actual slice
        tensors.append(tensor[tuple(slc)])
        cumsum += count
    return tensors
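
# GATA: graph attention block that updates the scalar node features (s) and the
# higher-degree tensor node features (t) via message passing and, except in the
# last layer, also refines the edge attributes.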
class GATA(MessagePassing):
    def __init__(
        self,
        n_atom_basis: int,
        activation: Callable,
        weight_init=nn.init.xavier_uniform_,
        bias_init=nn.init.zeros_,
        aggr="add",
        node_dim=0,
        epsilon: float = 1e-7,
        layer_norm=False,
        vector_norm=False,
        cutoff=5.0,
        num_heads=8,
        dropout=0.0,
        edge_updates=True,
        last_layer=False,
        scale_edge=True,
        edge_ln="",
        evec_dim=None,
        emlp_dim=None,
        sep_vecj=True,
        lmax=1,
    ):
        """
        Args:
            n_atom_basis (int): Number of features to describe atomic environments.
            activation (Callable): Activation function to be used. If None, no activation function is used.
            weight_init (Callable): Weight initialization function.
            bias_init (Callable): Bias initialization function.
            aggr (str): Aggregation method ('add', 'mean' or 'max').
            node_dim (int): The axis along which to aggregate.
        """
        super(GATA, self).__init__(aggr=aggr, node_dim=node_dim)

        self.lmax = lmax
        self.sep_vecj = sep_vecj
        self.epsilon = epsilon
        self.last_layer = last_layer
        self.edge_updates = edge_updates
        self.scale_edge = scale_edge
        self.activation = activation
        self.update_info = parse_update_info(edge_updates)
        self.dropout = dropout
        self.n_atom_basis = n_atom_basis

        InitDense = partial(Dense, weight_init=weight_init, bias_init=bias_init)

        self.gamma_s = nn.Sequential(
            InitDense(n_atom_basis, n_atom_basis, activation=activation),
            InitDense(n_atom_basis, 3 * n_atom_basis, activation=None),
        )

        self.num_heads = num_heads
        self.q_w = InitDense(n_atom_basis, n_atom_basis, activation=None)
        self.k_w = InitDense(n_atom_basis, n_atom_basis, activation=None)

        self.gamma_v = nn.Sequential(
            InitDense(n_atom_basis, n_atom_basis, activation=activation),
            InitDense(n_atom_basis, 3 * n_atom_basis, activation=None),
        )

        self.phik_w_ra = InitDense(
            n_atom_basis,
            n_atom_basis,
            activation=activation,
        )

        InitMLP = partial(MLP, weight_init=weight_init, bias_init=bias_init)

        self.edge_vec_dim = n_atom_basis if evec_dim is None else evec_dim
        self.edge_mlp_dim = n_atom_basis if emlp_dim is None else emlp_dim

        if not self.last_layer and self.edge_updates:
            if self.update_info["mlp"] or self.update_info["mlpa"]:
                dims = [n_atom_basis, self.edge_mlp_dim, n_atom_basis]
            else:
                dims = [n_atom_basis, n_atom_basis]
            self.edge_attr_up = InitMLP(
                dims,
                activation=activation,
                last_activation=None if self.update_info["mlp"] else self.activation,
                norm=edge_ln,
            )
            self.vecq_w = InitDense(n_atom_basis, self.edge_vec_dim, activation=None, bias=False)

            if self.sep_vecj:
                self.veck_w = nn.ModuleList(
                    [InitDense(n_atom_basis, self.edge_vec_dim, activation=None, bias=False) for _ in range(self.lmax)]
                )
            else:
                self.veck_w = InitDense(n_atom_basis, self.edge_vec_dim, activation=None, bias=False)

            if self.update_info["lin_w"] > 0:
                modules = []
                if self.update_info["lin_w"] % 10 == 2:
                    modules.append(self.activation)
                self.lin_w_linear = InitDense(
                    self.edge_vec_dim,
                    n_atom_basis,
                    activation=None,
                    norm="layer" if self.update_info["lin_w"] == 2 else "",  # "lin_ln" in the original code, which is undefined here
                )
                modules.append(self.lin_w_linear)
                self.lin_w = nn.Sequential(*modules)

        self.down_proj = nn.Identity()

        self.cutoff = CosineCutoff(cutoff)
        self._alpha = None

        self.w_re = InitDense(
            n_atom_basis,
            n_atom_basis * 3,
            None,
        )

        self.layernorm_ = layer_norm
        self.vector_norm_ = vector_norm
        if layer_norm:
            self.layernorm = nn.LayerNorm(n_atom_basis)
        else:
            self.layernorm = nn.Identity()
        if vector_norm:
            self.tln = TensorLayerNorm(n_atom_basis, trainable=False)
        else:
            self.tln = nn.Identity()

        self.reset_parameters()
    def reset_parameters(self):
        if self.layernorm_:
            self.layernorm.reset_parameters()
        if self.vector_norm_:
            self.tln.reset_parameters()
        for l in self.gamma_s:  # noqa: E741
            l.reset_parameters()
        self.q_w.reset_parameters()
        self.k_w.reset_parameters()
        for l in self.gamma_v:  # noqa: E741
            l.reset_parameters()
        self.w_re.reset_parameters()

        if not self.last_layer and self.edge_updates:
            self.edge_attr_up.reset_parameters()
            self.vecq_w.reset_parameters()
            if self.sep_vecj:
                for w in self.veck_w:
                    w.reset_parameters()
            else:
                self.veck_w.reset_parameters()
            if self.update_info["lin_w"] > 0:
                self.lin_w_linear.reset_parameters()
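    # forward() arguments: edge_index [2, E]; s, scalar node features; t, tensor
    # node features; dir_ij, per-edge direction/spherical-tensor features; r_ij,
    # per-edge scalar features (embedded radial basis); d_ij, edge distances;
    # num_edges_expanded, per-edge neighbor counts used to rescale attention.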
    def forward(
        self,
        edge_index,
        s: torch.Tensor,
        t: torch.Tensor,
        dir_ij: torch.Tensor,
        r_ij: torch.Tensor,
        d_ij: torch.Tensor,
        num_edges_expanded: torch.Tensor,
    ):
        """Compute interaction output."""
        s = self.layernorm(s)
        t = self.tln(t)
        q = self.q_w(s).reshape(-1, self.num_heads, self.n_atom_basis // self.num_heads)
        k = self.k_w(s).reshape(-1, self.num_heads, self.n_atom_basis // self.num_heads)

        x = self.gamma_s(s)
        val = self.gamma_v(s)
        f_ij = r_ij
        r_ij_attn = self.phik_w_ra(r_ij)
        r_ij = self.w_re(r_ij)

        # propagate_type: (x: Tensor, ten: Tensor, q: Tensor, k: Tensor, val: Tensor, r_ij: Tensor, r_ij_attn: Tensor, d_ij: Tensor, dir_ij: Tensor, num_edges_expanded: Tensor)
        su, tu = self.propagate(
            edge_index=edge_index,
            x=x,
            q=q,
            k=k,
            val=val,
            ten=t,
            r_ij=r_ij,
            r_ij_attn=r_ij_attn,
            d_ij=d_ij,
            dir_ij=dir_ij,
            num_edges_expanded=num_edges_expanded,
        )

        # Residual updates of the scalar and tensor node features.
        s = s + su
        t = t + tu

        if not self.last_layer and self.edge_updates:
            vec = t

            w1 = self.vecq_w(vec)
            if self.sep_vecj:
                vec_split = split_degree(vec, self.lmax, dim=1)
                w_out = torch.concat([w(vec_split[i]) for i, w in enumerate(self.veck_w)], dim=1)
            else:
                w_out = self.veck_w(vec)

            # edge_updater_type: (w1: Tensor, w2: Tensor, d_ij: Tensor, f_ij: Tensor)
            df_ij = self.edge_updater(edge_index, w1=w1, w2=w_out, d_ij=dir_ij, f_ij=f_ij)
            df_ij = f_ij + df_ij
            self._alpha = None
            return s, t, df_ij
        else:
            self._alpha = None
            return s, t, f_ij
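    # message() builds multi-head attention coefficients from q_i, k_j and the
    # edge filter r_ij_attn, then emits a scalar message (o_s) and a tensor
    # message (dmu) assembled from the direction features and neighbor tensors.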
    def message(
        self,
        edge_index,
        x_i: torch.Tensor,
        x_j: torch.Tensor,
        q_i: torch.Tensor,
        k_j: torch.Tensor,
        val_j: torch.Tensor,
        ten_j: torch.Tensor,
        r_ij: torch.Tensor,
        r_ij_attn: torch.Tensor,
        d_ij: torch.Tensor,
        dir_ij: torch.Tensor,
        num_edges_expanded: torch.Tensor,
        index: torch.Tensor,
        ptr: OptTensor,
        dim_size: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute message passing.
        """
        r_ij_attn = r_ij_attn.reshape(-1, self.num_heads, self.n_atom_basis // self.num_heads)
        attn = (q_i * k_j * r_ij_attn).sum(dim=-1, keepdim=True)

        attn = softmax(attn, index, ptr, dim_size)

        # Normalize the attention scores
        if self.scale_edge:
            norm = torch.sqrt(num_edges_expanded.reshape(-1, 1, 1)) / np.sqrt(self.n_atom_basis)
        else:
            norm = 1.0 / np.sqrt(self.n_atom_basis)
        attn = attn * norm
        self._alpha = attn
        attn = F.dropout(attn, p=self.dropout, training=self.training)

        self_attn = attn * val_j.reshape(-1, self.num_heads, (self.n_atom_basis * 3) // self.num_heads)
        SEA = self_attn.reshape(-1, 1, self.n_atom_basis * 3)

        x = SEA + (r_ij.unsqueeze(1) * x_j * self.cutoff(d_ij.unsqueeze(-1).unsqueeze(-1)))

        o_s, o_d, o_t = torch.split(x, self.n_atom_basis, dim=-1)
        dmu = o_d * dir_ij[..., None] + o_t * ten_j
        return o_s, dmu
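    # Vector rejection: remove the component of vec that is parallel to the edge
    # direction d_ij.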
    @staticmethod
    def rej(vec, d_ij):
        vec_proj = (vec * d_ij.unsqueeze(2)).sum(dim=1, keepdim=True)
        return vec - vec_proj * d_ij.unsqueeze(2)
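    # edge_update() gates the edge MLP output with (optionally rejected) inner
    # products between the transformed tensor features of the two end nodes,
    # computed per degree when sep_vecj is enabled.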
    def edge_update(self, w1_i, w2_j, d_ij, f_ij):
        if self.sep_vecj:
            vi = w1_i
            vj = w2_j
            vi_split = split_degree(vi, self.lmax, dim=1)
            vj_split = split_degree(vj, self.lmax, dim=1)
            d_ij_split = split_degree(d_ij, self.lmax, dim=1)

            pairs = []
            for i in range(len(vi_split)):
                if self.update_info["rej"]:
                    w1 = self.rej(vi_split[i], d_ij_split[i])
                    w2 = self.rej(vj_split[i], -d_ij_split[i])
                    pairs.append((w1, w2))
                else:
                    w1 = vi_split[i]
                    w2 = vj_split[i]
                    pairs.append((w1, w2))
        elif not self.update_info["rej"]:
            w1 = w1_i
            w2 = w2_j
            pairs = [(w1, w2)]
        else:
            w1 = self.rej(w1_i, d_ij)
            w2 = self.rej(w2_j, -d_ij)
            pairs = [(w1, w2)]

        w_dot_sum = None
        for w1, w2 in pairs:
            w_dot = (w1 * w2).sum(dim=1)
            if w_dot_sum is None:
                w_dot_sum = w_dot
            else:
                w_dot_sum = w_dot_sum + w_dot
        w_dot = w_dot_sum

        if self.update_info["lin_w"] > 0:
            w_dot = self.lin_w(w_dot)

        if self.update_info["gated"] == "gatedt":
            w_dot = torch.tanh(w_dot)
        elif self.update_info["gated"] == "gated":
            w_dot = torch.sigmoid(w_dot)
        elif self.update_info["gated"] == "act":
            w_dot = self.activation(w_dot)

        df_ij = self.edge_attr_up(f_ij) * w_dot
        return df_ij
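    # Aggregate the scalar and tensor messages per destination node with the
    # configured reduction (self.aggr).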
    # noinspection PyMethodOverriding
    def aggregate(
        self,
        features: Tuple[torch.Tensor, torch.Tensor],
        index: torch.Tensor,
        ptr: Optional[torch.Tensor],
        dim_size: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x, vec = features
        x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
        vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
        return x, vec
    def update(self, inputs: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        return inputs
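
# EQFF: equivariant feed-forward block that mixes the scalar features with the
# invariant magnitudes of the tensor features and produces residual updates for
# both.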
class EQFF(nn.Module):
    def __init__(
        self,
        n_atom_basis: int,
        activation: Callable,
        epsilon: float = 1e-8,
        weight_init=nn.init.xavier_uniform_,
        bias_init=nn.init.zeros_,
        vec_dim=None,
    ):
        """Equivariant Feed Forward layer."""
        super(EQFF, self).__init__()
        self.n_atom_basis = n_atom_basis

        InitDense = partial(Dense, weight_init=weight_init, bias_init=bias_init)

        vec_dim = n_atom_basis if vec_dim is None else vec_dim
        context_dim = 2 * n_atom_basis

        self.gamma_m = nn.Sequential(
            InitDense(context_dim, n_atom_basis, activation=activation),
            InitDense(n_atom_basis, 2 * n_atom_basis, activation=None),
        )
        self.w_vu = InitDense(n_atom_basis, vec_dim, activation=None, bias=False)

        self.epsilon = epsilon
    def reset_parameters(self):
        self.w_vu.reset_parameters()
        for l in self.gamma_m:  # noqa: E741
            l.reset_parameters()
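    # t_prime_mag is the norm over the degree axis, which is rotation invariant,
    # so it can be concatenated with the scalars and fed through gamma_m; the
    # resulting gate m_2 rescales t_prime without breaking equivariance.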
    def forward(self, s, v):
        """Compute Equivariant Feed Forward output."""
        t_prime = self.w_vu(v)
        t_prime_mag = torch.sqrt(torch.sum(t_prime**2, dim=-2, keepdim=True) + self.epsilon)
        combined = [s, t_prime_mag]
        combined_tensor = torch.cat(combined, dim=-1)
        m12 = self.gamma_m(combined_tensor)

        m_1, m_2 = torch.split(m12, self.n_atom_basis, dim=-1)
        delta_v = m_2 * t_prime

        s = s + m_1
        v = v + delta_v

        return s, v
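
# GotenNet: full representation model. It embeds atomic numbers, builds edge
# features from a radial basis, and alternates GATA interaction blocks with
# EQFF mixing blocks, collecting the scalar features of every layer.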
class GotenNet(nn.Module):
    def __init__(
        self,
        hidden_channels: int = 128,
        num_layers: int = 8,
        radial_basis: Union[Callable, str] = "BesselBasis",
        n_rbf: int = 20,
        cutoff: float = 5.0,
        activation: Optional[Union[Callable, str]] = F.silu,
        max_z: int = 100,
        epsilon: float = 1e-8,
        weight_init=nn.init.xavier_uniform_,
        bias_init=nn.init.zeros_,
        int_layer_norm=False,
        int_vector_norm=False,
        before_mixing_layer_norm=False,
        after_mixing_layer_norm=False,
        num_heads=8,
        attn_dropout=0.0,
        edge_updates=True,
        scale_edge=True,
        lmax=2,
        aggr="add",
        edge_ln="",
        evec_dim=None,
        emlp_dim=None,
        sep_int_vec=True,
    ):
        """
        Representation for GotenNet.
        """
        super(GotenNet, self).__init__()

        self.scale_edge = scale_edge
        if isinstance(weight_init, str):
            weight_init = get_weight_init_by_string(weight_init)
        if isinstance(bias_init, str):
            bias_init = get_weight_init_by_string(bias_init)
        if isinstance(activation, str):
            activation = str2act(activation)

        self.n_atom_basis = self.hidden_dim = hidden_channels
        self.n_interactions = num_layers
        self.cutoff = cutoff

        self.neighbor_embedding = NodeInit(
            [self.hidden_dim // 2, self.hidden_dim],
            n_rbf,
            self.cutoff,
            max_z=max_z,
            weight_init=weight_init,
            bias_init=bias_init,
            concat=False,
            proj_ln="layer",
            activation=activation,
        )
        self.edge_embedding = EdgeInit(
            n_rbf, [self.hidden_dim // 2, self.hidden_dim], weight_init=weight_init, bias_init=bias_init, proj_ln=""
        )

        radial_basis = str2basis(radial_basis)
        self.radial_basis = radial_basis(cutoff=self.cutoff, n_rbf=n_rbf)

        self.embedding = nn.Embedding(max_z, self.n_atom_basis, padding_idx=0)

        self.tensor_init = TensorInit(l=lmax)

        self.gata = nn.ModuleList(
            [
                GATA(
                    n_atom_basis=self.n_atom_basis,
                    activation=activation,
                    aggr=aggr,
                    weight_init=weight_init,
                    bias_init=bias_init,
                    layer_norm=int_layer_norm,
                    vector_norm=int_vector_norm,
                    cutoff=self.cutoff,
                    epsilon=epsilon,
                    num_heads=num_heads,
                    dropout=attn_dropout,
                    edge_updates=edge_updates,
                    last_layer=(i == self.n_interactions - 1),
                    scale_edge=scale_edge,
                    edge_ln=edge_ln,
                    evec_dim=evec_dim,
                    emlp_dim=emlp_dim,
                    sep_vecj=sep_int_vec,
                    lmax=lmax,
                )
                for i in range(self.n_interactions)
            ]
        )
        self.eqff = nn.ModuleList(
            [
                EQFF(n_atom_basis=self.n_atom_basis, activation=activation, epsilon=epsilon, weight_init=weight_init, bias_init=bias_init)
                for _ in range(self.n_interactions)
            ]
        )

        # Extra layer norms for the scalar quantities
        if before_mixing_layer_norm:
            self.before_mixing_ln = nn.LayerNorm(self.n_atom_basis)
        else:
            self.before_mixing_ln = nn.Identity()
        if after_mixing_layer_norm:
            self.after_mixing_ln = nn.LayerNorm(self.n_atom_basis)
        else:
            self.after_mixing_ln = nn.Identity()

        self.reset_parameters()
    def reset_parameters(self):
        self.edge_embedding.reset_parameters()
        self.neighbor_embedding.reset_parameters()
        for l in self.gata:  # noqa: E741
            l.reset_parameters()
        for l in self.eqff:  # noqa: E741
            l.reset_parameters()
        if not isinstance(self.before_mixing_ln, nn.Identity):
            self.before_mixing_ln.reset_parameters()
        if not isinstance(self.after_mixing_ln, nn.Identity):
            self.after_mixing_ln.reset_parameters()
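    # forward() expects atomic numbers z, positions pos, and a precomputed
    # cutoff-graph edge_index together with per-edge distances and vectors; it
    # returns a dict whose "embedding_0" entry stacks the scalar features of all
    # layers as [n_nodes, n_features, 1, n_layers].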
    def forward(self, z, pos, cutoff_edge_index, cutoff_edge_distance, cutoff_edge_vec):
        q = self.embedding(z)[:]

        edge_attr = self.radial_basis(cutoff_edge_distance)
        q = self.neighbor_embedding(z, q, cutoff_edge_index, cutoff_edge_distance, edge_attr)
        edge_attr = self.edge_embedding(cutoff_edge_index, edge_attr, q)
        mask = cutoff_edge_index[0] != cutoff_edge_index[1]
        # Normalize the vectors of non-self-loop edges to unit direction vectors.
        dist = torch.norm(cutoff_edge_vec[mask], dim=1).unsqueeze(1)
        cutoff_edge_vec[mask] = cutoff_edge_vec[mask] / dist
        cutoff_edge_vec = self.tensor_init(cutoff_edge_vec)
        equi_dim = ((self.tensor_init.l + 1) ** 2) - 1
        # Count the number of edges for each node.
        num_edges = scatter(torch.ones_like(cutoff_edge_distance), cutoff_edge_index[0], dim=0, reduce="sum")
        # num_edges has one entry per node; map it back to one entry per edge via
        # cutoff_edge_index so it can rescale the attention scores.
        num_edges_expanded = num_edges[cutoff_edge_index[0]]

        qs = q.shape
        mu = torch.zeros((qs[0], equi_dim, qs[1]), device=q.device)
        q.unsqueeze_(1)

        layer_outputs = []
        for interaction, mixing in zip(self.gata, self.eqff):
            q, mu, edge_attr = interaction(
                cutoff_edge_index,
                q,
                mu,
                dir_ij=cutoff_edge_vec,
                r_ij=edge_attr,
                d_ij=cutoff_edge_distance,
                num_edges_expanded=num_edges_expanded,
            )
            q = self.before_mixing_ln(q)
            q, mu = mixing(q, mu)
            q = self.after_mixing_ln(q)
            # Collect all scalars for inter-layer read-outs.
            layer_outputs.append(q.squeeze(1))

        layer_outputs = torch.stack(layer_outputs, dim=-1)

        output_dict = {}
        # Scalar (degree-0) features, so the irrep dimension is a singleton.
        output_dict["embedding_0"] = layer_outputs.unsqueeze(2)  # [n_nodes, n_features, irrep dimension, n_layers]

        return output_dict
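
# Minimal smoke-test sketch (illustrative only, not part of the adapted GotenNet
# code): the graph construction below is a stand-in assumption for whatever
# cutoff/radius-graph routine the surrounding project uses.
if __name__ == "__main__":
    num_nodes = 4
    z = torch.randint(1, 10, (num_nodes,))
    pos = torch.randn(num_nodes, 3)

    # Fully connected directed edge list without self-loops.
    row, col = torch.meshgrid(torch.arange(num_nodes), torch.arange(num_nodes), indexing="ij")
    mask = row != col
    edge_index = torch.stack([row[mask], col[mask]], dim=0)
    edge_vec = pos[edge_index[1]] - pos[edge_index[0]]
    edge_dist = edge_vec.norm(dim=-1)

    model = GotenNet(hidden_channels=32, num_layers=2, lmax=2)
    out = model(z, pos, edge_index, edge_dist, edge_vec)
    print(out["embedding_0"].shape)  # expected: [num_nodes, 32, 1, 2]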