PyTorch
oneprot / README.md
sealinka's picture
Update README.md
78ef4e4 verified
---
license: mit
---
[GitHub repo](https://github.com/klemens-floege/oneprot/) |
[Paper link](https://arxiv.org/abs/2411.04863)
## Overview
OneProt is a multimodal model that integrates protein sequence, protein structure (both in the form of an augmented sequence and in the form of a graph), protein binding sites and protein text annotations. Contrastive learning is used to align each modality to the central one, which is the protein sequence. In the pre-training phase, the InfoNCE loss is computed between pairs of the form (protein sequence, other modality).
## Model architecture
- Protein sequence encoder: [esm2_t33_650M_UR50D](https://huggingface.co/facebook/esm2_t33_650M_UR50D)
- Protein structure encoder (augmented sequence): [esm2_t12_35M_UR50D](https://huggingface.co/facebook/esm2_t12_35M_UR50D)
- Protein structure encoder (graph), GNN: [ProNet](https://github.com/divelab/DIG)
- Pocket (binding-site) encoder, GNN: [ProNet](https://github.com/divelab/DIG)
- Text encoder: [BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext](https://huggingface.co/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext)

Below is example code showing how to obtain the embeddings (it requires cloning our repository first). Note that the example data for the transformer models are read from `.txt` files and could in principle be passed directly as strings, whilst the data for the GNN models are contained in the example `.h5` files and must first be converted to graphs.
```python
import torch
import hydra
from omegaconf import OmegaConf
from huggingface_hub import HfApi, hf_hub_download
import sys
import os
import h5py
from torch_geometric.data import Batch
from transformers import AutoTokenizer
# Make the repo's `src` package importable: this appends the parent of the directory
# containing this script to sys.path, i.e. it assumes the script sits one level below
# the oneprot repo root (any other path to the repo works too).
# NOTE: `__file__` is undefined in interactive sessions (e.g. Jupyter); use an explicit path there.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.models.oneprot_module import OneProtLitModule
from src.data.utils.struct_graph_utils import protein_to_graph
### If you are not running on the supercomputer, you may need to uncomment the two following lines.
# They set RANK / WORLD_SIZE for a single process; presumably some model/loss code reads
# these distributed-environment variables — confirm against the oneprot repo.
#os.environ['RANK']='0'
#os.environ['WORLD_SIZE']='1'
# Fetch the training configuration from the Hugging Face Hub and parse it
# into an OmegaConf object (OmegaConf.load accepts the file path directly).
config_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="config.yaml",
)
cfg = OmegaConf.load(config_path)

# Instantiate one encoder per modality from the Hydra targets stored under
# cfg.model.components; the dict keeps this modality order.
components = {
    name: hydra.utils.instantiate(cfg.model.components[name])
    for name in ("sequence", "struct_token", "struct_graph", "pocket", "text")
}
# Download the pretrained OneProt weights from the Hugging Face Hub.
checkpoint_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="pytorch_model.bin",
    repo_type="model",
)
# Create the model instance from the per-modality encoders and the options stored
# in the config. No optimizer is passed because the model is only used for inference.
model = OneProtLitModule(
    components=components,
    optimizer=None,
    loss_fn=cfg.model.loss_fn,
    local_loss=cfg.model.local_loss,
    gather_with_grad=cfg.model.gather_with_grad,
    use_l1_regularization=cfg.model.use_l1_regularization,
    train_on_all_modalities_after_step=cfg.model.train_on_all_modalities_after_step,
    use_seqsim=cfg.model.use_seqsim,
)
# map_location="cpu" lets the weights load on machines without a GPU even if the
# checkpoint was saved from CUDA tensors; move the model to a device afterwards if needed.
state_dict = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
# nn.Module starts in training mode; switch to evaluation mode so that any dropout
# layers in the encoders do not make the embeddings below non-deterministic.
model.eval()
# Define the tokenizers: Hugging Face checkpoint names for the modalities that take
# string input (the graph modalities, struct_graph and pocket, need no tokenizer).
tokenizers = {
    'sequence': "facebook/esm2_t33_650M_UR50D",
    # NOTE(review): the struct_token encoder is esm2_t12_35M_UR50D, while this tokenizer
    # comes from the 650M checkpoint; this relies on ESM-2 checkpoints sharing one
    # vocabulary — confirm.
    'struct_token': "facebook/esm2_t33_650M_UR50D",
    'text': "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
}
# Extra single-character tokens (lowercase letters plus '#') used by the
# tokenized-structure input; they are added to the struct_token tokenizer only.
struct_token_new_tokens = ['p', 'y', 'n', 'w', 'r', 'q', 'h', 'g', 'd', 'l', 'v', 't', 'm',
                           'f', 's', 'a', 'e', 'i', 'k', 'c', '#']
loaded_tokenizers = {}
for modality, tokenizer_name in tokenizers.items():
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if modality == 'struct_token':
        tokenizer.add_tokens(struct_token_new_tokens)
    loaded_tokenizers[modality] = tokenizer
# Get example embeddings for each modality
##########################sequence##############################
# Embed an example protein sequence with the sequence encoder.
modality = "sequence"
file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/sequence_example.txt",
    repo_type="model",
)
with open(file_path, 'r') as file:
    input_sequence = file.read().strip()
input_tensor = loaded_tokenizers[modality](input_sequence, return_tensors="pt")["input_ids"]
# Inference only: skip autograd bookkeeping to save memory.
with torch.no_grad():
    output = model.network[modality](input_tensor)
print(f"Output for modality '{modality}': {output}")
###########################text#################################
# Embed an example protein text annotation with the text encoder.
modality = "text"
file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/text_example.txt",
    repo_type="model",
)
with open(file_path, 'r') as file:
    input_text = file.read().strip()
input_tensor = loaded_tokenizers[modality](input_text, return_tensors="pt")["input_ids"]
# Inference only: skip autograd bookkeeping to save memory.
with torch.no_grad():
    output = model.network[modality](input_tensor)
print(f"Output for modality '{modality}': {output}")
#####################tokenized structure########################
# Embed the tokenized-structure ("augmented sequence") example with the struct_token encoder.
modality = "struct_token"
file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/struct_token_example.txt",
    repo_type="model",
)
with open(file_path, 'r') as file:
    input_struct_token = file.read().strip()
# Remove every '#' character before tokenizing (a single str.replace instead of a
# per-character join). NOTE(review): '#' is also added to the struct_token tokenizer's
# extra tokens; confirm that stripping it here is intended.
input_struct_token = input_struct_token.replace("#", "")
input_tensor = loaded_tokenizers[modality](input_struct_token, return_tensors="pt")["input_ids"]
# Inference only: skip autograd bookkeeping to save memory.
with torch.no_grad():
    output = model.network[modality](input_tensor)
print(f"Output for modality '{modality}': {output}")
#####################graph structure############################
# Build a structure graph from the example HDF5 file and embed it with the struct_graph GNN.
modality = "struct_graph"
file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/seqstruc_example.h5",
    repo_type="model",
)
# protein_to_graph is given the file *path*, not an open handle, so no separate
# h5py handle is opened here. Arguments are those of the original example
# (presumably entry id, source type and chain id — confirm in struct_graph_utils).
input_struct_graph = [protein_to_graph('E6Y2X0', file_path, 'non_pdb', 'A', pockets=False)]
input_struct_graph = Batch.from_data_list(input_struct_graph)
# Inference only: skip autograd bookkeeping to save memory.
with torch.no_grad():
    output = model.network[modality](input_struct_graph)
print(f"Output for modality '{modality}': {output}")
##########################pocket################################
# Build a binding-site (pocket) graph from the example HDF5 file and embed it with the pocket GNN.
modality = "pocket"
file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/pocket_example.h5",
    repo_type="model",
)
# protein_to_graph is given the file *path*, not an open handle, so no separate
# h5py handle is opened here; pockets=True selects the binding-site variant.
input_pocket = [protein_to_graph('E6Y2X0', file_path, 'non_pdb', 'A', pockets=True)]
input_pocket = Batch.from_data_list(input_pocket)
# Inference only: skip autograd bookkeeping to save memory.
with torch.no_grad():
    output = model.network[modality](input_pocket)
print(f"Output for modality '{modality}': {output}")
```
## Citation
```bibtex
@misc{flöge2024oneprotmultimodalproteinfoundation,
title={OneProt: Towards Multi-Modal Protein Foundation Models},
author={Klemens Flöge and Srisruthi Udayakumar and Johanna Sommer and Marie Piraud and Stefan Kesselheim and Vincent Fortuin and Stephan Günnemann and Karel J van der Weg and Holger Gohlke and Alina Bazarova and Erinc Merdivan},
year={2024},
eprint={2411.04863},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2411.04863},
}
```