# esm/layers/geom_attention.py
from math import sqrt
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F


class GeometricReasoningOriginalImpl(nn.Module):
def __init__(
self,
c_s: int,
v_heads: int,
num_vector_messages: int = 1,
mask_and_zero_frameless: bool = True,
divide_residual_by_depth: bool = False,
bias: bool = False,
):
"""Approximate implementation:
ATTN(A, v) := (softmax_j A_ij) v_j
make_rot_vectors(x) := R(i->g) Linear(x).reshape(..., 3)
make_vectors(x) := T(i->g) Linear(x).reshape(..., 3)
v <- make_rot_vectors(x)
q_dir, k_dir <- make_rot_vectors(x)
q_dist, k_dist <- make_vectors(x)
        A_ij <- dot(q_dir_i, k_dir_j) - ||q_dist_i - k_dist_j||^2
x <- x + Linear(T(g->i) ATTN(A, v))
"""
super().__init__()
self.c_s = c_s
self.v_heads = v_heads
self.num_vector_messages = num_vector_messages
self.mask_and_zero_frameless = mask_and_zero_frameless
self.s_norm = nn.LayerNorm(c_s, bias=bias)
        dim_proj = (
            4 * self.v_heads * 3 + self.v_heads * 3 * self.num_vector_messages
        )  # (q_dir, k_dir, q_dist, k_dist) per head * (x, y, z), plus value vectors: heads * num_vector_messages * (x, y, z)
self.proj = nn.Linear(c_s, dim_proj, bias=bias)
channels_out = self.v_heads * 3 * self.num_vector_messages
self.out_proj = nn.Linear(channels_out, c_s, bias=bias)
        # The basic idea is to let each attention head weight rotation versus distance differently,
        # and to control the sharpness of the softmax (i.e., should this head attend only to residues
        # very nearby, or should there be a shallower drop-off in attention weight?).
self.distance_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))
self.rotation_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))

    def forward(self, s, affine, affine_mask, sequence_id, chain_id):
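        """Geometric attention over per-residue rigid frames.

        Shapes below are assumptions inferred from how the tensors are used:
            s: per-residue state, [B, L, c_s]
            affine: rigid frame per residue (supports indexing, .rot, .apply, .invert)
            affine_mask: bool mask of residues with a valid frame, [B, L]
            sequence_id: packed-sequence ids, used to bias attention within a sequence, [B, L]
            chain_id: chain ids, used to block cross-chain attention, [B, L]
        """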
attn_bias = sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2)
attn_bias = attn_bias.unsqueeze(1).float()
attn_bias = attn_bias.masked_fill(
~affine_mask[:, None, None, :], torch.finfo(attn_bias.dtype).min
)
chain_id_mask = chain_id.unsqueeze(1) != chain_id.unsqueeze(2)
attn_bias = attn_bias.masked_fill(
chain_id_mask.unsqueeze(1), torch.finfo(s.dtype).min
)
ns = self.s_norm(s)
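        # Project the normalized state once, then split into the rotation block
        # (q_dir, k_dir, and the value vectors, which are only rotated) and the
        # distance block (q_dist, k_dist, which are rotated and translated).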
vec_rot, vec_dist = self.proj(ns).split(
[
self.v_heads * 2 * 3 + self.v_heads * 3 * self.num_vector_messages,
self.v_heads * 2 * 3,
],
dim=-1,
)
# Rotate the queries and keys for the rotation term. We also rotate the values.
# NOTE(zeming, thayes): Values are only rotated, not translated. We may wish to change
# this in the future.
query_rot, key_rot, value = (
affine.rot[..., None]
.apply(rearrange(vec_rot, "... (h c) -> ... h c", c=3))
.split(
[
self.v_heads,
self.v_heads,
self.v_heads * self.num_vector_messages,
],
dim=-2,
)
)
# Rotate and translate the queries and keys for the distance term
# NOTE(thayes): a simple speedup would be to apply all rotations together, then
# separately apply the translations.
query_dist, key_dist = (
affine[..., None]
.apply(rearrange(vec_dist, "... (h c) -> ... h c", c=3))
.chunk(2, dim=-2)
)
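        # Reshape for attention: the distance queries/keys broadcast to a pairwise
        # [b, h, s_q, s_k] difference, the direction terms form a standard q @ k^T,
        # and the per-head value messages are flattened into one channel dimension.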
query_dist = rearrange(query_dist, "b s h d -> b h s 1 d")
key_dist = rearrange(key_dist, "b s h d -> b h 1 s d")
query_rot = rearrange(query_rot, "b s h d -> b h s d")
key_rot = rearrange(key_rot, "b s h d -> b h d s")
value = rearrange(
value, "b s (h m) d -> b h s (m d)", m=self.num_vector_messages
)
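        # Attention logits per head: directional agreement (dot product) minus the
        # Euclidean distance between the translated points, each scaled by 1/sqrt(3),
        # i.e. the square root of the 3-D vector dimension.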
distance_term = (query_dist - key_dist).norm(dim=-1) / sqrt(3)
rotation_term = query_rot.matmul(key_rot) / sqrt(3)
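        # Learned per-head scales, kept positive via softplus, trade off the rotation
        # term against the distance penalty and control the softmax sharpness.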
distance_term_weight = rearrange(
F.softplus(self.distance_scale_per_head), "h -> h 1 1"
)
rotation_term_weight = rearrange(
F.softplus(self.rotation_scale_per_head), "h -> h 1 1"
)
attn_weight = (
rotation_term * rotation_term_weight - distance_term * distance_term_weight
)
if attn_bias is not None:
# we can re-use the attention bias from the transformer layers
# NOTE(thayes): This attention bias is expected to handle two things:
# 1. Masking attention on padding tokens
# 2. Masking cross sequence attention in the case of bin packing
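            # If the bias covers more positions than the current query/key lengths,
            # keep only its trailing slice so shapes match attn_weight.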
s_q = attn_weight.size(2)
s_k = attn_weight.size(3)
_s_q = max(0, attn_bias.size(2) - s_q)
_s_k = max(0, attn_bias.size(3) - s_k)
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
attn_weight = attn_weight + attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_out = attn_weight.matmul(value)
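        # Map the attended vector messages back from the global frame into each
        # residue's local frame (rotation only, mirroring how the values were built).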
attn_out = (
affine.rot[..., None]
.invert()
.apply(
rearrange(
attn_out, "b h s (m d) -> b s (h m) d", m=self.num_vector_messages
)
)
)
attn_out = rearrange(
attn_out, "b s (h m) d -> b s (h m d)", m=self.num_vector_messages
)
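        # Optionally zero the output for residues without a valid frame so that
        # frameless positions contribute nothing downstream.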
if self.mask_and_zero_frameless:
attn_out = attn_out.masked_fill(~affine_mask[..., None], 0.0)
s = self.out_proj(attn_out)
return s