import copy

import esm
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModel

class GaussianFourierProjection(nn.Module):
    """Gaussian random features for encoding time steps."""

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Fixed random projection: sampled once at initialization, never trained.
        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)

    def forward(self, x):
        # x: (B,) scalar time steps -> (B, embed_dim) Fourier features.
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
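# Usage sketch (illustrative only): a batch of 8 scalar time steps embedded
# into 256-dimensional Fourier features.
#
#     t = torch.rand(8)
#     proj = GaussianFourierProjection(embed_dim=256)
#     assert proj(t).shape == (8, 256)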

class Dense(nn.Module):
    """A fully connected layer; callers broadcast its output over the
    sequence dimension (via `[:, :, None]`)."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.dense(x)

class Swish(nn.Module):
    """Swish activation: x * sigmoid(x)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.sigmoid(x) * x
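# Editorial note: Swish is identical to torch's built-in SiLU, e.g.
#
#     x = torch.randn(4)
#     assert torch.allclose(Swish()(x), nn.SiLU()(x))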

class CNNESMModel(nn.Module):
    """A time-dependent score model: a stack of dilated 1-D convolutions on
    top of frozen ESM-2 embeddings."""

    def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256):
        """
        Args:
            alphabet_size (int): Number of output classes per position.
            embed_dim (int): Dimensionality of the time embedding.
            hidden_dim (int): Number of channels in the convolutional blocks.
        """
        super().__init__()
        self.alphabet_size = alphabet_size

        # Frozen ESM-2 encoder used as a fixed feature extractor.
        self.esm = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.esm.eval()
        for param in self.esm.parameters():
            param.requires_grad = False

        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )

        self.swish = Swish()

        n = hidden_dim

        # The input projection must match the ESM hidden size (1280 for
        # esm2_t33_650M), not embed_dim, so read it from the model config.
        esm_dim = self.esm.config.hidden_size
        self.linear = nn.Conv1d(esm_dim, n, kernel_size=9, padding=4)

        # Five conv blocks with increasing dilation; padding is chosen as
        # dilation * (kernel_size - 1) / 2 so the sequence length is preserved.
        self.blocks = nn.ModuleList([
            nn.Conv1d(n, n, kernel_size=9, padding=4),
            nn.Conv1d(n, n, kernel_size=9, padding=4),
            nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
            nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
            nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
        ])

        self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in range(5)])
        self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in range(5)])

        self.final = nn.Sequential(
            nn.Conv1d(n, n, kernel_size=1),
            nn.GELU(),
            nn.Conv1d(n, self.alphabet_size, kernel_size=1)
        )

    def forward(self, x, t):
        """
        Args:
            x: Tensor of shape (B, L) containing ESM-2 input token IDs.
            t: Tensor of shape (B,) containing the time steps.
        Returns:
            out: Tensor of shape (B, L, alphabet_size) with per-position
                logits, mean-centered over the alphabet dimension.
        """
        # Extract frozen ESM-2 features: (B, L) -> (B, L, esm_dim).
        with torch.no_grad():
            x = self.esm(input_ids=x).last_hidden_state
        time_embed = self.swish(self.time_embed(t))

        out = x.permute(0, 2, 1)
        out = self.swish(self.linear(out))

        for block, dense, norm in zip(self.blocks, self.denses, self.norms):
            # Inject the time embedding per channel, normalize, convolve.
            h = self.swish(block(norm(out + dense(time_embed)[:, :, None])))
            if h.shape == out.shape:
                out = h + out  # residual connection
            else:
                out = h

        out = self.final(out)
        out = out.permute(0, 2, 1)

        # Center logits so they sum to zero over the alphabet dimension.
        out = out - out.mean(dim=-1, keepdim=True)
        return out
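# Usage sketch (illustrative only; assumes the matching ESM-2 tokenizer):
#
#     from transformers import AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
#     batch = tok(["MKTAYIAKQR", "MKTVFHQLAE"], return_tensors="pt")
#     model = CNNESMModel(alphabet_size=4)
#     logits = model(batch["input_ids"], torch.rand(2))  # (2, L, 4)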

class MLPModel(nn.Module):
    """A simple MLP score model: flattens the token embeddings of the whole
    sequence and concatenates a learned time embedding."""

    def __init__(self, input_dim: int = 128, time_dim: int = 1,
                 hidden_dim=128, length=500):
        super().__init__()
        self.input_dim = input_dim
        self.time_dim = time_dim
        self.hidden_dim = hidden_dim

        self.time_embedding = nn.Linear(1, time_dim)
        self.token_embedding = nn.Embedding(self.input_dim, hidden_dim)

        self.swish = Swish()

        self.main = nn.Sequential(
            self.swish,
            nn.Linear(hidden_dim * length + time_dim, hidden_dim),
            self.swish,
            nn.Linear(hidden_dim, hidden_dim),
            self.swish,
            nn.Linear(hidden_dim, hidden_dim),
            self.swish,
            nn.Linear(hidden_dim, self.input_dim * length),
        )

    def forward(self, x, t):
        """
        Args:
            x: Tensor of shape (B, L) containing token indices.
            t: Tensor of shape (B,) containing the time steps.
        Returns:
            Tensor of shape (B, L, input_dim) with per-position logits.
        """
        t = self.time_embedding(t.unsqueeze(-1))
        x = self.token_embedding(x)

        # Flatten the per-position embeddings into one vector per sequence.
        B, N, d = x.shape
        x = x.reshape(B, N * d)

        h = torch.cat([x, t], dim=1)
        h = self.main(h)

        h = h.reshape(B, N, self.input_dim)
        return h
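# Usage sketch (illustrative only): shapes for a batch of 2 sequences of
# length 500 over a 128-token vocabulary.
#
#     model = MLPModel(input_dim=128, length=500)
#     x = torch.randint(0, 128, (2, 500))
#     assert model(x, torch.rand(2)).shape == (2, 500, 128)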

class CNNModel(nn.Module):
    """A time-dependent score model: token embeddings followed by a stack of
    dilated 1-D convolutions with residual connections."""

    def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256):
        """
        Args:
            alphabet_size (int): Number of token types and output classes.
            embed_dim (int): Dimensionality of the token and time embeddings.
            hidden_dim (int): Number of channels in the convolutional blocks.
        """
        super().__init__()
        self.alphabet_size = alphabet_size

        self.token_embedding = nn.Embedding(self.alphabet_size, embed_dim)

        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )

        self.swish = Swish()

        n = hidden_dim

        self.linear = nn.Conv1d(embed_dim, n, kernel_size=9, padding=4)

        # Five conv blocks with increasing dilation; padding is chosen as
        # dilation * (kernel_size - 1) / 2 so the sequence length is preserved.
        self.blocks = nn.ModuleList([
            nn.Conv1d(n, n, kernel_size=9, padding=4),
            nn.Conv1d(n, n, kernel_size=9, padding=4),
            nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
            nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
            nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
        ])

        self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in range(5)])
        self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in range(5)])

        self.final = nn.Sequential(
            nn.Conv1d(n, n, kernel_size=1),
            nn.GELU(),
            nn.Conv1d(n, self.alphabet_size, kernel_size=1)
        )

    def forward(self, x, t):
        """
        Args:
            x: Tensor of shape (B, L) containing DNA token indices.
            t: Tensor of shape (B,) containing the time steps.
        Returns:
            out: Tensor of shape (B, L, alphabet_size) with per-position
                logits, mean-centered over the alphabet dimension.
        """
        x = self.token_embedding(x)
        time_embed = self.swish(self.time_embed(t))

        out = x.permute(0, 2, 1)
        out = self.swish(self.linear(out))

        for block, dense, norm in zip(self.blocks, self.denses, self.norms):
            # Inject the time embedding per channel, normalize, convolve.
            h = self.swish(block(norm(out + dense(time_embed)[:, :, None])))
            if h.shape == out.shape:
                out = h + out  # residual connection
            else:
                out = h

        out = self.final(out)
        out = out.permute(0, 2, 1)

        # Center logits so they sum to zero over the alphabet dimension.
        out = out - out.mean(dim=-1, keepdim=True)
        return out
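# Usage sketch (illustrative only): DNA sequences of length 500 over the
# 4-letter alphabet.
#
#     model = CNNModel(alphabet_size=4)
#     x = torch.randint(0, 4, (2, 500))
#     assert model(x, torch.rand(2)).shape == (2, 500, 4)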

class CNNModel_Large(nn.Module):
    """A larger variant of CNNModel: the same architecture with twenty dilated
    convolutional blocks (the five-block dilation pattern repeated four times)."""

    def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256):
        """
        Args:
            alphabet_size (int): Number of token types and output classes.
            embed_dim (int): Dimensionality of the token and time embeddings.
            hidden_dim (int): Number of channels in the convolutional blocks.
        """
        super().__init__()
        self.alphabet_size = alphabet_size

        self.token_embedding = nn.Embedding(self.alphabet_size, embed_dim)

        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )

        self.swish = Swish()

        n = hidden_dim

        self.linear = nn.Conv1d(embed_dim, n, kernel_size=9, padding=4)

        # Twenty conv blocks: the five-block dilation pattern repeated four
        # times. padding = dilation * (kernel_size - 1) / 2 preserves length.
        dilations = [1, 1, 4, 16, 64] * 4
        self.blocks = nn.ModuleList([
            nn.Conv1d(n, n, kernel_size=9, dilation=d, padding=4 * d)
            for d in dilations
        ])

        self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in range(20)])
        self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in range(20)])

        self.final = nn.Sequential(
            nn.Conv1d(n, n, kernel_size=1),
            nn.GELU(),
            nn.Conv1d(n, self.alphabet_size, kernel_size=1)
        )

    def forward(self, x, t):
        """
        Args:
            x: Tensor of shape (B, L) containing DNA token indices.
            t: Tensor of shape (B,) containing the time steps.
        Returns:
            out: Tensor of shape (B, L, alphabet_size) with per-position
                logits, mean-centered over the alphabet dimension.
        """
        x = self.token_embedding(x)
        time_embed = self.swish(self.time_embed(t))

        out = x.permute(0, 2, 1)
        out = self.swish(self.linear(out))

        for block, dense, norm in zip(self.blocks, self.denses, self.norms):
            # Inject the time embedding per channel, normalize, convolve.
            h = self.swish(block(norm(out + dense(time_embed)[:, :, None])))
            if h.shape == out.shape:
                out = h + out  # residual connection
            else:
                out = h

        out = self.final(out)
        out = out.permute(0, 2, 1)

        # Center logits so they sum to zero over the alphabet dimension.
        out = out - out.mean(dim=-1, keepdim=True)
        return out
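# Usage sketch (illustrative only): same interface as CNNModel, with four
# times as many convolutional blocks.
#
#     model = CNNModel_Large(alphabet_size=4)
#     x = torch.randint(0, 4, (2, 500))
#     assert model(x, torch.rand(2)).shape == (2, 500, 4)
#     n_params = sum(p.numel() for p in model.parameters())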