File size: 3,045 Bytes
5e01175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import pkg_resources
import logging

from pytorch_models import PROTAC_Model, load_model
from data_utils import (
    load_protein2embedding,
    load_cell2embedding,
    get_fingerprint,
)
from config import config

import numpy as np
import torch
from torch import sigmoid

# Presumably the name of the installed distribution/package. It is not
# referenced in the functions below (resource lookup there uses `__name__`)
# -- TODO confirm whether other modules rely on it.
package_name = 'protac_degradation_predictor'

def get_protac_active_proba(
        protac_smiles: str,
        e3_ligase: str,
        target_uniprot: str,
        cell_line: str,
        device: str = 'cpu',
) -> float:
    """ Run the packaged model on a single PROTAC / E3 / target / cell line.

    Unknown E3 ligases, target proteins or cell lines do not raise: a
    warning is logged and an all-zeros embedding is used in their place.

    Args:
        protac_smiles (str): The SMILES of the PROTAC.
        e3_ligase (str): The E3 ligase name, i.e. a key of
            `config.e3_ligase2uniprot`; it is mapped to a Uniprot ID here.
        target_uniprot (str): The Uniprot ID of the target protein.
        cell_line (str): The cell line identifier.
        device (str): The device to run the model on.

    Returns:
        float: The model output as a Python scalar. No sigmoid is applied
            in this function (`is_protac_active` applies one to this value),
            so it is presumably a logit -- confirm against the model.
    """
    # resource_stream returns an open binary stream; close it as soon as the
    # checkpoint has been loaded instead of leaking the handle.
    with pkg_resources.resource_stream(__name__, 'data/model.ckpt') as ckpt_stream:
        model = load_model(ckpt_stream).to(device)
    protein2embedding = load_protein2embedding()
    cell2embedding = load_cell2embedding()

    # Warn about inputs that will fall back to the default embeddings
    if e3_ligase not in config.e3_ligase2uniprot:
        available_e3_ligases = ', '.join(config.e3_ligase2uniprot.keys())
        logging.warning(f"The E3 ligase {e3_ligase} is not in the database. Using the default E3 ligase. Available E3 ligases are: {available_e3_ligases}")
    if target_uniprot not in protein2embedding:
        logging.warning(f"The target protein {target_uniprot} is not in the database. Using the default target protein.")
    # Reuse the mapping loaded above rather than loading it a second time.
    if cell_line not in cell2embedding:
        logging.warning(f"The cell line {cell_line} is not in the database. Using the default cell line.")

    default_protein_emb = np.zeros(config.protein_embedding_size)
    default_cell_emb = np.zeros(config.cell_embedding_size)

    # Convert the E3 ligase to Uniprot ID ('' for unknown ligases, which then
    # misses in protein2embedding and yields the default embedding)
    e3_ligase_uniprot = config.e3_ligase2uniprot.get(e3_ligase, '')

    # Get the embeddings
    poi_emb = protein2embedding.get(target_uniprot, default_protein_emb)
    e3_emb = protein2embedding.get(e3_ligase_uniprot, default_protein_emb)
    cell_emb = cell2embedding.get(cell_line, default_cell_emb)
    smiles_emb = get_fingerprint(protac_smiles)

    # Convert to torch tensors
    poi_emb = torch.tensor(poi_emb).to(device)
    e3_emb = torch.tensor(e3_emb).to(device)
    cell_emb = torch.tensor(cell_emb).to(device)
    smiles_emb = torch.tensor(smiles_emb).to(device)

    # Pure inference: no gradients are needed to extract a scalar.
    with torch.no_grad():
        return model(poi_emb, e3_emb, cell_emb, smiles_emb).item()


def is_protac_active(
        protac_smiles: str,
        e3_ligase: str,
        target_uniprot: str,
        cell_line: str,
        device: str = 'cpu',
        proba_threshold: float = 0.5,
) -> bool:
    """ Predict whether a PROTAC is active or not.

    The value returned by `get_protac_active_proba` is passed through a
    sigmoid and compared against `proba_threshold`.

    Args:
        protac_smiles (str): The SMILES of the PROTAC.
        e3_ligase (str): The E3 ligase name (it is looked up in
            `config.e3_ligase2uniprot` by `get_protac_active_proba`).
        target_uniprot (str): The Uniprot ID of the target protein.
        cell_line (str): The cell line identifier.
        device (str): The device to run the model on.
        proba_threshold (float): The probability threshold.

    Returns:
        bool: True if the sigmoid of the prediction is strictly greater
            than `proba_threshold`, False otherwise.
    """
    pred = get_protac_active_proba(
        protac_smiles,
        e3_ligase,
        target_uniprot,
        cell_line,
        device,
    )
    # `pred` is a plain Python number (the result of `.item()`), but
    # torch.sigmoid only accepts tensors, so wrap it before the call.
    # `.item()` turns the result back into a float so a real bool is returned.
    proba = sigmoid(torch.as_tensor(pred, dtype=torch.float64)).item()
    return proba > proba_threshold