Spaces:
Running
Running
File size: 3,833 Bytes
dded3ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
from torch import nn
import torch
from .encoder import GotenNet
from .utils import get_symmetric_displacement, BatchedPeriodicDistance, ACT_CLASS_MAPPING
#from torch_scatter import scatter
class NodeInvariantReadout(nn.Module):
    """Readout head that combines several slices of an invariant embedding.

    Each of the first ``num_residues - 1`` slices along the last axis goes
    through its own linear layer, the final slice goes through a small
    two-layer MLP, and all resulting tensors are summed element-wise.
    """

    def __init__(self, in_channels, num_residues, hidden_channels, out_channels, activation):
        """Build the per-slice heads.

        Args:
            in_channels: Feature width of every slice fed to the heads.
            num_residues: Number of slices; ``num_residues - 1`` plain linear
                heads are created, the remaining slice uses the MLP.
            hidden_channels: Width of the MLP's hidden layer.
            out_channels: Output width of every head.
            activation: Key into ``ACT_CLASS_MAPPING`` selecting the class
                instantiated (without arguments) between the MLP's layers.
        """
        super().__init__()

        # One independent linear head per slice except the last.
        slice_heads = [nn.Linear(in_channels, out_channels) for _ in range(num_residues - 1)]
        self.linears = nn.ModuleList(slice_heads)

        # Nonlinear head reserved for the last slice. Layers are created in
        # the same order they appear in the Sequential.
        expand = nn.Linear(in_channels, hidden_channels)
        act = ACT_CLASS_MAPPING[activation]()
        project = nn.Linear(hidden_channels, out_channels)
        self.non_linear = nn.Sequential(expand, act, project)

    def forward(self, embedding_0):
        """Apply every head to its slice and sum the results.

        Args:
            embedding_0: Tensor whose axis 2 is squeezed when it has size 1.
                The result is indexed as ``[:, :, i]``; the original comment
                gives that layout as ``[n_nodes, in_channels, num_residues]``
                (TODO confirm against the encoder output).

        Returns:
            The element-wise sum of all head outputs, with a trailing axis of
            size 1 squeezed away.
        """
        per_slice = embedding_0.squeeze(2)  # [n_nodes, in_channels, num_residues]

        contributions = [
            head(per_slice[:, :, idx]) for idx, head in enumerate(self.linears)
        ]
        # The last slice always goes through the MLP, independent of how many
        # linear heads exist.
        contributions.append(self.non_linear(per_slice[:, :, -1]))

        return torch.stack(contributions, dim=0).sum(dim=0).squeeze(-1)
