# POS-EGNN 

## Setup

In [1]:
# Uncomment to install notebook-only dependencies
# !pip install nglview ipywidgets

In [2]:
import nglview as nv
import torch
from ase import units
from ase.io import read
from ase.md.langevin import Langevin

from posegnn.calculator import PosEGNNCalculator



In [3]:
device = "cpu"
torch.set_float32_matmul_precision("high")

## Feature Extraction

In [4]:
# Please download checkpoint from https://huggingface.co/ibm-research/materials.pos-egnn
calculator = PosEGNNCalculator("pos-egnn.v1-6M.ckpt", device=device, compute_stress=False)
atoms = read("inputs/3BPA.xyz", index=0)
atoms.calc = calculator

In [5]:
embeddings = atoms.get_invariant_embeddings()
embeddings.shape

torch.Size([27, 256])

## Inference

In [6]:
energy = atoms.get_potential_energy()
forces = atoms.get_forces()

In [7]:
energy, forces

(array([-175.05188], dtype=float32),
 array([[ 0.34280202, -0.41967863, 0.7246248 ],
 [-0.86854756, -0.12186409, -2.305024 ],
 [ 0.26306945, 0.06607065, 0.85476065],
 [-0.230737 , 0.02304646, -0.5161394 ],
 [-0.43901953, 2.7678285 , -0.70297724],
 [ 0.03933215, -0.50390136, 1.0451801 ],
 [ 0.37628424, -2.2708364 , -0.7662437 ],
 [ 0.25884533, -1.6086004 , -0.08700082],
 [-0.09319548, -0.24666801, -0.48069426],
 [ 0.01849201, 1.001767 , 2.151208 ],
 [-0.46055827, 1.3630681 , -0.38470453],
 [ 0.38605827, -0.32170498, 0.6269282 ],
 [-0.29103595, 0.22509174, -0.26729944],
 [ 1.3340423 , -1.727819 , -0.08812339],
 [-0.96442086, 1.1447092 , 1.0665402 ],
 [-0.74679977, 0.56782806, 0.03098067],
 [ 0.42040402, 0.7405614 , -0.6953748 ],
 [-0.25654212, 0.25282693, 0.25414664],
 [ 2.0051584 , -0.38257334, -0.26911467],
 [-0.00743119, 0.43786597, -0.27683535],
 [ 0.64563835, -0.5602143 , -0.11240276],
 [-0.00601408, -1.03808 , 0.23635206],
 [-0.04149596, 0.02955294, -0.06748012],
 [-0.86066115, 0.0

## Molecular Dynamics Simulation

In [11]:
dyn = Langevin(atoms=atoms, friction=0.005, temperature_K=310, timestep=0.5 * units.fs)

def write_frame():
 dyn.atoms.write("output.xyz", append=True)

dyn.attach(write_frame, interval=5)
dyn.run(500)

True

In [12]:
traj = read('output.xyz', index=slice(None))
view = nv.show_asetraj(traj)
display(view)

NGLWidget(max_frame=234)