File size: 2,497 Bytes
dded3ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import numpy as np
import torch
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
from ase.data import atomic_numbers
from ase.stress import full_3x3_to_voigt_6_stress
from torch_geometric.data.data import Data

from .model import PosEGNN


class PosEGNNCalculator(Calculator):
    """ASE calculator that evaluates a PosEGNN checkpoint.

    Provides ``energy`` and ``forces``, plus ``stress`` when ``compute_stress`` is true.

    Args:
        checkpoint: Path to a ``torch.save``-d dict with keys ``"config"`` (passed to
            ``PosEGNN``) and ``"state_dict"`` (loaded strictly).
        device: Torch device string the model and input tensors are placed on.
        compute_stress: Whether to request and expose the ``stress`` property.
        **kwargs: Forwarded to ``ase.calculators.calculator.Calculator``.
    """

    def __init__(self, checkpoint: str, device: str, compute_stress: bool = True, **kwargs):
        super().__init__(**kwargs)

        # weights_only=True avoids unpickling arbitrary objects from the checkpoint file.
        checkpoint_dict = torch.load(checkpoint, weights_only=True, map_location=device)

        self.model = PosEGNN(checkpoint_dict["config"])
        self.model.load_state_dict(checkpoint_dict["state_dict"], strict=True)
        self.model.to(device)
        self.model.eval()

        self.implemented_properties = ["energy", "forces"]
        if compute_stress:
            self.implemented_properties.append("stress")
        self.device = device
        self.compute_stress = compute_stress

    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """Run the model on ``atoms`` (or the cached ``self.atoms``) and fill ``self.results``.

        ``results["energy"]`` is a Python float, ``results["forces"]`` a numpy array and,
        when enabled, ``results["stress"]`` a Voigt 6-vector.
        """
        if properties is None:
            properties = self.implemented_properties
        # The base implementation stores a copy of `atoms` in self.atoms (when not None).
        Calculator.calculate(self, atoms, properties, system_changes)

        # Use self.atoms so that calculate(atoms=None) works on the cached structure.
        data = self._build_data(self.atoms)
        # Not wrapped in no_grad: compute_properties may rely on autograd for forces/stress.
        out = self.model.compute_properties(data, compute_stress=self.compute_stress)

        self.results = {
            # A single graph is evaluated, so the total energy is expected to hold one
            # element; ASE consumers expect a plain float here.
            "energy": out["total_energy"].detach().cpu().item(),
            "forces": out["force"].detach().cpu().numpy(),
        }
        if self.compute_stress:
            # reshape(3, 3) drops a possible leading batch axis so ASE gets a (6,) vector.
            stress = out["stress"].detach().cpu().numpy().reshape(3, 3)
            self.results["stress"] = full_3x3_to_voigt_6_stress(stress)

    def _build_data(self, atoms):
        """Convert an ``ase.Atoms`` object into a single-graph ``Data`` on ``self.device``.

        Fields: ``z`` (atomic numbers, int64), ``pos`` (float32 positions), ``box``
        (float32 cell with a leading axis of size 1), ``batch`` (all zeros), ``num_graphs=1``
        and ``ptr``.
        """
        z = torch.as_tensor(atoms.get_atomic_numbers(), dtype=torch.long, device=self.device)
        box = torch.as_tensor(
            np.asarray(atoms.get_cell()), dtype=torch.float32, device=self.device
        ).unsqueeze(0)
        pos = torch.as_tensor(atoms.get_positions(), dtype=torch.float32, device=self.device)
        batch = torch.zeros(len(z), dtype=torch.long, device=self.device)
        # NOTE(review): PyG batches normally carry ptr=[0, n_atoms] (num_graphs + 1 entries);
        # kept as [0] to match what the model was used with — confirm against PosEGNN.
        ptr = torch.zeros(1, dtype=torch.long, device=self.device)
        return Data(z=z, pos=pos, box=box, batch=batch, num_graphs=1, ptr=ptr)


def get_invariant_embeddings(self):
    """Return the attached calculator model's ``embedding_0`` features for these atoms.

    Intended to be bound as an ``Atoms`` method, so ``self`` is the atoms object.
    The graph is built with ``self.calc._build_data`` and the model runs under
    ``torch.no_grad()``.

    The result is ``embedding_0[..., 1]`` with axis 2 squeezed. The meaning of
    index 1 on the last axis is defined by the model — NOTE(review): confirm.

    Raises:
        RuntimeError: if no calculator is attached.
    """
    calc = self.calc
    if calc is None:
        raise RuntimeError("No calculator is set.")

    graph = calc._build_data(self)
    with torch.no_grad():
        outputs = calc.model(graph)
        invariant = outputs["embedding_0"][..., 1]
        return invariant.squeeze(2)


# Monkey-patch: expose the helper as a method on every ase.Atoms instance.
setattr(Atoms, "get_invariant_embeddings", get_invariant_embeddings)