class PosEGNN(nn.Module):
    """Energy model: periodic neighbour search, GotenNet encoder, node readout.

    ``forward`` returns the encoder's embedding dictionary.
    ``compute_properties`` turns it into per-graph total energies and obtains
    forces (and optionally stresses) by differentiating that energy.
    """

    def __init__(self, config):
        """Build sub-modules and register normalisation buffers.

        Args:
            config: Mapping with the keys ``"encoder"`` (keyword arguments for
                ``GotenNet``, including ``"cutoff"``), ``"decoder"`` (keyword
                arguments for ``NodeInvariantReadout``), ``"e0_mean"``,
                ``"atomic_res_total_mean"`` and ``"atomic_res_total_std"``.
        """
        super().__init__()
        self.distance = BatchedPeriodicDistance(config["encoder"]["cutoff"])
        self.encoder = GotenNet(**config["encoder"])
        self.readout = NodeInvariantReadout(**config["decoder"])
        # Buffers (not parameters): they follow the module across devices and
        # are stored in the state dict, but are not trained.
        self.register_buffer("e0_mean", torch.tensor(config["e0_mean"]))
        self.register_buffer("atomic_res_total_mean", torch.tensor(config["atomic_res_total_mean"]))
        self.register_buffer("atomic_res_total_std", torch.tensor(config["atomic_res_total_std"]))

    @staticmethod
    def _sum_per_graph(values, batch, num_graphs):
        """Sum rows of ``values`` that share the same graph index.

        Torch-native stand-in for ``torch_scatter.scatter(..., dim=0,
        reduce="sum")``; it is differentiable with respect to ``values``.

        Args:
            values: Tensor whose first axis enumerates nodes.
            batch: Integer tensor assigning each node to a graph.
            num_graphs: Size of the first axis of the result.

        Returns:
            Tensor of shape ``(num_graphs, *values.shape[1:])``; graphs with
            no nodes get zeros.
        """
        totals = values.new_zeros((num_graphs,) + tuple(values.shape[1:]))
        return totals.index_add(0, batch, values)

    def forward(self, data):
        """Encode a batch and return the encoder's embedding dictionary.

        Mutates ``data`` in place: enables gradients on ``data.pos``,
        replaces ``data.pos`` / ``data.box`` with the outputs of
        ``get_symmetric_displacement``, and stores ``displacements`` plus the
        ``cutoff_*`` neighbour-list attributes on it.

        Args:
            data: Batch object exposing ``pos``, ``box``, ``z``, ``batch`` and
                ``num_graphs`` attributes.

        Returns:
            Whatever ``self.encoder`` returns (indexed with ``"embedding_0"``
            by ``compute_properties``).
        """
        # Positions must track gradients before anything is derived from
        # them, otherwise forces cannot be obtained by autograd later.
        data.pos.requires_grad_(True)

        data.pos, data.box, data.displacements = get_symmetric_displacement(
            data.pos, data.box, data.num_graphs, data.batch
        )

        (
            data.cutoff_edge_index,
            data.cutoff_edge_distance,
            data.cutoff_edge_vec,
            data.cutoff_shifts_idx,
        ) = self.distance(data.pos, data.box, data.batch)

        embedding_dict = self.encoder(
            data.z, data.pos, data.cutoff_edge_index, data.cutoff_edge_distance, data.cutoff_edge_vec
        )
        return embedding_dict

    def compute_properties(self, data, compute_stress=True):
        """Compute total energy, forces and (optionally) stress for a batch.

        Args:
            data: Batch object accepted by ``forward``.
            compute_stress: When true, also differentiate the energy with
                respect to ``data.displacements`` and report a stress.

        Returns:
            Dict with ``"total_energy"`` and ``"force"``, plus ``"stress"``
            when ``compute_stress`` is true.
        """
        output = {}

        embedding_dict = self.forward(data)
        embedding_0 = embedding_dict["embedding_0"]

        # Energy = per-graph sum of the de-normalised readout plus per-graph
        # sum of the reference values looked up by atomic number.
        node_e_res = self.readout(embedding_0)
        node_e_res = node_e_res * self.atomic_res_total_std + self.atomic_res_total_mean
        total_e_res = self._sum_per_graph(node_e_res, data.batch, data.num_graphs)

        node_e0 = self.e0_mean[data.z]
        total_e0 = self._sum_per_graph(node_e0, data.batch, data.num_graphs)

        total_energy = total_e0 + total_e_res
        output["total_energy"] = total_energy

        # Differentiate the energy with respect to positions (and
        # displacements when a stress is requested).
        inputs = [data.pos, data.displacements] if compute_stress else [data.pos]
        grad_outputs = torch.autograd.grad(
            outputs=[total_energy],
            inputs=inputs,
            grad_outputs=[torch.ones_like(total_energy)],
            # Keep the graph only while training, so a loss on the gradients
            # can itself be back-propagated.
            retain_graph=self.training,
            create_graph=self.training,
        )

        output["force"] = -grad_outputs[0]

        if compute_stress:
            virial = grad_outputs[1]
            # Normalise by the absolute box determinant of each graph.
            stress = virial / torch.det(data.box).abs().view(-1, 1, 1)
            # Entries whose magnitude is not below 1e10 (this includes inf and
            # NaN, e.g. from a zero determinant) are replaced with zero.
            stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress))
            output["stress"] = -stress

        return output