from math import sqrt

import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F


class GeometricReasoningOriginalImpl(nn.Module):
    def __init__(
        self,
        c_s: int,
        v_heads: int,
        num_vector_messages: int = 1,
        mask_and_zero_frameless: bool = True,
        divide_residual_by_depth: bool = False,
        bias: bool = False,
    ):
        """Approximate implementation:

        ATTN(A, v) := (softmax_j A_ij) v_j
        make_rot_vectors(x) := R(i->g) Linear(x).reshape(..., 3)
        make_vectors(x) := T(i->g) Linear(x).reshape(..., 3)

        v <- make_rot_vectors(x)
        q_dir, k_dir <- make_rot_vectors(x)
        q_dist, k_dist <- make_vectors(x)

        A_ij <- dot(q_dir_i, k_dir_j) - ||q_dist_i - k_dist_j||^2
        x <- x + Linear(T(g->i) ATTN(A, v))
        """
        super().__init__()
        self.c_s = c_s
        self.v_heads = v_heads
        self.num_vector_messages = num_vector_messages
        self.mask_and_zero_frameless = mask_and_zero_frameless

        self.s_norm = nn.LayerNorm(c_s, bias=bias)
        dim_proj = (
            4 * self.v_heads * 3 + self.v_heads * 3 * self.num_vector_messages
        )  # 2 x (q, k) * number of heads * (x, y, z) + number of heads * number of vector messages * (x, y, z)
        self.proj = nn.Linear(c_s, dim_proj, bias=bias)
        channels_out = self.v_heads * 3 * self.num_vector_messages
        self.out_proj = nn.Linear(channels_out, c_s, bias=bias)

        # The basic idea is for some attention heads to pay more or less attention to rotation versus distance,
        # as well as to control the sharpness of the softmax (i.e., should this head only attend to those residues
        # very nearby, or should there be a shallower dropoff in attention weight?)
        self.distance_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))
        self.rotation_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))

    def forward(self, s, affine, affine_mask, sequence_id, chain_id):
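        # Additive attention bias: pairs that share a sequence id get +1, other
        # pairs get 0; keys without a valid frame and pairs from different chains
        # are pushed to the dtype minimum so they are effectively masked out.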
        attn_bias = sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2)
        attn_bias = attn_bias.unsqueeze(1).float()
        attn_bias = attn_bias.masked_fill(
            ~affine_mask[:, None, None, :], torch.finfo(attn_bias.dtype).min
        )
        chain_id_mask = chain_id.unsqueeze(1) != chain_id.unsqueeze(2)
        attn_bias = attn_bias.masked_fill(
            chain_id_mask.unsqueeze(1), torch.finfo(s.dtype).min
        )

        ns = self.s_norm(s)
        vec_rot, vec_dist = self.proj(ns).split(
            [
                self.v_heads * 2 * 3 + self.v_heads * 3 * self.num_vector_messages,
                self.v_heads * 2 * 3,
            ],
            dim=-1,
        )

        # Rotate the queries and keys for the rotation term. We also rotate the values.
        # NOTE(zeming, thayes): Values are only rotated, not translated. We may wish to change
        # this in the future.
        query_rot, key_rot, value = (
            affine.rot[..., None]
            .apply(rearrange(vec_rot, "... (h c) -> ... h c", c=3))
            .split(
                [
                    self.v_heads,
                    self.v_heads,
                    self.v_heads * self.num_vector_messages,
                ],
                dim=-2,
            )
        )

        # Rotate and translate the queries and keys for the distance term.
        # NOTE(thayes): a simple speedup would be to apply all rotations together, then
        # separately apply the translations.
        query_dist, key_dist = (
            affine[..., None]
            .apply(rearrange(vec_dist, "... (h c) -> ... h c", c=3))
            .chunk(2, dim=-2)
        )

        query_dist = rearrange(query_dist, "b s h d -> b h s 1 d")
        key_dist = rearrange(key_dist, "b s h d -> b h 1 s d")
        query_rot = rearrange(query_rot, "b s h d -> b h s d")
        key_rot = rearrange(key_rot, "b s h d -> b h d s")
        value = rearrange(
            value, "b s (h m) d -> b h s (m d)", m=self.num_vector_messages
        )
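
        # Per-head logits: the rotation term is the dot product of the rotated
        # direction vectors and the distance term is the Euclidean distance between
        # the translated points; both are scaled by 1/sqrt(3) and by a learned,
        # softplus-constrained (non-negative) per-head weight before being combined.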
        distance_term = (query_dist - key_dist).norm(dim=-1) / sqrt(3)
        rotation_term = query_rot.matmul(key_rot) / sqrt(3)
        distance_term_weight = rearrange(
            F.softplus(self.distance_scale_per_head), "h -> h 1 1"
        )
        rotation_term_weight = rearrange(
            F.softplus(self.rotation_scale_per_head), "h -> h 1 1"
        )
        attn_weight = (
            rotation_term * rotation_term_weight - distance_term * distance_term_weight
        )

        if attn_bias is not None:
            # we can re-use the attention bias from the transformer layers
            # NOTE(thayes): This attention bias is expected to handle two things:
            # 1. Masking attention on padding tokens
            # 2. Masking cross sequence attention in the case of bin packing
            s_q = attn_weight.size(2)
            s_k = attn_weight.size(3)
            _s_q = max(0, attn_bias.size(2) - s_q)
            _s_k = max(0, attn_bias.size(3) - s_k)
            attn_bias = attn_bias[:, :, _s_q:, _s_k:]
            attn_weight = attn_weight + attn_bias

        attn_weight = torch.softmax(attn_weight, dim=-1)
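
        # Aggregate the per-head vector messages with the attention weights, rotate
        # the result back into each residue's local frame (inverse rotation only, no
        # translation), then flatten and project back to the c_s-dimensional stream.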
        attn_out = attn_weight.matmul(value)

        attn_out = (
            affine.rot[..., None]
            .invert()
            .apply(
                rearrange(
                    attn_out, "b h s (m d) -> b s (h m) d", m=self.num_vector_messages
                )
            )
        )

        attn_out = rearrange(
            attn_out, "b s (h m) d -> b s (h m d)", m=self.num_vector_messages
        )
        if self.mask_and_zero_frameless:
            attn_out = attn_out.masked_fill(~affine_mask[..., None], 0.0)
        s = self.out_proj(attn_out)

        return s
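

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). In the real model,
# `affine` is a per-residue rigid transform (rotation + translation) exposing
# the interface used above: `.rot`, `.apply(vectors)`, `.invert()`, and
# broadcasting via `[..., None]`. `_IdentityAffine` below is a hypothetical
# stand-in with identity rotations and zero translations, so this only
# exercises tensor shapes, not the geometry itself.
class _IdentityAffine:
    """Hypothetical stand-in for the framework's affine object."""

    @property
    def rot(self):
        # Rotation-only view; identical to the full transform for the identity.
        return self

    def __getitem__(self, idx):
        # Broadcasting over the head dimension is a no-op for the identity.
        return self

    def apply(self, vectors: torch.Tensor) -> torch.Tensor:
        # Identity rotation and zero translation leave the vectors unchanged.
        return vectors

    def invert(self):
        return self


if __name__ == "__main__":
    batch, seq_len, c_s, heads = 2, 8, 64, 4
    block = GeometricReasoningOriginalImpl(c_s=c_s, v_heads=heads)

    s = torch.randn(batch, seq_len, c_s)
    affine = _IdentityAffine()
    affine_mask = torch.ones(batch, seq_len, dtype=torch.bool)
    sequence_id = torch.zeros(batch, seq_len, dtype=torch.long)
    chain_id = torch.zeros(batch, seq_len, dtype=torch.long)

    out = block(s, affine, affine_mask, sequence_id, chain_id)
    print(out.shape)  # expected: torch.Size([2, 8, 64])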