{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# POS-EGNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment to install notebook-only dependencies\n",
    "# %pip install nglview ipywidgets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4bac12c6048044898065f0778d95caeb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import nglview as nv\n",
    "import torch\n",
    "from ase import units\n",
    "from ase.io import read\n",
    "from ase.md.langevin import Langevin\n",
    "\n",
    "from posegnn.calculator import PosEGNNCalculator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cpu\"\n",
    "torch.set_float32_matmul_precision(\"high\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Feature Extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Please download checkpoint from https://huggingface.co/ibm-research/materials.pos-egnn\n",
    "calculator = PosEGNNCalculator(\"pos-egnn.v1-6M.ckpt\", device=device, compute_stress=False)\n",
    "atoms = read(\"inputs/3BPA.xyz\", index=0)\n",
    "atoms.calc = calculator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([27, 256])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embeddings = atoms.get_invariant_embeddings()\n",
    "embeddings.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "energy = atoms.get_potential_energy()\n",
    "forces = atoms.get_forces()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([-175.05188], dtype=float32),\n",
       " array([[ 0.34280202, -0.41967863, 0.7246248 ],\n",
       " [-0.86854756, -0.12186409, -2.305024 ],\n",
       " [ 0.26306945, 0.06607065, 0.85476065],\n",
       " [-0.230737 , 0.02304646, -0.5161394 ],\n",
       " [-0.43901953, 2.7678285 , -0.70297724],\n",
       " [ 0.03933215, -0.50390136, 1.0451801 ],\n",
       " [ 0.37628424, -2.2708364 , -0.7662437 ],\n",
       " [ 0.25884533, -1.6086004 , -0.08700082],\n",
       " [-0.09319548, -0.24666801, -0.48069426],\n",
       " [ 0.01849201, 1.001767 , 2.151208 ],\n",
       " [-0.46055827, 1.3630681 , -0.38470453],\n",
       " [ 0.38605827, -0.32170498, 0.6269282 ],\n",
       " [-0.29103595, 0.22509174, -0.26729944],\n",
       " [ 1.3340423 , -1.727819 , -0.08812339],\n",
       " [-0.96442086, 1.1447092 , 1.0665402 ],\n",
       " [-0.74679977, 0.56782806, 0.03098067],\n",
       " [ 0.42040402, 0.7405614 , -0.6953748 ],\n",
       " [-0.25654212, 0.25282693, 0.25414664],\n",
       " [ 2.0051584 , -0.38257334, -0.26911467],\n",
       " [-0.00743119, 0.43786597, -0.27683535],\n",
       " [ 0.64563835, -0.5602143 , -0.11240276],\n",
       " [-0.00601408, -1.03808 , 0.23635206],\n",
       " [-0.04149596, 0.02955294, -0.06748012],\n",
       " [-0.86066115, 0.00299245, 0.06783121],\n",
       " [-0.05461264, 0.05352221, -0.06844339],\n",
       " [-0.26291835, 0.58347785, 0.19614606],\n",
       " [-0.50613666, -0.05826864, -0.16684091]], dtype=float32))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "energy, forces"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Molecular Dynamics Simulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write_frame appends to output.xyz, so remove any file left by a previous run;\n",
    "# otherwise re-running this cell stacks the new trajectory on top of the old one.\n",
    "Path(\"output.xyz\").unlink(missing_ok=True)\n",
    "\n",
    "dyn = Langevin(atoms=atoms, friction=0.005, temperature_K=310, timestep=0.5 * units.fs)\n",
    "\n",
    "def write_frame():\n",
    "    \"\"\"Append the current configuration of the running dynamics to output.xyz.\"\"\"\n",
    "    dyn.atoms.write(\"output.xyz\", append=True)\n",
    "\n",
    "# Save a frame every 5 MD steps over a 500-step run.\n",
    "dyn.attach(write_frame, interval=5)\n",
    "dyn.run(500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "traj = read('output.xyz', index=slice(None))\n",
    "view = nv.show_asetraj(traj)\n",
    "display(view)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py311",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}