ESM-2 for Post-Translational Modification

Metrics

Train metrics:
{'eval_loss': 0.024510689079761505,
'eval_accuracy': 0.9908227849618837,
'eval_precision': 0.22390420883031378,
'eval_recall': 0.9793229461354229,
'eval_f1': 0.3644773616334614,
'eval_auc': 0.9850883581685357,
'eval_mcc': 0.4660172779827273}

Test metrics:
{'eval_loss': 0.1606895923614502,
'eval_accuracy': 0.9363938912290479,
'eval_precision': 0.04428881619840198,
'eval_recall': 0.7708102070506146,
'eval_f1': 0.08376472210171558,
'eval_auc': 0.8539155251667717,
'eval_mcc': 0.17519724897930178}
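The reported F1 scores agree with the precision and recall above, since F1 is the harmonic mean 2PR/(P+R). The numbers also show why accuracy alone is a poor summary for this task. The sketch below checks both points using only the reported values. The helper names f1_from_precision_recall and max_positive_rate are our own illustrations, not part of any library. The bound assumes that accuracy, precision and recall were all computed over the same set of residues.

```python
# Metrics copied from the Train/Test dictionaries above (eval_ prefix dropped).
TRAIN_METRICS = {"accuracy": 0.9908227849618837, "precision": 0.22390420883031378,
                 "recall": 0.9793229461354229, "f1": 0.3644773616334614}
TEST_METRICS = {"accuracy": 0.9363938912290479, "precision": 0.04428881619840198,
                "recall": 0.7708102070506146, "f1": 0.08376472210171558}


def f1_from_precision_recall(precision: float, recall: float) -> float:
    """F1 is the harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def max_positive_rate(accuracy: float, precision: float, recall: float) -> float:
    """Upper bound on the fraction of residues that are true PTM sites.

    Assumes all three metrics were computed over the same residues.
    Errors (1 - accuracy) include every false positive, and false positives are a
    (1 - precision) share of predicted positives, so predicted positives are at
    most (1 - accuracy) / (1 - precision) of residues. True positives are a
    precision share of those, and true positives = recall * actual positives.
    """
    max_predicted_pos = (1 - accuracy) / (1 - precision)
    max_true_pos = precision * max_predicted_pos
    return max_true_pos / recall


for name, m in (("train", TRAIN_METRICS), ("test", TEST_METRICS)):
    f1 = f1_from_precision_recall(m["precision"], m["recall"])
    bound = max_positive_rate(m["accuracy"], m["precision"], m["recall"])
    print(f"{name}: recomputed F1 = {f1:.4f} (reported {m['f1']:.4f}), "
          f"true PTM sites <= {bound:.2%} of residues")
```

On both splits the bound falls well under 1% of residues. PTM sites are therefore a small minority, and accuracy is dominated by the easy non-site residues. F1 and MCC are the more informative numbers here.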

Using the Model

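To use this model, first install transformers and peft. The example below also imports torch, so PyTorch must be installed too. The leading ! is notebook syntax; drop it when running in a regular shell: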

!pip install transformers -q
!pip install peft -q

Then run the following on your protein sequence to predict post-translational modification (PTM) sites residue by residue:

from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Hugging Face Hub ID of the LoRA adapter (applied on top of the base model)
model_path = "AmelieSchreiber/esm2_t6_8M_ptm_lora_500K"
# ESM2 base model
base_model_path = "facebook/esm2_t6_8M_UR50D"

# Load the base model with a token-classification head, then attach the LoRA adapter
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

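# Tokenize the sequence. ESM-2 adds <cls> and <eos>, so max_length=1024 keeps at most 1022 residues; longer sequences are truncated.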
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No ptm site",
    1: "ptm site"
}

# Print the predicted labels for each token
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
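The loop above prints one (residue, label) pair per token. You may instead want the positions of the predicted sites, or you may need to handle proteins longer than the 1022 residues that fit in max_length=1024. Here is a minimal post-processing sketch for both cases. The helpers predicted_sites and split_into_windows are hypothetical, not part of transformers or peft. The sketch assumes the tokens and labels come from the code above (tokens and predictions[0]) and that label 1 means "ptm site", as in id2label.

```python
from typing import List, Sequence, Tuple

# Tokens skipped by the loop above: added by the tokenizer, not residues.
SPECIAL_TOKENS = {"<pad>", "<cls>", "<eos>"}

# ESM-2 wraps a sequence in <cls> ... <eos>, so max_length=1024 leaves 1022 residues.
MAX_RESIDUES = 1022


def predicted_sites(tokens: Sequence[str], labels: Sequence[int]) -> List[Tuple[int, str]]:
    """Return (1-based residue position, residue) for every residue labeled 1 ("ptm site")."""
    sites = []
    position = 0
    for token, label in zip(tokens, labels):
        if token in SPECIAL_TOKENS:
            continue
        position += 1  # count real residues only, so positions match the input sequence
        if int(label) == 1:  # int() also accepts numpy/torch integer scalars
            sites.append((position, token))
    return sites


def split_into_windows(sequence: str, size: int = MAX_RESIDUES) -> List[Tuple[int, str]]:
    """Split a sequence into non-overlapping (0-based offset, chunk) pieces of at most `size` residues."""
    if size <= 0:
        raise ValueError("size must be positive")
    return [(start, sequence[start:start + size]) for start in range(0, len(sequence), size)]


if __name__ == "__main__":
    # Toy tokens/labels for illustration only (not real model output).
    toy_tokens = ["<cls>", "M", "S", "T", "<eos>", "<pad>"]
    toy_labels = [0, 0, 1, 0, 0, 0]
    print(predicted_sites(toy_tokens, toy_labels))  # [(2, 'S')]

    # A hypothetical 2500-residue protein splits into windows at offsets 0, 1022, 2044.
    for offset, chunk in split_into_windows("A" * 2500):
        print(offset, len(chunk))
```

With the code above, predicted_sites(tokens, predictions[0].tolist()) lists the predicted sites. For a longer protein, run each window through the same tokenize-and-predict steps and add the window's offset to each returned position. Non-overlapping windows are the simplest choice. Residues near a window edge see less context, so overlapping windows with merged predictions are a reasonable alternative.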