# orthrus-large-6-track / orthrus_hf.py
import json
import os
import torch
import torch.nn as nn
from functools import partial
from mamba_ssm.modules.mamba_simple import Block, Mamba
from transformers import PretrainedConfig, PreTrainedModel
class OrthrusConfig(PretrainedConfig):
"""HuggingFace config for pre-trained Orthrus model."""
model_type = "orthrus"
def __init__(
self,
n_tracks: int = 4,
ssm_model_dim: int = 256,
ssm_n_layers: int = 3,
**kwargs
):
"""Initialize OrthrusConfig.
Args:
n_tracks: Number of data tracks.
ssm_model_dim: Hidden dimension of Mamba backbone.
ssm_n_layers: Number of layers in Mamba backbone.
"""
self.n_tracks = n_tracks
self.ssm_model_dim = ssm_model_dim
self.ssm_n_layers = ssm_n_layers
super().__init__(**kwargs)
@classmethod
def init_from_config(cls, config_dir_path: str) -> "OrthrusConfig":
"""Load config from pretraining config files.
Args:
config_dir_path: Path to folder with pretraining configs.
"""
model_config_path = os.path.join(config_dir_path, "model_config.json")
data_config_path = os.path.join(config_dir_path, "data_config.json")
with open(model_config_path, "r") as f:
model_params = json.load(f)
if "n_tracks" not in model_params:
with open(data_config_path, "r") as f:
data_params = json.load(f)
n_tracks = data_params["n_tracks"]
else:
n_tracks = model_params["n_tracks"]
return cls(
n_tracks=n_tracks,
ssm_model_dim=model_params["ssm_model_dim"],
ssm_n_layers=model_params["ssm_n_layers"]
)
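
# Example (illustrative): the config can be constructed directly or loaded
# from a directory of pretraining configs; "path/to/configs" is a placeholder
# path, not one shipped with this repository.
#
#   config = OrthrusConfig(n_tracks=4, ssm_model_dim=256, ssm_n_layers=3)
#   config = OrthrusConfig.init_from_config("path/to/configs")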
class OrthrusPretrainedModel(PreTrainedModel):
"""HuggingFace wrapper for a pretrained Orthrus model."""
config_class = OrthrusConfig
base_model_prefix = "orthrus"
def __init__(self, config: OrthrusConfig, **kwargs):
"""Initialize OrthrusPretrainedModel.
Args:
            config: Model configuration.
"""
super().__init__(config, **kwargs)
self.config = config
self.embedding = nn.Linear(
config.n_tracks,
config.ssm_model_dim,
)
self.layers = nn.ModuleList(
[
self.create_block(
config.ssm_model_dim,
layer_idx=i,
)
for i in range(config.ssm_n_layers)
]
)
self.norm_f = nn.LayerNorm(config.ssm_model_dim)
def create_block(
self,
d_model: int,
layer_idx: int | None = None
) -> Block:
"""Create Mamba Block.
Args:
d_model: Hidden dimension of Mamba blocks.
layer_idx: Index of current Mamba block in stack.
Returns:
Initialized Mamba block.
"""
mix_cls = partial(Mamba, layer_idx=layer_idx)
norm_cls = nn.LayerNorm
block = Block(
d_model,
mix_cls,
norm_cls=norm_cls,
)
block.layer_idx = layer_idx
return block
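
    # Note on create_block() above: in the mamba_ssm version this import
    # targets (mamba_ssm.modules.mamba_simple), Block adds the incoming
    # residual, applies norm_cls, then the Mamba mixer, and returns
    # (hidden_states, residual), so the residual stream is threaded through
    # the stack by forward() below.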
def forward(
self,
x: torch.Tensor,
channel_last: bool = False
) -> torch.Tensor:
"""Perform Orthrus forward pass.
Args:
x: Input data. Shape (B x C x L) or (B x L x C) if channel_last.
            channel_last: Whether the channel dimension is the last dimension.
Returns:
            Position-wise Orthrus embedding of shape (B x L x H), where H is ssm_model_dim.
"""
if not channel_last:
x = x.transpose(1, 2)
hidden_states = self.embedding(x)
res = None
for layer in self.layers:
hidden_states, res = layer(hidden_states, res)
res = (hidden_states + res) if res is not None else hidden_states
hidden_states = self.norm_f(res.to(dtype=self.norm_f.weight.dtype))
return hidden_states
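
    # Shape sketch for forward() above, with the default config
    # (n_tracks=4, ssm_model_dim=256, ssm_n_layers=3): an input of shape
    # (B x 4 x L) is projected to (B x L x 256), passed through 3 Mamba
    # blocks, and returned with shape (B x L x 256).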
def representation(
self,
x: torch.Tensor,
lengths: torch.Tensor,
channel_last: bool = False,
) -> torch.Tensor:
"""Get global representation of input data.
        The representation is mean-pooled across the length dimension.
Args:
x: Data to embed. Has shape (B x C x L) if not channel_last.
lengths: Unpadded length of each data input.
            channel_last: If True, expects input of shape (B x L x C).
Returns:
Global representation vector of shape (B x H).
"""
out = self.forward(x, channel_last=channel_last)
mean_tensor = self.mean_unpadded(out, lengths)
return mean_tensor
def seq_to_oh(self, seq: list[str]) -> torch.Tensor:
"""Convert nucleotide string into one-hot-encoding.
The encoding uses ordering ["A", "C", "G", "T"].
Args:
seq: Sequence to encode.
Returns:
            One-hot encoded sequence of shape (L x 4).
"""
oh = torch.zeros((len(seq), 4), dtype=torch.float32)
for i, base in enumerate(seq):
if base == "A":
oh[i, 0] = 1
elif base == "C":
oh[i, 1] = 1
elif base == "G":
oh[i, 2] = 1
elif base == "T":
oh[i, 3] = 1
return oh
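
    # Example for seq_to_oh() above: seq_to_oh("ACGT") returns the 4 x 4
    # identity matrix, one row per base; characters outside "ACGT" (e.g. "N")
    # are left as all-zero rows.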
def mean_unpadded(
self,
x: torch.Tensor,
lengths: torch.Tensor
) -> torch.Tensor:
"""Take mean of tensor across second dimension without padding.
Args:
            x: Tensor to average over. Has shape (B x L x H).
            lengths: Tensor of unpadded lengths. Has shape (B).
Returns:
Mean tensor of shape (B x H).
"""
mask = torch.arange(
x.size(1),
device=x.device
)[None, :] < lengths[:, None]
masked_tensor = x * mask.unsqueeze(-1)
sum_tensor = masked_tensor.sum(dim=1)
mean_tensor = sum_tensor / lengths.unsqueeze(-1).float()
return mean_tensor
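

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the model definition). It assumes a
# 4-track configuration so that seq_to_oh() alone supplies every input
# channel; the 6-track checkpoint presumably expects additional tracks beyond
# the one-hot sequence. mamba_ssm's selective-scan kernels generally require
# a CUDA device.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = OrthrusConfig(n_tracks=4, ssm_model_dim=256, ssm_n_layers=3)
    model = OrthrusPretrainedModel(config).to("cuda").eval()

    seq = "ATGGCCATTGTAATGGGCCGC"
    oh = model.seq_to_oh(seq)            # (L x 4)
    x = oh.unsqueeze(0).to("cuda")       # (B x L x C), B = 1
    lengths = torch.tensor([len(seq)], device="cuda")

    with torch.no_grad():
        emb = model.representation(x, lengths, channel_last=True)
    print(emb.shape)                     # (B x H) == (1 x 256)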