from typing import Optional, Tuple

import torch
from torch import Tensor, nn
from torch_nl import compute_neighborlist
from torch_nl.geometry import compute_distances
from torch_nl.neighbor_list import compute_cell_shifts

# Maps an activation name to the corresponding ``torch.nn`` module class.
ACT_CLASS_MAPPING = {"silu": nn.SiLU, "tanh": nn.Tanh, "sigmoid": nn.Sigmoid, "gelu": nn.GELU}


class BatchedPeriodicDistance(nn.Module):
    """
    Wraps the `torch_nl` package to calculate Periodic Distance using
    PyTorch operations efficiently. Compute the neighbor list for a given cutoff.
    Reference: https://github.com/felixmusil/torch_nl
    """

    def __init__(self, cutoff: float = 5.0) -> None:
        super().__init__()
        # Cutoff radius handed to ``compute_neighborlist``.
        self.cutoff = cutoff
        # Passed as the self-interaction flag of ``compute_neighborlist``.
        self.self_interactions = False

    def forward(
        self,
        pos: Tensor,
        box: Tensor,
        batch: Optional[Tensor] = None,
        precomputed_edge_index: Optional[Tensor] = None,
        precomputed_shifts_idx: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Build (or reuse) a neighbor list and return edge distances and vectors.

        Args:
            pos: Atom positions; indexed along dim 0 by ``edge_index``.
            box: Cell tensor(s). An all-zero cell disables periodic boundary
                conditions for that sample (assumes trailing dims are the
                3x3 cell -- TODO confirm against callers).
            batch: Per-atom sample index. When ``None`` every atom is assigned
                to sample 0.
            precomputed_edge_index: Optional edge index to reuse instead of
                running the neighbor search.
            precomputed_shifts_idx: Optional shift indices matching
                ``precomputed_edge_index``. The neighbor search is skipped
                only when BOTH precomputed tensors are given.

        Returns:
            ``(edge_index, edge_weight, edge_vec, shifts_idx)`` where
            ``edge_weight`` is the output of ``compute_distances`` and
            ``edge_vec`` is ``-(pos[edge_index[1]] - pos[edge_index[0]] + cell_shifts)``.
        """
        # No batch, single sample
        if batch is None:
            n_atoms = pos.shape[0]
            batch = torch.zeros(n_atoms, device=pos.device, dtype=torch.int64)

        # A sample is periodic iff its cell has at least one non-zero entry;
        # the same flag is replicated for all three directions.
        is_zero = torch.eq(box, 0)
        is_not_all_zero = ~is_zero.all(dim=-1).all(dim=-1)
        pbc = is_not_all_zero.unsqueeze(-1).repeat(1, 3)  # We need to change this when dealing with interfaces

        if (precomputed_edge_index is None) or (precomputed_shifts_idx is None):
            edge_index, batch_mapping, shifts_idx = compute_neighborlist(
                self.cutoff, pos, box, pbc, batch, self.self_interactions
            )
        else:
            edge_index = precomputed_edge_index
            shifts_idx = precomputed_shifts_idx
            batch_mapping = batch[edge_index[0]]  # NOTE: should be same as edge_index[1]

        cell_shifts = compute_cell_shifts(box, shifts_idx, batch_mapping)
        edge_weight = compute_distances(pos, edge_index, cell_shifts)
        edge_vec = -(pos[edge_index[1]] - pos[edge_index[0]] + cell_shifts)

        # edge_weight and edge_vec should have grad_fn
        return edge_index, edge_weight, edge_vec, shifts_idx


def get_symmetric_displacement(
    positions: torch.Tensor,
    box: Optional[torch.Tensor],
    num_graphs: int,
    batch: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Apply a differentiable symmetric displacement to positions and box.

    A zero ``(num_graphs, 3, 3)`` displacement tensor with ``requires_grad``
    enabled is created, symmetrised, and applied to both ``positions`` and
    ``box``. Because the displacement is zero the returned values are
    numerically equal to the inputs, but they are connected to
    ``displacement`` in the autograd graph.

    Args:
        positions: Atom positions, shape ``(n_atoms, 3)``.
        box: Cell tensor(s), viewable as ``(-1, 3, 3)``. When ``None`` an
            all-zero ``(num_graphs, 3, 3)`` box is used instead.
        num_graphs: Number of samples in the batch.
        batch: Per-atom sample index used to select each atom's displacement.

    Returns:
        ``(positions, box, displacement)``: the displaced positions, the
        displaced box with shape ``(-1, 3, 3)``, and the leaf displacement
        tensor to differentiate against.
    """
    if box is None:
        # The signature allows ``None``; fall back to a zero (non-periodic)
        # box so the function still returns three tensors instead of raising
        # on ``None.view``.
        box = torch.zeros(
            (num_graphs, 3, 3),
            dtype=positions.dtype,
            device=positions.device,
        )

    displacement = torch.zeros(
        (num_graphs, 3, 3),
        dtype=positions.dtype,
        device=positions.device,
    )
    displacement.requires_grad_(True)
    # Only the symmetric part of the displacement is applied.
    symmetric_displacement = 0.5 * (displacement + displacement.transpose(-1, -2))

    # Each atom is displaced with the matrix of the sample it belongs to.
    positions = positions + torch.einsum("be,bec->bc", positions, symmetric_displacement[batch])

    box = box.view(-1, 3, 3)
    box = box + torch.matmul(box, symmetric_displacement)

    return positions, box, displacement