ipd's picture
add pos-egnn
dded3ca
raw
history blame
2.74 kB
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from torch_nl import compute_neighborlist
from torch_nl.geometry import compute_distances
from torch_nl.neighbor_list import compute_cell_shifts
ACT_CLASS_MAPPING = {"silu": nn.SiLU, "tanh": nn.Tanh, "sigmoid": nn.Sigmoid, "gelu": nn.GELU}
class BatchedPeriodicDistance(nn.Module):
"""
Wraps the `torch_nl` package to calculate Periodic Distance using
PyTorch operations efficiently. Compute the neighbor list for a given cutoff.
Reference: https://github.com/felixmusil/torch_nl
"""
def __init__(self, cutoff: float = 5.0) -> None:
super().__init__()
self.cutoff = cutoff
self.self_interactions = False
def forward(
self, pos: Tensor, box: Tensor, batch: Optional[Tensor] = None, precomputed_edge_index=None, precomputed_shifts_idx=None
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
# No batch, single sample
if batch is None:
n_atoms = pos.shape[0]
batch = torch.zeros(n_atoms, device=pos.device, dtype=torch.int64)
is_zero = torch.eq(box, 0)
is_not_all_zero = ~is_zero.all(dim=-1).all(dim=-1)
pbc = is_not_all_zero.unsqueeze(-1).repeat(1, 3) # We need to change this when dealing with interfaces
if (precomputed_edge_index is None) or (precomputed_shifts_idx is None):
edge_index, batch_mapping, shifts_idx = compute_neighborlist(self.cutoff, pos, box, pbc, batch, self.self_interactions)
else:
edge_index = precomputed_edge_index
shifts_idx = precomputed_shifts_idx
batch_mapping = batch[edge_index[0]] # NOTE: should be same as edge_index[1]
cell_shifts = compute_cell_shifts(box, shifts_idx, batch_mapping)
edge_weight = compute_distances(pos, edge_index, cell_shifts)
edge_vec = -(pos[edge_index[1]] - pos[edge_index[0]] + cell_shifts)
# edge_weight and edge_vec should have grad_fn
return edge_index, edge_weight, edge_vec, shifts_idx
def get_symmetric_displacement(
positions: torch.Tensor,
box: Optional[torch.Tensor],
num_graphs: int,
batch: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
displacement = torch.zeros(
(num_graphs, 3, 3),
dtype=positions.dtype,
device=positions.device,
)
displacement.requires_grad_(True)
symmetric_displacement = 0.5 * (displacement + displacement.transpose(-1, -2))
positions = positions + torch.einsum("be,bec->bc", positions, symmetric_displacement[batch])
box = box.view(-1, 3, 3)
box = box + torch.matmul(box, symmetric_displacement)
return positions, box, displacement