import torch

from tqdm import tqdm
from torch.nn.functional import normalize
from transformers import EsmConfig, EsmForMaskedLM, EsmTokenizer


class ProteinEncoder(torch.nn.Module):
    def __init__(self,
                 config_path: str,
                 out_dim: int,
                 load_pretrained: bool = True,
                 gradient_checkpointing: bool = False):
        """
        Args:
            config_path: Path to the ESM config/checkpoint directory (also used
                to load the pretrained weights and the tokenizer)
            out_dim: Output dimension of the protein representation
            load_pretrained: Whether to load pretrained weights
            gradient_checkpointing: Whether to use gradient checkpointing
        """
        super().__init__()
        config = EsmConfig.from_pretrained(config_path)
        if load_pretrained:
            self.model = EsmForMaskedLM.from_pretrained(config_path)
        else:
            self.model = EsmForMaskedLM(config)
        self.out = torch.nn.Linear(config.hidden_size, out_dim)
        
        # Set gradient checkpointing
        self.model.esm.encoder.gradient_checkpointing = gradient_checkpointing
        
        # Remove contact head
        self.model.esm.contact_head = None
        
        # Remove position embedding if the embedding type is ``rotary``
        if config.position_embedding_type == "rotary":
            self.model.esm.embeddings.position_embeddings = None
        
        self.tokenizer = EsmTokenizer.from_pretrained(config_path)
    
    def get_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        """

        Compute protein representation for the given proteins

        Args:

            protein: A list of protein sequences

            batch_size: Batch size for inference

            verbose: Whether to print progress

        """
        device = next(self.parameters()).device
        
        protein_repr = []
        if verbose:
            iterator = tqdm(range(0, len(proteins), batch_size), desc="Computing protein embeddings")
        else:
            iterator = range(0, len(proteins), batch_size)
            
        for i in iterator:
            protein_inputs = self.tokenizer.batch_encode_plus(proteins[i:i + batch_size],
                                                              return_tensors="pt",
                                                              padding=True)
            protein_inputs = {k: v.to(device) for k, v in protein_inputs.items()}
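            # Encode the batch; mask logits are not requested, so the second output is None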
            output, _ = self.forward(protein_inputs)
            
            protein_repr.append(output)
        
        protein_repr = torch.cat(protein_repr, dim=0)
        return normalize(protein_repr, dim=-1)

    def forward(self, inputs: dict, get_mask_logits: bool = False):
        """

        Encode protein sequence into protein representation

        Args:

            inputs: A dictionary containing the following keys:

                - input_ids: [batch, seq_len]

                - attention_mask: [batch, seq_len]

            get_mask_logits: Whether to return the logits for masked tokens



        Returns:

            protein_repr: [batch, protein_repr_dim]

            mask_logits : [batch, seq_len, vocab_size]

        """
        last_hidden_state = self.model.esm(**inputs).last_hidden_state
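        # Use the first token (<cls>) embedding as the sequence-level representation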
        reprs = last_hidden_state[:, 0, :]
        reprs = self.out(reprs)

        # Get logits for masked tokens
        if get_mask_logits:
            mask_logits = self.model.lm_head(last_hidden_state)
        else:
            mask_logits = None

        return reprs, mask_logits
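

# Minimal usage sketch (not part of the module's public API): the checkpoint name
# below is an assumption; substitute any ESM checkpoint directory compatible with
# EsmConfig / EsmForMaskedLM / EsmTokenizer.
if __name__ == "__main__":
    encoder = ProteinEncoder(config_path="facebook/esm2_t12_35M_UR50D",  # hypothetical choice
                             out_dim=256,
                             load_pretrained=True)
    encoder.eval()
    with torch.no_grad():
        emb = encoder.get_repr(["MKTAYIAKQRQISFVKSHFSPRAQLEERLGLIEVQ"], batch_size=2)
    print(emb.shape)  # expected: torch.Size([1, 256])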