---
license: mit
---
[GitHub repo](https://github.com/klemens-floege/oneprot/) | [Paper link](https://arxiv.org/abs/2411.04863)
## Overview

OneProt is a multimodal model that integrates protein sequence, protein structure (both in the form of an augmented sequence and in the form of a graph), protein binding sites, and protein text annotations. Contrastive learning is used to align each modality to the central one, the protein sequence: during the pre-training phase, an InfoNCE loss is computed for each (protein sequence, other modality) pair.
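For illustration, the pre-training objective can be sketched as a symmetric InfoNCE loss over a batch of paired embeddings. The snippet below is a generic sketch, not the implementation from our repo; the function and argument names (`info_nce`, `seq_emb`, `other_emb`, `temperature`) are placeholders.

```python
import torch
import torch.nn.functional as F

def info_nce(seq_emb: torch.Tensor, other_emb: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """Symmetric InfoNCE for a batch of paired embeddings.

    seq_emb, other_emb: [batch, dim] tensors where row i of each tensor forms a positive pair.
    """
    seq_emb = F.normalize(seq_emb, dim=-1)
    other_emb = F.normalize(other_emb, dim=-1)
    logits = seq_emb @ other_emb.t() / temperature                  # pairwise cosine similarities
    targets = torch.arange(seq_emb.size(0), device=seq_emb.device)  # positives sit on the diagonal
    # cross-entropy in both directions: sequence -> other modality and other modality -> sequence
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))
```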
## Model architecture

Protein sequence encoder: [esm2_t33_650M_UR50D](https://huggingface.co/facebook/esm2_t33_650M_UR50D)

Protein structure encoder (structure tokens): [esm2_t12_35M_UR50D](https://huggingface.co/facebook/esm2_t12_35M_UR50D)

Protein structure encoder (GNN): [ProNet](https://github.com/divelab/DIG)

Pocket (binding site) encoder (GNN): [ProNet](https://github.com/divelab/DIG)

Text encoder: [BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext](https://huggingface.co/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext)

Below is example code showing how to obtain embeddings for each modality (it requires cloning our repo first). Note that the example data for the transformer encoders are read from `.txt` files and can in principle be passed as plain strings, whilst the data for the GNN encoders are stored in the example `.h5` files and need to be converted to graphs first.
```python
import os
import sys

import h5py
import hydra
import torch
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from torch_geometric.data import Batch
from transformers import AutoTokenizer

# Make the oneprot repo importable; this assumes you are running the script from inside
# the cloned oneprot repo, adjust the path otherwise
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from src.models.oneprot_module import OneProtLitModule
from src.data.utils.struct_graph_utils import protein_to_graph

# If you are not running on the supercomputer (i.e. no launcher sets these variables for you),
# you may need to uncomment the two following lines
# os.environ['RANK'] = '0'
# os.environ['WORLD_SIZE'] = '1'

# Download the config file from the Hub and load it
config_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="config.yaml",
)

with open(config_path, 'r') as f:
    cfg = OmegaConf.load(f)

# Instantiate the per-modality encoders defined in the config
components = {
    'sequence': hydra.utils.instantiate(cfg.model.components.sequence),
    'struct_token': hydra.utils.instantiate(cfg.model.components.struct_token),
    'struct_graph': hydra.utils.instantiate(cfg.model.components.struct_graph),
    'pocket': hydra.utils.instantiate(cfg.model.components.pocket),
    'text': hydra.utils.instantiate(cfg.model.components.text)
}

# Download the model checkpoint
checkpoint_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="pytorch_model.bin",
    repo_type="model"
)

# Create the model instance and load the checkpoint
model = OneProtLitModule(
    components=components,
    optimizer=None,
    loss_fn=cfg.model.loss_fn,
    local_loss=cfg.model.local_loss,
    gather_with_grad=cfg.model.gather_with_grad,
    use_l1_regularization=cfg.model.use_l1_regularization,
    train_on_all_modalities_after_step=cfg.model.train_on_all_modalities_after_step,
    use_seqsim=cfg.model.use_seqsim
)

state_dict = torch.load(checkpoint_path, map_location="cpu")  # map_location avoids requiring a GPU
model.load_state_dict(state_dict, strict=True)
model.eval()

# Define the tokenisers
tokenizers = {
    'sequence': "facebook/esm2_t33_650M_UR50D",
    'struct_token': "facebook/esm2_t33_650M_UR50D",
    'text': "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
}

loaded_tokenizers = {}
for modality, tokenizer_name in tokenizers.items():
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if modality == 'struct_token':
        # extra tokens of the structure alphabet used by the structure-token encoder
        new_tokens = ['p', 'y', 'n', 'w', 'r', 'q', 'h', 'g', 'd', 'l', 'v', 't', 'm', 'f', 's', 'a', 'e', 'i', 'k', 'c', '#']
        tokenizer.add_tokens(new_tokens)
    loaded_tokenizers[modality] = tokenizer

# Get example embeddings for each modality

########################## sequence ##############################

modality = "sequence"

file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/sequence_example.txt",
    repo_type="model"  # or "dataset"
)

with open(file_path, 'r') as file:
    input_sequence = file.read().strip()

input_tensor = loaded_tokenizers[modality](input_sequence, return_tensors="pt")["input_ids"]
output = model.network[modality](input_tensor)  # each modality encoder is available under model.network[<modality>]
print(f"Output for modality '{modality}': {output}")

########################### text #################################

modality = "text"

file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/text_example.txt",
    repo_type="model"  # or "dataset"
)

with open(file_path, 'r') as file:
    input_text = file.read().strip()

input_tensor = loaded_tokenizers[modality](input_text, return_tensors="pt")["input_ids"]
output = model.network[modality](input_tensor)
print(f"Output for modality '{modality}': {output}")

##################### tokenized structure ########################

modality = "struct_token"

file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/struct_token_example.txt",
    repo_type="model"  # or "dataset"
)

with open(file_path, 'r') as file:
    input_struct_token = file.read().strip()

input_struct_token = input_struct_token.replace("#", "")  # drop '#' characters from the structure-token string
input_tensor = loaded_tokenizers[modality](input_struct_token, return_tensors="pt")["input_ids"]
output = model.network[modality](input_tensor)
print(f"Output for modality '{modality}': {output}")

##################### graph structure ############################

modality = "struct_graph"

file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/seqstruc_example.h5",
    repo_type="model"  # or "dataset"
)

# Build a graph for the example protein 'E6Y2X0' (chain 'A') and collate it into a batch
with h5py.File(file_path, 'r') as file:
    input_struct_graph = [protein_to_graph('E6Y2X0', file_path, 'non_pdb', 'A', pockets=False)]

input_struct_graph = Batch.from_data_list(input_struct_graph)
output = model.network[modality](input_struct_graph)
print(f"Output for modality '{modality}': {output}")

########################## pocket ################################

modality = "pocket"

file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/pocket_example.h5",
    repo_type="model"  # or "dataset"
)

# Same as above, but building the binding-pocket graph (pockets=True)
with h5py.File(file_path, 'r') as file:
    input_pocket = [protein_to_graph('E6Y2X0', file_path, 'non_pdb', 'A', pockets=True)]

input_pocket = Batch.from_data_list(input_pocket)
output = model.network[modality](input_pocket)
print(f"Output for modality '{modality}': {output}")
```
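Since all modality encoders are aligned with the sequence encoder in a shared embedding space, the embeddings obtained above can be compared across modalities, for example with cosine similarity. Below is a minimal sketch with placeholder tensors; replace `seq_output` and `text_output` with the actual (pooled, `[1, dim]`) outputs of the sequence and text encoders from the example above, reshaping or pooling first if your outputs have a different shape.

```python
import torch
import torch.nn.functional as F

# Placeholder embeddings standing in for the sequence and text encoder outputs above
# (replace them with the real `output` tensors of the corresponding modalities).
seq_output = torch.randn(1, 1024)
text_output = torch.randn(1, 1024)

similarity = F.cosine_similarity(seq_output, text_output)  # one value per row pair
print(f"Sequence-text cosine similarity: {similarity.item():.3f}")
```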
## Citation
```bibtex
@misc{flöge2024oneprotmultimodalproteinfoundation,
      title={OneProt: Towards Multi-Modal Protein Foundation Models},
      author={Klemens Flöge and Srisruthi Udayakumar and Johanna Sommer and Marie Piraud and Stefan Kesselheim and Vincent Fortuin and Stephan Günnemann and Karel J van der Weg and Holger Gohlke and Alina Bazarova and Erinc Merdivan},
      year={2024},
      eprint={2411.04863},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2411.04863},
}
```