import json
import os
from functools import partial

import torch
import torch.nn as nn
from mamba_ssm.modules.mamba_simple import Block, Mamba
from transformers import PretrainedConfig, PreTrainedModel


class OrthrusConfig(PretrainedConfig):
    """HuggingFace config for pre-trained Orthrus model."""

    model_type = "orthrus"

    def __init__(
        self,
        n_tracks: int = 4,
        ssm_model_dim: int = 256,
        ssm_n_layers: int = 3,
        **kwargs,
    ):
        """Initialize OrthrusConfig.

        Args:
            n_tracks: Number of data tracks.
            ssm_model_dim: Hidden dimension of Mamba backbone.
            ssm_n_layers: Number of layers in Mamba backbone.
        """
        self.n_tracks = n_tracks
        self.ssm_model_dim = ssm_model_dim
        self.ssm_n_layers = ssm_n_layers
        super().__init__(**kwargs)

    @classmethod
    def init_from_config(cls, config_dir_path: str) -> "OrthrusConfig":
        """Load config from pretraining config files.

        Args:
            config_dir_path: Path to folder with pretraining configs.
        """
        model_config_path = os.path.join(config_dir_path, "model_config.json")
        data_config_path = os.path.join(config_dir_path, "data_config.json")

        with open(model_config_path, "r") as f:
            model_params = json.load(f)

        if "n_tracks" not in model_params:
            with open(data_config_path, "r") as f:
                data_params = json.load(f)
            n_tracks = data_params["n_tracks"]
        else:
            n_tracks = model_params["n_tracks"]

        return cls(
            n_tracks=n_tracks,
            ssm_model_dim=model_params["ssm_model_dim"],
            ssm_n_layers=model_params["ssm_n_layers"],
        )


class OrthrusPretrainedModel(PreTrainedModel):
    """HuggingFace wrapper for a pretrained Orthrus model."""

    config_class = OrthrusConfig
    base_model_prefix = "orthrus"

    def __init__(self, config: OrthrusConfig, **kwargs):
        """Initialize OrthrusPretrainedModel.

        Args:
            config: Model configs.
        """
        super().__init__(config, **kwargs)
        self.config = config

        self.embedding = nn.Linear(
            config.n_tracks,
            config.ssm_model_dim,
        )
        self.layers = nn.ModuleList(
            [
                self.create_block(
                    config.ssm_model_dim,
                    layer_idx=i,
                )
                for i in range(config.ssm_n_layers)
            ]
        )
        self.norm_f = nn.LayerNorm(config.ssm_model_dim)

    def create_block(
        self,
        d_model: int,
        layer_idx: int | None = None,
    ) -> Block:
        """Create Mamba Block.

        Args:
            d_model: Hidden dimension of Mamba blocks.
            layer_idx: Index of current Mamba block in stack.

        Returns:
            Initialized Mamba block.
        """
        mix_cls = partial(Mamba, layer_idx=layer_idx)
        norm_cls = nn.LayerNorm
        block = Block(
            d_model,
            mix_cls,
            norm_cls=norm_cls,
        )
        block.layer_idx = layer_idx
        return block

    def forward(
        self,
        x: torch.Tensor,
        channel_last: bool = False,
    ) -> torch.Tensor:
        """Perform Orthrus forward pass.

        Args:
            x: Input data. Shape (B x C x L), or (B x L x C) if channel_last.
            channel_last: Whether the channel dimension is the last dimension.

        Returns:
            Position-wise Orthrus embedding with shape (B x L x H).
        """
        if not channel_last:
            x = x.transpose(1, 2)

        hidden_states = self.embedding(x)
        res = None

        for layer in self.layers:
            hidden_states, res = layer(hidden_states, res)

        res = (hidden_states + res) if res is not None else hidden_states
        hidden_states = self.norm_f(res.to(dtype=self.norm_f.weight.dtype))
        return hidden_states

    def representation(
        self,
        x: torch.Tensor,
        lengths: torch.Tensor,
        channel_last: bool = False,
    ) -> torch.Tensor:
        """Get global representation of input data.

        The representation is mean-pooled across the length dimension.

        Args:
            x: Data to embed. Has shape (B x C x L) if not channel_last.
            lengths: Unpadded length of each data input.
            channel_last: Whether input has shape (B x L x C).

        Returns:
            Global representation vector of shape (B x H).
        """
        out = self.forward(x, channel_last=channel_last)
        mean_tensor = self.mean_unpadded(out, lengths)
        return mean_tensor

    def seq_to_oh(self, seq: list[str]) -> torch.Tensor:
        """Convert a nucleotide string into a one-hot encoding.

        The encoding uses the ordering ["A", "C", "G", "T"].

        Args:
            seq: Sequence to encode.

        Returns:
            One-hot encoded sequence with shape (L x 4).
        """
        oh = torch.zeros((len(seq), 4), dtype=torch.float32)
        for i, base in enumerate(seq):
            if base == "A":
                oh[i, 0] = 1
            elif base == "C":
                oh[i, 1] = 1
            elif base == "G":
                oh[i, 2] = 1
            elif base == "T":
                oh[i, 3] = 1
        return oh

    def mean_unpadded(
        self,
        x: torch.Tensor,
        lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Take the mean across the second dimension, excluding padding.

        Args:
            x: Tensor to take the unpadded mean of. Has shape (B x L x H).
            lengths: Tensor of unpadded lengths. Has shape (B).

        Returns:
            Mean tensor of shape (B x H).
        """
        mask = torch.arange(
            x.size(1), device=x.device
        )[None, :] < lengths[:, None]
        masked_tensor = x * mask.unsqueeze(-1)
        sum_tensor = masked_tensor.sum(dim=1)
        mean_tensor = sum_tensor / lengths.unsqueeze(-1).float()
        return mean_tensor