import json
import os
from functools import partial

import torch
import torch.nn as nn
from mamba_ssm.modules.mamba_simple import Block, Mamba
from transformers import PretrainedConfig, PreTrainedModel


class OrthrusConfig(PretrainedConfig):
    """HuggingFace config for pre-trained Orthrus model."""

    model_type = "orthrus"

    def __init__(
        self,
        n_tracks: int = 4,
        ssm_model_dim: int = 256,
        ssm_n_layers: int = 3,
        **kwargs
    ):
        """Initialize OrthrusConfig.

        Args:
            n_tracks: Number of data tracks.
            ssm_model_dim: Hidden dimension of Mamba backbone.
            ssm_n_layers: Number of layers in Mamba backbone.
        """
        self.n_tracks = n_tracks
        self.ssm_model_dim = ssm_model_dim
        self.ssm_n_layers = ssm_n_layers
        super().__init__(**kwargs)

    @classmethod
    def init_from_config(cls, config_dir_path: str) -> "OrthrusConfig":
        """Load config from pretraining config files.

        Args:
            config_dir_path: Path to folder with pretraining configs.

        Returns:
            Config populated from the pretraining config files.
        """
        model_config_path = os.path.join(config_dir_path, "model_config.json")
        data_config_path = os.path.join(config_dir_path, "data_config.json")

        with open(model_config_path, "r") as f:
            model_params = json.load(f)

        if "n_tracks" not in model_params:
            with open(data_config_path, "r") as f:
                data_params = json.load(f)
            n_tracks = data_params["n_tracks"]
        else:
            n_tracks = model_params["n_tracks"]

        return cls(
            n_tracks=n_tracks,
            ssm_model_dim=model_params["ssm_model_dim"],
            ssm_n_layers=model_params["ssm_n_layers"],
        )


class OrthrusPretrainedModel(PreTrainedModel):
    """HuggingFace wrapper for a pretrained Orthrus model."""

    config_class = OrthrusConfig
    base_model_prefix = "orthrus"

    def __init__(self, config: OrthrusConfig, **kwargs):
        """Initialize OrthrusPretrainedModel.

        Args:
            config: Model configuration.
        """
        super().__init__(config, **kwargs)

        self.config = config
        self.embedding = nn.Linear(
            config.n_tracks,
            config.ssm_model_dim,
        )

        self.layers = nn.ModuleList(
            [
                self.create_block(
                    config.ssm_model_dim,
                    layer_idx=i,
                )
                for i in range(config.ssm_n_layers)
            ]
        )

        self.norm_f = nn.LayerNorm(config.ssm_model_dim)

    def create_block(
        self,
        d_model: int,
        layer_idx: int | None = None
    ) -> Block:
        """Create Mamba Block.

        Args:
            d_model: Hidden dimension of Mamba blocks.
            layer_idx: Index of current Mamba block in stack.

        Returns:
            Initialized Mamba block.
        """
        mix_cls = partial(Mamba, layer_idx=layer_idx)
        norm_cls = nn.LayerNorm
        block = Block(
            d_model,
            mix_cls,
            norm_cls=norm_cls,
        )
        block.layer_idx = layer_idx
        return block

    def forward(
        self,
        x: torch.Tensor,
        channel_last: bool = False
    ) -> torch.Tensor:
        """Perform Orthrus forward pass.

        Args:
            x: Input data. Shape (B x C x L) or (B x L x C) if channel_last.
            channel_last: Whether channel dimension is last dimension.

        Returns:
            Position-wise Orthrus embedding with shape (B x L x H), where H is
            the hidden dimension of the Mamba backbone.
        """
        if not channel_last:
            x = x.transpose(1, 2)

        hidden_states = self.embedding(x)
        res = None
        for layer in self.layers:
            hidden_states, res = layer(hidden_states, res)

        res = (hidden_states + res) if res is not None else hidden_states
        hidden_states = self.norm_f(res.to(dtype=self.norm_f.weight.dtype))

        return hidden_states

    def representation(
        self,
        x: torch.Tensor,
        lengths: torch.Tensor,
        channel_last: bool = False,
    ) -> torch.Tensor:
        """Get global representation of input data.

        Representation is pooled across the length dimension.

        Args:
            x: Data to embed. Has shape (B x C x L) if not channel_last.
            lengths: Unpadded length of each data input.
            channel_last: Whether input has shape (B x L x C).

        Returns:
            Global representation vector of shape (B x H).
        """
        out = self.forward(x, channel_last=channel_last)

        mean_tensor = self.mean_unpadded(out, lengths)
        return mean_tensor

    def seq_to_oh(self, seq: str) -> torch.Tensor:
        """Convert a nucleotide string into a one-hot encoding.

        The encoding uses the ordering ["A", "C", "G", "T"].

        Args:
            seq: Sequence to encode.

        Returns:
            One-hot encoded sequence, with shape (L x 4).
        """
        oh = torch.zeros((len(seq), 4), dtype=torch.float32)
        for i, base in enumerate(seq):
            if base == "A":
                oh[i, 0] = 1
            elif base == "C":
                oh[i, 1] = 1
            elif base == "G":
                oh[i, 2] = 1
            elif base == "T":
                oh[i, 3] = 1
        return oh

    def mean_unpadded(
        self,
        x: torch.Tensor,
        lengths: torch.Tensor
    ) -> torch.Tensor:
        """Take the mean of a tensor over the length dimension, ignoring padding.

        Args:
            x: Tensor to take the unpadded mean of. Has shape (B x L x H).
            lengths: Tensor of unpadded lengths. Has shape (B).

        Returns:
            Mean tensor of shape (B x H).
        """
        mask = torch.arange(
            x.size(1),
            device=x.device
        )[None, :] < lengths[:, None]
        masked_tensor = x * mask.unsqueeze(-1)
        sum_tensor = masked_tensor.sum(dim=1)
        mean_tensor = sum_tensor / lengths.unsqueeze(-1).float()

        return mean_tensor
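

# Minimal usage sketch. Illustrative only: the sequences below are made up, the
# weights are randomly initialized rather than loaded from a released
# checkpoint, and running the Mamba blocks assumes a CUDA device with the
# mamba_ssm kernels installed.
if __name__ == "__main__":
    from torch.nn.utils.rnn import pad_sequence

    config = OrthrusConfig(n_tracks=4, ssm_model_dim=256, ssm_n_layers=3)
    model = OrthrusPretrainedModel(config).to("cuda").eval()

    seqs = ["ACGTACGT", "ACGT"]
    one_hots = [model.seq_to_oh(s) for s in seqs]  # each (L_i, 4)
    lengths = torch.tensor([len(s) for s in seqs], device="cuda")

    # Pad to a common length, channel-last -> (B, L_max, 4).
    x = pad_sequence(one_hots, batch_first=True).to("cuda")

    with torch.no_grad():
        emb = model.representation(x, lengths, channel_last=True)
    print(emb.shape)  # torch.Size([2, 256])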