File size: 2,735 Bytes
dded3ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from typing import Optional, Tuple

import torch
from torch import Tensor, nn
from torch_nl import compute_neighborlist
from torch_nl.geometry import compute_distances
from torch_nl.neighbor_list import compute_cell_shifts


# Lookup table from activation-function name to its torch.nn module *class*
# (callers instantiate it themselves).
ACT_CLASS_MAPPING = {
    "silu": nn.SiLU,
    "tanh": nn.Tanh,
    "sigmoid": nn.Sigmoid,
    "gelu": nn.GELU,
}

class BatchedPeriodicDistance(nn.Module):
    """
    Wraps the `torch_nl` package to compute periodic neighbor lists and edge geometry.

    For every structure in a batch, finds the atom pairs within ``cutoff``
    (searching periodic images when that structure's cell is non-zero) and
    returns edge indices, edge lengths, edge vectors and integer cell-shift
    indices.
    Reference: https://github.com/felixmusil/torch_nl

    Args:
        cutoff: Neighbor-list cutoff radius, passed to ``compute_neighborlist``.
        self_interactions: Passed to ``compute_neighborlist`` as its
            self-interaction flag. Defaults to ``False`` (the previous
            hard-coded value).
    """

    def __init__(self, cutoff: float = 5.0, self_interactions: bool = False) -> None:
        super().__init__()
        self.cutoff = cutoff
        self.self_interactions = self_interactions

    def forward(
        self,
        pos: Tensor,
        box: Tensor,
        batch: Optional[Tensor] = None,
        precomputed_edge_index: Optional[Tensor] = None,
        precomputed_shifts_idx: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Compute (or reuse) the neighbor list and the matching edge geometry.

        Args:
            pos: Atomic positions, one row per atom (presumably (n_atoms, 3) — TODO confirm).
            box: Cell(s). Interpreted as one 3x3 cell per structure after
                reshaping to (-1, 3, 3); an all-zero cell marks that structure
                as non-periodic.
            batch: Structure index of every atom. ``None`` means all atoms
                belong to a single structure.
            precomputed_edge_index: Edge index to reuse, read as
                ``edge_index[0]`` / ``edge_index[1]``.
            precomputed_shifts_idx: Cell-shift indices matching
                ``precomputed_edge_index``. The neighbor search is skipped only
                when BOTH precomputed tensors are given.

        Returns:
            ``(edge_index, edge_weight, edge_vec, shifts_idx)`` where
            ``edge_vec = pos[edge_index[0]] - pos[edge_index[1]] - cell_shifts``
            and ``edge_weight`` is the Euclidean norm of ``edge_vec``.
        """
        if batch is None:
            # No batch given: every atom belongs to structure 0.
            batch = torch.zeros(pos.shape[0], device=pos.device, dtype=torch.int64)

        if precomputed_edge_index is None or precomputed_shifts_idx is None:
            # A structure is periodic in all three directions iff its cell is not
            # entirely zero. Reshaping to (-1, 3, 3) first keeps the reduction
            # per-structure for both (n_structures, 3, 3) and flattened
            # (n_structures * 3, 3) boxes, so pbc always has shape (n_structures, 3).
            # NOTE: mixed periodicity (e.g. slabs / interfaces) is not supported yet.
            is_periodic = ~torch.eq(box.reshape(-1, 3, 3), 0).all(dim=-1).all(dim=-1)
            pbc = is_periodic.unsqueeze(-1).repeat(1, 3)
            edge_index, batch_mapping, shifts_idx = compute_neighborlist(
                self.cutoff, pos, box, pbc, batch, self.self_interactions
            )
        else:
            edge_index = precomputed_edge_index
            shifts_idx = precomputed_shifts_idx
            # Structure index of each edge. Using batch[edge_index[1]] is expected
            # to give the same result (both endpoints lie in one structure).
            batch_mapping = batch[edge_index[0]]

        cell_shifts = compute_cell_shifts(box, shifts_idx, batch_mapping)

        # Gather the displacement once and take its length from it, rather than
        # gathering the same positions a second time for the distance.
        edge_vec = -(pos[edge_index[1]] - pos[edge_index[0]] + cell_shifts)
        edge_weight = edge_vec.norm(p=2, dim=1)

        # edge_weight and edge_vec should have grad_fn
        return edge_index, edge_weight, edge_vec, shifts_idx


def get_symmetric_displacement(
    positions: torch.Tensor,
    box: Optional[torch.Tensor],
    num_graphs: int,
    batch: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Attach a differentiable, symmetric strain to positions and cells.

    A zero ``(num_graphs, 3, 3)`` ``displacement`` tensor with
    ``requires_grad=True`` is created and symmetrised. Each atom's position and
    each cell is then deformed by its structure's strain (``x -> x + x @ eps``).
    Numerically the outputs equal the inputs, but they now depend on
    ``displacement``, so autograd can differentiate downstream quantities with
    respect to the strain.

    Args:
        positions: Atomic positions; row ``i`` is strained with the matrix of
            structure ``batch[i]`` (presumably shape (n_atoms, 3) — TODO confirm).
        box: Cell(s), reshaped to ``(-1, 3, 3)``. ``None`` is replaced by
            all-zero cells for each of the ``num_graphs`` structures (the
            non-periodic encoding), instead of failing on ``None.view``.
        num_graphs: Number of structures in the batch.
        batch: Structure index of every atom.

    Returns:
        ``(strained_positions, strained_box, displacement)``. ``strained_box``
        has shape ``(-1, 3, 3)``; ``displacement`` is the leaf tensor to
        differentiate against.
    """
    if box is None:
        box = torch.zeros((num_graphs, 3, 3), dtype=positions.dtype, device=positions.device)

    displacement = torch.zeros(
        (num_graphs, 3, 3),
        dtype=positions.dtype,
        device=positions.device,
    )
    displacement.requires_grad_(True)
    # Only the symmetric part is applied, so the deformation contains no
    # antisymmetric (rotational) component.
    symmetric_displacement = 0.5 * (displacement + displacement.transpose(-1, -2))

    # Per-atom strain: each row is multiplied by its own structure's matrix.
    positions = positions + torch.einsum("be,bec->bc", positions, symmetric_displacement[batch])
    box = box.view(-1, 3, 3)
    box = box + torch.matmul(box, symmetric_displacement)

    return positions, box, displacement