Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from esm.utils.constants.physics import ( | |
BB_COORDINATES, | |
) | |
from esm.utils.structure.affine3d import ( | |
Affine3D, | |
RotationMatrix, | |
) | |
class Dim6RotStructureHead(nn.Module):
    """Structure head that predicts a rigid-frame update from two 3-vectors.

    AF2 normally specifies rotations with quaternions. There is some evidence
    that other representations are better behaved; according to Zhou et al.
    (CVPR 2019) the best one builds the rotation by Gram-Schmidt on two
    vectors, which is the representation implemented here.

    Paper:
    https://openaccess.thecvf.com/content_CVPR_2019/papers/Zhou_On_the_Continuity_of_Rotation_Representations_in_Neural_Networks_CVPR_2019_paper.pdf
    """

    def __init__(
        self,
        input_dim: int,
        trans_scale_factor: float = 10,
        norm_type: str = "layernorm",
        activation_fn: str = "esm_gelu",
        predict_torsion_angles: bool = True,
    ):
        """Create the feed-forward stack and the output projection.

        Args:
            input_dim: Size of the last dimension of the input features.
            trans_scale_factor: Multiplier applied to the predicted translation.
            norm_type: Accepted but not read; an ``nn.LayerNorm`` is always used.
            activation_fn: Accepted but not read; an ``nn.GELU`` is always used.
            predict_torsion_angles: Stored on the module; ``forward`` does not
                consult it.
        """
        super().__init__()

        # Output layout of the projection: translation (3) + two direction
        # vectors (3 + 3), followed by 7 * 2 angle values.
        frame_width = 9
        angle_width = 7 * 2

        # Submodules are created in the order forward() applies them.
        self.ffn1 = nn.Linear(input_dim, input_dim)
        self.activation_fn = nn.GELU()
        self.norm = nn.LayerNorm(input_dim)
        self.proj = nn.Linear(input_dim, frame_width + angle_width)

        self.trans_scale_factor = trans_scale_factor
        self.predict_torsion_angles = predict_torsion_angles

        # Plain tensor attribute (not a registered buffer); forward() moves it
        # to the working device on every call.
        self.bb_local_coords = torch.tensor(BB_COORDINATES).float()

    def forward(self, x, affine, affine_mask, **kwargs):
        """Predict updated frames and approximate backbone coordinates.

        Args:
            x: Input features whose last dimension is ``input_dim``.
                Presumably shaped (batch, length, input_dim) - TODO confirm.
            affine: Existing ``Affine3D`` frames to compose the update onto, or
                ``None`` to start from identity frames.
            affine_mask: Passed to ``Affine3D.mask`` on the predicted update
                before composition.
            **kwargs: Ignored.

        Returns:
            A tuple ``(affine, pred_xyz)``: the tensor form of the composed
            frames, and the local backbone coordinates mapped through them.
        """
        rigids = affine
        if rigids is None:
            # No starting frames supplied: begin from identity rotation-matrix
            # frames; requires_grad mirrors the module's training flag.
            rigids = Affine3D.identity(
                x.shape[:-1],
                dtype=x.dtype,
                device=x.device,
                requires_grad=self.training,
                rotation_type=RotationMatrix,
            )

        hidden = self.ffn1(x)
        hidden = self.activation_fn(hidden)
        hidden = self.norm(hidden)

        # The trailing 7 * 2 angle outputs are projected but not used here.
        projected = self.proj(hidden)
        trans, vec_a, vec_b, _ = projected.split([3, 3, 3, 7 * 2], dim=-1)

        trans = trans * self.trans_scale_factor
        # Normalise both direction vectors; the epsilon guards against a zero norm.
        vec_a = vec_a / (vec_a.norm(dim=-1, keepdim=True) + 1e-5)
        vec_b = vec_b / (vec_b.norm(dim=-1, keepdim=True) + 1e-5)

        # Both direction vectors are offset by the translation, which is also
        # passed on its own as the middle argument.
        update = Affine3D.from_graham_schmidt(vec_a + trans, trans, vec_b + trans)
        rigids = rigids.compose(update.mask(affine_mask))
        affine = rigids.tensor

        # We approximate the global-frame positions of the backbone atoms by
        # applying the rigid transformation to the mean local-frame backbone
        # coordinates.
        leading_shape = vec_a.shape[:-1]
        local_template = self.bb_local_coords[None, None, :, :]
        all_bb_coords_local = local_template.expand(*leading_shape, 3, 3)
        all_bb_coords_local = all_bb_coords_local.to(vec_a.device)

        pred_xyz = rigids[..., None].apply(all_bb_coords_local)
        return affine, pred_xyz