ipd committed · Commit dded3ca · 1 Parent(s): 3915777

add pos-egnn

models/fm4m.py CHANGED
@@ -32,6 +32,7 @@ from models.selfies_ted.load import SELFIES as bart
 from models.mhg_model import load as mhg
 from models.smi_ted.smi_ted_light.load import load_smi_ted
 from models.smi_ssed.load import load_smi_ssed
+from models.pos_egnn.load import POSEGNN as pos
 
 import mordred
 from mordred import Calculator, descriptors
@@ -60,6 +61,7 @@ def avail_models_data():
     {"Name": "mol-xl","Model Name": "MolFormer", "Description": "MolFormer model for string based SMILES modality", "Timestamp": "2024-06-21 12:35:56"},
     {"Name": "mhg", "Model Name": "MHG-GED","Description": "Molecular hypergraph model", "Timestamp": "2024-07-10 00:09:42"},
     {"Name": "smi-ted", "Model Name": "SMI-TED","Description": "SMILES based encoder decoder model", "Timestamp": "2024-07-10 00:09:42"},
+    {"Name": "pos-egnn", "Model Name": "POS-EGNN","Description": "3D atom position model", "Timestamp": "2025-04-04 00:11:42"},
     {"Name": "smi-ssed", "Model Name": "SMI-SSED","Description": "SMILES based encoder decoder model", "Timestamp": "2024-07-10 00:09:42"}]
 
 
@@ -71,6 +73,7 @@ def avail_models(raw=False):
     {"Name": "bart","Model Name": "SELFIES-TED","Description": "BART model for string based SELFIES modality"},
     {"Name": "mol-xl","Model Name": "MolFormer", "Description": "MolFormer model for string based SMILES modality"},
     {"Name": "mhg", "Model Name": "MHG-GED","Description": "Molecular hypergraph model"},
+    {"Name": "pos", "Model Name": "POS-EGNN","Description": "3D atom position model"},
     {"Name": "Mordred", "Model Name": "Mordred","Description": "Baseline: A descriptor-calculation software application that can calculate more than 1800 two- and three-dimensional descriptors"},
     {"Name": "MorganFingerprint", "Model Name": "MorganFingerprint","Description": "Baseline: Circular atom environments based descriptor"}
     ]
@@ -150,6 +153,7 @@ def reset():
     {"Name": "mol-xl", "Description": "MolFormer model for string based SMILES modality",
      "Timestamp": "2024-06-21 12:35:56"},
     {"Name": "mhg", "Description": "MHG", "Timestamp": "2024-07-10 00:09:42"},
+    {"Name": "pos", "Description": "POS-EGNN", "Timestamp": "2024-07-10 00:09:42"},
     {"Name": "spec-gru", "Description": "Spectrum modality with GRU", "Timestamp": "2024-07-10 00:09:42"},
     {"Name": "spec-lstm", "Description": "Spectrum modality with LSTM", "Timestamp": "2024-07-10 00:09:54"},
     {"Name": "3d-vae", "Description": "VAE model for 3D atom positions", "Timestamp": "2024-07-10 00:10:08"}]
@@ -204,7 +208,9 @@ avail_models_data()
 
 
 def get_representation(train_data,test_data,model_type, return_tensor=True):
-    alias = {"MHG-GED": "mhg", "SELFIES-TED": "bart", "MolFormer": "mol-xl", "Molformer": "mol-xl", "SMI-TED": "smi-ted", "SMI-SSED": "smi-ssed"}
+    #alias = {"MHG-GED": "mhg", "SELFIES-TED": "bart", "MolFormer": "mol-xl", "Molformer": "mol-xl", "SMI-TED": "smi-ted", "SMI-SSED": "smi-ssed"}
+    alias = {"MHG-GED": "mhg", "SELFIES-TED": "bart", "SMI-TED": "smi-ted", "POS-EGNN": "pos", "MolFormer": "mol-xl", "Molformer": "mol-xl",
+             }
     if model_type in alias.keys():
         model_type = alias[model_type]
 
@@ -227,6 +233,12 @@ def get_representation(train_data,test_data,model_type, return_tensor=True):
         x_batch = model.encode(train_data, return_tensor=return_tensor)
         x_batch_test = model.encode(test_data, return_tensor=return_tensor)
 
+    elif model_type == "pos":
+        model = pos()
+        model.load()
+        x_batch = model.encode(train_data, return_tensor=return_tensor)
+        x_batch_test = model.encode(test_data, return_tensor=return_tensor)
+
     elif model_type == "smi-ted":
         model = load_smi_ted(folder='../models/smi_ted/smi_ted_light', ckpt_filename='smi-ted-Light_40.pt')
         with torch.no_grad():
@@ -314,7 +326,9 @@ def get_representation(train_data,test_data,model_type, return_tensor=True):
 
 def single_modal(model,dataset=None, downstream_model=None, params=None, x_train=None, x_test=None, y_train=None, y_test=None):
     print(model)
-    alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "Molformer": "mol-xl", "SMI-TED": "smi-ted", "SMI-SSED": "smi-ssed"}
+    #alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "Molformer": "mol-xl", "SMI-TED": "smi-ted", "SMI-SSED": "smi-ssed"}
+    alias = {"MHG-GED": "mhg", "SELFIES-TED": "bart", "SMI-TED": "smi-ted", "POS-EGNN": "pos", "MolFormer": "mol-xl", "Molformer": "mol-xl",
+             }
     data = avail_models(raw=True)
     df = pd.DataFrame(data)
     #print(list(df["Name"].values))
@@ -383,13 +397,14 @@ def single_modal(model,dataset=None, downstream_model=None, params=None, x_train
     print(f'x_batch_test Nan index: {nan_indices}')
     print(f'x_batch_test shape: {x_batch_test.shape}, y_batch_test len: {len(y_batch_test)}')
 
-    print(f" Calculating ROC AUC Score ...")
+    print(f"Model selected: {downstream_model} - Calculating ROC AUC Score ...")
 
     if downstream_model == "XGBClassifier":
         if params == None:
             xgb_predict_concat = XGBClassifier()
         else:
             xgb_predict_concat = XGBClassifier(**params)  # n_estimators=5000, learning_rate=0.01, max_depth=10
+
         xgb_predict_concat.fit(x_batch, y_batch)
 
         y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
@@ -628,7 +643,10 @@ def multi_modal(model_list,dataset=None, downstream_model=None,params=None, x_tr
     df = pd.DataFrame(data)
     list(df["Name"].values)
 
-    alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "Molformer": "mol-xl","SMI-TED":"smi-ted","SMI-SSED":"smi-ssed", "Mordred": "Mordred", "MorganFingerprint": "MorganFingerprint"}
+    #alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "Molformer": "mol-xl","SMI-TED":"smi-ted","SMI-SSED":"smi-ssed", "Mordred": "Mordred", "MorganFingerprint": "MorganFingerprint"}
+    alias = {"MHG-GED": "mhg", "SELFIES-TED": "bart", "SMI-TED": "smi-ted", "POS-EGNN": "pos", "MolFormer": "mol-xl",
+             "Molformer": "mol-xl", "Mordred": "Mordred", "MorganFingerprint": "MorganFingerprint"
+             }
    #if set(model_list).issubset(list(df["Name"].values)):
     if set(model_list).issubset(list(alias.keys())):
         for i, model in enumerate(model_list):
@@ -717,7 +735,7 @@ def multi_modal(model_list,dataset=None, downstream_model=None,params=None, x_tr
 
     print("Generating latent plots : Done")
 
-    print(f" Calculating ROC AUC Score ...")
+    print(f"Model selected: {downstream_model} - Calculating ROC AUC Score ...")
 
 
     if downstream_model == "XGBClassifier":
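
For orientation, here is a minimal sketch of how the new `pos` route is exercised end to end. The SMILES lists are hypothetical placeholders, `"POS-EGNN"` resolves through the alias table added above, and the sketch assumes `get_representation` returns the train/test embedding pair its body computes:

```python
# Hedged usage sketch for the new branch (not part of the commit).
from models import fm4m

train_smiles = ["CCO", "c1ccccc1"]   # placeholder data
test_smiles = ["CCN"]

# "pos" (or the alias "POS-EGNN") now loads POSEGNN and encodes both splits.
x_train, x_test = fm4m.get_representation(
    train_smiles, test_smiles, model_type="pos", return_tensor=True
)
```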
models/pos_egnn/.gitignore ADDED
@@ -0,0 +1,158 @@
1
+ data/
2
+ !/morningstar/data/
3
+ !/tests/data
4
+ .ruff*
5
+ lightning_logs/
6
+ .DS_Store
7
+ .vscode
8
+
9
+ # Useful scratch subfolder
10
+ _scratch
11
+
12
+ # Byte-compiled / optimized / DLL files
13
+ __pycache__/
14
+ *.py[cod]
15
+ *$py.class
16
+
17
+ # C extensions
18
+ *.so
19
+
20
+ # Distribution / packaging
21
+ .Python
22
+ build/
23
+ develop-eggs/
24
+ dist/
25
+ downloads/
26
+ eggs/
27
+ .eggs/
28
+ lib/
29
+ lib64/
30
+ parts/
31
+ sdist/
32
+ var/
33
+ wheels/
34
+ pip-wheel-metadata/
35
+ share/python-wheels/
36
+ *.egg-info/
37
+ .installed.cfg
38
+ *.egg
39
+ MANIFEST
40
+
41
+ # PyInstaller
42
+ # Usually these files are written by a python script from a template
43
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
44
+ *.manifest
45
+ *.spec
46
+
47
+ # Installer logs
48
+ pip-log.txt
49
+ pip-delete-this-directory.txt
50
+
51
+ # Unit test / coverage reports
52
+ htmlcov/
53
+ .tox/
54
+ .nox/
55
+ .coverage
56
+ .coverage.*
57
+ .cache
58
+ nosetests.xml
59
+ coverage.xml
60
+ *.cover
61
+ *.py,cover
62
+ .hypothesis/
63
+ .pytest_cache/
64
+
65
+ # Translations
66
+ *.mo
67
+ *.pot
68
+
69
+ # Django stuff:
70
+ *.log
71
+ local_settings.py
72
+ db.sqlite3
73
+ db.sqlite3-journal
74
+
75
+ # Flask stuff:
76
+ instance/
77
+ .webassets-cache
78
+
79
+ # Scrapy stuff:
80
+ .scrapy
81
+
82
+ # Sphinx documentation
83
+ docs/_build/
84
+
85
+ # PyBuilder
86
+ target/
87
+
88
+ # Jupyter Notebook
89
+ .ipynb_checkpoints
90
+
91
+ # IPython
92
+ profile_default/
93
+ ipython_config.py
94
+
95
+ # pyenv
96
+ .python-version
97
+
98
+ # pipenv
99
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
101
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
102
+ # install all needed dependencies.
103
+ *.Pipfile.lock
104
+
105
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106
+ __pypackages__/
107
+
108
+ # Celery stuff
109
+ celerybeat-schedule
110
+ celerybeat.pid
111
+
112
+ # SageMath parsed files
113
+ *.sage.py
114
+
115
+ # Environments
116
+ .env
117
+ .venv
118
+ env/
119
+ venv/
120
+ ENV/
121
+ env.bak/
122
+ venv.bak/
123
+
124
+ # Spyder project settings
125
+ .spyderproject
126
+ .spyproject
127
+
128
+ # Rope project settings
129
+ .ropeproject
130
+
131
+ # mkdocs documentation
132
+ /site
133
+
134
+ # mypy
135
+ .mypy_cache/
136
+ .dmypy.json
137
+ dmypy.json
138
+
139
+ # Pyre type checker
140
+ .pyre/
141
+
142
+ temp_data/
143
+ .aim/
144
+ csv_logs/
145
+
146
+ # Dataset files
147
+ *data.mdb
148
+ *.zip
149
+ *.tar.gz
150
+
151
+ /plots
152
+ /checkpoints
153
+ /data*
154
+
155
+ *.err
156
+ *.out
157
+ *.ckpt
158
+ output.xyz
models/pos_egnn/README.md ADDED
@@ -0,0 +1,40 @@
1
+ # Position-based Equivariant Graph Neural Network (`pos-egnn`)
2
+ This repository contains PyTorch source code for loading and running inference with `pos-egnn`, a foundation model for Chemistry and Materials.
3
+
4
+ **GitHub**: https://github.com/ibm/materials
5
+
6
+ **HuggingFace**: https://huggingface.co/ibm-research/materials.pos-egnn
7
+
8
+ <p align="center">
9
+ <img src="../../img/posegnn.svg">
10
+ </p>
11
+
12
+ ## Introduction
13
+ We present `pos-egnn`, a Position-based Equivariant Graph Neural Network foundation model for Chemistry and Materials. The model was pre-trained on 1.4M samples (i.e., 90%) from the Materials Project Trajectory (MPtrj) dataset to predict energies, forces and stress. `pos-egnn` can be used as a machine-learning potential, as a feature extractor, or can be fine-tuned for specific downstream tasks.
14
+
15
+ Besides the model weights `pos-egnn.v1-6M.pt` (download from [HuggingFace](https://huggingface.co/ibm-research/materials.pos-egnn)), we also provide an `example.ipynb` notebook (download from [GitHub](https://github.com/ibm/materials)), which demonstrates how to perform inference, feature extraction and molecular dynamics simulation with the model.
16
+
17
+ For more information, please reach out to [email protected] and/or [email protected]
18
+
19
+ ## Table of Contents
20
+ 1. [**Getting Started**](#getting-started)
21
+ 2. [**Example**](#example)
22
+
23
+ ## Getting Started
24
+ Follow these steps to replicate our environment and install the necessary libraries:
25
+
26
+ First, make sure to have Python 3.11 installed. Then, to create the virtual environment, run the following commands:
27
+
28
+ ```bash
29
+ python3.11 -m venv env
30
+ source env/bin/activate
31
+ ```
32
+
33
+ Run the following command to install the library dependencies.
34
+
35
+ ```bash
36
+ pip install -r requirements.txt
37
+ ```
38
+
39
+ ## Example
40
+ Please refer to the `example.ipynb` for a step-by-step demonstration on how to perform inference, feature extraction and molecular dynamics simulation with the model.
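
For a quick orientation before opening the notebook, the feature-extraction and inference flow condenses to roughly the following. This mirrors `example.ipynb` (the checkpoint filename and input path are the ones used there), so treat it as a sketch rather than a second reference:

```python
import torch
from ase.io import read
from posegnn.calculator import PosEGNNCalculator

torch.set_float32_matmul_precision("high")

# Checkpoint from https://huggingface.co/ibm-research/materials.pos-egnn
calculator = PosEGNNCalculator("pos-egnn.v1-6M.ckpt", device="cpu", compute_stress=False)
atoms = read("inputs/3BPA.xyz", index=0)
atoms.calc = calculator

embeddings = atoms.get_invariant_embeddings()  # per-atom invariant features
energy = atoms.get_potential_energy()          # machine-learning potential
forces = atoms.get_forces()
```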
models/pos_egnn/example.ipynb ADDED
@@ -0,0 +1,250 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# POS-EGNN "
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "## Setup"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 1,
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "# Uncomment to install notebook-only dependencies\n",
24
+ "# !pip install nglview ipywidgets"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 2,
30
+ "metadata": {},
31
+ "outputs": [
32
+ {
33
+ "data": {
34
+ "application/vnd.jupyter.widget-view+json": {
35
+ "model_id": "4bac12c6048044898065f0778d95caeb",
36
+ "version_major": 2,
37
+ "version_minor": 0
38
+ },
39
+ "text/plain": []
40
+ },
41
+ "metadata": {},
42
+ "output_type": "display_data"
43
+ }
44
+ ],
45
+ "source": [
46
+ "import nglview as nv\n",
47
+ "import torch\n",
48
+ "from ase import units\n",
49
+ "from ase.io import read\n",
50
+ "from ase.md.langevin import Langevin\n",
51
+ "\n",
52
+ "from posegnn.calculator import PosEGNNCalculator"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": 3,
58
+ "metadata": {},
59
+ "outputs": [],
60
+ "source": [
61
+ "device = \"cpu\"\n",
62
+ "torch.set_float32_matmul_precision(\"high\")"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "markdown",
67
+ "metadata": {},
68
+ "source": [
69
+ "## Feature Extraction"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": 4,
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "# Please download checkpoint from https://huggingface.co/ibm-research/materials.pos-egnn\n",
79
+ "calculator = PosEGNNCalculator(\"pos-egnn.v1-6M.ckpt\", device=device, compute_stress=False)\n",
80
+ "atoms = read(\"inputs/3BPA.xyz\", index=0)\n",
81
+ "atoms.calc = calculator"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": 5,
87
+ "metadata": {},
88
+ "outputs": [
89
+ {
90
+ "data": {
91
+ "text/plain": [
92
+ "torch.Size([27, 256])"
93
+ ]
94
+ },
95
+ "execution_count": 5,
96
+ "metadata": {},
97
+ "output_type": "execute_result"
98
+ }
99
+ ],
100
+ "source": [
101
+ "embeddings = atoms.get_invariant_embeddings()\n",
102
+ "embeddings.shape"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "markdown",
107
+ "metadata": {},
108
+ "source": [
109
+ "## Inference"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": 6,
115
+ "metadata": {},
116
+ "outputs": [],
117
+ "source": [
118
+ "energy = atoms.get_potential_energy()\n",
119
+ "forces = atoms.get_forces()"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": 7,
125
+ "metadata": {},
126
+ "outputs": [
127
+ {
128
+ "data": {
129
+ "text/plain": [
130
+ "(array([-175.05188], dtype=float32),\n",
131
+ " array([[ 0.34280202, -0.41967863, 0.7246248 ],\n",
132
+ " [-0.86854756, -0.12186409, -2.305024 ],\n",
133
+ " [ 0.26306945, 0.06607065, 0.85476065],\n",
134
+ " [-0.230737 , 0.02304646, -0.5161394 ],\n",
135
+ " [-0.43901953, 2.7678285 , -0.70297724],\n",
136
+ " [ 0.03933215, -0.50390136, 1.0451801 ],\n",
137
+ " [ 0.37628424, -2.2708364 , -0.7662437 ],\n",
138
+ " [ 0.25884533, -1.6086004 , -0.08700082],\n",
139
+ " [-0.09319548, -0.24666801, -0.48069426],\n",
140
+ " [ 0.01849201, 1.001767 , 2.151208 ],\n",
141
+ " [-0.46055827, 1.3630681 , -0.38470453],\n",
142
+ " [ 0.38605827, -0.32170498, 0.6269282 ],\n",
143
+ " [-0.29103595, 0.22509174, -0.26729944],\n",
144
+ " [ 1.3340423 , -1.727819 , -0.08812339],\n",
145
+ " [-0.96442086, 1.1447092 , 1.0665402 ],\n",
146
+ " [-0.74679977, 0.56782806, 0.03098067],\n",
147
+ " [ 0.42040402, 0.7405614 , -0.6953748 ],\n",
148
+ " [-0.25654212, 0.25282693, 0.25414664],\n",
149
+ " [ 2.0051584 , -0.38257334, -0.26911467],\n",
150
+ " [-0.00743119, 0.43786597, -0.27683535],\n",
151
+ " [ 0.64563835, -0.5602143 , -0.11240276],\n",
152
+ " [-0.00601408, -1.03808 , 0.23635206],\n",
153
+ " [-0.04149596, 0.02955294, -0.06748012],\n",
154
+ " [-0.86066115, 0.00299245, 0.06783121],\n",
155
+ " [-0.05461264, 0.05352221, -0.06844339],\n",
156
+ " [-0.26291835, 0.58347785, 0.19614606],\n",
157
+ " [-0.50613666, -0.05826864, -0.16684091]], dtype=float32))"
158
+ ]
159
+ },
160
+ "execution_count": 7,
161
+ "metadata": {},
162
+ "output_type": "execute_result"
163
+ }
164
+ ],
165
+ "source": [
166
+ "energy, forces"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "markdown",
171
+ "metadata": {},
172
+ "source": [
173
+ "## Molecular Dynamics Simulation"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": 11,
179
+ "metadata": {},
180
+ "outputs": [
181
+ {
182
+ "data": {
183
+ "text/plain": [
184
+ "True"
185
+ ]
186
+ },
187
+ "execution_count": 11,
188
+ "metadata": {},
189
+ "output_type": "execute_result"
190
+ }
191
+ ],
192
+ "source": [
193
+ "dyn = Langevin(atoms=atoms, friction=0.005, temperature_K=310, timestep=0.5 * units.fs)\n",
194
+ "\n",
195
+ "def write_frame():\n",
196
+ " dyn.atoms.write(\"output.xyz\", append=True)\n",
197
+ "\n",
198
+ "dyn.attach(write_frame, interval=5)\n",
199
+ "dyn.run(500)"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": 12,
205
+ "metadata": {},
206
+ "outputs": [
207
+ {
208
+ "data": {
209
+ "application/vnd.jupyter.widget-view+json": {
210
+ "model_id": "139d7605baca43d79ea515d3454d9941",
211
+ "version_major": 2,
212
+ "version_minor": 0
213
+ },
214
+ "text/plain": [
215
+ "NGLWidget(max_frame=234)"
216
+ ]
217
+ },
218
+ "metadata": {},
219
+ "output_type": "display_data"
220
+ }
221
+ ],
222
+ "source": [
223
+ "traj = read('output.xyz', index=slice(None))\n",
224
+ "view = nv.show_asetraj(traj)\n",
225
+ "display(view)"
226
+ ]
227
+ }
228
+ ],
229
+ "metadata": {
230
+ "kernelspec": {
231
+ "display_name": "py311",
232
+ "language": "python",
233
+ "name": "python3"
234
+ },
235
+ "language_info": {
236
+ "codemirror_mode": {
237
+ "name": "ipython",
238
+ "version": 3
239
+ },
240
+ "file_extension": ".py",
241
+ "mimetype": "text/x-python",
242
+ "name": "python",
243
+ "nbconvert_exporter": "python",
244
+ "pygments_lexer": "ipython3",
245
+ "version": "3.11.11"
246
+ }
247
+ },
248
+ "nbformat": 4,
249
+ "nbformat_minor": 2
250
+ }
models/pos_egnn/inputs/3BPA.xyz ADDED
@@ -0,0 +1,29 @@
1
+ 27
2
+ Lattice="50.0 0.0 0.0 0.0 50.0 0.0 0.0 0.0 50.0" Properties=species:S:1:pos:R:3
3
+ C 0.32656990 -1.01286015 0.72107275
4
+ C -0.19461567 0.25830309 0.83413890
5
+ C 0.79249430 -1.60244652 1.90438013
6
+ H 0.43514812 -1.59784402 -0.17897327
7
+ C -0.06984110 0.97877174 2.04804512
8
+ O -0.82551374 0.94764149 -0.23208227
9
+ N -0.29006580 2.37382353 1.86234713
10
+ N 0.38837699 0.47664821 3.17063941
11
+ H -0.98416331 2.63164516 1.17783211
12
+ H -0.35613357 2.91935548 2.65951281
13
+ C 0.72925588 -0.88036074 3.09652633
14
+ H 0.89111308 -1.33212550 4.04892841
15
+ H 1.21249480 -2.61699312 1.95761271
16
+ C -1.55059042 0.13992010 -1.19894310
17
+ C -2.93143480 -0.31868852 -0.84882155
18
+ H -0.88027102 -0.72270141 -1.50495635
19
+ H -1.71075243 0.76854244 -2.05655313
20
+ C -3.36709795 -1.63868610 -0.72393961
21
+ C -3.88193920 0.73797993 -0.59066374
22
+ C -4.66143774 -1.92714685 -0.27081530
23
+ H -2.71536973 -2.46157596 -0.92494411
24
+ C -5.57704224 -0.87827710 -0.11327951
25
+ H -4.95294558 -2.96156288 -0.08227495
26
+ C -5.17381471 0.43421272 -0.29908870
27
+ H -6.59525001 -1.10219470 0.20624618
28
+ H -5.92759350 1.19353463 -0.20195647
29
+ H -3.44855719 1.74786696 -0.56747187
models/pos_egnn/load.py ADDED
@@ -0,0 +1,63 @@
1
+ import torch
2
+ from .posegnn.calculator import PosEGNNCalculator
3
+ import ase
4
+ from ase import Atoms
5
+ from rdkit import Chem
6
+ from rdkit.Chem import AllChem
7
+ import pandas as pd
8
+ import numpy as np
9
+ from huggingface_hub import hf_hub_download
10
+ from tqdm import tqdm
11
+
12
+ torch.set_float32_matmul_precision("high")
13
+
14
+ def smiles_to_atoms(smiles):
15
+ mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
16
+ AllChem.EmbedMolecule(mol)
17
+ ase_atoms = ase.Atoms(
18
+ numbers=[
19
+ atom.GetAtomicNum() for atom in mol.GetAtoms()
20
+ ],
21
+ positions=mol.GetConformer().GetPositions()
22
+ )
23
+ return ase_atoms
24
+
25
+ class POSEGNN():
26
+ def __init__(self, use_gpu=True):
27
+ device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
28
+ self.device = device
29
+ self.calculator = None
30
+
31
+ def load(self, checkpoint=None):
32
+ repo_id = "ibm-research/materials.pos-egnn"
33
+ filename = "pytorch_model.bin"
34
+ model_path = hf_hub_download(repo_id=repo_id, filename=filename)
35
+ self.calculator = PosEGNNCalculator(model_path, device=self.device, compute_stress=False)
36
+
37
+ def encode(self, smiles_list, return_tensor=False, batch_size=32):
38
+ results = []
39
+
40
+ # process the SMILES list in batches, with a progress bar
41
+ for i in tqdm(range(0, len(smiles_list), batch_size), desc="Batch Encoding"):
42
+ batch = smiles_list[i:i+batch_size]
43
+ atoms_batch = []
44
+
45
+ for smiles in batch:
46
+ try:
47
+ atoms = smiles_to_atoms(smiles)
48
+ atoms.calc = self.calculator
49
+ atoms_batch.append(atoms)
50
+ except Exception as e:
51
+ print(f"Skipping {smiles}: {e}")
52
+
53
+ if atoms_batch:
54
+ embeddings = [a.get_invariant_embeddings().mean(dim=0).cpu() for a in atoms_batch]
55
+ batch_tensor = torch.stack(embeddings)
56
+ results.append(batch_tensor)
57
+
58
+ if not results:
59
+ raise RuntimeError("No valid SMILES could be processed.")
60
+
61
+ all_embeddings = torch.cat(results, dim=0)
62
+ return all_embeddings if return_tensor else pd.DataFrame(all_embeddings.numpy())
63
+
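
A hedged usage sketch for the wrapper above (the SMILES strings are placeholders; `load()` fetches `pytorch_model.bin` from the Hub exactly as coded):

```python
from models.pos_egnn.load import POSEGNN

encoder = POSEGNN(use_gpu=False)
encoder.load()  # downloads ibm-research/materials.pos-egnn / pytorch_model.bin

# Mean-pooled per-molecule embeddings: return_tensor=True yields a torch.Tensor,
# otherwise a pandas DataFrame is returned.
embeddings = encoder.encode(["CCO", "c1ccccc1"], return_tensor=False)
print(embeddings.shape)
```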
models/pos_egnn/posegnn/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from . import calculator, encoder, model, ops, utils
2
+
3
+ __all__ = ["calculator", "encoder", "model", "ops", "utils"]
models/pos_egnn/posegnn/calculator.py ADDED
@@ -0,0 +1,65 @@
1
+ import numpy as np
2
+ import torch
3
+ from ase import Atoms
4
+ from ase.calculators.calculator import Calculator, all_changes
5
+ from ase.data import atomic_numbers
6
+ from ase.stress import full_3x3_to_voigt_6_stress
7
+ from torch_geometric.data.data import Data
8
+
9
+ from .model import PosEGNN
10
+
11
+
12
+ class PosEGNNCalculator(Calculator):
13
+ def __init__(self, checkpoint: str, device: str, compute_stress: bool = True, **kwargs):
14
+ Calculator.__init__(self, **kwargs)
15
+
16
+ checkpoint_dict = torch.load(checkpoint, weights_only=True, map_location=device)
17
+
18
+ self.model = PosEGNN(checkpoint_dict["config"])
19
+ self.model.load_state_dict(checkpoint_dict["state_dict"], strict=True)
20
+
21
+
22
+ self.model.to(device)
23
+ self.model.eval()
24
+
25
+ self.implemented_properties = ["energy", "forces"]
26
+ self.implemented_properties += ["stress"] if compute_stress else []
27
+ self.device = device
28
+ self.compute_stress = compute_stress
29
+
30
+ def calculate(self, atoms=None, properties=None, system_changes=all_changes):
31
+ Calculator.calculate(self, atoms)
32
+ self.results = {}
33
+ data = self._build_data(atoms)
34
+ out = self.model.compute_properties(data, compute_stress=self.compute_stress)
35
+
36
+ # Decoder Forward
37
+ self.results = {
38
+ "energy": out["total_energy"].cpu().detach().numpy(),
39
+ "forces": out["force"].cpu().detach().numpy()
40
+ }
41
+ if self.compute_stress:
42
+ self.results.update({
43
+ "stress": full_3x3_to_voigt_6_stress(out["stress"].cpu().detach().numpy())
44
+ })
45
+
46
+ def _build_data(self, atoms):
47
+ z = torch.tensor(np.array([atomic_numbers[symbol] for symbol in atoms.symbols]), device=self.device)
48
+ box = torch.tensor(atoms.get_cell().tolist(), device=self.device).unsqueeze(0).float()
49
+ pos = torch.tensor(atoms.get_positions().tolist(), device=self.device).float()
50
+ batch = torch.zeros(len(z), device=self.device).long()
51
+ ptr = torch.zeros(1, device=self.device).long()
52
+ return Data(z=z, pos=pos, box=box, batch=batch, num_graphs=1, ptr=ptr)
53
+
54
+
55
+ def get_invariant_embeddings(self):
56
+ if self.calc is None:
57
+ raise RuntimeError("No calculator is set.")
58
+ else:
59
+ data = self.calc._build_data(self)
60
+ with torch.no_grad():
61
+ embeddings = self.calc.model(data)["embedding_0"][..., 1].squeeze(2)
62
+ return embeddings
63
+
64
+
65
+ Atoms.get_invariant_embeddings = get_invariant_embeddings
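
Because the last line patches `get_invariant_embeddings` onto `ase.Atoms`, any `Atoms` object with this calculator attached gains the method. A small sketch, with the molecule built inline via ASE and the checkpoint filename taken from `example.ipynb` (both illustrative assumptions):

```python
from ase.build import molecule
from posegnn.calculator import PosEGNNCalculator

calc = PosEGNNCalculator("pos-egnn.v1-6M.ckpt", device="cpu", compute_stress=False)
atoms = molecule("H2O")   # 3-atom gas-phase molecule from ASE's built-in database
atoms.calc = calc

emb = atoms.get_invariant_embeddings()
assert emb.shape[0] == len(atoms)  # one embedding row per atom
```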
models/pos_egnn/posegnn/encoder.py ADDED
@@ -0,0 +1,621 @@
1
+ """
2
+ This code was adapted from https://github.com/sarpaykent/GotenNet
3
+ Copyright (c) 2025 Sarp Aykent
4
+ MIT License
5
+
6
+ GotenNet: Rethinking Efficient 3D Equivariant Graph Neural Networks
7
+ Sarp Aykent and Tian Xia
8
+ https://openreview.net/pdf?id=5wxCQDtbMo
9
+ """
10
+
11
+ from functools import partial
12
+ from typing import Callable, Optional, Tuple, Union
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from torch_geometric.nn import MessagePassing
19
+ from torch_geometric.typing import OptTensor
20
+ from torch_geometric.utils import scatter, softmax
21
+
22
+ from .ops import (
23
+ MLP,
24
+ CosineCutoff,
25
+ Dense,
26
+ EdgeInit,
27
+ NodeInit,
28
+ TensorInit,
29
+ TensorLayerNorm,
30
+ get_weight_init_by_string,
31
+ parse_update_info,
32
+ str2act,
33
+ str2basis,
34
+ )
35
+
36
+
37
+ def lmax_tensor_size(lmax):
38
+ return ((lmax + 1) ** 2) - 1
39
+
40
+
41
+ def split_degree(tensor, lmax, dim=-1): # default to last dim
42
+ cumsum = 0
43
+ tensors = []
44
+ for i in range(1, lmax + 1):
45
+ count = lmax_tensor_size(i) - lmax_tensor_size(i - 1)
46
+ # Create slice object for the specified dimension
47
+ slc = [slice(None)] * tensor.ndim # Create list of slice(None) for all dims
48
+ slc[dim] = slice(cumsum, cumsum + count) # Replace desired dim with actual slice
49
+ tensors.append(tensor[tuple(slc)])
50
+ cumsum += count
51
+ return tensors
52
+
53
+
54
+ class GATA(MessagePassing):
55
+ def __init__(
56
+ self,
57
+ n_atom_basis: int,
58
+ activation: Callable,
59
+ weight_init=nn.init.xavier_uniform_,
60
+ bias_init=nn.init.zeros_,
61
+ aggr="add",
62
+ node_dim=0,
63
+ epsilon: float = 1e-7,
64
+ layer_norm=False,
65
+ vector_norm=False,
66
+ cutoff=5.0,
67
+ num_heads=8,
68
+ dropout=0.0,
69
+ edge_updates=True,
70
+ last_layer=False,
71
+ scale_edge=True,
72
+ edge_ln="",
73
+ evec_dim=None,
74
+ emlp_dim=None,
75
+ sep_vecj=True,
76
+ lmax=1,
77
+ ):
78
+ """
79
+ Args:
80
+ n_atom_basis (int): Number of features to describe atomic environments.
81
+ activation (Callable): Activation function to be used. If None, no activation function is used.
82
+ weight_init (Callable): Weight initialization function.
83
+ bias_init (Callable): Bias initialization function.
84
+ aggr (str): Aggregation method ('add', 'mean' or 'max').
85
+ node_dim (int): The axis along which to aggregate.
86
+ """
87
+ super(GATA, self).__init__(aggr=aggr, node_dim=node_dim)
88
+ self.lmax = lmax
89
+ self.sep_vecj = sep_vecj
90
+ self.epsilon = epsilon
91
+ self.last_layer = last_layer
92
+ self.edge_updates = edge_updates
93
+ self.scale_edge = scale_edge
94
+ self.activation = activation
95
+
96
+ self.update_info = parse_update_info(edge_updates)
97
+
98
+ self.dropout = dropout
99
+ self.n_atom_basis = n_atom_basis
100
+
101
+ InitDense = partial(Dense, weight_init=weight_init, bias_init=bias_init)
102
+ self.gamma_s = nn.Sequential(
103
+ InitDense(n_atom_basis, n_atom_basis, activation=activation),
104
+ InitDense(n_atom_basis, 3 * n_atom_basis, activation=None),
105
+ )
106
+
107
+ self.num_heads = num_heads
108
+ self.q_w = InitDense(n_atom_basis, n_atom_basis, activation=None)
109
+ self.k_w = InitDense(n_atom_basis, n_atom_basis, activation=None)
110
+
111
+ self.gamma_v = nn.Sequential(
112
+ InitDense(n_atom_basis, n_atom_basis, activation=activation),
113
+ InitDense(n_atom_basis, 3 * n_atom_basis, activation=None),
114
+ )
115
+
116
+ self.phik_w_ra = InitDense(
117
+ n_atom_basis,
118
+ n_atom_basis,
119
+ activation=activation,
120
+ )
121
+
122
+ InitMLP = partial(MLP, weight_init=weight_init, bias_init=bias_init)
123
+
124
+ self.edge_vec_dim = n_atom_basis if evec_dim is None else evec_dim
125
+ self.edge_mlp_dim = n_atom_basis if emlp_dim is None else emlp_dim
126
+ if not self.last_layer and self.edge_updates:
127
+ if self.update_info["mlp"] or self.update_info["mlpa"]:
128
+ dims = [n_atom_basis, self.edge_mlp_dim, n_atom_basis]
129
+ else:
130
+ dims = [n_atom_basis, n_atom_basis]
131
+ self.edge_attr_up = InitMLP(
132
+ dims, activation=activation, last_activation=None if self.update_info["mlp"] else self.activation, norm=edge_ln
133
+ )
134
+ self.vecq_w = InitDense(n_atom_basis, self.edge_vec_dim, activation=None, bias=False)
135
+
136
+ if self.sep_vecj:
137
+ self.veck_w = nn.ModuleList(
138
+ [InitDense(n_atom_basis, self.edge_vec_dim, activation=None, bias=False) for i in range(self.lmax)]
139
+ )
140
+ else:
141
+ self.veck_w = InitDense(n_atom_basis, self.edge_vec_dim, activation=None, bias=False)
142
+
143
+ if self.update_info["lin_w"] > 0:
144
+ modules = []
145
+ if self.update_info["lin_w"] % 10 == 2:
146
+ modules.append(self.activation)
147
+ self.lin_w_linear = InitDense(
148
+ self.edge_vec_dim,
149
+ n_atom_basis,
150
+ activation=None,
151
+ norm="layer" if self.update_info["lin_w"] == 2 else "", # lin_ln in original code but error
152
+ )
153
+ modules.append(self.lin_w_linear)
154
+ self.lin_w = nn.Sequential(*modules)
155
+
156
+ self.down_proj = nn.Identity()
157
+
158
+ self.cutoff = CosineCutoff(cutoff)
159
+ self._alpha = None
160
+
161
+ self.w_re = InitDense(
162
+ n_atom_basis,
163
+ n_atom_basis * 3,
164
+ None,
165
+ )
166
+
167
+ self.layernorm_ = layer_norm
168
+ self.vector_norm_ = vector_norm
169
+
170
+ if layer_norm:
171
+ self.layernorm = nn.LayerNorm(n_atom_basis)
172
+ else:
173
+ self.layernorm = nn.Identity()
174
+ if vector_norm:
175
+ self.tln = TensorLayerNorm(n_atom_basis, trainable=False)
176
+ else:
177
+ self.tln = nn.Identity()
178
+
179
+ self.reset_parameters()
180
+
181
+ def reset_parameters(self):
182
+ if self.layernorm_:
183
+ self.layernorm.reset_parameters()
184
+ if self.vector_norm_:
185
+ self.tln.reset_parameters()
186
+ for l in self.gamma_s: # noqa: E741
187
+ l.reset_parameters()
188
+
189
+ self.q_w.reset_parameters()
190
+ self.k_w.reset_parameters()
191
+ for l in self.gamma_v: # noqa: E741
192
+ l.reset_parameters()
193
+ # self.v_w.reset_parameters()
194
+ # self.out_w.reset_parameters()
195
+ self.w_re.reset_parameters()
196
+
197
+ if not self.last_layer and self.edge_updates:
198
+ self.edge_attr_up.reset_parameters()
199
+ self.vecq_w.reset_parameters()
200
+
201
+ if self.sep_vecj:
202
+ for w in self.veck_w:
203
+ w.reset_parameters()
204
+ else:
205
+ self.veck_w.reset_parameters()
206
+
207
+ if self.update_info["lin_w"] > 0:
208
+ self.lin_w_linear.reset_parameters()
209
+
210
+ def forward(
211
+ self,
212
+ edge_index,
213
+ s: torch.Tensor,
214
+ t: torch.Tensor,
215
+ dir_ij: torch.Tensor,
216
+ r_ij: torch.Tensor,
217
+ d_ij: torch.Tensor,
218
+ num_edges_expanded: torch.Tensor,
219
+ ):
220
+ """Compute interaction output."""
221
+ s = self.layernorm(s)
222
+ t = self.tln(t)
223
+
224
+ q = self.q_w(s).reshape(-1, self.num_heads, self.n_atom_basis // self.num_heads)
225
+ k = self.k_w(s).reshape(-1, self.num_heads, self.n_atom_basis // self.num_heads)
226
+
227
+ x = self.gamma_s(s)
228
+ val = self.gamma_v(s)
229
+ f_ij = r_ij
230
+ r_ij_attn = self.phik_w_ra(r_ij)
231
+ r_ij = self.w_re(r_ij)
232
+
233
+ # propagate_type: (x: Tensor, ten: Tensor, q:Tensor, k:Tensor, val:Tensor, r_ij: Tensor, r_ij_attn: Tensor, d_ij:Tensor, dir_ij: Tensor, num_edges_expanded: Tensor)
234
+ su, tu = self.propagate(
235
+ edge_index=edge_index,
236
+ x=x,
237
+ q=q,
238
+ k=k,
239
+ val=val,
240
+ ten=t,
241
+ r_ij=r_ij,
242
+ r_ij_attn=r_ij_attn,
243
+ d_ij=d_ij,
244
+ dir_ij=dir_ij,
245
+ num_edges_expanded=num_edges_expanded,
246
+ ) # , f_ij=f_ij
247
+
248
+ s = s + su
249
+ t = t + tu
250
+
251
+ if not self.last_layer and self.edge_updates:
252
+ vec = t
253
+
254
+ w1 = self.vecq_w(vec)
255
+ if self.sep_vecj:
256
+ vec_split = split_degree(vec, self.lmax, dim=1)
257
+ w_out = torch.concat([w(vec_split[i]) for i, w in enumerate(self.veck_w)], dim=1)
258
+
259
+ else:
260
+ w_out = self.veck_w(vec)
261
+
262
+ # edge_updater_type: (w1: Tensor, w2:Tensor, d_ij: Tensor, f_ij: Tensor)
263
+ df_ij = self.edge_updater(edge_index, w1=w1, w2=w_out, d_ij=dir_ij, f_ij=f_ij)
264
+ df_ij = f_ij + df_ij
265
+ self._alpha = None
266
+ return s, t, df_ij
267
+ else:
268
+ self._alpha = None
269
+ return s, t, f_ij
270
+
271
+ # return s, t
272
+
273
+ def message(
274
+ self,
275
+ edge_index,
276
+ x_i: torch.Tensor,
277
+ x_j: torch.Tensor,
278
+ q_i: torch.Tensor,
279
+ k_j: torch.Tensor,
280
+ val_j: torch.Tensor,
281
+ ten_j: torch.Tensor,
282
+ r_ij: torch.Tensor,
283
+ r_ij_attn: torch.Tensor,
284
+ d_ij: torch.Tensor,
285
+ dir_ij: torch.Tensor,
286
+ num_edges_expanded: torch.Tensor,
287
+ index: torch.Tensor,
288
+ ptr: OptTensor,
289
+ dim_size: Optional[int],
290
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
291
+ """
292
+ Compute message passing.
293
+ """
294
+
295
+ r_ij_attn = r_ij_attn.reshape(-1, self.num_heads, self.n_atom_basis // self.num_heads)
296
+ attn = (q_i * k_j * r_ij_attn).sum(dim=-1, keepdim=True)
297
+
298
+ attn = softmax(attn, index, ptr, dim_size)
299
+
300
+ # Normalize the attention scores
301
+ if self.scale_edge:
302
+ norm = torch.sqrt(num_edges_expanded.reshape(-1, 1, 1)) / np.sqrt(self.n_atom_basis)
303
+ else:
304
+ norm = 1.0 / np.sqrt(self.n_atom_basis)
305
+ attn = attn * norm
306
+ self._alpha = attn
307
+ attn = F.dropout(attn, p=self.dropout, training=self.training)
308
+
309
+ self_attn = attn * val_j.reshape(-1, self.num_heads, (self.n_atom_basis * 3) // self.num_heads)
310
+ SEA = self_attn.reshape(-1, 1, self.n_atom_basis * 3)
311
+
312
+ x = SEA + (r_ij.unsqueeze(1) * x_j * self.cutoff(d_ij.unsqueeze(-1).unsqueeze(-1)))
313
+
314
+ o_s, o_d, o_t = torch.split(x, self.n_atom_basis, dim=-1)
315
+ dmu = o_d * dir_ij[..., None] + o_t * ten_j
316
+ return o_s, dmu
317
+
318
+ @staticmethod
319
+ def rej(vec, d_ij):
320
+ vec_proj = (vec * d_ij.unsqueeze(2)).sum(dim=1, keepdim=True)
321
+ return vec - vec_proj * d_ij.unsqueeze(2)
322
+
323
+ def edge_update(self, w1_i, w2_j, w3_j, d_ij, f_ij):
324
+ if self.sep_vecj:
325
+ vi = w1_i
326
+ vj = w2_j
327
+ vi_split = split_degree(vi, self.lmax, dim=1)
328
+ vj_split = split_degree(vj, self.lmax, dim=1)
329
+ d_ij_split = split_degree(d_ij, self.lmax, dim=1)
330
+
331
+ pairs = []
332
+ for i in range(len(vi_split)):
333
+ if self.update_info["rej"]:
334
+ w1 = self.rej(vi_split[i], d_ij_split[i])
335
+ w2 = self.rej(vj_split[i], -d_ij_split[i])
336
+ pairs.append((w1, w2))
337
+ else:
338
+ w1 = vi_split[i]
339
+ w2 = vj_split[i]
340
+ pairs.append((w1, w2))
341
+ elif not self.update_info["rej"]:
342
+ w1 = w1_i
343
+ w2 = w2_j
344
+ pairs = [(w1, w2)]
345
+ else:
346
+ w1 = self.rej(w1_i, d_ij)
347
+ w2 = self.rej(w2_j, -d_ij)
348
+ pairs = [(w1, w2)]
349
+
350
+ w_dot_sum = None
351
+ for el in pairs:
352
+ w1, w2 = el
353
+ w_dot = (w1 * w2).sum(dim=1)
354
+ if w_dot_sum is None:
355
+ w_dot_sum = w_dot
356
+ else:
357
+ w_dot_sum = w_dot_sum + w_dot
358
+ w_dot = w_dot_sum
359
+ if self.update_info["lin_w"] > 0:
360
+ w_dot = self.lin_w(w_dot)
361
+
362
+ if self.update_info["gated"] == "gatedt":
363
+ w_dot = torch.tanh(w_dot)
364
+ elif self.update_info["gated"] == "gated":
365
+ w_dot = torch.sigmoid(w_dot)
366
+ elif self.update_info["gated"] == "act":
367
+ w_dot = self.activation(w_dot)
368
+
369
+ df_ij = self.edge_attr_up(f_ij) * w_dot
370
+ return df_ij
371
+
372
+ # noinspection PyMethodOverriding
373
+ def aggregate(
374
+ self,
375
+ features: Tuple[torch.Tensor, torch.Tensor],
376
+ index: torch.Tensor,
377
+ ptr: Optional[torch.Tensor],
378
+ dim_size: Optional[int],
379
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
380
+ x, vec = features
381
+ x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
382
+ vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
383
+ return x, vec
384
+
385
+ def update(self, inputs: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
386
+ return inputs
387
+
388
+
389
+ class EQFF(nn.Module):
390
+ def __init__(
391
+ self,
392
+ n_atom_basis: int,
393
+ activation: Callable,
394
+ epsilon: float = 1e-8,
395
+ weight_init=nn.init.xavier_uniform_,
396
+ bias_init=nn.init.zeros_,
397
+ vec_dim=None,
398
+ ):
399
+ """Equiavariant Feed Forward layer."""
400
+ super(EQFF, self).__init__()
401
+ self.n_atom_basis = n_atom_basis
402
+
403
+ InitDense = partial(Dense, weight_init=weight_init, bias_init=bias_init)
404
+
405
+ vec_dim = n_atom_basis if vec_dim is None else vec_dim
406
+ context_dim = 2 * n_atom_basis
407
+
408
+ self.gamma_m = nn.Sequential(
409
+ InitDense(context_dim, n_atom_basis, activation=activation),
410
+ InitDense(n_atom_basis, 2 * n_atom_basis, activation=None),
411
+ )
412
+ self.w_vu = InitDense(n_atom_basis, vec_dim, activation=None, bias=False)
413
+
414
+ self.epsilon = epsilon
415
+
416
+ def reset_parameters(self):
417
+ self.w_vu.reset_parameters()
418
+ for l in self.gamma_m: # noqa: E741
419
+ l.reset_parameters()
420
+
421
+ def forward(self, s, v):
422
+ """Compute Equivariant Feed Forward output."""
423
+
424
+ t_prime = self.w_vu(v)
425
+ t_prime_mag = torch.sqrt(torch.sum(t_prime**2, dim=-2, keepdim=True) + self.epsilon)
426
+ combined = [s, t_prime_mag]
427
+ combined_tensor = torch.cat(combined, dim=-1)
428
+ m12 = self.gamma_m(combined_tensor)
429
+
430
+ m_1, m_2 = torch.split(m12, self.n_atom_basis, dim=-1)
431
+ delta_v = m_2 * t_prime
432
+
433
+ s = s + m_1
434
+ v = v + delta_v
435
+
436
+ return s, v
437
+
438
+
439
+ class GotenNet(nn.Module):
440
+ def __init__(
441
+ self,
442
+ hidden_channels: int = 128,
443
+ num_layers: int = 8,
444
+ radial_basis: Union[Callable, str] = "BesselBasis",
445
+ n_rbf: int = 20,
446
+ cutoff: float = 5.0,
447
+ activation: Optional[Union[Callable, str]] = F.silu,
448
+ max_z: int = 100,
449
+ epsilon: float = 1e-8,
450
+ weight_init=nn.init.xavier_uniform_,
451
+ bias_init=nn.init.zeros_,
452
+ int_layer_norm=False,
453
+ int_vector_norm=False,
454
+ before_mixing_layer_norm=False,
455
+ after_mixing_layer_norm=False,
456
+ num_heads=8,
457
+ attn_dropout=0.0,
458
+ edge_updates=True,
459
+ scale_edge=True,
460
+ lmax=2,
461
+ aggr="add",
462
+ edge_ln="",
463
+ evec_dim=None,
464
+ emlp_dim=None,
465
+ sep_int_vec=True,
466
+ ):
467
+ """
468
+ Representation for GotenNet
469
+ """
470
+ super(GotenNet, self).__init__()
471
+
472
+ self.scale_edge = scale_edge
473
+ if type(weight_init) == str: # noqa: E721
474
+ # print(f"Using {weight_init} weight initialization")
475
+ weight_init = get_weight_init_by_string(weight_init)
476
+
477
+ if type(bias_init) == str: # noqa: E721
478
+ bias_init = get_weight_init_by_string(bias_init)
479
+
480
+ if type(activation) is str:
481
+ activation = str2act(activation)
482
+
483
+ self.n_atom_basis = self.hidden_dim = hidden_channels
484
+ self.n_interactions = num_layers
485
+ self.cutoff = cutoff
486
+
487
+ self.neighbor_embedding = NodeInit(
488
+ [self.hidden_dim // 2, self.hidden_dim],
489
+ n_rbf,
490
+ self.cutoff,
491
+ max_z=max_z,
492
+ weight_init=weight_init,
493
+ bias_init=bias_init,
494
+ concat=False,
495
+ proj_ln="layer",
496
+ activation=activation,
497
+ )
498
+ self.edge_embedding = EdgeInit(
499
+ n_rbf, [self.hidden_dim // 2, self.hidden_dim], weight_init=weight_init, bias_init=bias_init, proj_ln=""
500
+ )
501
+
502
+ radial_basis = str2basis(radial_basis)
503
+ self.radial_basis = radial_basis(cutoff=self.cutoff, n_rbf=n_rbf)
504
+
505
+ self.embedding = nn.Embedding(max_z, self.n_atom_basis, padding_idx=0)
506
+
507
+ self.tensor_init = TensorInit(l=lmax)
508
+
509
+ self.gata = nn.ModuleList(
510
+ [
511
+ GATA(
512
+ n_atom_basis=self.n_atom_basis,
513
+ activation=activation,
514
+ aggr=aggr,
515
+ weight_init=weight_init,
516
+ bias_init=bias_init,
517
+ layer_norm=int_layer_norm,
518
+ vector_norm=int_vector_norm,
519
+ cutoff=self.cutoff,
520
+ epsilon=epsilon,
521
+ num_heads=num_heads,
522
+ dropout=attn_dropout,
523
+ edge_updates=edge_updates,
524
+ last_layer=(i == self.n_interactions - 1),
525
+ scale_edge=scale_edge,
526
+ edge_ln=edge_ln,
527
+ evec_dim=evec_dim,
528
+ emlp_dim=emlp_dim,
529
+ sep_vecj=sep_int_vec,
530
+ lmax=lmax,
531
+ )
532
+ for i in range(self.n_interactions)
533
+ ]
534
+ )
535
+
536
+ self.eqff = nn.ModuleList(
537
+ [
538
+ EQFF(n_atom_basis=self.n_atom_basis, activation=activation, epsilon=epsilon, weight_init=weight_init, bias_init=bias_init)
539
+ for i in range(self.n_interactions)
540
+ ]
541
+ )
542
+
543
+ # Extra layer norms for the scalar quantities
544
+ if before_mixing_layer_norm:
545
+ self.before_mixing_ln = nn.LayerNorm(self.n_atom_basis)
546
+ else:
547
+ self.before_mixing_ln = nn.Identity()
548
+
549
+ if after_mixing_layer_norm:
550
+ self.after_mixing_ln = nn.LayerNorm(self.n_atom_basis)
551
+ else:
552
+ self.after_mixing_ln = nn.Identity()
553
+
554
+ self.reset_parameters()
555
+
556
+ def reset_parameters(self):
557
+ self.edge_embedding.reset_parameters()
558
+ self.neighbor_embedding.reset_parameters()
559
+ for l in self.gata: # noqa: E741
560
+ l.reset_parameters()
561
+ for l in self.eqff: # noqa: E741
562
+ l.reset_parameters()
563
+
564
+ if not isinstance(self.before_mixing_ln, nn.Identity):
565
+ self.before_mixing_ln.reset_parameters()
566
+ if not isinstance(self.after_mixing_ln, nn.Identity):
567
+ self.after_mixing_ln.reset_parameters()
568
+
569
+ def forward(self, z, pos, cutoff_edge_index, cutoff_edge_distance, cutoff_edge_vec):
570
+ q = self.embedding(z)[:]
571
+
572
+ edge_attr = self.radial_basis(cutoff_edge_distance)
573
+
574
+ q = self.neighbor_embedding(z, q, cutoff_edge_index, cutoff_edge_distance, edge_attr)
575
+ edge_attr = self.edge_embedding(cutoff_edge_index, edge_attr, q)
576
+ mask = cutoff_edge_index[0] != cutoff_edge_index[1]
577
+ # direction vector
578
+ dist = torch.norm(cutoff_edge_vec[mask], dim=1).unsqueeze(1)
579
+ cutoff_edge_vec[mask] = cutoff_edge_vec[mask] / dist
580
+
581
+ cutoff_edge_vec = self.tensor_init(cutoff_edge_vec)
582
+ equi_dim = ((self.tensor_init.l + 1) ** 2) - 1
583
+ # count number of edges for each node
584
+ num_edges = scatter(torch.ones_like(cutoff_edge_distance), cutoff_edge_index[0], dim=0, reduce="sum")
585
+ # the shape of num edges is [num_nodes, 1], we want to expand this to [num_edges, 1]
586
+ # Map num_edges back to the shape of attn using cutoff_edge_index
587
+ num_edges_expanded = num_edges[cutoff_edge_index[0]]
588
+
589
+ qs = q.shape
590
+ mu = torch.zeros((qs[0], equi_dim, qs[1]), device=q.device)
591
+ q.unsqueeze_(1)
592
+
593
+ layer_outputs = []
594
+
595
+ for i, (interaction, mixing) in enumerate(zip(self.gata, self.eqff)):
596
+ q, mu, edge_attr = interaction(
597
+ cutoff_edge_index,
598
+ q,
599
+ mu,
600
+ dir_ij=cutoff_edge_vec,
601
+ r_ij=edge_attr,
602
+ d_ij=cutoff_edge_distance,
603
+ num_edges_expanded=num_edges_expanded,
604
+ )
605
+
606
+ q = self.before_mixing_ln(q)
607
+ q, mu = mixing(q, mu)
608
+ q = self.after_mixing_ln(q)
609
+
610
+ # Collect all scalars for inter-layer read-outs
611
+ layer_outputs.append(q.squeeze(1))
612
+
613
+ # q = q.squeeze(1)
614
+
615
+ layer_outputs = torch.stack(layer_outputs, dim=-1)
616
+
617
+ output_dict = {}
618
+ output_dict["embedding_0"] = layer_outputs.unsqueeze(2) # [n_nodes, n_features, dimension of irrep, n_layers]
619
+ # This is a scalar so a single irrep
620
+
621
+ return output_dict
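
To make the degree bookkeeping concrete: `lmax_tensor_size` counts the equivariant components for degrees 1..l, and `split_degree` slices them back apart. A self-contained shape check (assumes the package imports as `posegnn`):

```python
import torch
from posegnn.encoder import lmax_tensor_size, split_degree

# lmax=2 -> (2+1)**2 - 1 = 8 components: 3 of degree 1 plus 5 of degree 2
x = torch.randn(10, lmax_tensor_size(2), 128)    # [nodes, components, features]
parts = split_degree(x, lmax=2, dim=1)
print([tuple(p.shape) for p in parts])           # [(10, 3, 128), (10, 5, 128)]
```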
models/pos_egnn/posegnn/model.py ADDED
@@ -0,0 +1,99 @@
1
+ from torch import nn
2
+ import torch
3
+ from .encoder import GotenNet
4
+ from .utils import get_symmetric_displacement, BatchedPeriodicDistance, ACT_CLASS_MAPPING
5
+ from torch_geometric.utils import scatter  # scatter is used in compute_properties below
6
+
7
+ class NodeInvariantReadout(nn.Module):
8
+ def __init__(self, in_channels, num_residues, hidden_channels, out_channels, activation):
9
+ super().__init__()
10
+
11
+ self.linears = nn.ModuleList([nn.Linear(in_channels, out_channels) for _ in range(num_residues - 1)])
12
+
13
+ # Define the nonlinear layer for the last layer's output
14
+ self.non_linear = nn.Sequential(
15
+ nn.Linear(in_channels, hidden_channels),
16
+ ACT_CLASS_MAPPING[activation](),
17
+ nn.Linear(hidden_channels, out_channels),
18
+ )
19
+
20
+ def forward(self, embedding_0):
21
+ layer_outputs = embedding_0.squeeze(2) # [n_nodes, in_channels, num_residues]
22
+
23
+ processed_outputs = []
24
+ for i, linear in enumerate(self.linears):
25
+ processed_outputs.append(linear(layer_outputs[:, :, i]))
26
+
27
+ processed_outputs.append(self.non_linear(layer_outputs[:, :, -1]))
28
+ output = torch.stack(processed_outputs, dim=0).sum(dim=0).squeeze(-1)
29
+ return output
30
+
31
+ class PosEGNN(nn.Module):
32
+ def __init__(self, config):
33
+ super().__init__()
34
+
35
+ self.distance = BatchedPeriodicDistance(config["encoder"]["cutoff"])
36
+ self.encoder = GotenNet(**config["encoder"])
37
+ self.readout = NodeInvariantReadout(**config["decoder"])
38
+ self.register_buffer("e0_mean", torch.tensor(config["e0_mean"]))
39
+ self.register_buffer("atomic_res_total_mean", torch.tensor(config["atomic_res_total_mean"]))
40
+ self.register_buffer("atomic_res_total_std", torch.tensor(config["atomic_res_total_std"]))
41
+
42
+ def forward(self, data):
43
+ data.pos.requires_grad_(True)
44
+
45
+ data.pos, data.box, data.displacements = get_symmetric_displacement(data.pos, data.box, data.num_graphs, data.batch)
46
+
47
+ data.cutoff_edge_index, data.cutoff_edge_distance, data.cutoff_edge_vec, data.cutoff_shifts_idx = self.distance(
48
+ data.pos, data.box, data.batch
49
+ )
50
+
51
+ embedding_dict = self.encoder(data.z, data.pos, data.cutoff_edge_index, data.cutoff_edge_distance, data.cutoff_edge_vec)
52
+
53
+ return embedding_dict
54
+
55
+ def compute_properties(self, data, compute_stress = True):
56
+ output = {}
57
+
58
+ embedding_dict = self.forward(data)
59
+ embedding_0 = embedding_dict["embedding_0"]
60
+
61
+ # Compute energy
62
+ node_e_res = self.readout(embedding_0)
63
+
64
+ node_e_res = node_e_res * self.atomic_res_total_std + self.atomic_res_total_mean
65
+ total_e_res = scatter(src=node_e_res, index=data["batch"], dim=0, reduce="sum")
66
+
67
+ node_e0 = self.e0_mean[data.z]
68
+ total_e0 = scatter(src=node_e0, index=data["batch"], dim=0, reduce="sum")
69
+
70
+ total_energy = total_e0 + total_e_res
71
+ output["total_energy"] = total_energy
72
+
73
+ # Compute gradients
74
+ if compute_stress:
75
+ inputs = [data.pos, data.displacements]
76
+ compute_stress = True
77
+ else:
78
+ inputs = [data.pos]
79
+
80
+ grad_outputs = torch.autograd.grad(
81
+ outputs=[total_energy],
82
+ inputs=inputs,
83
+ grad_outputs=[torch.ones_like(total_energy)],
84
+ retain_graph=self.training,
85
+ create_graph=self.training,
86
+ )
87
+
88
+ # Get forces and stresses
89
+ if compute_stress:
90
+ force, virial = grad_outputs
91
+ stress = virial / torch.det(data.box).abs().view(-1, 1, 1)
92
+ stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress))
93
+ output["force"] = -force
94
+ output["stress"] = -stress
95
+ else:
96
+ force = grad_outputs[0]
97
+ output["force"] = -force
98
+
99
+ return output
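
The gradient block in `compute_properties` is the standard conservative-force recipe: differentiate the predicted energy with respect to positions (and displacements, for stress) and negate. A toy version of the same pattern, with a quadratic energy standing in for the model:

```python
import torch

pos = torch.randn(5, 3, requires_grad=True)
energy = (pos ** 2).sum()   # toy potential E(pos); the model predicts this instead
(grad,) = torch.autograd.grad([energy], [pos], [torch.ones_like(energy)])
force = -grad               # F = -dE/dpos, matching output["force"] above
```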
models/pos_egnn/posegnn/ops.py ADDED
@@ -0,0 +1,1584 @@
1
+ """
2
+ This code was adapted from https://github.com/sarpaykent/GotenNet
3
+ Copyright (c) 2025 Sarp Aykent
4
+ MIT License
5
+
6
+ GotenNet: Rethinking Efficient 3D Equivariant Graph Neural Networks
7
+ Sarp Aykent and Tian Xia
8
+ https://openreview.net/pdf?id=5wxCQDtbMo
9
+ """
10
+
11
+ from __future__ import absolute_import, division, print_function
12
+
13
+ import inspect
14
+ import math
15
+ from functools import partial
16
+ from typing import List
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import Tensor
21
+ from torch import nn as nn
22
+ from torch.nn.init import constant_, xavier_uniform_
23
+ from torch_geometric.nn import MessagePassing
24
+ from torch_geometric.nn.inits import glorot_orthogonal
25
+ from torch_geometric.nn.models.schnet import ShiftedSoftplus
26
+ from torch_scatter import scatter  # required by centralize() below
27
+
28
+ zeros_initializer = partial(constant_, val=0.0)
29
+
30
+
31
+ def centralize(
32
+ batch,
33
+ key: str,
34
+ batch_index: torch.Tensor,
35
+ ): # note: cannot make assumptions on output shape
36
+ # derive centroid of each batch element, and center entities using corresponding centroids
37
+ entities_centroid = scatter(batch[key], batch_index, dim=0, reduce="mean") # e.g., [batch_size, 3]
38
+ entities_centered = batch[key] - entities_centroid[batch_index]
39
+
40
+ return entities_centroid, entities_centered
41
+
42
+
43
+ def decentralize(
44
+ positions: torch.Tensor,
45
+ batch_index: torch.Tensor,
46
+ entities_centroid: torch.Tensor,
47
+ ) -> torch.Tensor: # note: cannot make assumptions on output shape
48
+ entities_centered = positions + entities_centroid[batch_index]
49
+ return entities_centered
50
+
51
+
52
+ def parse_update_info(edge_updates):
53
+ update_info = {
54
+ "gated": False,
55
+ "rej": True,
56
+ "vec_norm": False,
57
+ "mlp": False,
58
+ "mlpa": False,
59
+ "lin_w": 0,
60
+ "drej": False,
61
+ }
62
+ if isinstance(edge_updates, str):
63
+ update_parts = edge_updates.split("_")
64
+ else:
65
+ update_parts = []
66
+
67
+ allowed_parts = ["gated", "gatedt", "norej", "mlp", "mlpa", "act", "linw", "linwa", "drej"]
68
+ if not all([part in allowed_parts for part in update_parts]):
69
+ raise ValueError(f"Invalid edge update parts. Allowed parts are {allowed_parts}")
70
+
71
+ if "gated" in update_parts:
72
+ update_info["gated"] = "gated"
73
+ if "gatedt" in update_parts:
74
+ update_info["gated"] = "gatedt"
75
+ if "act" in update_parts:
76
+ update_info["gated"] = "act"
77
+ if "norej" in update_parts:
78
+ update_info["rej"] = False
79
+ if "mlp" in update_parts:
80
+ update_info["mlp"] = True
81
+ if "mlpa" in update_parts:
82
+ update_info["mlpa"] = True
83
+ if "linw" in update_parts:
84
+ update_info["lin_w"] = 1
85
+ if "linwa" in update_parts:
86
+ update_info["lin_w"] = 2
87
+ if "drej" in update_parts:
88
+ update_info["drej"] = True
89
+ return update_info
90
+
91
+
92
+ class SmoothLeakyReLU(torch.nn.Module):
93
+ def __init__(self, negative_slope=0.2):
94
+ super().__init__()
95
+ self.alpha = negative_slope
96
+
97
+ def forward(self, x):
98
+ x1 = ((1 + self.alpha) / 2) * x
99
+ x2 = ((1 - self.alpha) / 2) * x * (2 * torch.sigmoid(x) - 1)
100
+ return x1 + x2
101
+
102
+ def extra_repr(self):
103
+ return "negative_slope={}".format(self.alpha)
104
+
105
+
106
+ def shifted_softplus(x: torch.Tensor):
107
+ return F.softplus(x) - math.log(2.0)
108
+
109
+
110
+ class PolynomialCutoff(nn.Module):
111
+ def __init__(self, cutoff, p: int = 6):
112
+ super(PolynomialCutoff, self).__init__()
113
+ self.cutoff = cutoff
114
+ self.p = p
115
+
116
+ @staticmethod
117
+ def polynomial_cutoff(r: Tensor, rcut: float, p: float = 6.0) -> Tensor:
118
+ """
119
+ Polynomial cutoff, as proposed in DimeNet: https://arxiv.org/abs/2003.03123
120
+ """
121
+ if not p >= 2.0:
122
+ raise ValueError(f"Exponent p={p} has to be >= 2.")
129
+
130
+ rscaled = r / rcut
131
+
132
+ out = 1.0
133
+ out = out - (((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(rscaled, p))
134
+ out = out + (p * (p + 2.0) * torch.pow(rscaled, p + 1.0))
135
+ out = out - ((p * (p + 1.0) / 2) * torch.pow(rscaled, p + 2.0))
136
+
137
+ return out * (rscaled < 1.0).float()
138
+
139
+ def forward(self, r):
140
+ return self.polynomial_cutoff(r=r, rcut=self.cutoff, p=self.p)
141
+
142
+ def __repr__(self):
143
+ return f"{self.__class__.__name__}(cutoff={self.cutoff}, p={self.p})"
144
+
145
+
146
+ class CosineCutoff(nn.Module):
147
+ def __init__(self, cutoff):
148
+ super(CosineCutoff, self).__init__()
149
+
150
+ if isinstance(cutoff, torch.Tensor):
151
+ cutoff = cutoff.item()
152
+ self.cutoff = cutoff
153
+
154
+ def forward(self, distances):
155
+ cutoffs = 0.5 * (torch.cos(distances * math.pi / self.cutoff) + 1.0)
156
+ cutoffs = cutoffs * (distances < self.cutoff).float()
157
+ return cutoffs
158
+
159
+
160
+ class ScaleShift(nn.Module):
161
+ r"""Scale and shift layer for standardization.
162
+
163
+ .. math::
164
+ y = x \times \sigma + \mu
165
+
166
+ Args:
167
+ mean (torch.Tensor): mean value :math:`\mu`.
168
+ stddev (torch.Tensor): standard deviation value :math:`\sigma`.
169
+
170
+ """
171
+
172
+ def __init__(self, mean, stddev):
173
+ super(ScaleShift, self).__init__()
174
+ if isinstance(mean, float):
175
+ mean = torch.FloatTensor([mean])
176
+ if isinstance(stddev, float):
177
+ stddev = torch.FloatTensor([stddev])
178
+ self.register_buffer("mean", mean)
179
+ self.register_buffer("stddev", stddev)
180
+
181
+ def forward(self, input):
182
+ """Compute layer output.
183
+
184
+ Args:
185
+ input (torch.Tensor): input data.
186
+
187
+ Returns:
188
+ torch.Tensor: layer output.
189
+
190
+ """
191
+ y = input * self.stddev + self.mean
192
+ return y
193
+
194
+
195
+ class GetItem(nn.Module):
196
+ """Extraction layer to get an item from SchNetPack dictionary of input tensors.
197
+ Args:
198
+ key (str): Property to be extracted from SchNetPack input tensors.
199
+ """
200
+
201
+ def __init__(self, key):
202
+ super(GetItem, self).__init__()
203
+ self.key = key
204
+
205
+ def forward(self, inputs):
206
+ """Compute layer output.
207
+ Args:
208
+ inputs (dict of torch.Tensor): SchNetPack dictionary of input tensors.
209
+ Returns:
210
+ torch.Tensor: layer output.
211
+ """
212
+ return inputs[self.key]
213
+
214
+
215
+ class SchnetMLP(nn.Module):
216
+ """Multiple layer fully connected perceptron neural network.
217
+ Args:
218
+ n_in (int): number of input nodes.
219
+ n_out (int): number of output nodes.
220
+ n_hidden (list of int or int, optional): number of hidden layer nodes.
221
+ If an integer, the same number of nodes is used for all hidden layers,
222
+ resulting in a rectangular network.
223
+ If None, the number of neurons is halved after each layer, starting
224
+ from n_in, resulting in a pyramidal network.
225
+ n_layers (int, optional): number of layers.
226
+ activation (callable, optional): activation function. All hidden layers use
227
+ the same activation function; the output layer does not apply
228
+ any activation function.
229
+ """
230
+
231
+ def __init__(self, n_in, n_out, n_hidden=None, n_layers=2, activation=shifted_softplus):
232
+ super(SchnetMLP, self).__init__()
233
+ # get list of number of nodes in input, hidden & output layers
234
+ if n_hidden is None:
235
+ c_neurons = n_in
236
+ self.n_neurons = []
237
+ for i in range(n_layers):
238
+ self.n_neurons.append(c_neurons)
239
+ c_neurons = c_neurons // 2
240
+ self.n_neurons.append(n_out)
241
+ else:
242
+ # get list of number of nodes in hidden layers
243
+ if type(n_hidden) is int:
244
+ n_hidden = [n_hidden] * (n_layers - 1)
245
+ self.n_neurons = [n_in] + n_hidden + [n_out]
246
+
247
+ # assign a Dense layer (with activation function) to each hidden layer
248
+ layers = [Dense(self.n_neurons[i], self.n_neurons[i + 1], activation=activation) for i in range(n_layers - 1)]
249
+ # assign a Dense layer (without activation function) to the output layer
250
+ layers.append(Dense(self.n_neurons[-2], self.n_neurons[-1], activation=None))
251
+ # put all layers together to make the network
252
+ self.out_net = nn.Sequential(*layers)
253
+
254
+ def forward(self, inputs):
255
+ """Compute neural network output.
256
+ Args:
257
+ inputs (torch.Tensor): network input.
258
+ Returns:
259
+ torch.Tensor: network output.
260
+ """
261
+ return self.out_net(inputs)
262
+
263
+
264
+ def scaled_silu(x, scale=0.6):
265
+ return F.silu(x) * scale
266
+
267
+
268
+ def gaussian_rbf(inputs: torch.Tensor, offsets: torch.Tensor, widths: torch.Tensor):
269
+ coeff = -0.5 / torch.pow(widths, 2)
270
+ diff = inputs[..., None] - offsets
271
+ y = torch.exp(coeff * torch.pow(diff, 2))
272
+ return y
273
+
274
+
275
+ class GaussianRBF(nn.Module):
276
+ r"""Gaussian radial basis functions."""
277
+
278
+ def __init__(self, n_rbf: int, cutoff: float, start: float = 0.0, trainable: bool = False):
279
+ """
280
+ Args:
281
+ n_rbf: total number of Gaussian functions, :math:`N_g`.
282
+ cutoff: center of last Gaussian function, :math:`\mu_{N_g}`
283
+ start: center of first Gaussian function, :math:`\mu_0`.
284
+ trainable: If True, widths and offset of Gaussian functions
285
+ are adjusted during training process.
286
+ """
287
+ super(GaussianRBF, self).__init__()
288
+ self.n_rbf = n_rbf
289
+
290
+ # compute offset and width of Gaussian functions
291
+ offset = torch.linspace(start, cutoff, n_rbf)
292
+ widths = torch.FloatTensor(torch.abs(offset[1] - offset[0]) * torch.ones_like(offset))
293
+ if trainable:
294
+ self.widths = nn.Parameter(widths)
295
+ self.offsets = nn.Parameter(offset)
296
+ else:
297
+ self.register_buffer("widths", widths)
298
+ self.register_buffer("offsets", offset)
299
+
300
+ def forward(self, inputs: torch.Tensor):
301
+ return gaussian_rbf(inputs, self.offsets, self.widths)
302
+
303
+
304
+ class BesselBasis(nn.Module):
305
+ """
306
+ Sine for radial basis expansion with Coulomb decay (0th-order Bessel from DimeNet).
307
+ """
308
+
309
+ def __init__(self, cutoff=5.0, n_rbf=None, trainable=False):
310
+ """
311
+ Args:
312
+ cutoff: radial cutoff
313
+ n_rbf: number of basis functions.
314
+ """
315
+ super(BesselBasis, self).__init__()
316
+ self.n_rbf = n_rbf
317
+ # compute angular frequencies of the sine basis functions
318
+ freqs = torch.arange(1, n_rbf + 1) * math.pi / cutoff
319
+ self.register_buffer("freqs", freqs)
320
+ self.register_buffer("norm1", torch.tensor(1.0))
321
+
322
+ def forward(self, inputs):
323
+ input_size = len(inputs.shape) # noqa: F841
324
+ a = self.freqs[None, :]
325
+ inputs = inputs[..., None]
326
+ ax = inputs * a
327
+ sinax = torch.sin(ax)
328
+
329
+ norm = torch.where(inputs == 0, self.norm1, inputs)
330
+ y = sinax / norm
331
+
332
+ return y
333
+
334
+
335
+ def glorot_orthogonal_wrapper_(tensor, scale=2.0):
336
+ return glorot_orthogonal(tensor, scale=scale)
337
+
338
+
339
+ def _standardize(kernel):
340
+ """
341
+ Makes sure that Var(W) = 1 and E[W] = 0
342
+ """
343
+ eps = 1e-6
344
+
345
+ if len(kernel.shape) == 3:
346
+ axis = [0, 1] # last dimension is output dimension
347
+ else:
348
+ axis = 1
349
+
350
+ var, mean = torch.var_mean(kernel, dim=axis, unbiased=True, keepdim=True)
351
+ kernel = (kernel - mean) / (var + eps) ** 0.5
352
+ return kernel
353
+
354
+
355
+ def he_orthogonal_init(tensor):
356
+ """
357
+ Generate a weight matrix with variance according to He initialization.
358
+ Based on a random (semi-)orthogonal matrix; neural networks
359
+ are expected to learn better when features are decorrelated
360
+ (as stated by e.g. "Reducing overfitting in deep networks by decorrelating representations",
361
+ "Dropout: a simple way to prevent neural networks from overfitting",
362
+ "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks")
363
+ """
364
+ tensor = torch.nn.init.orthogonal_(tensor)
365
+
366
+ if len(tensor.shape) == 3:
367
+ fan_in = tensor.shape[:-1].numel()
368
+ else:
369
+ fan_in = tensor.shape[1]
370
+
371
+ with torch.no_grad():
372
+ tensor.data = _standardize(tensor.data)
373
+ tensor.data *= (1 / fan_in) ** 0.5
374
+
375
+ return tensor
376
+
377
+
378
+ def get_weight_init_by_string(init_str):
379
+ if init_str == "":
380
+ # Noop
381
+ return lambda x: x
382
+ elif init_str == "zeros":
383
+ return torch.nn.init.zeros_
384
+ elif init_str == "xavier_uniform":
385
+ return torch.nn.init.xavier_uniform_
386
+ elif init_str == "glo_orthogonal":
387
+ return glorot_orthogonal_wrapper_
388
+ elif init_str == "he_orthogonal":
389
+ return he_orthogonal_init
390
+ else:
391
+ raise ValueError(f"Unknown initialization {init_str}")
392
+
393
+
394
+ class Dense(nn.Linear):
395
+ r"""Fully connected linear layer with activation function.
396
+ Borrowed from https://github.com/atomistic-machine-learning/schnetpack/blob/master/src/schnetpack/nn/base.py
397
+
398
+ .. math::
399
+ y = activation(xW^T + b)
400
+
401
+ Args:
402
+ in_features (int): number of input features :math:`x`.
403
+ out_features (int): number of output features :math:`y`.
404
+ bias (bool, optional): if False, the layer will not adapt bias :math:`b`.
405
+ activation (callable, optional): if None, no activation function is used.
406
+ weight_init (callable, optional): weight initializer from current weight.
407
+ bias_init (callable, optional): bias initializer from current bias.
408
+
409
+ """
410
+
411
+ def __init__(
412
+ self,
413
+ in_features,
414
+ out_features,
415
+ bias=True,
416
+ activation=None,
417
+ weight_init=xavier_uniform_,
418
+ bias_init=zeros_initializer,
419
+ norm=None,
420
+ gain=None,
421
+ ):
422
+ # initialize linear layer y = xW^T + b
423
+ self.weight_init = weight_init
424
+ self.bias_init = bias_init
425
+ self.gain = gain
426
+ super(Dense, self).__init__(in_features, out_features, bias)
427
+ # Initialize activation function
428
+ if inspect.isclass(activation):
429
+ activation = activation()
430
+ self.activation = activation
431
+
432
+ if norm == "layer":
433
+ self.norm = nn.LayerNorm(out_features)
434
+ elif norm == "batch":
435
+ self.norm = nn.BatchNorm1d(out_features)
436
+ elif norm == "instance":
437
+ self.norm = nn.InstanceNorm1d(out_features)
438
+ else:
439
+ self.norm = None
440
+
441
+ def reset_parameters(self):
442
+ """Reinitialize model weight and bias values."""
443
+ if self.gain:
444
+ self.weight_init(self.weight, gain=self.gain)
445
+ else:
446
+ self.weight_init(self.weight)
447
+ if self.bias is not None:
448
+ self.bias_init(self.bias)
449
+
450
+ def forward(self, inputs):
451
+ """Compute layer output.
452
+
453
+ Args:
454
+ inputs (dict of torch.Tensor): batch of input values.
455
+
456
+ Returns:
457
+ torch.Tensor: layer output.
458
+
459
+ """
460
+ # compute linear layer y = xW^T + b
461
+ y = super(Dense, self).forward(inputs)
462
+ if self.norm is not None:
463
+ y = self.norm(y)
464
+ # add activation function
465
+ if self.activation:
466
+ y = self.activation(y)
467
+ return y
468
+
469
+
470
+ class _VDropout(nn.Module):
471
+ """
472
+ Vector channel dropout where the elements of each
473
+ vector channel are dropped together.
474
+ """
475
+
476
+ def __init__(self, drop_rate, scale=True):
477
+ super(_VDropout, self).__init__()
478
+ self.drop_rate = drop_rate
479
+ self.scale = scale
480
+
481
+ def forward(self, x, dim=-1):
482
+ """
483
+ :param x: `torch.Tensor` corresponding to vector channels
484
+ """
485
+ if self.drop_rate == 0:
486
+ return x
487
+ device = x.device
488
+ if not self.training:
489
+ return x
490
+
491
+ shape = list(x.shape)
492
+ assert shape[dim] == 3, "The dimension must be vector"
493
+ shape[dim] = 1
494
+
495
+ mask = torch.bernoulli((1 - self.drop_rate) * torch.ones(shape, device=device))
496
+ x = mask * x
497
+ if self.scale:
498
+ # scale the output to keep the expected output distribution
499
+ # same as input distribution. However, this might be harmful
500
+ # for vector space.
501
+ x = x / (1 - self.drop_rate)
502
+
503
+ return x
504
+
505
+
506
+ class Dropout(nn.Module):
507
+ """
508
+ Combined dropout for tuples (s, V).
509
+ Takes tuples (s, V) as input and as output.
510
+ """
511
+
512
+ def __init__(self, drop_rate, vector_dropout=True):
513
+ super(Dropout, self).__init__()
514
+ self.sdropout = nn.Dropout(drop_rate)
515
+ if vector_dropout:
516
+ self.vdropout = _VDropout(drop_rate)
517
+ else:
518
+ self.vdropout = lambda x, dim: x
519
+
520
+ def forward(self, x):
521
+ """
522
+ :param x: tuple (s, V) of `torch.Tensor`,
523
+ or single `torch.Tensor`
524
+ (will be assumed to be scalar channels)
525
+ """
526
+ if type(x) is torch.Tensor:
527
+ return self.sdropout(x)
528
+ s, v = x
529
+ return self.sdropout(s), self.vdropout(v, dim=1)
530
+
531
+
532
+ class TensorInit(nn.Module):
533
+ def __init__(self, l=2): # noqa: E741
534
+ super(TensorInit, self).__init__()
535
+ self.l = l
536
+
537
+ def forward(self, edge_vec):
538
+ edge_sh = self._calculate_components(self.l, edge_vec[..., 0], edge_vec[..., 1], edge_vec[..., 2])
539
+ return edge_sh
540
+
541
+ @property
542
+ def tensor_size(self):
543
+ return ((self.l + 1) ** 2) - 1
544
+
545
+ @staticmethod
546
+ def _calculate_components(lmax: int, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
547
+ sh_1_0, sh_1_1, sh_1_2 = x, y, z
548
+
549
+ if lmax == 1:
550
+ return torch.stack([sh_1_0, sh_1_1, sh_1_2], dim=-1)
551
+
552
+ # (x^2, y^2, z^2) ^2
553
+
554
+ sh_2_0 = math.sqrt(3.0) * x * z
555
+ sh_2_1 = math.sqrt(3.0) * x * y
556
+ y2 = y.pow(2)
557
+ x2z2 = x.pow(2) + z.pow(2)
558
+ sh_2_2 = y2 - 0.5 * x2z2
559
+ sh_2_3 = math.sqrt(3.0) * y * z
560
+ sh_2_4 = math.sqrt(3.0) / 2.0 * (z.pow(2) - x.pow(2))
561
+
562
+ if lmax == 2:
563
+ return torch.stack([sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4], dim=-1)
564
+
565
+ # Borrowed from e3nn: https://github.com/e3nn/e3nn/blob/main/e3nn/o3/_spherical_harmonics.py#L188
566
+ sh_3_0 = (1 / 6) * math.sqrt(42) * (sh_2_0 * z + sh_2_4 * x)
567
+ sh_3_1 = math.sqrt(7) * sh_2_0 * y
568
+ sh_3_2 = (1 / 8) * math.sqrt(168) * (4.0 * y2 - x2z2) * x
569
+ sh_3_3 = (1 / 2) * math.sqrt(7) * y * (2.0 * y2 - 3.0 * x2z2)
570
+ sh_3_4 = (1 / 8) * math.sqrt(168) * z * (4.0 * y2 - x2z2)
571
+ sh_3_5 = math.sqrt(7) * sh_2_4 * y
572
+ sh_3_6 = (1 / 6) * math.sqrt(42) * (sh_2_4 * z - sh_2_0 * x)
573
+
574
+ if lmax == 3:
575
+ return torch.stack(
576
+ [
577
+ sh_1_0,
578
+ sh_1_1,
579
+ sh_1_2,
580
+ sh_2_0,
581
+ sh_2_1,
582
+ sh_2_2,
583
+ sh_2_3,
584
+ sh_2_4,
585
+ sh_3_0,
586
+ sh_3_1,
587
+ sh_3_2,
588
+ sh_3_3,
589
+ sh_3_4,
590
+ sh_3_5,
591
+ sh_3_6,
592
+ ],
593
+ dim=-1,
594
+ )
595
+
596
+ sh_4_0 = (3 / 4) * math.sqrt(2) * (sh_3_0 * z + sh_3_6 * x)
597
+ sh_4_1 = (3 / 4) * sh_3_0 * y + (3 / 8) * math.sqrt(6) * sh_3_1 * z + (3 / 8) * math.sqrt(6) * sh_3_5 * x
598
+ sh_4_2 = (
599
+ -3 / 56 * math.sqrt(14) * sh_3_0 * z
600
+ + (3 / 14) * math.sqrt(21) * sh_3_1 * y
601
+ + (3 / 56) * math.sqrt(210) * sh_3_2 * z
602
+ + (3 / 56) * math.sqrt(210) * sh_3_4 * x
603
+ + (3 / 56) * math.sqrt(14) * sh_3_6 * x
604
+ )
605
+ sh_4_3 = (
606
+ -3 / 56 * math.sqrt(42) * sh_3_1 * z
607
+ + (3 / 28) * math.sqrt(105) * sh_3_2 * y
608
+ + (3 / 28) * math.sqrt(70) * sh_3_3 * x
609
+ + (3 / 56) * math.sqrt(42) * sh_3_5 * x
610
+ )
611
+ sh_4_4 = -3 / 28 * math.sqrt(42) * sh_3_2 * x + (3 / 7) * math.sqrt(7) * sh_3_3 * y - 3 / 28 * math.sqrt(42) * sh_3_4 * z
612
+ sh_4_5 = (
613
+ -3 / 56 * math.sqrt(42) * sh_3_1 * x
614
+ + (3 / 28) * math.sqrt(70) * sh_3_3 * z
615
+ + (3 / 28) * math.sqrt(105) * sh_3_4 * y
616
+ - 3 / 56 * math.sqrt(42) * sh_3_5 * z
617
+ )
618
+ sh_4_6 = (
619
+ -3 / 56 * math.sqrt(14) * sh_3_0 * x
620
+ - 3 / 56 * math.sqrt(210) * sh_3_2 * x
621
+ + (3 / 56) * math.sqrt(210) * sh_3_4 * z
622
+ + (3 / 14) * math.sqrt(21) * sh_3_5 * y
623
+ - 3 / 56 * math.sqrt(14) * sh_3_6 * z
624
+ )
625
+ sh_4_7 = -3 / 8 * math.sqrt(6) * sh_3_1 * x + (3 / 8) * math.sqrt(6) * sh_3_5 * z + (3 / 4) * sh_3_6 * y
626
+ sh_4_8 = (3 / 4) * math.sqrt(2) * (-sh_3_0 * x + sh_3_6 * z)
627
+ if lmax == 4:
628
+ return torch.stack(
629
+ [
630
+ sh_1_0,
631
+ sh_1_1,
632
+ sh_1_2,
633
+ sh_2_0,
634
+ sh_2_1,
635
+ sh_2_2,
636
+ sh_2_3,
637
+ sh_2_4,
638
+ sh_3_0,
639
+ sh_3_1,
640
+ sh_3_2,
641
+ sh_3_3,
642
+ sh_3_4,
643
+ sh_3_5,
644
+ sh_3_6,
645
+ sh_4_0,
646
+ sh_4_1,
647
+ sh_4_2,
648
+ sh_4_3,
649
+ sh_4_4,
650
+ sh_4_5,
651
+ sh_4_6,
652
+ sh_4_7,
653
+ sh_4_8,
654
+ ],
655
+ dim=-1,
656
+ )
657
+
658
+ sh_5_0 = (1 / 10) * math.sqrt(110) * (sh_4_0 * z + sh_4_8 * x)
659
+ sh_5_1 = (1 / 5) * math.sqrt(11) * sh_4_0 * y + (1 / 5) * math.sqrt(22) * sh_4_1 * z + (1 / 5) * math.sqrt(22) * sh_4_7 * x
660
+ sh_5_2 = (
661
+ -1 / 30 * math.sqrt(22) * sh_4_0 * z
662
+ + (4 / 15) * math.sqrt(11) * sh_4_1 * y
663
+ + (1 / 15) * math.sqrt(154) * sh_4_2 * z
664
+ + (1 / 15) * math.sqrt(154) * sh_4_6 * x
665
+ + (1 / 30) * math.sqrt(22) * sh_4_8 * x
666
+ )
667
+ sh_5_3 = (
668
+ -1 / 30 * math.sqrt(66) * sh_4_1 * z
669
+ + (1 / 15) * math.sqrt(231) * sh_4_2 * y
670
+ + (1 / 30) * math.sqrt(462) * sh_4_3 * z
671
+ + (1 / 30) * math.sqrt(462) * sh_4_5 * x
672
+ + (1 / 30) * math.sqrt(66) * sh_4_7 * x
673
+ )
674
+ sh_5_4 = (
675
+ -1 / 15 * math.sqrt(33) * sh_4_2 * z
676
+ + (2 / 15) * math.sqrt(66) * sh_4_3 * y
677
+ + (1 / 15) * math.sqrt(165) * sh_4_4 * x
678
+ + (1 / 15) * math.sqrt(33) * sh_4_6 * x
679
+ )
680
+ sh_5_5 = -1 / 15 * math.sqrt(110) * sh_4_3 * x + (1 / 3) * math.sqrt(11) * sh_4_4 * y - 1 / 15 * math.sqrt(110) * sh_4_5 * z
681
+ sh_5_6 = (
682
+ -1 / 15 * math.sqrt(33) * sh_4_2 * x
683
+ + (1 / 15) * math.sqrt(165) * sh_4_4 * z
684
+ + (2 / 15) * math.sqrt(66) * sh_4_5 * y
685
+ - 1 / 15 * math.sqrt(33) * sh_4_6 * z
686
+ )
687
+ sh_5_7 = (
688
+ -1 / 30 * math.sqrt(66) * sh_4_1 * x
689
+ - 1 / 30 * math.sqrt(462) * sh_4_3 * x
690
+ + (1 / 30) * math.sqrt(462) * sh_4_5 * z
691
+ + (1 / 15) * math.sqrt(231) * sh_4_6 * y
692
+ - 1 / 30 * math.sqrt(66) * sh_4_7 * z
693
+ )
694
+ sh_5_8 = (
695
+ -1 / 30 * math.sqrt(22) * sh_4_0 * x
696
+ - 1 / 15 * math.sqrt(154) * sh_4_2 * x
697
+ + (1 / 15) * math.sqrt(154) * sh_4_6 * z
698
+ + (4 / 15) * math.sqrt(11) * sh_4_7 * y
699
+ - 1 / 30 * math.sqrt(22) * sh_4_8 * z
700
+ )
701
+ sh_5_9 = -1 / 5 * math.sqrt(22) * sh_4_1 * x + (1 / 5) * math.sqrt(22) * sh_4_7 * z + (1 / 5) * math.sqrt(11) * sh_4_8 * y
702
+ sh_5_10 = (1 / 10) * math.sqrt(110) * (-sh_4_0 * x + sh_4_8 * z)
703
+ if lmax == 5:
704
+ return torch.stack(
705
+ [
706
+ sh_1_0,
707
+ sh_1_1,
708
+ sh_1_2,
709
+ sh_2_0,
710
+ sh_2_1,
711
+ sh_2_2,
712
+ sh_2_3,
713
+ sh_2_4,
714
+ sh_3_0,
715
+ sh_3_1,
716
+ sh_3_2,
717
+ sh_3_3,
718
+ sh_3_4,
719
+ sh_3_5,
720
+ sh_3_6,
721
+ sh_4_0,
722
+ sh_4_1,
723
+ sh_4_2,
724
+ sh_4_3,
725
+ sh_4_4,
726
+ sh_4_5,
727
+ sh_4_6,
728
+ sh_4_7,
729
+ sh_4_8,
730
+ sh_5_0,
731
+ sh_5_1,
732
+ sh_5_2,
733
+ sh_5_3,
734
+ sh_5_4,
735
+ sh_5_5,
736
+ sh_5_6,
737
+ sh_5_7,
738
+ sh_5_8,
739
+ sh_5_9,
740
+ sh_5_10,
741
+ ],
742
+ dim=-1,
743
+ )
744
+
745
+ sh_6_0 = (1 / 6) * math.sqrt(39) * (sh_5_0 * z + sh_5_10 * x)
746
+ sh_6_1 = (1 / 6) * math.sqrt(13) * sh_5_0 * y + (1 / 12) * math.sqrt(130) * sh_5_1 * z + (1 / 12) * math.sqrt(130) * sh_5_9 * x
747
+ sh_6_2 = (
748
+ -1 / 132 * math.sqrt(286) * sh_5_0 * z
749
+ + (1 / 33) * math.sqrt(715) * sh_5_1 * y
750
+ + (1 / 132) * math.sqrt(286) * sh_5_10 * x
751
+ + (1 / 44) * math.sqrt(1430) * sh_5_2 * z
752
+ + (1 / 44) * math.sqrt(1430) * sh_5_8 * x
753
+ )
754
+ sh_6_3 = (
755
+ -1 / 132 * math.sqrt(858) * sh_5_1 * z
756
+ + (1 / 22) * math.sqrt(429) * sh_5_2 * y
757
+ + (1 / 22) * math.sqrt(286) * sh_5_3 * z
758
+ + (1 / 22) * math.sqrt(286) * sh_5_7 * x
759
+ + (1 / 132) * math.sqrt(858) * sh_5_9 * x
760
+ )
761
+ sh_6_4 = (
762
+ -1 / 66 * math.sqrt(429) * sh_5_2 * z
763
+ + (2 / 33) * math.sqrt(286) * sh_5_3 * y
764
+ + (1 / 66) * math.sqrt(2002) * sh_5_4 * z
765
+ + (1 / 66) * math.sqrt(2002) * sh_5_6 * x
766
+ + (1 / 66) * math.sqrt(429) * sh_5_8 * x
767
+ )
768
+ sh_6_5 = (
769
+ -1 / 66 * math.sqrt(715) * sh_5_3 * z
770
+ + (1 / 66) * math.sqrt(5005) * sh_5_4 * y
771
+ + (1 / 66) * math.sqrt(3003) * sh_5_5 * x
772
+ + (1 / 66) * math.sqrt(715) * sh_5_7 * x
773
+ )
774
+ sh_6_6 = -1 / 66 * math.sqrt(2145) * sh_5_4 * x + (1 / 11) * math.sqrt(143) * sh_5_5 * y - 1 / 66 * math.sqrt(2145) * sh_5_6 * z
775
+ sh_6_7 = (
776
+ -1 / 66 * math.sqrt(715) * sh_5_3 * x
777
+ + (1 / 66) * math.sqrt(3003) * sh_5_5 * z
778
+ + (1 / 66) * math.sqrt(5005) * sh_5_6 * y
779
+ - 1 / 66 * math.sqrt(715) * sh_5_7 * z
780
+ )
781
+ sh_6_8 = (
782
+ -1 / 66 * math.sqrt(429) * sh_5_2 * x
783
+ - 1 / 66 * math.sqrt(2002) * sh_5_4 * x
784
+ + (1 / 66) * math.sqrt(2002) * sh_5_6 * z
785
+ + (2 / 33) * math.sqrt(286) * sh_5_7 * y
786
+ - 1 / 66 * math.sqrt(429) * sh_5_8 * z
787
+ )
788
+ sh_6_9 = (
789
+ -1 / 132 * math.sqrt(858) * sh_5_1 * x
790
+ - 1 / 22 * math.sqrt(286) * sh_5_3 * x
791
+ + (1 / 22) * math.sqrt(286) * sh_5_7 * z
792
+ + (1 / 22) * math.sqrt(429) * sh_5_8 * y
793
+ - 1 / 132 * math.sqrt(858) * sh_5_9 * z
794
+ )
795
+ sh_6_10 = (
796
+ -1 / 132 * math.sqrt(286) * sh_5_0 * x
797
+ - 1 / 132 * math.sqrt(286) * sh_5_10 * z
798
+ - 1 / 44 * math.sqrt(1430) * sh_5_2 * x
799
+ + (1 / 44) * math.sqrt(1430) * sh_5_8 * z
800
+ + (1 / 33) * math.sqrt(715) * sh_5_9 * y
801
+ )
802
+ sh_6_11 = -1 / 12 * math.sqrt(130) * sh_5_1 * x + (1 / 6) * math.sqrt(13) * sh_5_10 * y + (1 / 12) * math.sqrt(130) * sh_5_9 * z
803
+ sh_6_12 = (1 / 6) * math.sqrt(39) * (-sh_5_0 * x + sh_5_10 * z)
804
+ if lmax == 6:
805
+ return torch.stack(
806
+ [
807
+ sh_1_0,
808
+ sh_1_1,
809
+ sh_1_2,
810
+ sh_2_0,
811
+ sh_2_1,
812
+ sh_2_2,
813
+ sh_2_3,
814
+ sh_2_4,
815
+ sh_3_0,
816
+ sh_3_1,
817
+ sh_3_2,
818
+ sh_3_3,
819
+ sh_3_4,
820
+ sh_3_5,
821
+ sh_3_6,
822
+ sh_4_0,
823
+ sh_4_1,
824
+ sh_4_2,
825
+ sh_4_3,
826
+ sh_4_4,
827
+ sh_4_5,
828
+ sh_4_6,
829
+ sh_4_7,
830
+ sh_4_8,
831
+ sh_5_0,
832
+ sh_5_1,
833
+ sh_5_2,
834
+ sh_5_3,
835
+ sh_5_4,
836
+ sh_5_5,
837
+ sh_5_6,
838
+ sh_5_7,
839
+ sh_5_8,
840
+ sh_5_9,
841
+ sh_5_10,
842
+ sh_6_0,
843
+ sh_6_1,
844
+ sh_6_2,
845
+ sh_6_3,
846
+ sh_6_4,
847
+ sh_6_5,
848
+ sh_6_6,
849
+ sh_6_7,
850
+ sh_6_8,
851
+ sh_6_9,
852
+ sh_6_10,
853
+ sh_6_11,
854
+ sh_6_12,
855
+ ],
856
+ dim=-1,
857
+ )
858
+
859
+ sh_7_0 = (1 / 14) * math.sqrt(210) * (sh_6_0 * z + sh_6_12 * x)
860
+ sh_7_1 = (1 / 7) * math.sqrt(15) * sh_6_0 * y + (3 / 7) * math.sqrt(5) * sh_6_1 * z + (3 / 7) * math.sqrt(5) * sh_6_11 * x
861
+ sh_7_2 = (
862
+ -1 / 182 * math.sqrt(390) * sh_6_0 * z
863
+ + (6 / 91) * math.sqrt(130) * sh_6_1 * y
864
+ + (3 / 91) * math.sqrt(715) * sh_6_10 * x
865
+ + (1 / 182) * math.sqrt(390) * sh_6_12 * x
866
+ + (3 / 91) * math.sqrt(715) * sh_6_2 * z
867
+ )
868
+ sh_7_3 = (
869
+ -3 / 182 * math.sqrt(130) * sh_6_1 * z
870
+ + (3 / 182) * math.sqrt(130) * sh_6_11 * x
871
+ + (3 / 91) * math.sqrt(715) * sh_6_2 * y
872
+ + (5 / 182) * math.sqrt(858) * sh_6_3 * z
873
+ + (5 / 182) * math.sqrt(858) * sh_6_9 * x
874
+ )
875
+ sh_7_4 = (
876
+ (3 / 91) * math.sqrt(65) * sh_6_10 * x
877
+ - 3 / 91 * math.sqrt(65) * sh_6_2 * z
878
+ + (10 / 91) * math.sqrt(78) * sh_6_3 * y
879
+ + (15 / 182) * math.sqrt(78) * sh_6_4 * z
880
+ + (15 / 182) * math.sqrt(78) * sh_6_8 * x
881
+ )
882
+ sh_7_5 = (
883
+ -5 / 91 * math.sqrt(39) * sh_6_3 * z
884
+ + (15 / 91) * math.sqrt(39) * sh_6_4 * y
885
+ + (3 / 91) * math.sqrt(390) * sh_6_5 * z
886
+ + (3 / 91) * math.sqrt(390) * sh_6_7 * x
887
+ + (5 / 91) * math.sqrt(39) * sh_6_9 * x
888
+ )
889
+ sh_7_6 = (
890
+ -15 / 182 * math.sqrt(26) * sh_6_4 * z
891
+ + (12 / 91) * math.sqrt(65) * sh_6_5 * y
892
+ + (2 / 91) * math.sqrt(1365) * sh_6_6 * x
893
+ + (15 / 182) * math.sqrt(26) * sh_6_8 * x
894
+ )
895
+ sh_7_7 = -3 / 91 * math.sqrt(455) * sh_6_5 * x + (1 / 13) * math.sqrt(195) * sh_6_6 * y - 3 / 91 * math.sqrt(455) * sh_6_7 * z
896
+ sh_7_8 = (
897
+ -15 / 182 * math.sqrt(26) * sh_6_4 * x
898
+ + (2 / 91) * math.sqrt(1365) * sh_6_6 * z
899
+ + (12 / 91) * math.sqrt(65) * sh_6_7 * y
900
+ - 15 / 182 * math.sqrt(26) * sh_6_8 * z
901
+ )
902
+ sh_7_9 = (
903
+ -5 / 91 * math.sqrt(39) * sh_6_3 * x
904
+ - 3 / 91 * math.sqrt(390) * sh_6_5 * x
905
+ + (3 / 91) * math.sqrt(390) * sh_6_7 * z
906
+ + (15 / 91) * math.sqrt(39) * sh_6_8 * y
907
+ - 5 / 91 * math.sqrt(39) * sh_6_9 * z
908
+ )
909
+ sh_7_10 = (
910
+ -3 / 91 * math.sqrt(65) * sh_6_10 * z
911
+ - 3 / 91 * math.sqrt(65) * sh_6_2 * x
912
+ - 15 / 182 * math.sqrt(78) * sh_6_4 * x
913
+ + (15 / 182) * math.sqrt(78) * sh_6_8 * z
914
+ + (10 / 91) * math.sqrt(78) * sh_6_9 * y
915
+ )
916
+ sh_7_11 = (
917
+ -3 / 182 * math.sqrt(130) * sh_6_1 * x
918
+ + (3 / 91) * math.sqrt(715) * sh_6_10 * y
919
+ - 3 / 182 * math.sqrt(130) * sh_6_11 * z
920
+ - 5 / 182 * math.sqrt(858) * sh_6_3 * x
921
+ + (5 / 182) * math.sqrt(858) * sh_6_9 * z
922
+ )
923
+ sh_7_12 = (
924
+ -1 / 182 * math.sqrt(390) * sh_6_0 * x
925
+ + (3 / 91) * math.sqrt(715) * sh_6_10 * z
926
+ + (6 / 91) * math.sqrt(130) * sh_6_11 * y
927
+ - 1 / 182 * math.sqrt(390) * sh_6_12 * z
928
+ - 3 / 91 * math.sqrt(715) * sh_6_2 * x
929
+ )
930
+ sh_7_13 = -3 / 7 * math.sqrt(5) * sh_6_1 * x + (3 / 7) * math.sqrt(5) * sh_6_11 * z + (1 / 7) * math.sqrt(15) * sh_6_12 * y
931
+ sh_7_14 = (1 / 14) * math.sqrt(210) * (-sh_6_0 * x + sh_6_12 * z)
932
+ if lmax == 7:
933
+ return torch.stack(
934
+ [
935
+ sh_1_0,
936
+ sh_1_1,
937
+ sh_1_2,
938
+ sh_2_0,
939
+ sh_2_1,
940
+ sh_2_2,
941
+ sh_2_3,
942
+ sh_2_4,
943
+ sh_3_0,
944
+ sh_3_1,
945
+ sh_3_2,
946
+ sh_3_3,
947
+ sh_3_4,
948
+ sh_3_5,
949
+ sh_3_6,
950
+ sh_4_0,
951
+ sh_4_1,
952
+ sh_4_2,
953
+ sh_4_3,
954
+ sh_4_4,
955
+ sh_4_5,
956
+ sh_4_6,
957
+ sh_4_7,
958
+ sh_4_8,
959
+ sh_5_0,
960
+ sh_5_1,
961
+ sh_5_2,
962
+ sh_5_3,
963
+ sh_5_4,
964
+ sh_5_5,
965
+ sh_5_6,
966
+ sh_5_7,
967
+ sh_5_8,
968
+ sh_5_9,
969
+ sh_5_10,
970
+ sh_6_0,
971
+ sh_6_1,
972
+ sh_6_2,
973
+ sh_6_3,
974
+ sh_6_4,
975
+ sh_6_5,
976
+ sh_6_6,
977
+ sh_6_7,
978
+ sh_6_8,
979
+ sh_6_9,
980
+ sh_6_10,
981
+ sh_6_11,
982
+ sh_6_12,
983
+ sh_7_0,
984
+ sh_7_1,
985
+ sh_7_2,
986
+ sh_7_3,
987
+ sh_7_4,
988
+ sh_7_5,
989
+ sh_7_6,
990
+ sh_7_7,
991
+ sh_7_8,
992
+ sh_7_9,
993
+ sh_7_10,
994
+ sh_7_11,
995
+ sh_7_12,
996
+ sh_7_13,
997
+ sh_7_14,
998
+ ],
999
+ dim=-1,
1000
+ )
1001
+
1002
+ sh_8_0 = (1 / 4) * math.sqrt(17) * (sh_7_0 * z + sh_7_14 * x)
1003
+ sh_8_1 = (1 / 8) * math.sqrt(17) * sh_7_0 * y + (1 / 16) * math.sqrt(238) * sh_7_1 * z + (1 / 16) * math.sqrt(238) * sh_7_13 * x
1004
+ sh_8_2 = (
1005
+ -1 / 240 * math.sqrt(510) * sh_7_0 * z
1006
+ + (1 / 60) * math.sqrt(1785) * sh_7_1 * y
1007
+ + (1 / 240) * math.sqrt(46410) * sh_7_12 * x
1008
+ + (1 / 240) * math.sqrt(510) * sh_7_14 * x
1009
+ + (1 / 240) * math.sqrt(46410) * sh_7_2 * z
1010
+ )
1011
+ sh_8_3 = (
1012
+ (1 / 80)
1013
+ * math.sqrt(2)
1014
+ * (
1015
+ -math.sqrt(85) * sh_7_1 * z
1016
+ + math.sqrt(2210) * sh_7_11 * x
1017
+ + math.sqrt(85) * sh_7_13 * x
1018
+ + math.sqrt(2210) * sh_7_2 * y
1019
+ + math.sqrt(2210) * sh_7_3 * z
1020
+ )
1021
+ )
1022
+ sh_8_4 = (
1023
+ (1 / 40) * math.sqrt(935) * sh_7_10 * x
1024
+ + (1 / 40) * math.sqrt(85) * sh_7_12 * x
1025
+ - 1 / 40 * math.sqrt(85) * sh_7_2 * z
1026
+ + (1 / 10) * math.sqrt(85) * sh_7_3 * y
1027
+ + (1 / 40) * math.sqrt(935) * sh_7_4 * z
1028
+ )
1029
+ sh_8_5 = (
1030
+ (1 / 48)
1031
+ * math.sqrt(2)
1032
+ * (
1033
+ math.sqrt(102) * sh_7_11 * x
1034
+ - math.sqrt(102) * sh_7_3 * z
1035
+ + math.sqrt(1122) * sh_7_4 * y
1036
+ + math.sqrt(561) * sh_7_5 * z
1037
+ + math.sqrt(561) * sh_7_9 * x
1038
+ )
1039
+ )
1040
+ sh_8_6 = (
1041
+ (1 / 16) * math.sqrt(34) * sh_7_10 * x
1042
+ - 1 / 16 * math.sqrt(34) * sh_7_4 * z
1043
+ + (1 / 4) * math.sqrt(17) * sh_7_5 * y
1044
+ + (1 / 16) * math.sqrt(102) * sh_7_6 * z
1045
+ + (1 / 16) * math.sqrt(102) * sh_7_8 * x
1046
+ )
1047
+ sh_8_7 = (
1048
+ -1 / 80 * math.sqrt(1190) * sh_7_5 * z
1049
+ + (1 / 40) * math.sqrt(1785) * sh_7_6 * y
1050
+ + (1 / 20) * math.sqrt(255) * sh_7_7 * x
1051
+ + (1 / 80) * math.sqrt(1190) * sh_7_9 * x
1052
+ )
1053
+ sh_8_8 = -1 / 60 * math.sqrt(1785) * sh_7_6 * x + (1 / 15) * math.sqrt(255) * sh_7_7 * y - 1 / 60 * math.sqrt(1785) * sh_7_8 * z
1054
+ sh_8_9 = (
1055
+ -1 / 80 * math.sqrt(1190) * sh_7_5 * x
1056
+ + (1 / 20) * math.sqrt(255) * sh_7_7 * z
1057
+ + (1 / 40) * math.sqrt(1785) * sh_7_8 * y
1058
+ - 1 / 80 * math.sqrt(1190) * sh_7_9 * z
1059
+ )
1060
+ sh_8_10 = (
1061
+ -1 / 16 * math.sqrt(34) * sh_7_10 * z
1062
+ - 1 / 16 * math.sqrt(34) * sh_7_4 * x
1063
+ - 1 / 16 * math.sqrt(102) * sh_7_6 * x
1064
+ + (1 / 16) * math.sqrt(102) * sh_7_8 * z
1065
+ + (1 / 4) * math.sqrt(17) * sh_7_9 * y
1066
+ )
1067
+ sh_8_11 = (
1068
+ (1 / 48)
1069
+ * math.sqrt(2)
1070
+ * (
1071
+ math.sqrt(1122) * sh_7_10 * y
1072
+ - math.sqrt(102) * sh_7_11 * z
1073
+ - math.sqrt(102) * sh_7_3 * x
1074
+ - math.sqrt(561) * sh_7_5 * x
1075
+ + math.sqrt(561) * sh_7_9 * z
1076
+ )
1077
+ )
1078
+ sh_8_12 = (
1079
+ (1 / 40) * math.sqrt(935) * sh_7_10 * z
1080
+ + (1 / 10) * math.sqrt(85) * sh_7_11 * y
1081
+ - 1 / 40 * math.sqrt(85) * sh_7_12 * z
1082
+ - 1 / 40 * math.sqrt(85) * sh_7_2 * x
1083
+ - 1 / 40 * math.sqrt(935) * sh_7_4 * x
1084
+ )
1085
+ sh_8_13 = (
1086
+ (1 / 80)
1087
+ * math.sqrt(2)
1088
+ * (
1089
+ -math.sqrt(85) * sh_7_1 * x
1090
+ + math.sqrt(2210) * sh_7_11 * z
1091
+ + math.sqrt(2210) * sh_7_12 * y
1092
+ - math.sqrt(85) * sh_7_13 * z
1093
+ - math.sqrt(2210) * sh_7_3 * x
1094
+ )
1095
+ )
1096
+ sh_8_14 = (
1097
+ -1 / 240 * math.sqrt(510) * sh_7_0 * x
1098
+ + (1 / 240) * math.sqrt(46410) * sh_7_12 * z
1099
+ + (1 / 60) * math.sqrt(1785) * sh_7_13 * y
1100
+ - 1 / 240 * math.sqrt(510) * sh_7_14 * z
1101
+ - 1 / 240 * math.sqrt(46410) * sh_7_2 * x
1102
+ )
1103
+ sh_8_15 = -1 / 16 * math.sqrt(238) * sh_7_1 * x + (1 / 16) * math.sqrt(238) * sh_7_13 * z + (1 / 8) * math.sqrt(17) * sh_7_14 * y
1104
+ sh_8_16 = (1 / 4) * math.sqrt(17) * (-sh_7_0 * x + sh_7_14 * z)
1105
+ if lmax == 8:
1106
+ return torch.stack(
1107
+ [
1108
+ sh_1_0,
1109
+ sh_1_1,
1110
+ sh_1_2,
1111
+ sh_2_0,
1112
+ sh_2_1,
1113
+ sh_2_2,
1114
+ sh_2_3,
1115
+ sh_2_4,
1116
+ sh_3_0,
1117
+ sh_3_1,
1118
+ sh_3_2,
1119
+ sh_3_3,
1120
+ sh_3_4,
1121
+ sh_3_5,
1122
+ sh_3_6,
1123
+ sh_4_0,
1124
+ sh_4_1,
1125
+ sh_4_2,
1126
+ sh_4_3,
1127
+ sh_4_4,
1128
+ sh_4_5,
1129
+ sh_4_6,
1130
+ sh_4_7,
1131
+ sh_4_8,
1132
+ sh_5_0,
1133
+ sh_5_1,
1134
+ sh_5_2,
1135
+ sh_5_3,
1136
+ sh_5_4,
1137
+ sh_5_5,
1138
+ sh_5_6,
1139
+ sh_5_7,
1140
+ sh_5_8,
1141
+ sh_5_9,
1142
+ sh_5_10,
1143
+ sh_6_0,
1144
+ sh_6_1,
1145
+ sh_6_2,
1146
+ sh_6_3,
1147
+ sh_6_4,
1148
+ sh_6_5,
1149
+ sh_6_6,
1150
+ sh_6_7,
1151
+ sh_6_8,
1152
+ sh_6_9,
1153
+ sh_6_10,
1154
+ sh_6_11,
1155
+ sh_6_12,
1156
+ sh_7_0,
1157
+ sh_7_1,
1158
+ sh_7_2,
1159
+ sh_7_3,
1160
+ sh_7_4,
1161
+ sh_7_5,
1162
+ sh_7_6,
1163
+ sh_7_7,
1164
+ sh_7_8,
1165
+ sh_7_9,
1166
+ sh_7_10,
1167
+ sh_7_11,
1168
+ sh_7_12,
1169
+ sh_7_13,
1170
+ sh_7_14,
1171
+ sh_8_0,
1172
+ sh_8_1,
1173
+ sh_8_2,
1174
+ sh_8_3,
1175
+ sh_8_4,
1176
+ sh_8_5,
1177
+ sh_8_6,
1178
+ sh_8_7,
1179
+ sh_8_8,
1180
+ sh_8_9,
1181
+ sh_8_10,
1182
+ sh_8_11,
1183
+ sh_8_12,
1184
+ sh_8_13,
1185
+ sh_8_14,
1186
+ sh_8_15,
1187
+ sh_8_16,
1188
+ ],
1189
+ dim=-1,
1190
+ )
1191
+
1192
+
1193
+ def lmax_tensor_size(lmax):
1194
+ return ((lmax + 1) ** 2) - 1
1195
+
1196
+
1197
+ def get_split_sizes_from_dim(feature_dim):
1198
+ """
1199
+ Find the lmax value and return split sizes for torch.split based on feature dimension.
1200
+
1201
+ Args:
1202
+ feature_dim: The dimension of the feature (shape[1] of the tensor)
1203
+
1204
+ Returns:
1205
+ split_sizes: A list of split sizes for torch.split (sizes of spherical harmonic components)
1206
+ """
1207
+ lmax = 1
1208
+ while lmax_tensor_size(lmax) < feature_dim:
1209
+ lmax += 1
1210
+
1211
+ if lmax_tensor_size(lmax) != feature_dim:
1212
+ raise ValueError(f"Feature dimension {feature_dim} does not correspond to a valid lmax value")
1213
+
1214
+ # Return the sizes of each spherical harmonic component
1215
+ return [2 * l + 1 for l in range(1, lmax + 1)] # noqa: E741
1216
+
1217
+
1218
+ class TensorLayerNorm(nn.Module):
1219
+ def __init__(self, hidden_channels, trainable):
1220
+ super(TensorLayerNorm, self).__init__()
1221
+
1222
+ self.hidden_channels = hidden_channels
1223
+ self.eps = 1e-12
1224
+
1225
+ weight = torch.ones(self.hidden_channels)
1226
+ if trainable:
1227
+ self.register_parameter("weight", nn.Parameter(weight))
1228
+ else:
1229
+ self.register_buffer("weight", weight)
1230
+
1231
+ self.reset_parameters()
1232
+
1233
+ def reset_parameters(self):
1234
+ weight = torch.ones(self.hidden_channels)
1235
+ self.weight.data.copy_(weight)
1236
+
1237
+ def max_min_norm(self, tensor):
1238
+ # Based on VisNet (https://www.nature.com/articles/s41467-023-43720-2)
1239
+ dist = torch.norm(tensor, dim=1, keepdim=True)
1240
+
1241
+ if (dist == 0).all():
1242
+ return torch.zeros_like(tensor)
1243
+
1244
+ dist = dist.clamp(min=self.eps)
1245
+ direct = tensor / dist
1246
+
1247
+ max_val, _ = torch.max(dist, dim=-1)
1248
+ min_val, _ = torch.min(dist, dim=-1)
1249
+ delta = (max_val - min_val).view(-1)
1250
+ delta = torch.where(delta == 0, torch.ones_like(delta), delta)
1251
+ dist = (dist - min_val.view(-1, 1, 1)) / delta.view(-1, 1, 1)
1252
+
1253
+ return F.relu(dist) * direct
1254
+
1255
+ def forward(self, tensor):
1256
+ # vec: (num_atoms, feature_dim, hidden_channels)
1257
+ feature_dim = tensor.shape[1]
1258
+
1259
+ try:
1260
+ split_sizes = get_split_sizes_from_dim(feature_dim)
1261
+ except ValueError as e:
1262
+ raise ValueError(f"TensorLayerNorm received unsupported feature dimension {feature_dim}: {str(e)}")
1263
+
1264
+ # Split the vector into parts
1265
+ vec_parts = torch.split(tensor, split_sizes, dim=1)
1266
+
1267
+ # Normalize each part separately
1268
+ normalized_parts = [self.max_min_norm(part) for part in vec_parts]
1269
+
1270
+ # Concatenate the normalized parts
1271
+ normalized_vec = torch.cat(normalized_parts, dim=1)
1272
+
1273
+ # Apply weight
1274
+ return normalized_vec * self.weight.unsqueeze(0).unsqueeze(0)
1275
+
1276
+
1277
+ def normalize_string(s: str) -> str:
1278
+ return s.lower().replace("-", "").replace("_", "").replace(" ", "")
1279
+
1280
+
1281
+ class Swish(nn.Module):
1282
+ def __init__(self):
1283
+ super(Swish, self).__init__()
1284
+
1285
+ def forward(self, x):
1286
+ return x * torch.sigmoid(x)
1287
+
1288
+
1289
+ act_class_mapping = {"ssp": ShiftedSoftplus, "silu": nn.SiLU, "tanh": nn.Tanh, "sigmoid": nn.Sigmoid, "swish": Swish}
1290
+
1291
+
1292
+ # https://github.com/sunglasses-ai/classy/blob/3e74cba1fdf1b9f9f2ba1cfcfa6c2017aa59fc04/classy/optim/factories.py#L14
1293
+ def get_activations(optional=False, *args, **kwargs):
1294
+ activations = {
1295
+ normalize_string(act.__name__): act
1296
+ for act in vars(torch.nn.modules.activation).values()
1297
+ if isinstance(act, type) and issubclass(act, torch.nn.Module)
1298
+ }
1299
+ activations.update(
1300
+ {
1301
+ "relu": torch.nn.ReLU,
1302
+ "elu": torch.nn.ELU,
1303
+ "sigmoid": torch.nn.Sigmoid,
1304
+ "silu": torch.nn.SiLU,
1305
+ "mish": torch.nn.Mish,
1306
+ "swish": torch.nn.SiLU,
1307
+ "selu": torch.nn.SELU,
1308
+ "scaled_swish": scaled_silu,
1309
+ "softplus": shifted_softplus,
1310
+ "slrelu": SmoothLeakyReLU,
1311
+ }
1312
+ )
1313
+
1314
+ if optional:
1315
+ activations[""] = None
1316
+
1317
+ return activations
1318
+
1319
+
1320
+ def get_activations_none(optional=False, *args, **kwargs):
1321
+ activations = {
1322
+ normalize_string(act.__name__): act
1323
+ for act in vars(torch.nn.modules.activation).values()
1324
+ if isinstance(act, type) and issubclass(act, torch.nn.Module)
1325
+ }
1326
+ activations.update(
1327
+ {
1328
+ "relu": torch.nn.ReLU,
1329
+ "elu": torch.nn.ELU,
1330
+ "sigmoid": torch.nn.Sigmoid,
1331
+ "silu": torch.nn.SiLU,
1332
+ "selu": torch.nn.SELU,
1333
+ }
1334
+ )
1335
+
1336
+ if optional:
1337
+ activations[""] = None
1338
+ activations[None] = None
1339
+
1340
+ return activations
1341
+
1342
+
1343
+ def dictionary_to_option(options, selected):
1344
+ if selected not in options:
1345
+ raise ValueError(f'Invalid choice "{selected}", choose one from {", ".join(list(options.keys()))} ')
1346
+
1347
+ activation = options[selected]
1348
+ if inspect.isclass(activation):
1349
+ activation = activation()
1350
+ return activation
1351
+
1352
+
1353
+ def str2act(input_str, *args, **kwargs):
1354
+ if input_str == "":
1355
+ return None
1356
+
1357
+ act = get_activations(optional=True, *args, **kwargs)
1358
+ out = dictionary_to_option(act, input_str)
1359
+ return out
1360
+
1361
+
1362
+ class ExpNormalSmearing(nn.Module):
1363
+ def __init__(self, cutoff=5.0, n_rbf=50, trainable=False):
1364
+ super(ExpNormalSmearing, self).__init__()
1365
+ if isinstance(cutoff, torch.Tensor):
1366
+ cutoff = cutoff.item()
1367
+ self.cutoff = cutoff
1368
+ self.n_rbf = n_rbf
1369
+ self.trainable = trainable
1370
+
1371
+ self.cutoff_fn = CosineCutoff(cutoff)
1372
+ self.alpha = 5.0 / cutoff
1373
+
1374
+ means, betas = self._initial_params()
1375
+ if trainable:
1376
+ self.register_parameter("means", nn.Parameter(means))
1377
+ self.register_parameter("betas", nn.Parameter(betas))
1378
+ else:
1379
+ self.register_buffer("means", means)
1380
+ self.register_buffer("betas", betas)
1381
+
1382
+ def _initial_params(self):
1383
+ start_value = torch.exp(torch.scalar_tensor(-self.cutoff))
1384
+ means = torch.linspace(start_value, 1, self.n_rbf)
1385
+ betas = torch.tensor([(2 / self.n_rbf * (1 - start_value)) ** -2] * self.n_rbf)
1386
+ return means, betas
1387
+
1388
+ def reset_parameters(self):
1389
+ means, betas = self._initial_params()
1390
+ self.means.data.copy_(means)
1391
+ self.betas.data.copy_(betas)
1392
+
1393
+ def forward(self, dist):
1394
+ dist = dist.unsqueeze(-1)
1395
+ return self.cutoff_fn(dist) * torch.exp(-self.betas * (torch.exp(self.alpha * (-dist)) - self.means) ** 2)
1396
+
1397
+
1398
+ def str2basis(input_str):
1399
+ if type(input_str) != str: # noqa: E721
1400
+ return input_str
1401
+
1402
+ if input_str == "BesselBasis":
1403
+ radial_basis = BesselBasis
1404
+ elif input_str == "GaussianRBF":
1405
+ radial_basis = GaussianRBF
1406
+ elif input_str.lower() == "expnorm":
1407
+ radial_basis = ExpNormalSmearing
1408
+ else:
1409
+ raise ValueError("Unknown radial basis: {}".format(input_str))
1410
+
1411
+ return radial_basis
1412
+
1413
+
1414
+ class MLP(nn.Module):
1415
+ def __init__(
1416
+ self,
1417
+ hidden_dims: List[int],
1418
+ bias=True,
1419
+ activation=None,
1420
+ last_activation=None,
1421
+ weight_init=xavier_uniform_,
1422
+ bias_init=zeros_initializer,
1423
+ norm="",
1424
+ ):
1425
+ super().__init__()
1426
+
1427
+ # hidden_dims = [hidden, half, hidden]
1428
+
1429
+ dims = hidden_dims
1430
+ n_layers = len(dims)
1431
+
1432
+ DenseMLP = partial(Dense, bias=bias, weight_init=weight_init, bias_init=bias_init)
1433
+
1434
+ self.dense_layers = nn.ModuleList(
1435
+ [DenseMLP(dims[i], dims[i + 1], activation=activation, norm=norm) for i in range(n_layers - 2)]
1436
+ + [DenseMLP(dims[-2], dims[-1], activation=last_activation)]
1437
+ )
1438
+
1439
+ self.layers = nn.Sequential(*self.dense_layers)
1440
+
1441
+ self.reset_parameters()
1442
+
1443
+ def reset_parameters(self):
1444
+ for m in self.dense_layers:
1445
+ m.reset_parameters()
1446
+
1447
+ def forward(self, x):
1448
+ return self.layers(x)
1449
+
1450
+
1451
+ class NodeInit(MessagePassing):
1452
+ def __init__(
1453
+ self,
1454
+ hidden_channels,
1455
+ num_rbf,
1456
+ cutoff,
1457
+ max_z=100,
1458
+ activation=F.silu,
1459
+ proj_ln="",
1460
+ last_activation=False,
1461
+ weight_init=nn.init.xavier_uniform_,
1462
+ bias_init=nn.init.zeros_,
1463
+ concat=False,
1464
+ ):
1465
+ super(NodeInit, self).__init__(aggr="add")
1466
+
1467
+ if type(hidden_channels) == int: # noqa: E721
1468
+ hidden_channels = [hidden_channels]
1469
+
1470
+ first_channel = hidden_channels[0]
1471
+ last_channel = hidden_channels[-1]
1472
+
1473
+ DenseInit = partial(Dense, weight_init=weight_init, bias_init=bias_init) # noqa: F841
1474
+
1475
+ self.concat = concat
1476
+ self.embedding = nn.Embedding(max_z, last_channel)
1477
+ if self.concat:
1478
+ self.embedding_src = nn.Embedding(max_z, first_channel)
1479
+ self.distance_proj = MLP(
1480
+ [num_rbf + 2 * first_channel] + hidden_channels,
1481
+ activation=activation,
1482
+ norm=proj_ln,
1483
+ weight_init=weight_init,
1484
+ bias_init=bias_init,
1485
+ last_activation=activation if last_activation else None,
1486
+ )
1487
+ else:
1488
+ self.distance_proj = MLP(
1489
+ [num_rbf] + [last_channel], activation=None, norm="", weight_init=weight_init, bias_init=bias_init, last_activation=None
1490
+ )
1491
+
1492
+ if not self.concat:
1493
+ self.combine = MLP(
1494
+ [2 * last_channel] + hidden_channels,
1495
+ activation=activation,
1496
+ norm=proj_ln,
1497
+ weight_init=weight_init,
1498
+ bias_init=bias_init,
1499
+ last_activation=activation if last_activation else None,
1500
+ )
1501
+ self.cutoff = CosineCutoff(cutoff)
1502
+
1503
+ self.reset_parameters()
1504
+
1505
+ def reset_parameters(self):
1506
+ self.embedding.reset_parameters()
1507
+ if self.concat:
1508
+ self.embedding_src.reset_parameters()
1509
+ self.distance_proj.reset_parameters()
1510
+ if not self.concat:
1511
+ self.combine.reset_parameters()
1512
+
1513
+ def forward(self, z, x, edge_index, edge_weight, edge_attr):
1514
+ # remove self loops
1515
+ mask = edge_index[0] != edge_index[1]
1516
+ if not mask.all():
1517
+ edge_index = edge_index[:, mask]
1518
+ edge_weight = edge_weight[mask]
1519
+ edge_attr = edge_attr[mask]
1520
+
1521
+ x_neighbors = self.embedding(z)
1522
+ if not self.concat:
1523
+ C = self.cutoff(edge_weight)
1524
+ W = self.distance_proj(edge_attr) * C.view(-1, 1)
1525
+ x_src = x_neighbors
1526
+ else:
1527
+ x_src = self.embedding_src(z)
1528
+ W = edge_attr
1529
+ # propagate_type: (x: Tensor, s:Tensor, W: Tensor)
1530
+ x_neighbors = self.propagate(edge_index, x=x_neighbors, s=x_src, W=W, size=None)
1531
+
1532
+ if self.concat:
1533
+ x_neighbors = x + x_neighbors
1534
+ else:
1535
+ x_neighbors = self.combine(torch.cat([x, x_neighbors], dim=1))
1536
+ return x_neighbors
1537
+
1538
+ def message(self, s_i, x_j, W):
1539
+ if self.concat:
1540
+ return self.distance_proj(torch.cat([W, x_j, s_i], dim=1))
1541
+ return x_j * W
1542
+
1543
+
1544
+ class EdgeInit(MessagePassing):
1545
+ def __init__(
1546
+ self,
1547
+ num_rbf,
1548
+ hidden_channels,
1549
+ activation=F.silu,
1550
+ proj_ln="",
1551
+ last_activation=False,
1552
+ weight_init=nn.init.xavier_uniform_,
1553
+ bias_init=nn.init.zeros_,
1554
+ ):
1555
+ super(EdgeInit, self).__init__(aggr=None)
1556
+ self.activation = activation
1557
+
1558
+ if type(hidden_channels) == int: # noqa: E721
1559
+ hidden_channels = [hidden_channels]
1560
+ self.edge_up = MLP(
1561
+ [num_rbf] + hidden_channels,
1562
+ activation=activation,
1563
+ norm=proj_ln,
1564
+ weight_init=weight_init,
1565
+ bias_init=bias_init,
1566
+ last_activation=activation if last_activation else None,
1567
+ )
1568
+
1569
+ self.reset_parameters()
1570
+
1571
+ def reset_parameters(self):
1572
+ self.edge_up.reset_parameters()
1573
+
1574
+ def forward(self, edge_index, edge_attr, x):
1575
+ # propagate_type: (x: Tensor, edge_attr: Tensor)
1576
+ out = self.propagate(edge_index, x=x, edge_attr=edge_attr)
1577
+ return out
1578
+
1579
+ def message(self, x_i, x_j, edge_attr):
1580
+ return (x_i + x_j) * self.edge_up(edge_attr)
1581
+
1582
+ def aggregate(self, features, index):
1583
+ # no aggregate
1584
+ return features
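The radial machinery above follows a common pattern: pairwise distances are expanded into smooth basis features that decay to zero at the cutoff. A short usage sketch with the classes defined in this file; the import path `posegnn.ops` is an assumption:

import torch

from posegnn.ops import CosineCutoff, ExpNormalSmearing  # assumed import path

dist = torch.linspace(0.1, 6.0, 25)            # 25 sample pair distances
rbf = ExpNormalSmearing(cutoff=5.0, n_rbf=16)  # applies its cosine cutoff internally
feats = rbf(dist)                              # (25, 16); zero for distances beyond 5.0
envelope = CosineCutoff(5.0)(dist)             # (25,); smooth 1 -> 0 taper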
models/pos_egnn/posegnn/utils.py ADDED
@@ -0,0 +1,69 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+ from torch_nl import compute_neighborlist
6
+ from torch_nl.geometry import compute_distances
7
+ from torch_nl.neighbor_list import compute_cell_shifts
8
+
9
+
10
+ ACT_CLASS_MAPPING = {"silu": nn.SiLU, "tanh": nn.Tanh, "sigmoid": nn.Sigmoid, "gelu": nn.GELU}
11
+
12
+ class BatchedPeriodicDistance(nn.Module):
13
+ """
14
+ Wraps the `torch_nl` package to compute periodic distances efficiently with
15
+ PyTorch operations, building the neighbor list for a given cutoff.
16
+ Reference: https://github.com/felixmusil/torch_nl
17
+ """
18
+
19
+ def __init__(self, cutoff: float = 5.0) -> None:
20
+ super().__init__()
21
+ self.cutoff = cutoff
22
+ self.self_interactions = False
23
+
24
+ def forward(
25
+ self, pos: Tensor, box: Tensor, batch: Optional[Tensor] = None, precomputed_edge_index=None, precomputed_shifts_idx=None
26
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
27
+ # No batch, single sample
28
+ if batch is None:
29
+ n_atoms = pos.shape[0]
30
+ batch = torch.zeros(n_atoms, device=pos.device, dtype=torch.int64)
31
+
32
+ is_zero = torch.eq(box, 0)
33
+ is_not_all_zero = ~is_zero.all(dim=-1).all(dim=-1)
34
+ pbc = is_not_all_zero.unsqueeze(-1).repeat(1, 3) # We need to change this when dealing with interfaces
35
+
36
+ if (precomputed_edge_index is None) or (precomputed_shifts_idx is None):
37
+ edge_index, batch_mapping, shifts_idx = compute_neighborlist(self.cutoff, pos, box, pbc, batch, self.self_interactions)
38
+ else:
39
+ edge_index = precomputed_edge_index
40
+ shifts_idx = precomputed_shifts_idx
41
+ batch_mapping = batch[edge_index[0]] # NOTE: should be same as edge_index[1]
42
+
43
+ cell_shifts = compute_cell_shifts(box, shifts_idx, batch_mapping)
44
+ edge_weight = compute_distances(pos, edge_index, cell_shifts)
45
+
46
+ edge_vec = -(pos[edge_index[1]] - pos[edge_index[0]] + cell_shifts)
47
+
48
+ # edge_weight and edge_vec should have grad_fn
49
+ return edge_index, edge_weight, edge_vec, shifts_idx
50
+
51
+
52
+ def get_symmetric_displacement(
53
+ positions: torch.Tensor,
54
+ box: Optional[torch.Tensor],
55
+ num_graphs: int,
56
+ batch: torch.Tensor,
57
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
58
+ displacement = torch.zeros(
59
+ (num_graphs, 3, 3),
60
+ dtype=positions.dtype,
61
+ device=positions.device,
62
+ )
63
+ displacement.requires_grad_(True)
64
+ symmetric_displacement = 0.5 * (displacement + displacement.transpose(-1, -2))
65
+ positions = positions + torch.einsum("be,bec->bc", positions, symmetric_displacement[batch])
66
+ box = box.view(-1, 3, 3)
67
+ box = box + torch.matmul(box, symmetric_displacement)
68
+
69
+ return positions, box, displacement
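A usage sketch for the neighbor-list wrapper above on a toy periodic box; it requires the `torch_nl` dependency pinned below, and the import path is an assumption:

import torch

from posegnn.utils import BatchedPeriodicDistance  # assumed import path

pos = torch.rand(4, 3) * 10.0           # 4 atoms in a 10 Angstrom cubic box
box = 10.0 * torch.eye(3).unsqueeze(0)  # (1, 3, 3); an all-zero box disables PBC
distance = BatchedPeriodicDistance(cutoff=5.0)
edge_index, edge_weight, edge_vec, shifts_idx = distance(pos, box)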
models/pos_egnn/requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ -f https://data.pyg.org/whl/torch-2.5.1%2Bcu121.html
2
+ -f https://data.pyg.org/whl/torch-2.5.1%2Bcpu.html
3
+ numpy==1.26.4
4
+ ase==3.24.0
5
+ torch==2.5.1
6
+ torch_geometric==2.5.3
7
+ torch_nl==0.3
8
+ torch_scatter
9
+ torch_sparse
10
+ tqdm>=4.66.1
requirements.txt CHANGED
@@ -27,3 +27,5 @@ torch-optimizer
27
  tqdm>=4.66.4
28
  pandas==2.2.3
29
  mordred
30
+ ase==3.24.0
31
+ torch_nl==0.3